release/v1.4 -> master (#548)

cybermaggedon 2025-10-06 17:54:26 +01:00 committed by GitHub
parent 3ec2cd54f9
commit 2bd68ed7f4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
94 changed files with 8571 additions and 1740 deletions


@@ -70,6 +70,8 @@ some-containers:
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
${DOCKER} build -f containers/Containerfile.flow \
-t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
${DOCKER} build -f containers/Containerfile.vertexai \
-t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
# ${DOCKER} build -f containers/Containerfile.mcp \
# -t ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} .
# ${DOCKER} build -f containers/Containerfile.vertexai \


@@ -158,17 +158,17 @@ The current primary key `PRIMARY KEY (collection, s, p, o)` provides minimal clu
- Uneven load distribution across cluster nodes
- Scalability bottlenecks as collections grow
## Proposed Solution: 4-Table Denormalization Strategy
### Overview
Replace the single `triples` table with four purpose-built tables, each optimized for specific query patterns. This eliminates the need for secondary indexes and ALLOW FILTERING while providing optimal performance for all query types. The fourth table enables efficient collection deletion despite compound partition keys.
### New Schema Design
**Table 1: Subject-Centric Queries (triples_s)**
```sql
CREATE TABLE triples_s (
collection text,
s text,
p text,
o text,
PRIMARY KEY ((collection, s), p, o)
);
```
- **Optimizes:** get_s, get_sp
- **Partition Key:** (collection, s) - Better distribution than collection alone
- **Clustering:** (p, o) - Enables efficient predicate/object lookups for a subject
**Table 2: Predicate-Object Queries (triples_p)**
```sql
CREATE TABLE triples_p (
collection text,
p text,
o text,
s text,
PRIMARY KEY ((collection, p), o, s)
);
```
- **Optimizes:** get_p, get_po
- **Partition Key:** (collection, p) - Direct access by predicate
- **Clustering:** (o, s) - Efficient object-subject traversal
**Table 3: Object-Centric Queries (triples_o)**
```sql
CREATE TABLE triples_o (
collection text,
o text,
s text,
p text,
PRIMARY KEY ((collection, o), s, p)
);
```
- **Optimizes:** get_o, get_os
- **Partition Key:** (collection, o) - Direct access by object
- **Clustering:** (s, p) - Efficient subject-predicate traversal
**Table 4: Collection Management & SPO Queries (triples_collection)**
```sql
CREATE TABLE triples_collection (
collection text,
s text,
p text,
o text,
PRIMARY KEY (collection, s, p, o)
);
```
- **Optimizes:** get_spo, delete_collection
- **Partition Key:** collection only - Enables efficient collection-level operations
- **Clustering:** (s, p, o) - Standard triple ordering
- **Purpose:** Dual use for exact SPO lookups and as deletion index
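To illustrate the layout (this helper is not part of the implementation): one triple is written in four different key orders, one per table, so that each query pattern reads a contiguous slice of exactly one partition.

```python
def row_layouts(collection, s, p, o):
    # Column order per table, matching the INSERT statements in this design.
    return {
        "triples_s":          (collection, s, p, o),   # partition (collection, s)
        "triples_p":          (collection, p, o, s),   # partition (collection, p)
        "triples_o":          (collection, o, s, p),   # partition (collection, o)
        "triples_collection": (collection, s, p, o),   # partition (collection)
    }
```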
### Query Mapping
| Original Query | Target Table | Performance Improvement |
|----------------|-------------|------------------------|
| get_all(collection) | triples_s | ALLOW FILTERING (acceptable for scan) |
| get_s(collection, s) | triples_s | Direct partition access |
| get_p(collection, p) | triples_p | Direct partition access |
| get_o(collection, o) | triples_o | Direct partition access |
| get_sp(collection, s, p) | triples_s | Partition + clustering |
| get_po(collection, p, o) | triples_p | **No more ALLOW FILTERING!** |
| get_os(collection, o, s) | triples_o | Partition + clustering |
| get_spo(collection, s, p, o) | triples_collection | Exact key lookup |
| delete_collection(collection) | triples_collection | Read index, batch delete all |
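The mapping above can be expressed as a simple routing dict (an illustrative sketch; the table names are the ones defined in this design):

```python
# Hypothetical routing table: which physical table serves each query shape.
QUERY_TABLE = {
    "get_all": "triples_s",            # full scan; ALLOW FILTERING accepted
    "get_s":   "triples_s",
    "get_p":   "triples_p",
    "get_o":   "triples_o",
    "get_sp":  "triples_s",
    "get_po":  "triples_p",            # no more ALLOW FILTERING
    "get_os":  "triples_o",
    "get_spo": "triples_collection",
    "delete_collection": "triples_collection",
}
```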
### Collection Deletion Strategy
With compound partition keys, we cannot simply execute `DELETE FROM table WHERE collection = ?`. Instead:
1. **Read Phase:** Query `triples_collection` to enumerate all triples:
```sql
SELECT s, p, o FROM triples_collection WHERE collection = ?
```
This is efficient since `collection` is the partition key for this table.
2. **Delete Phase:** For each triple (s, p, o), delete from all 4 tables using full partition keys:
```sql
DELETE FROM triples_s WHERE collection = ? AND s = ? AND p = ? AND o = ?
DELETE FROM triples_p WHERE collection = ? AND p = ? AND o = ? AND s = ?
DELETE FROM triples_o WHERE collection = ? AND o = ? AND s = ? AND p = ?
DELETE FROM triples_collection WHERE collection = ? AND s = ? AND p = ? AND o = ?
```
Batched in groups of 100 for efficiency.
**Trade-off Analysis:**
- ✅ Maintains optimal query performance with distributed partitions
- ✅ No hot partitions for large collections
- ❌ More complex deletion logic (read-then-delete)
- ❌ Deletion time proportional to collection size
### Benefits
1. **Eliminates ALLOW FILTERING** - Every query has an optimal access path
1. **Eliminates ALLOW FILTERING** - Every query has an optimal access path (except get_all scan)
2. **No Secondary Indexes** - Each table IS the index for its query pattern
3. **Better Data Distribution** - Composite partition keys spread load effectively
4. **Predictable Performance** - Query time proportional to result size, not total data
5. **Leverages Cassandra Strengths** - Designed for Cassandra's architecture
6. **Enables Collection Deletion** - triples_collection serves as deletion index
## Implementation Plan
@@ -295,10 +337,11 @@ def delete_collection(self, collection) -> None # Delete from all three tables
### Implementation Strategy
#### Phase 1: Schema and Core Methods
1. **Rewrite `init()` method** - Create four tables instead of one
2. **Rewrite `insert()` method** - Batch writes to all four tables
3. **Implement prepared statements** - For optimal performance
4. **Add table routing logic** - Direct queries to optimal tables
5. **Implement collection deletion** - Read from triples_collection, batch delete from all tables
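A minimal sketch of the `init()` rewrite under these assumptions: the DDL follows the four-table schema above, and a cassandra-driver `Session` is held on `self.session`.

```python
# Partition/clustering layout per table, from the schema design above.
TABLE_KEYS = {
    "triples_s":          "((collection, s), p, o)",
    "triples_p":          "((collection, p), o, s)",
    "triples_o":          "((collection, o), s, p)",
    "triples_collection": "(collection, s, p, o)",
}

def ddl_statements():
    # One CREATE TABLE per table; all four share the same columns.
    return [
        f"CREATE TABLE IF NOT EXISTS {name} ("
        f"collection text, s text, p text, o text, "
        f"PRIMARY KEY {key})"
        for name, key in TABLE_KEYS.items()
    ]

class TripleTables:
    def __init__(self, session):
        self.session = session

    def init(self):
        # Create all four tables instead of the single legacy table.
        for stmt in ddl_statements():
            self.session.execute(stmt)
```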
#### Phase 2: Query Method Optimization
1. **Rewrite each get_* method** to use optimal table
@@ -318,18 +361,11 @@ def delete_collection(self, collection) -> None # Delete from all three tables
```python
def insert(self, collection, s, p, o):
    batch = BatchStatement()

    # Insert into all four tables
    batch.add(self.insert_subject_stmt, (collection, s, p, o))
    batch.add(self.insert_po_stmt, (collection, p, o, s))
    batch.add(self.insert_object_stmt, (collection, o, s, p))
    batch.add(self.insert_collection_stmt, (collection, s, p, o))

    self.session.execute(batch)
```
@@ -337,11 +373,65 @@ def insert(self, collection, s, p, o):
#### Query Routing Logic
```python
def get_po(self, collection, p, o, limit=10):
    # Route to triples_p table - NO ALLOW FILTERING!
    return self.session.execute(
        self.get_po_stmt,
        (collection, p, o, limit)
    )

def get_spo(self, collection, s, p, o, limit=10):
    # Route to triples_collection table for exact SPO lookup
    return self.session.execute(
        self.get_spo_stmt,
        (collection, s, p, o, limit)
    )
```
#### Collection Deletion Logic
```python
def delete_collection(self, collection):
    # Step 1: Read all triples from collection table
    rows = self.session.execute(
        f"SELECT s, p, o FROM {self.collection_table} WHERE collection = %s",
        (collection,)
    )

    # Step 2: Batch delete from all 4 tables
    batch = BatchStatement()
    count = 0

    for row in rows:
        s, p, o = row.s, row.p, row.o

        # Delete using full primary keys for each table.
        # Note: SimpleStatement binds with %s placeholders, not ?
        batch.add(SimpleStatement(
            f"DELETE FROM {self.subject_table} "
            f"WHERE collection = %s AND s = %s AND p = %s AND o = %s"
        ), (collection, s, p, o))
        batch.add(SimpleStatement(
            f"DELETE FROM {self.po_table} "
            f"WHERE collection = %s AND p = %s AND o = %s AND s = %s"
        ), (collection, p, o, s))
        batch.add(SimpleStatement(
            f"DELETE FROM {self.object_table} "
            f"WHERE collection = %s AND o = %s AND s = %s AND p = %s"
        ), (collection, o, s, p))
        batch.add(SimpleStatement(
            f"DELETE FROM {self.collection_table} "
            f"WHERE collection = %s AND s = %s AND p = %s AND o = %s"
        ), (collection, s, p, o))

        count += 1

        # Execute every 100 triples to avoid oversized batches
        if count % 100 == 0:
            self.session.execute(batch)
            batch = BatchStatement()

    # Execute remaining deletions
    if count % 100 != 0:
        self.session.execute(batch)

    logger.info(f"Deleted {count} triples from collection {collection}")
```
#### Prepared Statement Optimization
@@ -349,12 +439,18 @@ def get_po(self, collection, p, o, limit=10):
```python
def prepare_statements(self):
    # Cache prepared statements for better performance
    self.insert_subject_stmt = self.session.prepare(
        f"INSERT INTO {self.subject_table} (collection, s, p, o) VALUES (?, ?, ?, ?)"
    )
    self.insert_po_stmt = self.session.prepare(
        f"INSERT INTO {self.po_table} (collection, p, o, s) VALUES (?, ?, ?, ?)"
    )
    self.insert_object_stmt = self.session.prepare(
        f"INSERT INTO {self.object_table} (collection, o, s, p) VALUES (?, ?, ?, ?)"
    )
    self.insert_collection_stmt = self.session.prepare(
        f"INSERT INTO {self.collection_table} (collection, s, p, o) VALUES (?, ?, ?, ?)"
    )
    # ... query statements
```
## Migration Strategy
@@ -511,9 +607,10 @@ def rollback_to_legacy():
## Risks and Considerations
### Performance Risks
- **Write latency increase** - 4x write operations per insert (33% more than 3-table approach)
- **Storage overhead** - 4x storage requirement (33% more than 3-table approach)
- **Batch write failures** - Need proper error handling
- **Deletion complexity** - Collection deletion requires read-then-delete loop
### Operational Risks
- **Migration complexity** - Data migration for large datasets


@@ -2,16 +2,17 @@
## Overview
This specification describes the collection management capabilities for TrustGraph, requiring explicit collection creation and providing direct control over the collection lifecycle. Collections must be explicitly created before use, ensuring proper synchronization between the librarian metadata and all storage backends. The feature supports four primary use cases:
1. **Collection Creation**: Explicitly create collections before storing data
2. **Collection Listing**: View all existing collections in the system
3. **Collection Metadata Management**: Update collection names, descriptions, and tags
4. **Collection Deletion**: Remove collections and their associated data across all storage types
## Goals
- **Explicit Collection Creation**: Require collections to be created before data can be stored
- **Storage Synchronization**: Ensure collections exist in all storage backends (vectors, objects, triples)
- **Collection Visibility**: Enable users to list and inspect all collections in their environment
- **Collection Cleanup**: Allow deletion of collections that are no longer needed
- **Collection Organization**: Support labels and tags for better collection tracking and discovery
@@ -19,22 +20,25 @@ This specification describes the collection management capabilities for TrustGra
- **Collection Discovery**: Make it easier to find specific collections through filtering and search
- **Operational Transparency**: Provide clear visibility into collection lifecycle and usage
- **Resource Management**: Enable cleanup of unused collections to optimize resource utilization
- **Data Integrity**: Prevent orphaned collections in storage without metadata tracking
## Background
Previously, collections in TrustGraph were implicitly created during data loading operations, leading to synchronization issues where collections could exist in storage backends without corresponding metadata in the librarian. This created management challenges and potential orphaned data.
Current limitations include:
- No way to list existing collections
- No mechanism to delete unwanted collections
- No ability to associate metadata with collections for tracking purposes
- Difficulty in organizing and discovering collections over time
The explicit collection creation model addresses these issues by:
- Requiring collections to be created before use via `tg-set-collection`
- Broadcasting collection creation to all storage backends
- Maintaining synchronized state between librarian metadata and storage
- Preventing writes to non-existent collections
- Providing clear collection lifecycle management
This specification defines the explicit collection management model. By requiring explicit collection creation, TrustGraph ensures:
- Collections are tracked in librarian metadata from creation
- All storage backends are aware of collections before receiving data
- No orphaned collections exist in storage
- Clear operational visibility and control over collection lifecycle
- Consistent error handling when operations reference non-existent collections
## Technical Design
@@ -98,24 +102,52 @@ This approach allows:
#### Collection Lifecycle
Collections are explicitly created in the librarian before data operations can proceed:
1. **Collection Creation** (Two Paths):
**Path A: User-Initiated Creation** via `tg-set-collection`:
- User provides collection ID, name, description, and tags
- Librarian creates metadata record in `collections` table
- Librarian broadcasts "create-collection" to all storage backends
- All storage processors create collection and confirm success
- Collection is now ready for data operations
**Path B: Automatic Creation on Document Submission**:
- User submits document specifying a collection ID
- Librarian checks if collection exists in metadata table
- If not exists: Librarian creates metadata with defaults (name=collection_id, empty description/tags)
- Librarian broadcasts "create-collection" to all storage backends
- All storage processors create collection and confirm success
- Document processing proceeds with collection now established
Both paths ensure collection exists in librarian metadata AND all storage backends before data operations.
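The shared creation path behind Path A and Path B can be sketched as follows. This is a minimal illustration: the in-memory metadata dict and direct backend calls stand in for the Cassandra `collections` table and the storage management queues.

```python
class Librarian:
    def __init__(self, backends):
        self.metadata = {}        # (user, collection) -> metadata record
        self.backends = backends  # storage processors to notify

    def ensure_collection(self, user, collection, name=None,
                          description="", tags=()):
        key = (user, collection)
        if key in self.metadata:
            # Already created: Path B becomes a no-op, Path A an update.
            return self.metadata[key]
        # Path B defaults: name falls back to the collection ID.
        record = {"name": name or collection,
                  "description": description,
                  "tags": set(tags)}
        self.metadata[key] = record
        # Broadcast "create-collection"; all backends must confirm
        # before data operations proceed.
        for backend in self.backends:
            backend.create_collection(user, collection)
        return record
```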
2. **Storage Validation**: Write operations validate collection exists:
- Storage processors check collection state before accepting writes
- Writes to non-existent collections return error
- This prevents direct writes bypassing the librarian's collection creation logic
3. **Query Behavior**: Query operations handle non-existent collections gracefully:
- Queries to non-existent collections return empty results
- No error thrown for query operations
- Allows exploration without requiring collection to exist
4. **Metadata Updates**: Users can update collection metadata after creation:
- Update name, description, and tags via `tg-set-collection`
- Updates apply to librarian metadata only
- Storage backends maintain collection but metadata updates don't propagate
5. **Explicit Deletion**: Users delete collections via `tg-delete-collection`:
- Librarian broadcasts "delete-collection" to all storage backends
- Waits for confirmation from all storage processors
- Deletes librarian metadata record only after storage cleanup complete
- Ensures no orphaned data remains in storage
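The deletion ordering in step 5 (storage cleanup first, metadata last) can be sketched like this; the metadata dict and backend objects are stand-ins for the librarian table store and the storage management queues.

```python
def delete_collection(librarian_metadata, backends, user, collection):
    # 1. Broadcast delete-collection and require every confirmation.
    for backend in backends:
        ok = backend.delete_collection(user, collection)
        if not ok:
            # Keep the metadata record so the partially-deleted
            # collection remains visible for retry.
            raise RuntimeError(
                f"storage cleanup failed for {collection}")
    # 2. Only after all storage cleanup completes, remove the
    #    librarian metadata record - no orphaned data remains.
    librarian_metadata.pop((user, collection), None)
```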
**Key Principle**: The librarian is the single point of control for collection creation. Whether initiated by user command or document submission, the librarian ensures proper metadata tracking and storage backend synchronization before allowing data operations.
Operations required:
- **Create Collection**: User operation via `tg-set-collection` OR automatic on document submission
- **Update Collection Metadata**: User operation to modify name, description, and tags
- **Delete Collection**: User operation to remove collection and its data across all stores
- **List Collections**: User operation to view collections with filtering by tags
@@ -123,32 +155,65 @@ Operations required:
#### Multi-Store Collection Management
Collections exist across multiple storage backends in TrustGraph:
- **Vector Stores** (Qdrant, Milvus, Pinecone): Store embeddings and vector data
- **Object Stores** (Cassandra): Store documents and file data
- **Triple Stores** (Cassandra, Neo4j, Memgraph, FalkorDB): Store graph/RDF data
Each store type implements:
- **Collection Creation**: Accept and process "create-collection" operations
- **Collection Validation**: Check collection exists before accepting writes
- **Collection Deletion**: Remove all data for specified collection
The librarian service coordinates collection operations across all store types, ensuring:
- Collections created in all backends before use
- All backends confirm creation before returning success
- Synchronized collection lifecycle across storage types
- Consistent error handling when collections don't exist
#### Collection State Tracking by Storage Type
Each storage backend tracks collection state differently based on its capabilities:
**Cassandra Triple Store:**
- Uses existing `triples_collection` table
- Creates system marker triple when collection created
- Query: `SELECT collection FROM triples_collection WHERE collection = ? LIMIT 1`
- Efficient single-partition check for collection existence
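The existence check above might be wrapped as follows (a sketch: the query matches the one quoted in the text, while the session handling and function name are assumptions):

```python
def collection_exists(session, collection):
    # Single-partition lookup: `collection` is the partition key of
    # triples_collection, so this touches exactly one partition.
    rows = session.execute(
        "SELECT collection FROM triples_collection "
        "WHERE collection = %s LIMIT 1",
        (collection,),
    )
    return rows.one() is not None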
**Qdrant/Milvus/Pinecone Vector Stores:**
- Native collection APIs provide existence checking
- Collections created with proper vector configuration
- `collection_exists()` method uses storage API
- Collection creation validates dimension requirements
**Neo4j/Memgraph/FalkorDB Graph Stores:**
- Use `:CollectionMetadata` nodes to track collections
- Node properties: `{user, collection, created_at}`
- Query: `MATCH (c:CollectionMetadata {user: $user, collection: $collection})`
- Separate from data nodes for clean separation
- Enables efficient collection listing and validation
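The `:CollectionMetadata` pattern might look like this in a store writer. The Cypher follows the queries quoted above; wrapping them in transaction functions is an assumption about how the neo4j driver would be used.

```python
CREATE_MARKER = (
    "MERGE (c:CollectionMetadata {user: $user, collection: $collection}) "
    "ON CREATE SET c.created_at = timestamp()"
)

EXISTS_QUERY = (
    "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
    "RETURN c LIMIT 1"
)

def create_collection(tx, user, collection):
    # Idempotent: MERGE only creates the marker node if it is missing.
    tx.run(CREATE_MARKER, user=user, collection=collection)

def collection_exists(tx, user, collection):
    result = tx.run(EXISTS_QUERY, user=user, collection=collection)
    return result.single() is not None
```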
**Cassandra Object Store:**
- Uses collection metadata table or marker rows
- Similar pattern to triple store
- Validates collection before document writes
### APIs
Collection Management APIs (Librarian):
- **Create/Update Collection**: Create new collection or update existing metadata via `tg-set-collection`
- **List Collections**: Retrieve collections for a user with optional tag filtering
- **Update Collection Metadata**: Modify collection name, description, and tags
- **Delete Collection**: Remove collection and associated data, cascading to all store types
Storage Management APIs (All Storage Processors):
- **Create Collection**: Handle "create-collection" operation, establish collection in storage
- **Delete Collection**: Handle "delete-collection" operation, remove all collection data
- **Collection Exists Check**: Internal validation before accepting write operations
Data Operation APIs (Modified Behavior):
- **Write APIs**: Validate collection exists before accepting data, return error if not
- **Query APIs**: Return empty results for non-existent collections without error
### Implementation Details
@@ -168,32 +233,35 @@ When a user initiates collection deletion through the librarian service:
#### Collection Management Interface
All store writers implement a standardized collection management interface with a common schema:
**Message Schema (`StorageManagementRequest`):**
```json
{
"operation": "create-collection" | "delete-collection",
"user": "user123",
"collection": "documents-2024"
}
```
**Queue Architecture:**
- **Vector Store Management Queue** (`vector-storage-management`): Vector/embedding stores
- **Object Store Management Queue** (`object-storage-management`): Object/document stores
- **Triple Store Management Queue** (`triples-storage-management`): Graph/RDF stores
- **Storage Response Queue** (`storage-management-response`): All responses sent here
Each store writer implements:
- **Collection Management Handler**: Processes `StorageManagementRequest` messages
- **Create Collection Operation**: Establishes collection in storage backend
- **Delete Collection Operation**: Removes all data associated with collection
- **Collection State Tracking**: Maintains knowledge of which collections exist
- **Message Processing**: Consumes from dedicated management queue
- **Status Reporting**: Returns success/failure via `StorageManagementResponse`
- **Idempotent Operations**: Safe to call create/delete multiple times
**Supported Operations:**
- `create-collection`: Create collection in storage backend
- `delete-collection`: Remove all collection data from storage backend
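A store writer's management handler might dispatch the two supported operations like this (a sketch: the request fields follow the message schema above, while the `store` object and queue wiring are assumptions):

```python
def handle_management(request, store):
    op = request["operation"]
    user, collection = request["user"], request["collection"]
    if op == "create-collection":
        # Idempotent: creating an existing collection is a no-op.
        store.create_collection(user, collection)
        return {"status": "success"}
    if op == "delete-collection":
        # Idempotent: deleting a missing collection is a no-op.
        store.delete_collection(user, collection)
        return {"status": "success"}
    return {"status": "error", "error": f"unknown operation {op}"}
```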
#### Cassandra Triple Store Refactor
@@ -244,13 +312,11 @@ As part of this implementation, the Cassandra triple store will be refactored fr
- Maintain same query logic with collection parameter
**Benefits:**
- **Simplified Collection Deletion**: Delete using `collection` partition key across all 4 tables
- **Resource Efficiency**: Fewer database connections and table objects
- **Cross-Collection Operations**: Easier to implement operations spanning multiple collections
- **Consistent Architecture**: Aligns with unified collection metadata approach
- **Collection Validation**: Easy to check collection existence via `triples_collection` table
Collection operations will be atomic where possible and provide appropriate error handling and validation.
@@ -264,37 +330,25 @@ Collection listing operations may need pagination for environments with large nu
## Testing Strategy
Comprehensive testing will cover:
- Collection creation workflow end-to-end
- Storage backend synchronization
- Write validation for non-existent collections
- Query handling of non-existent collections
- Collection deletion cascade across all stores
- Error handling and recovery scenarios
- Unit tests for each storage backend
- Integration tests for cross-store operations
## Implementation Status
### ✅ Completed Components
1. **Librarian Collection Management Service** (`trustgraph-flow/trustgraph/librarian/collection_manager.py`)
- Collection metadata CRUD operations (list, update, delete)
- Cassandra collection metadata table integration via `LibraryTableStore`
- Collection deletion cascade coordination across all storage types
- Async request/response handling with proper error management
2. **Collection Metadata Schema** (`trustgraph-base/trustgraph/schema/services/collection.py`)
- `CollectionManagementRequest` and `CollectionManagementResponse` schemas
@@ -303,47 +357,70 @@ Migration will be performed during a maintenance window to ensure data consisten
3. **Storage Management Schema** (`trustgraph-base/trustgraph/schema/services/storage.py`)
- `StorageManagementRequest` and `StorageManagementResponse` schemas
- Storage management queue topics defined
- Message format for storage-level collection operations
4. **Cassandra 4-Table Schema** (`trustgraph-flow/trustgraph/direct/cassandra_kg.py`)
- Compound partition keys for query performance
- `triples_collection` table for SPO queries and deletion tracking
- Collection deletion implemented with read-then-delete pattern
### 🔄 In Progress Components
1. **Collection Creation Broadcast** (`trustgraph-flow/trustgraph/librarian/collection_manager.py`)
- Update `update_collection()` to send "create-collection" to storage backends
- Wait for confirmations from all storage processors
- Handle creation failures appropriately
3. **Collection Management Interface Implementation**
- Store writers need collection management message consumers
- Collection deletion operations need to be implemented per store type
- Response handling back to librarian service
2. **Document Submission Handler** (`trustgraph-flow/trustgraph/librarian/service.py` or similar)
- Check if collection exists when document submitted
- If not exists: Create collection with defaults before processing document
- Trigger same "create-collection" broadcast as `tg-set-collection`
- Ensure collection established before document flows to storage processors
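As an illustration of the handler shape each store writer needs, a collection-management consumer might look like this. The message fields (`operation`, `collection`) and the `delete_collection()` helper are assumptions for the sketch, not the real TrustGraph API:

```python
# Illustrative sketch of StorageManagementRequest processing in a store
# writer. Field names and the store interface are assumptions.

class CollectionManagementHandler:
    def __init__(self, store):
        self.store = store

    async def on_storage_management(self, request):
        """Process a storage-management request; return a response dict
        to be sent back to the librarian service."""
        try:
            if request["operation"] == "delete-collection":
                await self.store.delete_collection(request["collection"])
                return {"status": "ok", "collection": request["collection"]}
            return {"status": "error", "error": "unknown operation"}
        except Exception as e:
            # Report failures back rather than dropping the request
            return {"status": "error", "error": str(e)}
```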
### ❌ Pending Components
1. **Collection State Tracking** - Need to implement in each storage backend:
- **Cassandra Triples**: Use `triples_collection` table with marker triples
- **Neo4j/Memgraph/FalkorDB**: Create `:CollectionMetadata` nodes
- **Qdrant/Milvus/Pinecone**: Use native collection APIs
- **Cassandra Objects**: Add collection metadata tracking
2. **Storage Management Handlers** - Need "create-collection" support in 12 files:
- `trustgraph-flow/trustgraph/storage/triples/cassandra/write.py`
- `trustgraph-flow/trustgraph/storage/triples/neo4j/write.py`
- `trustgraph-flow/trustgraph/storage/triples/memgraph/write.py`
- `trustgraph-flow/trustgraph/storage/triples/falkordb/write.py`
- `trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py`
- `trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py`
- `trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py`
- `trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py`
- `trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py`
- `trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py`
- `trustgraph-flow/trustgraph/storage/objects/cassandra/write.py`
- Plus any other storage implementations
3. **Write Operation Validation** - Add collection existence checks to all `store_*` methods
4. **Query Operation Handling** - Update queries to return empty for non-existent collections
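For the Cassandra triples backend, for instance, the state-tracking methods could be sketched as follows. The session interface mimics the cassandra-driver; the `triples_collection` table columns are assumptions based on the schema described earlier in this spec:

```python
# Sketch of collection_exists() / create_collection() for the Cassandra
# triples backend, using the triples_collection table as a marker store.

class CollectionTracker:
    def __init__(self, session, keyspace):
        self.session = session
        self.keyspace = keyspace

    def collection_exists(self, collection):
        rows = self.session.execute(
            f"SELECT collection FROM {self.keyspace}.triples_collection "
            "WHERE collection = %s LIMIT 1",
            (collection,),
        )
        return rows.one() is not None

    def create_collection(self, collection):
        # Idempotent: re-inserting an existing marker row is a no-op upsert
        self.session.execute(
            f"INSERT INTO {self.keyspace}.triples_collection (collection) "
            "VALUES (%s)",
            (collection,),
        )
```

Write validation then reduces to a `collection_exists()` check at the top of each `store_*` method, rejecting writes to collections that were never created.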
### Next Implementation Steps
1. **Define Storage Management Topics** in `trustgraph-base/trustgraph/schema/services/storage.py`
2. **Implement Collection Management Handlers** in each storage writer:
- Add `StorageManagementRequest` consumers
- Implement collection deletion operations
- Add response producers for status reporting
3. **Test End-to-End Collection Deletion** across all storage types
**Phase 1: Core Infrastructure (2-3 days)**
1. Add collection state tracking methods to all storage backends
2. Implement `collection_exists()` and `create_collection()` methods
## Timeline
**Phase 2: Storage Handlers (1 week)**
3. Add "create-collection" handlers to all storage processors
4. Add write validation to reject non-existent collections
5. Update query handling for non-existent collections
Phase 1 (Storage Topics): 1-2 days
Phase 2 (Store Handlers): 1-2 weeks depending on number of storage backends
Phase 3 (Testing & Integration): 3-5 days
**Phase 3: Collection Manager (2-3 days)**
6. Update collection_manager to broadcast creates
7. Implement response tracking and error handling
## Open Questions
- Should collection deletion be soft or hard delete by default?
- What metadata fields should be required vs optional?
- Should we implement storage management handlers incrementally by store type?
**Phase 4: Testing (3-5 days)**
8. End-to-end testing of explicit creation workflow
9. Test all storage backends
10. Validate error handling and edge cases

@ -6,7 +6,7 @@ A flow class defines a complete dataflow pattern template in the TrustGraph syst
## Structure
A flow class definition consists of four main sections:
A flow class definition consists of five main sections:
### 1. Class Section
Defines shared service processors that are instantiated once per flow class. These processors handle requests from all flow instances of this class.
@ -15,7 +15,11 @@ Defines shared service processors that are instantiated once per flow class. The
"class": {
"service-name:{class}": {
"request": "queue-pattern:{class}",
"response": "queue-pattern:{class}"
"response": "queue-pattern:{class}",
"settings": {
"setting-name": "fixed-value",
"parameterized-setting": "{parameter-name}"
}
}
}
```
@ -24,6 +28,7 @@ Defines shared service processors that are instantiated once per flow class. The
- Shared across all flow instances of the same class
- Typically expensive or stateless services (LLMs, embedding models)
- Use `{class}` template variable for queue naming
- Settings can be fixed values or parameterized with `{parameter-name}` syntax
- Examples: `embeddings:{class}`, `text-completion:{class}`, `graph-rag:{class}`
### 2. Flow Section
@ -33,7 +38,11 @@ Defines flow-specific processors that are instantiated for each individual flow
"flow": {
"processor-name:{id}": {
"input": "queue-pattern:{id}",
"output": "queue-pattern:{id}"
"output": "queue-pattern:{id}",
"settings": {
"setting-name": "fixed-value",
"parameterized-setting": "{parameter-name}"
}
}
}
```
@ -42,6 +51,7 @@ Defines flow-specific processors that are instantiated for each individual flow
- Unique instance per flow
- Handle flow-specific data and state
- Use `{id}` template variable for queue naming
- Settings can be fixed values or parameterized with `{parameter-name}` syntax
- Examples: `chunker:{id}`, `pdf-decoder:{id}`, `kg-extract-relationships:{id}`
### 3. Interfaces Section
@ -72,7 +82,24 @@ Interfaces can take two forms:
- **Service Interfaces**: Request/response patterns for services (`embeddings`, `text-completion`)
- **Data Interfaces**: Fire-and-forget data flow connection points (`triples-store`, `entity-contexts-load`)
### 4. Metadata
### 4. Parameters Section
Maps flow-specific parameter names to centrally-stored parameter definitions:
```json
"parameters": {
"model": "llm-model",
"temp": "temperature",
"chunk": "chunk-size"
}
```
**Characteristics:**
- Keys are parameter names used in processor settings (e.g., `{model}`)
- Values reference parameter definitions stored in schema/config
- Enables reuse of common parameter definitions across flows
- Reduces duplication of parameter schemas
### 5. Metadata
Additional information about the flow class:
```json
@ -82,16 +109,98 @@ Additional information about the flow class:
## Template Variables
### {id}
### System Variables
#### {id}
- Replaced with the unique flow instance identifier
- Creates isolated resources for each flow
- Example: `flow-123`, `customer-A-flow`
### {class}
#### {class}
- Replaced with the flow class name
- Creates shared resources across flows of the same class
- Example: `standard-rag`, `enterprise-rag`
### Parameter Variables
#### {parameter-name}
- Custom parameters defined at flow launch time
- Parameter names match keys in the flow's `parameters` section
- Used in processor settings to customize behavior
- Examples: `{model}`, `{temp}`, `{chunk}`
- Replaced with values provided when launching the flow
- Validated against centrally-stored parameter definitions
## Processor Settings
Settings provide configuration values to processors at instantiation time. They can be:
### Fixed Settings
Direct values that don't change:
```json
"settings": {
"model": "gemma3:12b",
"temperature": 0.7,
"max_retries": 3
}
```
### Parameterized Settings
Values that use parameters provided at flow launch:
```json
"settings": {
"model": "{model}",
"temperature": "{temp}",
"endpoint": "https://{region}.api.example.com"
}
```
Parameter names in settings correspond to keys in the flow's `parameters` section.
### Settings Examples
**LLM Processor with Parameters:**
```json
// In parameters section:
"parameters": {
"model": "llm-model",
"temp": "temperature",
"tokens": "max-tokens",
"key": "openai-api-key"
}
// In processor definition:
"text-completion:{class}": {
"request": "non-persistent://tg/request/text-completion:{class}",
"response": "non-persistent://tg/response/text-completion:{class}",
"settings": {
"model": "{model}",
"temperature": "{temp}",
"max_tokens": "{tokens}",
"api_key": "{key}"
}
}
```
**Chunker with Fixed and Parameterized Settings:**
```json
// In parameters section:
"parameters": {
"chunk": "chunk-size"
}
// In processor definition:
"chunker:{id}": {
"input": "persistent://tg/flow/chunk:{id}",
"output": "persistent://tg/flow/chunk-load:{id}",
"settings": {
"chunk_size": "{chunk}",
"chunk_overlap": 100,
"encoding": "utf-8"
}
}
```
## Queue Patterns (Pulsar)
Flow classes use Apache Pulsar for messaging. Queue names follow the Pulsar format:
@ -137,15 +246,27 @@ All processors (both `{id}` and `{class}`) work together as a cohesive dataflow
Given:
- Flow Instance ID: `customer-A-flow`
- Flow Class: `standard-rag`
- Flow parameter mappings:
- `"model": "llm-model"`
- `"temp": "temperature"`
- `"chunk": "chunk-size"`
- User-provided parameters:
- `model`: `gpt-4`
- `temp`: `0.5`
- `chunk`: `512`
Template expansions:
- `persistent://tg/flow/chunk-load:{id}``persistent://tg/flow/chunk-load:customer-A-flow`
- `non-persistent://tg/request/embeddings:{class}``non-persistent://tg/request/embeddings:standard-rag`
- `"model": "{model}"``"model": "gpt-4"`
- `"temperature": "{temp}"``"temperature": "0.5"`
- `"chunk_size": "{chunk}"``"chunk_size": "512"`
This creates:
- Isolated document processing pipeline for `customer-A-flow`
- Shared embedding service for all `standard-rag` flows
- Complete dataflow from document ingestion through querying
- Processors configured with the provided parameter values
## Benefits

@ -0,0 +1,485 @@
# Flow Class Configurable Parameters Technical Specification
## Overview
This specification describes the implementation of configurable parameters for flow classes in TrustGraph. Parameters enable users to customize processor parameters at flow launch time by providing values that replace parameter placeholders in the flow class definition.
Parameters work through template variable substitution in processor parameters, similar to how `{id}` and `{class}` variables work, but with user-provided values.
The integration supports four primary use cases:
1. **Model Selection**: Allowing users to choose different LLM models (e.g., `gemma3:8b`, `gpt-4`, `claude-3`) for processors
2. **Resource Configuration**: Adjusting processor parameters like chunk sizes, batch sizes, and concurrency limits
3. **Behavioral Tuning**: Modifying processor behavior through parameters like temperature, max-tokens, or retrieval thresholds
4. **Environment-Specific Parameters**: Configuring endpoints, API keys, or region-specific URLs per deployment
## Goals
- **Dynamic Processor Configuration**: Enable runtime configuration of processor parameters through parameter substitution
- **Parameter Validation**: Provide type checking and validation for parameters at flow launch time
- **Default Values**: Support sensible defaults while allowing overrides for advanced users
- **Template Substitution**: Seamlessly replace parameter placeholders in processor parameters
- **UI Integration**: Enable parameter input through both API and UI interfaces
- **Type Safety**: Ensure parameter types match expected processor parameter types
- **Documentation**: Self-documenting parameter schemas within flow class definitions
- **Backward Compatibility**: Maintain compatibility with existing flow classes that don't use parameters
## Background
Flow classes in TrustGraph now support processor parameters that can contain either fixed values or parameter placeholders. This creates an opportunity for runtime customization.
Current processor parameters support:
- Fixed values: `"model": "gemma3:12b"`
- Parameter placeholders: `"model": "gemma3:{model-size}"`
This specification defines how parameters are:
- Declared in flow class definitions
- Validated when flows are launched
- Substituted in processor parameters
- Exposed through APIs and UI
By leveraging parameterized processor parameters, TrustGraph can:
- Reduce flow class duplication by using parameters for variations
- Enable users to tune processor behavior without modifying definitions
- Support environment-specific configurations through parameter values
- Maintain type safety through parameter schema validation
## Technical Design
### Architecture
The configurable parameters system requires the following technical components:
1. **Parameter Schema Definition**
- JSON Schema-based parameter definitions within flow class metadata
- Type definitions including string, number, boolean, enum, and object types
- Validation rules including min/max values, patterns, and required fields
Module: trustgraph-flow/trustgraph/flow/definition.py
2. **Parameter Resolution Engine**
- Runtime parameter validation against schema
- Default value application for unspecified parameters
- Parameter injection into flow execution context
- Type coercion and conversion as needed
Module: trustgraph-flow/trustgraph/flow/parameter_resolver.py
3. **Parameter Store Integration**
- Retrieval of parameter definitions from schema/config store
- Caching of frequently-used parameter definitions
- Validation against centrally-stored schemas
Module: trustgraph-flow/trustgraph/flow/parameter_store.py
4. **Flow Launcher Extensions**
- API extensions to accept parameter values during flow launch
- Parameter mapping resolution (flow names to definition names)
- Error handling for invalid parameter combinations
Module: trustgraph-flow/trustgraph/flow/launcher.py
5. **UI Parameter Forms**
- Dynamic form generation from flow parameter metadata
- Ordered parameter display using `order` field
- Descriptive parameter labels using `description` field
- Input validation against parameter type definitions
- Parameter presets and templates
Module: trustgraph-ui/components/flow-parameters/
### Data Models
#### Parameter Definitions (Stored in Schema/Config)
Parameter definitions are stored centrally in the schema and config system with type "parameter-types":
```json
{
"llm-model": {
"type": "string",
"description": "LLM model to use",
"default": "gpt-4",
"enum": [
{
"id": "gpt-4",
"description": "OpenAI GPT-4 (Most Capable)"
},
{
"id": "gpt-3.5-turbo",
"description": "OpenAI GPT-3.5 Turbo (Fast & Efficient)"
},
{
"id": "claude-3",
"description": "Anthropic Claude 3 (Thoughtful & Safe)"
},
{
"id": "gemma3:8b",
"description": "Google Gemma 3 8B (Open Source)"
}
],
"required": false
},
"model-size": {
"type": "string",
"description": "Model size variant",
"default": "8b",
"enum": ["2b", "8b", "12b", "70b"],
"required": false
},
"temperature": {
"type": "number",
"description": "Model temperature for generation",
"default": 0.7,
"minimum": 0.0,
"maximum": 2.0,
"required": false
},
"chunk-size": {
"type": "integer",
"description": "Document chunk size",
"default": 512,
"minimum": 128,
"maximum": 2048,
"required": false
}
}
```
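Validating a user-supplied value against one of these type definitions could be sketched as below. This is a hedged illustration; the actual validator lives in the parameter resolution engine and may differ. Note that enum entries may be plain strings or `{"id": ...}` objects, as in the examples above:

```python
# Sketch of validating a value against a parameter type definition.

def validate_parameter(type_def, value):
    t = type_def["type"]
    if t == "number":
        value = float(value)
    elif t == "integer":
        value = int(value)
    if "minimum" in type_def and value < type_def["minimum"]:
        raise ValueError(f"below minimum {type_def['minimum']}")
    if "maximum" in type_def and value > type_def["maximum"]:
        raise ValueError(f"above maximum {type_def['maximum']}")
    if "enum" in type_def:
        # Enum entries are either bare values or {"id", "description"} objects
        allowed = {e["id"] if isinstance(e, dict) else e
                   for e in type_def["enum"]}
        if value not in allowed:
            raise ValueError(f"{value!r} not one of {sorted(allowed)}")
    return value

chunk_size = {"type": "integer", "default": 512,
              "minimum": 128, "maximum": 2048}
validate_parameter(chunk_size, "1024")   # → 1024
```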
#### Flow Class with Parameter References
Flow classes define parameter metadata with type references, descriptions, and ordering:
```json
{
"flow_class": "document-analysis",
"parameters": {
"llm-model": {
"type": "llm-model",
"description": "Primary LLM model for text completion",
"order": 1
},
"llm-rag-model": {
"type": "llm-model",
"description": "LLM model for RAG operations",
"order": 2,
"advanced": true,
"controlled-by": "llm-model"
},
"llm-temperature": {
"type": "temperature",
"description": "Generation temperature for creativity control",
"order": 3,
"advanced": true
},
"chunk-size": {
"type": "chunk-size",
"description": "Document chunk size for processing",
"order": 4,
"advanced": true
},
"chunk-overlap": {
"type": "integer",
"description": "Overlap between document chunks",
"order": 5,
"advanced": true,
"controlled-by": "chunk-size"
}
},
"class": {
"text-completion:{class}": {
"request": "non-persistent://tg/request/text-completion:{class}",
"response": "non-persistent://tg/response/text-completion:{class}",
"parameters": {
"model": "{llm-model}",
"temperature": "{llm-temperature}"
}
},
"rag-completion:{class}": {
"request": "non-persistent://tg/request/rag-completion:{class}",
"response": "non-persistent://tg/response/rag-completion:{class}",
"parameters": {
"model": "{llm-rag-model}",
"temperature": "{llm-temperature}"
}
}
},
"flow": {
"chunker:{id}": {
"input": "persistent://tg/flow/chunk:{id}",
"output": "persistent://tg/flow/chunk-load:{id}",
"parameters": {
"chunk_size": "{chunk-size}",
"chunk_overlap": "{chunk-overlap}"
}
}
}
}
```
The `parameters` section maps flow-specific parameter names (keys) to parameter metadata objects containing:
- `type`: Reference to centrally-defined parameter definition (e.g., "llm-model")
- `description`: Human-readable description for UI display
- `order`: Display order for parameter forms (lower numbers appear first)
- `advanced` (optional): Boolean flag indicating if this is an advanced parameter (default: false). When set to true, the UI may hide this parameter by default or place it in an "Advanced" section
- `controlled-by` (optional): Name of another parameter that controls this parameter's value when in simple mode. When specified, this parameter inherits its value from the controlling parameter unless explicitly overridden
This approach allows:
- Reusable parameter type definitions across multiple flow classes
- Centralized parameter type management and validation
- Flow-specific parameter descriptions and ordering
- Enhanced UI experience with descriptive parameter forms
- Consistent parameter validation across flows
- Easy addition of new standard parameter types
- Simplified UI with basic/advanced mode separation
- Parameter value inheritance for related settings
#### Flow Launch Request
The flow launch API accepts parameters using the flow's parameter names:
```json
{
"flow_class": "document-analysis",
"flow_id": "customer-A-flow",
"parameters": {
"llm-model": "claude-3",
"llm-temperature": 0.5,
"chunk-size": 1024
}
}
```
Note: In this example, `llm-rag-model` is not explicitly provided but will inherit the value "claude-3" from `llm-model` due to its `controlled-by` relationship. Similarly, `chunk-overlap` could inherit a calculated value based on `chunk-size`.
The system will:
1. Extract parameter metadata from flow class definition
2. Map flow parameter names to their type definitions (e.g., `llm-model``llm-model` type)
3. Resolve controlled-by relationships (e.g., `llm-rag-model` inherits from `llm-model`)
4. Validate user-provided and inherited values against the parameter type definitions
5. Substitute resolved values into processor parameters during flow instantiation
### Implementation Details
#### Parameter Resolution Process
When a flow is started, the system performs the following parameter resolution steps:
1. **Flow Class Loading**: Load flow class definition and extract parameter metadata
2. **Metadata Extraction**: Extract `type`, `description`, `order`, `advanced`, and `controlled-by` for each parameter defined in the flow class's `parameters` section
3. **Type Definition Lookup**: For each parameter in the flow class:
- Retrieve the parameter type definition from schema/config store using the `type` field
- The type definitions are stored with type "parameter-types" in the config system
- Each type definition contains the parameter's schema, default value, and validation rules
4. **Default Value Resolution**:
- For each parameter defined in the flow class:
- Check if the user provided a value for this parameter
- If no user value provided, use the `default` value from the parameter type definition
- Build a complete parameter map containing both user-provided and default values
5. **Parameter Inheritance Resolution** (controlled-by relationships):
- For parameters with `controlled-by` field, check if a value was explicitly provided
- If no explicit value provided, inherit the value from the controlling parameter
- If the controlling parameter also has no value, use the default from the type definition
- Validate that no circular dependencies exist in `controlled-by` relationships
6. **Validation**: Validate the complete parameter set (user-provided, defaults, and inherited) against type definitions
7. **Storage**: Store the complete resolved parameter set with the flow instance for auditability
8. **Template Substitution**: Replace parameter placeholders in processor parameters with resolved values
9. **Processor Instantiation**: Create processors with substituted parameters
**Important Implementation Notes:**
- The flow service MUST merge user-provided parameters with defaults from parameter type definitions
- The complete parameter set (including applied defaults) MUST be stored with the flow for traceability
- Parameter resolution happens at flow start time, not at processor instantiation time
- Missing required parameters without defaults MUST cause flow start to fail with a clear error message
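The default-merging and inheritance requirements above might be sketched as follows. Names are illustrative, not the actual service code; `flow_params` is the flow class's `parameters` section and `type_defs` maps type names to the centrally-stored definitions:

```python
# Illustrative sketch of resolution steps 3-6: merge user values with
# defaults from type definitions, then apply controlled-by inheritance.

def resolve(flow_params, type_defs, user_values):
    resolved = {}
    # Apply user values, falling back to defaults from type definitions
    for name, meta in flow_params.items():
        if name in user_values:
            resolved[name] = user_values[name]
        elif "default" in type_defs[meta["type"]]:
            resolved[name] = type_defs[meta["type"]]["default"]
    # Inherit from controlling parameters where no explicit value given
    for name, meta in flow_params.items():
        controller = meta.get("controlled-by")
        if controller and name not in user_values and controller in resolved:
            resolved[name] = resolved[controller]
    # Fail clearly on required parameters that still have no value
    for name, meta in flow_params.items():
        if name not in resolved and type_defs[meta["type"]].get("required"):
            raise ValueError(f"missing required parameter: {name}")
    return resolved
```

With the earlier example, providing only `llm-model: "claude-3"` resolves `llm-rag-model` to `"claude-3"` as well, while an explicit `llm-rag-model` value overrides the inheritance.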
#### Parameter Inheritance with controlled-by
The `controlled-by` field enables parameter value inheritance, particularly useful for simplifying user interfaces while maintaining flexibility:
**Example Scenario**:
- `llm-model` parameter controls the primary LLM model
- `llm-rag-model` parameter has `"controlled-by": "llm-model"`
- In simple mode, setting `llm-model` to "gpt-4" automatically sets `llm-rag-model` to "gpt-4" as well
- In advanced mode, users can override `llm-rag-model` with a different value
**Resolution Rules**:
1. If a parameter has an explicitly provided value, use that value
2. If no explicit value and `controlled-by` is set, use the controlling parameter's value
3. If the controlling parameter has no value, fall back to the default from the type definition
4. Circular dependencies in `controlled-by` relationships result in a validation error
**UI Behavior**:
- In basic/simple mode: Parameters with `controlled-by` may be hidden or shown as read-only with inherited value
- In advanced mode: All parameters are shown and can be individually configured
- When a controlling parameter changes, dependent parameters update automatically unless explicitly overridden
#### Pulsar Integration
1. **Start-Flow Operation**
- The Pulsar start-flow operation needs to accept a `parameters` field containing a map of parameter values
- The Pulsar schema for the start-flow request must be updated to include the optional `parameters` field
- Example request:
```json
{
"flow_class": "document-analysis",
"flow_id": "customer-A-flow",
"parameters": {
"model": "claude-3",
"size": "12b",
"temp": 0.5,
"chunk": 1024
}
}
```
2. **Get-Flow Operation**
- The Pulsar schema for the get-flow response must be updated to include the `parameters` field
- This allows clients to retrieve the parameter values that were used when the flow was started
- Example response:
```json
{
"flow_id": "customer-A-flow",
"flow_class": "document-analysis",
"status": "running",
"parameters": {
"model": "claude-3",
"size": "12b",
"temp": 0.5,
"chunk": 1024
}
}
```
#### Flow Service Implementation
The flow configuration service (`trustgraph-flow/trustgraph/config/service/flow.py`) requires the following enhancements:
1. **Parameter Resolution Function**
```python
async def resolve_parameters(self, flow_class, user_params):
"""
Resolve parameters by merging user-provided values with defaults.
Args:
flow_class: The flow class definition dict
user_params: User-provided parameters dict
Returns:
Complete parameter dict with user values and defaults merged
"""
```
This function should:
- Extract parameter metadata from the flow class's `parameters` section
- For each parameter, fetch its type definition from config store
- Apply defaults for any parameters not provided by the user
- Handle `controlled-by` inheritance relationships
- Return the complete parameter set
2. **Modified `handle_start_flow` Method**
- Call `resolve_parameters` after loading the flow class
- Use the complete resolved parameter set for template substitution
- Store the complete parameter set (not just user-provided) with the flow
- Validate that all required parameters have values
3. **Parameter Type Fetching**
- Parameter type definitions are stored in config with type "parameter-types"
- Each type definition contains schema, default value, and validation rules
- Cache frequently-used parameter types to reduce config lookups
#### Config System Integration
3. **Flow Object Storage**
- When a flow is added to the config system by the flow component in the config manager, the flow object must include the resolved parameter values
- The config manager needs to store both the original user-provided parameters and the resolved values (with defaults applied)
- Flow objects in the config system should include:
- `parameters`: The final resolved parameter values used for the flow
#### CLI Integration
4. **Library CLI Commands**
- CLI commands that start flows need parameter support:
- Accept parameter values via command-line flags or configuration files
- Validate parameters against flow class definitions before submission
- Support parameter file input (JSON/YAML) for complex parameter sets
- CLI commands that show flows need to display parameter information:
- Show parameter values used when the flow was started
- Display available parameters for a flow class
- Show parameter validation schemas and defaults
#### Processor Base Class Integration
5. **ParametersSpec Support**
- Processor base classes need to support parameter substitution through the existing ParametersSpec mechanism
- The ParametersSpec class (located in the same module as ConsumerSpec and ProducerSpec) should be enhanced if necessary to support parameter template substitution
- Processors should be able to invoke ParametersSpec to configure their parameters with parameter values resolved at flow launch time
- The ParametersSpec implementation needs to:
- Accept parameter configurations that contain parameter placeholders (e.g., `{model}`, `{temperature}`)
- Support runtime parameter substitution when the processor is instantiated
- Validate that substituted values match expected types and constraints
- Provide error handling for missing or invalid parameter references
#### Substitution Rules
- Parameters use the format `{parameter-name}` in processor parameters
- Parameter names in processor parameters match the keys in the flow's `parameters` section
- Substitution occurs alongside `{id}` and `{class}` replacement
- Invalid parameter references result in launch-time errors
- Type validation happens based on the centrally-stored parameter definition
- **IMPORTANT**: All parameter values are stored and transmitted as strings
- Numbers are converted to strings (e.g., `0.7` becomes `"0.7"`)
- Booleans are converted to lowercase strings (e.g., `true` becomes `"true"`)
- This is required by the Pulsar schema which defines `parameters = Map(String())`
Example resolution:
```
Flow parameter mapping: "model": "llm-model"
Processor parameter: "model": "{model}"
User provides: "model": "gemma3:8b"
Final parameter: "model": "gemma3:8b"
Example with type conversion:
Parameter type default: 0.7 (number)
Stored in flow: "0.7" (string)
Substituted in processor: "0.7" (string)
```
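The string-encoding rule could be captured in a small helper (a sketch; in practice the conversion happens wherever parameters are written to the Pulsar message):

```python
# Sketch of the value-to-string conversion required by the Pulsar
# schema's parameters = Map(String()). Booleans must be checked before
# numbers, since bool is a subclass of int in Python.

def to_wire_string(value):
    if isinstance(value, bool):
        return "true" if value else "false"
    return str(value)

to_wire_string(0.7)    # → "0.7"
to_wire_string(True)   # → "true"
to_wire_string(512)    # → "512"
```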
## Testing Strategy
- Unit tests for parameter schema validation
- Integration tests for parameter substitution in processor parameters
- End-to-end tests for launching flows with different parameter values
- UI tests for parameter form generation and validation
- Performance tests for flows with many parameters
- Edge cases: missing parameters, invalid types, undefined parameter references
## Migration Plan
1. The system should continue to support flow classes with no parameters declared.
2. The system should continue to support launching flows with no parameters specified: this works both for flow classes that declare no parameters and for flow classes whose parameters all have defaults.
## Open Questions
Q: Should parameters support complex nested objects or keep to simple types?
A: Parameter values are string-encoded, so we will probably stick to simple string values.
Q: Should parameter placeholders be allowed in queue names or only in parameters?
A: Only in parameters, to avoid injection issues and edge cases.
Q: How to handle conflicts between parameter names and system variables like `id` and `class`?
A: It is not valid to use `id` or `class` as parameter names when launching a flow.
Q: Should we support computed parameters (derived from other parameters)?
A: No; only plain string substitution, to avoid injection issues and edge cases.
## References
- JSON Schema Specification: https://json-schema.org/
- Flow Class Definition Spec: docs/tech-specs/flow-class-definition.md

@ -0,0 +1,629 @@
# GraphRAG Performance Optimisation Technical Specification
## Overview
This specification describes comprehensive performance optimisations for the GraphRAG (Graph Retrieval-Augmented Generation) algorithm in TrustGraph. The current implementation suffers from significant performance bottlenecks that limit scalability and response times. This specification addresses four primary optimisation areas:
1. **Graph Traversal Optimisation**: Eliminate inefficient recursive database queries and implement batched graph exploration
2. **Label Resolution Optimisation**: Replace sequential label fetching with parallel/batched operations
3. **Caching Strategy Enhancement**: Implement intelligent caching with LRU eviction and prefetching
4. **Query Optimisation**: Add result memoisation and embedding caching for improved response times
## Goals
- **Reduce Database Query Volume**: Achieve 50-80% reduction in total database queries through batching and caching
- **Improve Response Times**: Target 3-5x faster subgraph construction and 2-3x faster label resolution
- **Enhance Scalability**: Support larger knowledge graphs with better memory management
- **Maintain Accuracy**: Preserve existing GraphRAG functionality and result quality
- **Enable Concurrency**: Improve parallel processing capabilities for multiple concurrent requests
- **Reduce Memory Footprint**: Implement efficient data structures and memory management
- **Add Observability**: Include performance metrics and monitoring capabilities
- **Ensure Reliability**: Add proper error handling and timeout mechanisms
## Background
The current GraphRAG implementation in `trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py` exhibits several critical performance issues that severely impact system scalability:
### Current Performance Problems
**1. Inefficient Graph Traversal (`follow_edges` function, lines 79-127)**
- Makes 3 separate database queries per entity per depth level
- Query pattern: subject-based, predicate-based, and object-based queries for each entity
- No batching: Each query processes only one entity at a time
- No cycle detection: Can revisit the same nodes multiple times
- Recursive implementation without memoisation leads to exponential complexity
- Time complexity: O(entities × max_path_length × triple_limit³)
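The batched, cycle-aware alternative this spec proposes could be sketched as follows. Here `query_triples()` stands in for the per-entity database queries and is an assumption, not the real TrustGraph client API:

```python
import asyncio

# Sketch of batched graph expansion with cycle detection: all queries
# for a depth level are issued concurrently, and visited nodes are
# never re-queried.

async def expand_frontier(query_triples, seeds, max_depth):
    visited, triples = set(seeds), []
    frontier = list(seeds)
    for _ in range(max_depth):
        if not frontier:
            break
        # One gather batches every query for the current depth level
        results = await asyncio.gather(
            *(query_triples(e) for e in frontier))
        next_frontier = []
        for entity_triples in results:
            for s, p, o in entity_triples:
                triples.append((s, p, o))
                for node in (s, o):
                    if node not in visited:   # cycle detection
                        visited.add(node)
                        next_frontier.append(node)
        frontier = next_frontier
    return triples
```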
**2. Sequential Label Resolution (`get_labelgraph` function, lines 144-171)**
- Processes each triple component (subject, predicate, object) sequentially
- Each `maybe_label` call potentially triggers a database query
- No parallel execution or batching of label queries
- Results in up to 3 × subgraph_size individual database calls
**3. Primitive Caching Strategy (`maybe_label` function, lines 62-77)**
- Simple dictionary cache without size limits or TTL
- No cache eviction policy leads to unbounded memory growth
- Cache misses trigger individual database queries
- No prefetching or intelligent cache warming
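A size-bounded LRU cache, as proposed in the optimisation areas above, could replace the plain dictionary; a minimal sketch (the `max_size` value is an assumed tuning parameter):

```python
from collections import OrderedDict

# Minimal sketch of a size-bounded LRU cache to replace the unbounded
# label dictionary.

class LRUCache:
    def __init__(self, max_size=10000):
        self.max_size = max_size
        self.data = OrderedDict()

    def get(self, key):
        if key not in self.data:
            return None
        self.data.move_to_end(key)   # mark as most recently used
        return self.data[key]

    def put(self, key, value):
        self.data[key] = value
        self.data.move_to_end(key)
        if len(self.data) > self.max_size:
            self.data.popitem(last=False)   # evict least recently used
```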
**4. Suboptimal Query Patterns**
- Entity vector similarity queries not cached between similar requests
- No result memoisation for repeated query patterns
- Missing query optimisation for common access patterns
**5. Critical Object Lifetime Issues (`rag.py:96-102`)**
- **GraphRag object recreated per request**: Fresh instance created for every query, losing all cache benefits
- **Query object extremely short-lived**: Created and destroyed within single query execution (lines 201-207)
- **Label cache reset per request**: Cache warming and accumulated knowledge lost between requests
- **Client recreation overhead**: Database clients potentially re-established for each request
- **No cross-request optimisation**: Cannot benefit from query patterns or result sharing
### Performance Impact Analysis
Current worst-case scenario for a typical query:
- **Entity Retrieval**: 1 vector similarity query
- **Graph Traversal**: entities × max_path_length × 3 × triple_limit queries
- **Label Resolution**: subgraph_size × 3 individual label queries
For default parameters (50 entities, path length 2, 30 triple limit, 150 subgraph size):
- **Minimum queries**: 1 + (50 × 2 × 3 × 30) + (150 × 3) = **9,451 database queries**
- **Response time**: 15-30 seconds for moderate-sized graphs
- **Memory usage**: Unbounded cache growth over time
- **Cache effectiveness**: 0% - caches reset on every request
- **Object creation overhead**: GraphRag + Query objects created/destroyed per request
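The worst-case arithmetic above is easy to verify with a short sketch (the function and parameter names are illustrative, not from the codebase):

```python
def worst_case_queries(entities: int, max_path_length: int,
                       triple_limit: int, subgraph_size: int) -> int:
    """Worst-case database query count for the current implementation."""
    entity_retrieval = 1                                       # one vector similarity query
    traversal = entities * max_path_length * 3 * triple_limit  # 3 query directions per entity per level
    label_resolution = subgraph_size * 3                       # one lookup per triple component
    return entity_retrieval + traversal + label_resolution

# Default parameters: 50 entities, path length 2, 30 triple limit, 150 subgraph size
print(worst_case_queries(50, 2, 30, 150))  # 9451
```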
This specification addresses these gaps by implementing batched queries, intelligent caching, and parallel processing. By optimizing query patterns and data access, TrustGraph can:
- Support enterprise-scale knowledge graphs with millions of entities
- Provide sub-second response times for typical queries
- Handle hundreds of concurrent GraphRAG requests
- Scale efficiently with graph size and complexity
## Technical Design
### Architecture
The GraphRAG performance optimisation requires the following technical components:
#### 1. **Object Lifetime Architectural Refactor**
- **Make GraphRag long-lived**: Move GraphRag instance to Processor level for persistence across requests
- **Preserve caches**: Maintain label cache, embedding cache, and query result cache between requests
- **Optimize Query object**: Refactor Query as lightweight execution context, not data container
- **Connection persistence**: Maintain database client connections across requests
Module: `trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py` (modified)
#### 2. **Optimized Graph Traversal Engine**
- Replace recursive `follow_edges` with iterative breadth-first search
- Implement batched entity processing at each traversal level
- Add cycle detection using visited node tracking
- Include early termination when limits are reached
Module: `trustgraph-flow/trustgraph/retrieval/graph_rag/optimized_traversal.py`
#### 3. **Parallel Label Resolution System**
- Batch label queries for multiple entities simultaneously
- Implement async/await patterns for concurrent database access
- Add intelligent prefetching for common label patterns
- Include label cache warming strategies
Module: `trustgraph-flow/trustgraph/retrieval/graph_rag/label_resolver.py`
#### 4. **Conservative Label Caching Layer**
- LRU cache with short TTL for labels only (5min) to balance performance vs consistency
- Cache metrics and hit ratio monitoring
- **No embedding caching**: Already cached per-query, no cross-query benefit
- **No query result caching**: Due to graph mutation consistency concerns
Module: `trustgraph-flow/trustgraph/retrieval/graph_rag/cache_manager.py`
#### 5. **Query Optimisation Framework**
- Query pattern analysis and optimisation suggestions
- Batch query coordinator for database access
- Connection pooling and query timeout management
- Performance monitoring and metrics collection
Module: `trustgraph-flow/trustgraph/retrieval/graph_rag/query_optimizer.py`
### Data Models
#### Optimized Graph Traversal State
The traversal engine maintains state to avoid redundant operations:
```python
from dataclasses import dataclass
from typing import List, Set, Tuple

@dataclass
class TraversalState:
visited_entities: Set[str]
current_level_entities: Set[str]
next_level_entities: Set[str]
subgraph: Set[Tuple[str, str, str]]
depth: int
query_batch: List[TripleQuery]
```
This approach allows:
- Efficient cycle detection through visited entity tracking
- Batched query preparation at each traversal level
- Memory-efficient state management
- Early termination when size limits are reached
#### Enhanced Cache Structure
```python
@dataclass
class CacheEntry:
value: Any
timestamp: float
access_count: int
ttl: Optional[float]
class CacheManager:
    label_cache: LRUCache[str, CacheEntry]
    # Embedding and query-result caches are intentionally omitted
    # (see the Conservative Label Caching Layer rationale above)
    cache_stats: CacheStatistics
```
#### Batch Query Structures
```python
@dataclass
class BatchTripleQuery:
entities: List[str]
query_type: QueryType # SUBJECT, PREDICATE, OBJECT
limit_per_entity: int
@dataclass
class BatchLabelQuery:
entities: List[str]
predicate: str = LABEL
```
### APIs
#### New APIs:
**GraphTraversal API**
```python
async def optimized_follow_edges_batch(
entities: List[str],
max_depth: int,
triple_limit: int,
max_subgraph_size: int
) -> Set[Tuple[str, str, str]]
```
**Batch Label Resolution API**
```python
async def resolve_labels_batch(
entities: List[str],
cache_manager: CacheManager
) -> Dict[str, str]
```
**Cache Management API**
```python
class CacheManager:
    async def get_or_fetch_label(self, entity: str) -> str
    # Embedding and query-result caching methods are deliberately omitted
    # (see the Conservative Label Caching Layer rationale)
    def get_cache_statistics(self) -> CacheStatistics
```
#### Modified APIs:
**GraphRag.query()** - Enhanced with performance optimisations:
- Add cache_manager parameter for cache control
- Include performance_metrics return value
- Add query_timeout parameter for reliability
**Query class** - Refactored for batch processing:
- Replace individual entity processing with batch operations
- Add async context managers for resource cleanup
- Include progress callbacks for long-running operations
### Implementation Details
#### Phase 0: Critical Architectural Lifetime Refactor
**Current Problematic Implementation:**
```python
# INEFFICIENT: GraphRag recreated every request
class Processor(FlowProcessor):
async def on_request(self, msg, consumer, flow):
# PROBLEM: New GraphRag instance per request!
self.rag = GraphRag(
embeddings_client = flow("embeddings-request"),
graph_embeddings_client = flow("graph-embeddings-request"),
triples_client = flow("triples-request"),
prompt_client = flow("prompt-request"),
verbose=True,
)
# Cache starts empty every time - no benefit from previous requests
response = await self.rag.query(...)
# VERY SHORT-LIVED: Query object created/destroyed per request
class GraphRag:
async def query(self, query, user="trustgraph", collection="default", ...):
q = Query(rag=self, user=user, collection=collection, ...) # Created
kg = await q.get_labelgraph(query) # Used briefly
# q automatically destroyed when function exits
```
**Optimized Long-Lived Architecture:**
```python
class Processor(FlowProcessor):
def __init__(self, **params):
super().__init__(**params)
self.rag_instance = None # Will be initialized once
self.client_connections = {}
async def initialize_rag(self, flow):
"""Initialize GraphRag once, reuse for all requests"""
if self.rag_instance is None:
self.rag_instance = LongLivedGraphRag(
embeddings_client=flow("embeddings-request"),
graph_embeddings_client=flow("graph-embeddings-request"),
triples_client=flow("triples-request"),
prompt_client=flow("prompt-request"),
verbose=True,
)
return self.rag_instance
async def on_request(self, msg, consumer, flow):
# REUSE the same GraphRag instance - caches persist!
rag = await self.initialize_rag(flow)
# Query object becomes lightweight execution context
response = await rag.query_with_context(
query=v.query,
execution_context=QueryContext(
user=v.user,
collection=v.collection,
entity_limit=entity_limit,
# ... other params
)
)
class LongLivedGraphRag:
def __init__(self, ...):
# CONSERVATIVE caches - balance performance vs consistency
        self.label_cache = LRUCacheWithTTL(max_size=5000, default_ttl=300)  # 5min TTL for freshness
# Note: No embedding cache - already cached per-query, no cross-query benefit
# Note: No query result cache due to consistency concerns
self.performance_metrics = PerformanceTracker()
async def query_with_context(self, query: str, context: QueryContext):
# Use lightweight QueryExecutor instead of heavyweight Query object
executor = QueryExecutor(self, context) # Minimal object
return await executor.execute(query)
@dataclass
class QueryContext:
"""Lightweight execution context - no heavy operations"""
user: str
collection: str
entity_limit: int
triple_limit: int
max_subgraph_size: int
max_path_length: int
class QueryExecutor:
"""Lightweight execution context - replaces old Query class"""
def __init__(self, rag: LongLivedGraphRag, context: QueryContext):
self.rag = rag
self.context = context
# No heavy initialization - just references
async def execute(self, query: str):
# All heavy lifting uses persistent rag caches
return await self.rag.execute_optimized_query(query, self.context)
```
This architectural change provides:
- **10-20% database query reduction** for graphs with common relationships (vs 0% currently)
- **Eliminated object creation overhead** for every request
- **Persistent connection pooling** and client reuse
- **Cross-request optimization** within cache TTL windows
**Important Cache Consistency Limitation:**
Long-term caching introduces staleness risk when entities/labels are deleted or modified in the underlying graph. The LRU cache with TTL provides a balance between performance gains and data freshness, but cannot detect real-time graph changes.
#### Phase 1: Graph Traversal Optimisation
**Current Implementation Problems:**
```python
# INEFFICIENT: 3 queries per entity per level
async def follow_edges(self, ent, subgraph, path_length):
# Query 1: s=ent, p=None, o=None
res = await self.rag.triples_client.query(s=ent, p=None, o=None, limit=self.triple_limit)
# Query 2: s=None, p=ent, o=None
res = await self.rag.triples_client.query(s=None, p=ent, o=None, limit=self.triple_limit)
# Query 3: s=None, p=None, o=ent
res = await self.rag.triples_client.query(s=None, p=None, o=ent, limit=self.triple_limit)
```
**Optimized Implementation:**
```python
async def optimized_traversal(self, entities: List[str], max_depth: int) -> Set[Triple]:
visited = set()
current_level = set(entities)
subgraph = set()
for depth in range(max_depth):
if not current_level or len(subgraph) >= self.max_subgraph_size:
break
# Batch all queries for current level
batch_queries = []
for entity in current_level:
if entity not in visited:
batch_queries.extend([
TripleQuery(s=entity, p=None, o=None),
TripleQuery(s=None, p=entity, o=None),
TripleQuery(s=None, p=None, o=entity)
])
# Execute all queries concurrently
results = await self.execute_batch_queries(batch_queries)
# Process results and prepare next level
next_level = set()
for result in results:
subgraph.update(result.triples)
next_level.update(result.new_entities)
visited.update(current_level)
current_level = next_level - visited
return subgraph
```
#### Phase 2: Parallel Label Resolution
**Current Sequential Implementation:**
```python
# INEFFICIENT: Sequential processing
for edge in subgraph:
s = await self.maybe_label(edge[0]) # Individual query
p = await self.maybe_label(edge[1]) # Individual query
o = await self.maybe_label(edge[2]) # Individual query
```
**Optimized Parallel Implementation:**
```python
async def resolve_labels_parallel(self, subgraph: List[Triple]) -> List[Triple]:
# Collect all unique entities needing labels
entities_to_resolve = set()
for s, p, o in subgraph:
entities_to_resolve.update([s, p, o])
    # Remove already cached entities (dict-style access shown for brevity;
    # a production version would route through the TTL-aware cache from Phase 3)
    uncached_entities = [e for e in entities_to_resolve if e not in self.label_cache]
# Batch query for all uncached labels
if uncached_entities:
label_results = await self.batch_label_query(uncached_entities)
self.label_cache.update(label_results)
# Apply labels to subgraph
return [
(self.label_cache.get(s, s), self.label_cache.get(p, p), self.label_cache.get(o, o))
for s, p, o in subgraph
]
```
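If the triples client only exposes the single-triple `query(s, p, o, limit)` call shown earlier, the `batch_label_query` helper referenced above can be approximated client-side with `asyncio.gather`. This is a sketch under that assumption; the client interface and label predicate default are carried over from the current code, not a confirmed batch API:

```python
import asyncio

async def batch_label_query(triples_client, entities, label_predicate="label"):
    """Client-side sketch: resolve many labels concurrently.

    Assumes the existing query(s=..., p=..., o=..., limit=...) signature;
    a true server-side batch API would replace this with one round trip.
    """
    async def resolve_one(entity):
        res = await triples_client.query(
            s=entity, p=label_predicate, o=None, limit=1)
        # Fall back to the raw identifier when no label triple exists
        return entity, (res[0].o if res else entity)

    pairs = await asyncio.gather(*(resolve_one(e) for e in entities))
    return dict(pairs)
```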
#### Phase 3: Advanced Caching Strategy
**LRU Cache with TTL:**
```python
class LRUCacheWithTTL:
def __init__(self, max_size: int, default_ttl: int = 3600):
self.cache = OrderedDict()
self.max_size = max_size
self.default_ttl = default_ttl
self.access_times = {}
async def get(self, key: str) -> Optional[Any]:
if key in self.cache:
# Check TTL expiration
if time.time() - self.access_times[key] > self.default_ttl:
del self.cache[key]
del self.access_times[key]
return None
# Move to end (most recently used)
self.cache.move_to_end(key)
return self.cache[key]
return None
async def put(self, key: str, value: Any):
if key in self.cache:
self.cache.move_to_end(key)
else:
if len(self.cache) >= self.max_size:
# Remove least recently used
oldest_key = next(iter(self.cache))
del self.cache[oldest_key]
del self.access_times[oldest_key]
self.cache[key] = value
self.access_times[key] = time.time()
```
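The interaction of LRU ordering and capacity eviction in the class above can be demonstrated with a compact synchronous sketch of the same semantics (class and field names here are illustrative, not from the codebase):

```python
import time
from collections import OrderedDict

class TinyLruTtl:
    """Compact synchronous sketch of the eviction semantics above."""
    def __init__(self, max_size: int, ttl: float):
        self.entries: "OrderedDict[str, tuple]" = OrderedDict()
        self.max_size = max_size
        self.ttl = ttl

    def get(self, key):
        entry = self.entries.get(key)
        if entry is None:
            return None
        stored_at, value = entry
        if time.time() - stored_at > self.ttl:   # expired: treat as a miss
            del self.entries[key]
            return None
        self.entries.move_to_end(key)            # mark as most recently used
        return value

    def put(self, key, value):
        if key in self.entries:
            self.entries.move_to_end(key)
        elif len(self.entries) >= self.max_size:
            self.entries.popitem(last=False)     # evict least recently used
        self.entries[key] = (time.time(), value)

cache = TinyLruTtl(max_size=2, ttl=300)
cache.put("e1", "Alice")
cache.put("e2", "Bob")
cache.get("e1")               # touching e1 makes e2 the LRU entry
cache.put("e3", "Carol")      # capacity reached: e2 is evicted
print(cache.get("e2"), cache.get("e1"))  # None Alice
```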
#### Phase 4: Query Optimisation and Monitoring
**Performance Metrics Collection:**
```python
@dataclass
class PerformanceMetrics:
total_queries: int
cache_hits: int
cache_misses: int
avg_response_time: float
subgraph_construction_time: float
label_resolution_time: float
total_entities_processed: int
memory_usage_mb: float
```
**Query Timeout and Circuit Breaker:**
```python
async def execute_with_timeout(self, query_func, timeout: int = 30):
try:
return await asyncio.wait_for(query_func(), timeout=timeout)
except asyncio.TimeoutError:
logger.error(f"Query timeout after {timeout}s")
raise GraphRagTimeoutError(f"Query exceeded timeout of {timeout}s")
```
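Caller-side handling of the timeout wrapper might look as follows; the exception class is an assumed name from the snippet above, and the sleep stands in for a slow traversal:

```python
import asyncio

class GraphRagTimeoutError(Exception):
    """Assumed exception type from the snippet above."""

async def execute_with_timeout(query_func, timeout: float = 30.0):
    try:
        return await asyncio.wait_for(query_func(), timeout=timeout)
    except asyncio.TimeoutError:
        raise GraphRagTimeoutError(f"Query exceeded timeout of {timeout}s")

async def main():
    async def slow_query():
        await asyncio.sleep(0.2)   # stands in for a slow graph traversal
        return "subgraph"

    try:
        await execute_with_timeout(slow_query, timeout=0.01)
    except GraphRagTimeoutError as e:
        return str(e)

print(asyncio.run(main()))  # Query exceeded timeout of 0.01s
```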
## Cache Consistency Considerations
**Data Staleness Trade-offs:**
- **Label cache (5min TTL)**: Risk of serving deleted/renamed entity labels
- **No embedding caching**: Not needed - embeddings already cached per-query
- **No result caching**: Prevents stale subgraph results from deleted entities/relationships
**Mitigation Strategies:**
- **Conservative TTL values**: Balance performance gains (10-20%) with data freshness
- **Cache invalidation hooks**: Optional integration with graph mutation events
- **Monitoring dashboards**: Track cache hit rates vs staleness incidents
- **Configurable cache policies**: Allow per-deployment tuning based on mutation frequency
**Recommended Cache Configuration by Graph Mutation Rate:**
- **High mutation (>100 changes/hour)**: TTL=60s, smaller cache sizes
- **Medium mutation (10-100 changes/hour)**: TTL=300s (default)
- **Low mutation (<10 changes/hour)**: TTL=600s, larger cache sizes
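The tiers above translate directly into a TTL-selection helper; a minimal sketch, with boundary handling (exactly 100 or 10 changes/hour) an assumption:

```python
def label_cache_ttl(changes_per_hour: float) -> int:
    """Map the observed graph mutation rate to a label-cache TTL in seconds."""
    if changes_per_hour > 100:     # high mutation: favour freshness
        return 60
    if changes_per_hour >= 10:     # medium mutation: the default
        return 300
    return 600                     # low mutation: favour hit rate

print(label_cache_ttl(250), label_cache_ttl(50), label_cache_ttl(2))  # 60 300 600
```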
## Security Considerations
**Query Injection Prevention:**
- Validate all entity identifiers and query parameters
- Use parameterized queries for all database interactions
- Implement query complexity limits to prevent DoS attacks
**Resource Protection:**
- Enforce maximum subgraph size limits
- Implement query timeouts to prevent resource exhaustion
- Add memory usage monitoring and limits
**Access Control:**
- Maintain existing user and collection isolation
- Add audit logging for performance-impacting operations
- Implement rate limiting for expensive operations
## Performance Considerations
### Expected Performance Improvements
**Query Reduction:**
- Current: ~9,000+ queries for typical request
- Optimized: ~50-100 batched queries (98% reduction)
**Response Time Improvements:**
- Graph traversal: 15-20s → 3-5s (4-5x faster)
- Label resolution: 8-12s → 2-4s (3x faster)
- Overall query: 25-35s → 6-10s (3-4x improvement)
**Memory Efficiency:**
- Bounded cache sizes prevent memory leaks
- Efficient data structures reduce memory footprint by ~40%
- Better garbage collection through proper resource cleanup
**Realistic Performance Expectations:**
- **Label cache**: 10-20% query reduction for graphs with common relationships
- **Batching optimization**: 50-80% query reduction (primary optimization)
- **Object lifetime optimization**: Eliminate per-request creation overhead
- **Overall improvement**: 3-4x response time improvement primarily from batching
**Scalability Improvements:**
- Support for 3-5x larger knowledge graphs (limited by cache consistency needs)
- 3-5x higher concurrent request capacity
- Better resource utilization through connection reuse
### Performance Monitoring
**Real-time Metrics:**
- Query execution times by operation type
- Cache hit ratios and effectiveness
- Database connection pool utilisation
- Memory usage and garbage collection impact
**Performance Benchmarking:**
- Automated performance regression testing
- Load testing with realistic data volumes
- Comparison benchmarks against current implementation
## Testing Strategy
### Unit Testing
- Individual component testing for traversal, caching, and label resolution
- Mock database interactions for performance testing
- Cache eviction and TTL expiration testing
- Error handling and timeout scenarios
### Integration Testing
- End-to-end GraphRAG query testing with optimisations
- Database interaction testing with real data
- Concurrent request handling and resource management
- Memory leak detection and resource cleanup verification
### Performance Testing
- Benchmark testing against current implementation
- Load testing with varying graph sizes and complexities
- Stress testing for memory and connection limits
- Regression testing for performance improvements
### Compatibility Testing
- Verify existing GraphRAG API compatibility
- Test with various graph database backends
- Validate result accuracy compared to current implementation
## Implementation Plan
### Direct Implementation Approach
Since APIs are allowed to change, implement optimizations directly without migration complexity:
1. **Replace `follow_edges` method**: Rewrite with iterative batched traversal
2. **Optimize `get_labelgraph`**: Implement parallel label resolution
3. **Add long-lived GraphRag**: Modify Processor to maintain persistent instance
4. **Implement label caching**: Add LRU cache with TTL to GraphRag class
### Scope of Changes
- **Query class**: Replace ~50 lines in `follow_edges`, add ~30 lines batch handling
- **GraphRag class**: Add caching layer (~40 lines)
- **Processor class**: Modify to use persistent GraphRag instance (~20 lines)
- **Total**: ~140 lines of focused changes, mostly within existing classes
## Timeline
**Week 1: Core Implementation**
- Replace `follow_edges` with batched iterative traversal
- Implement parallel label resolution in `get_labelgraph`
- Add long-lived GraphRag instance to Processor
- Implement label caching layer
**Week 2: Testing and Integration**
- Unit tests for new traversal and caching logic
- Performance benchmarking against current implementation
- Integration testing with real graph data
- Code review and optimization
**Week 3: Deployment**
- Deploy optimized implementation
- Monitor performance improvements
- Fine-tune cache TTL and batch sizes based on real usage
## Open Questions
- **Database Connection Pooling**: Should we implement custom connection pooling or rely on existing database client pooling?
- **Cache Persistence**: Should label and embedding caches persist across service restarts?
- **Distributed Caching**: For multi-instance deployments, should we implement distributed caching with Redis/Memcached?
- **Query Result Format**: Should we optimize the internal triple representation for better memory efficiency?
- **Monitoring Integration**: Which metrics should be exposed to existing monitoring systems (Prometheus, etc.)?
## References
- [GraphRAG Original Implementation](trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py)
- [TrustGraph Architecture Principles](architecture-principles.md)
- [Collection Management Specification](collection-management.md)

View file

@ -29,23 +29,25 @@ class TestEndToEndConfigurationFlow:
'CASSANDRA_USERNAME': 'integration-user',
'CASSANDRA_PASSWORD': 'integration-pass'
}
mock_cluster_instance = MagicMock()
mock_session = MagicMock()
mock_cluster_instance.connect.return_value = mock_session
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = TriplesWriter(taskgroup=MagicMock())
# Create a mock message to trigger TrustGraph creation
mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_message.triples = []
# This should create TrustGraph with environment config
await processor.store_triples(mock_message)
# Mock collection_exists to return True
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
# This should create TrustGraph with environment config
await processor.store_triples(mock_message)
# Verify Cluster was created with correct hosts
mock_cluster.assert_called_once()
@ -145,8 +147,10 @@ class TestConfigurationPriorityEndToEnd:
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_message.triples = []
await processor.store_triples(mock_message)
# Mock collection_exists to return True
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Should use CLI parameters, not environment
mock_cluster.assert_called_once()
@ -243,8 +247,10 @@ class TestNoBackwardCompatibilityEndToEnd:
mock_message.metadata.user = 'legacy_user'
mock_message.metadata.collection = 'legacy_collection'
mock_message.triples = []
await processor.store_triples(mock_message)
# Mock collection_exists to return True
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Should use defaults since old parameters are not recognized
mock_cluster.assert_called_once()
@ -299,8 +305,10 @@ class TestNoBackwardCompatibilityEndToEnd:
mock_message.metadata.user = 'precedence_user'
mock_message.metadata.collection = 'precedence_collection'
mock_message.triples = []
await processor.store_triples(mock_message)
# Mock collection_exists to return True
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Should use new parameters, not old ones
mock_cluster.assert_called_once()
@ -349,8 +357,10 @@ class TestMultipleHostsHandling:
mock_message.metadata.user = 'single_user'
mock_message.metadata.collection = 'single_collection'
mock_message.triples = []
await processor.store_triples(mock_message)
# Mock collection_exists to return True
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Single host should be converted to list
mock_cluster.assert_called_once()

View file

@ -0,0 +1,276 @@
"""
Integration tests for Dynamic LLM Parameters
Testing end-to-end flow of runtime parameter changes in LLM processors
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
from openai.types.completion_usage import CompletionUsage
from trustgraph.model.text_completion.openai.llm import Processor as OpenAIProcessor
from trustgraph.base import LlmResult
@pytest.mark.integration
class TestDynamicLlmParameters:
"""Integration tests for dynamic parameter configuration"""
@pytest.fixture
def mock_openai_client(self):
"""Mock OpenAI client that returns realistic responses"""
client = MagicMock()
# Default mock response
usage = CompletionUsage(prompt_tokens=25, completion_tokens=15, total_tokens=40)
message = ChatCompletionMessage(role="assistant", content="Dynamic parameter test response")
choice = Choice(index=0, message=message, finish_reason="stop")
completion = ChatCompletion(
id="chatcmpl-test-dynamic",
choices=[choice],
created=1234567890,
model="gpt-4", # Will be overridden based on test
object="chat.completion",
usage=usage
)
client.chat.completions.create.return_value = completion
return client
@pytest.fixture
def base_processor_config(self):
"""Base configuration for test processors"""
return {
"api_key": "test-api-key",
"url": "https://api.openai.com/v1",
"temperature": 0.0, # Default temperature
"max_output": 1024,
}
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_runtime_temperature_override(self, mock_llm_init, mock_async_init,
mock_openai_class, mock_openai_client, base_processor_config):
"""Test that temperature can be overridden at runtime"""
# Arrange
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = base_processor_config | {
"model": "gpt-3.5-turbo",
"concurrency": 1,
"taskgroup": AsyncMock(),
"id": "test-processor"
}
processor = OpenAIProcessor(**config)
# Act - Call with different temperature than configured default (0.0)
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Dynamic parameter test response"
# Verify the OpenAI API was called with the overridden temperature
mock_openai_client.chat.completions.create.assert_called_once()
call_args = mock_openai_client.chat.completions.create.call_args
assert call_args.kwargs['temperature'] == 0.9 # Should use runtime parameter
assert call_args.kwargs['model'] == "gpt-3.5-turbo" # Should use processor default
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_runtime_model_override(self, mock_llm_init, mock_async_init,
mock_openai_class, mock_openai_client, base_processor_config):
"""Test that model can be overridden at runtime"""
# Arrange
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = base_processor_config | {
"model": "gpt-3.5-turbo", # Default model
"concurrency": 1,
"taskgroup": AsyncMock(),
"id": "test-processor"
}
processor = OpenAIProcessor(**config)
# Act - Call with different model than configured default
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gpt-4", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
# Verify the OpenAI API was called with the overridden model
mock_openai_client.chat.completions.create.assert_called_once()
call_args = mock_openai_client.chat.completions.create.call_args
assert call_args.kwargs['model'] == "gpt-4" # Should use runtime parameter
assert call_args.kwargs['temperature'] == 0.0 # Should use processor default
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_both_parameters_override(self, mock_llm_init, mock_async_init,
mock_openai_class, mock_openai_client, base_processor_config):
"""Test that both model and temperature can be overridden simultaneously"""
# Arrange
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = base_processor_config | {
"model": "gpt-3.5-turbo", # Default model
"concurrency": 1,
"taskgroup": AsyncMock(),
"id": "test-processor"
}
processor = OpenAIProcessor(**config)
# Act - Override both parameters
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gpt-4", # Override model
temperature=0.5 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
# Verify both parameters were overridden
mock_openai_client.chat.completions.create.assert_called_once()
call_args = mock_openai_client.chat.completions.create.call_args
assert call_args.kwargs['model'] == "gpt-4" # Should use runtime parameter
assert call_args.kwargs['temperature'] == 0.5 # Should use runtime parameter
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_fallback_to_defaults_when_no_override(self, mock_llm_init, mock_async_init,
mock_openai_class, mock_openai_client, base_processor_config):
"""Test that processor falls back to configured defaults when no parameters are provided"""
# Arrange
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = base_processor_config | {
"model": "gpt-3.5-turbo", # Default model
"temperature": 0.2, # Default temperature
"concurrency": 1,
"taskgroup": AsyncMock(),
"id": "test-processor"
}
processor = OpenAIProcessor(**config)
# Act - Call with no parameter overrides
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default
temperature=None # Use default
)
# Assert
assert isinstance(result, LlmResult)
# Verify defaults were used
mock_openai_client.chat.completions.create.assert_called_once()
call_args = mock_openai_client.chat.completions.create.call_args
assert call_args.kwargs['model'] == "gpt-3.5-turbo" # Should use processor default
assert call_args.kwargs['temperature'] == 0.2 # Should use processor default
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_multiple_concurrent_calls_different_parameters(self, mock_llm_init, mock_async_init,
mock_openai_class, mock_openai_client, base_processor_config):
"""Test multiple concurrent calls with different parameters don't interfere"""
# Arrange
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = base_processor_config | {
"model": "gpt-3.5-turbo",
"concurrency": 1,
"taskgroup": AsyncMock(),
"id": "test-processor"
}
processor = OpenAIProcessor(**config)
# Reset the mock to track multiple calls
mock_openai_client.reset_mock()
# Act - Make multiple calls with different parameters concurrently
import asyncio
tasks = [
processor.generate_content("System 1", "Prompt 1", model="gpt-3.5-turbo", temperature=0.1),
processor.generate_content("System 2", "Prompt 2", model="gpt-4", temperature=0.8),
processor.generate_content("System 3", "Prompt 3", model="gpt-3.5-turbo", temperature=0.5)
]
results = await asyncio.gather(*tasks)
# Assert
assert len(results) == 3
for result in results:
assert isinstance(result, LlmResult)
# Verify all calls were made with correct parameters
assert mock_openai_client.chat.completions.create.call_count == 3
# Get all call arguments
call_args_list = mock_openai_client.chat.completions.create.call_args_list
# Verify each call had the expected parameters
expected_params = [
("gpt-3.5-turbo", 0.1),
("gpt-4", 0.8),
("gpt-3.5-turbo", 0.5)
]
for i, (expected_model, expected_temp) in enumerate(expected_params):
call_kwargs = call_args_list[i].kwargs
assert call_kwargs['model'] == expected_model
assert call_kwargs['temperature'] == expected_temp
async def test_parameter_boundary_values(self, mock_openai_client, base_processor_config):
"""Test parameter boundary values (edge cases)"""
# This would test extreme values like temperature=0.0, temperature=2.0, etc.
# Implementation depends on specific validation requirements
pass
async def test_invalid_parameter_types_handling(self, mock_openai_client, base_processor_config):
"""Test handling of invalid parameter types"""
# This would test what happens with invalid temperature values, non-existent models, etc.
# Implementation depends on error handling requirements
pass
if __name__ == '__main__':
pytest.main([__file__])

View file

@ -22,7 +22,36 @@ class TestObjectsCassandraIntegration:
def mock_cassandra_session(self):
"""Mock Cassandra session for integration tests"""
session = MagicMock()
session.execute = MagicMock()
# Track if keyspaces have been created
created_keyspaces = set()
# Mock the execute method to return a valid result for keyspace checks
def execute_mock(query, *args, **kwargs):
result = MagicMock()
query_str = str(query)
# Track keyspace creation
if "CREATE KEYSPACE" in query_str:
# Extract keyspace name from query
import re
match = re.search(r'CREATE KEYSPACE IF NOT EXISTS (\w+)', query_str)
if match:
created_keyspaces.add(match.group(1))
# For keyspace existence checks
if "system_schema.keyspaces" in query_str:
# Check if this keyspace was created
if args and args[0] in created_keyspaces:
result.one.return_value = MagicMock() # Exists
else:
result.one.return_value = None # Doesn't exist
else:
result.one.return_value = None
return result
session.execute = MagicMock(side_effect=execute_mock)
return session
@pytest.fixture
@ -57,7 +86,8 @@ class TestObjectsCassandraIntegration:
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.create_collection = Processor.create_collection.__get__(processor, Processor)
return processor, mock_cassandra_cluster, mock_cassandra_session
@pytest.mark.asyncio
@ -85,7 +115,10 @@ class TestObjectsCassandraIntegration:
await processor.on_schema_config(config, version=1)
assert "customer_records" in processor.schemas
# Step 1.5: Create the collection first (simulate tg-set-collection)
await processor.create_collection("test_user", "import_2024")
# Step 2: Process an ExtractedObject
test_obj = ExtractedObject(
metadata=Metadata(
@ -104,10 +137,10 @@ class TestObjectsCassandraIntegration:
confidence=0.95,
source_span="Customer: John Doe..."
)
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
# Verify Cassandra interactions
@ -178,7 +211,11 @@ class TestObjectsCassandraIntegration:
await processor.on_schema_config(config, version=1)
assert len(processor.schemas) == 2
# Create collections first
await processor.create_collection("shop", "catalog")
await processor.create_collection("shop", "sales")
# Process objects for different schemas
product_obj = ExtractedObject(
metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]),
@ -187,7 +224,7 @@ class TestObjectsCassandraIntegration:
confidence=0.9,
source_span="Product..."
)
order_obj = ExtractedObject(
metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]),
schema_name="orders",
@ -195,7 +232,7 @@ class TestObjectsCassandraIntegration:
confidence=0.85,
source_span="Order..."
)
# Process both objects
for obj in [product_obj, order_obj]:
msg = MagicMock()
@ -225,6 +262,9 @@ class TestObjectsCassandraIntegration:
]
)
# Create collection first
await processor.create_collection("test", "test")
# Create object missing required field
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
@ -233,10 +273,10 @@ class TestObjectsCassandraIntegration:
confidence=0.8,
source_span="Test"
)
msg = MagicMock()
msg.value.return_value = test_obj
# Should still process (Cassandra doesn't enforce NOT NULL)
await processor.on_object(msg, None, None)
@ -261,6 +301,9 @@ class TestObjectsCassandraIntegration:
]
)
# Create collection first
await processor.create_collection("logger", "app_events")
# Process object
test_obj = ExtractedObject(
metadata=Metadata(id="e1", user="logger", collection="app_events", metadata=[]),
@ -269,10 +312,10 @@ class TestObjectsCassandraIntegration:
confidence=1.0,
source_span="Event"
)
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
# Verify synthetic_id was added
@ -325,8 +368,10 @@ class TestObjectsCassandraIntegration:
)
# Make insert fail
mock_result = MagicMock()
mock_result.one.return_value = MagicMock() # Keyspace exists
mock_session.execute.side_effect = [
None, # keyspace creation succeeds
mock_result, # keyspace existence check succeeds
None, # table creation succeeds
Exception("Connection timeout") # insert fails
]
@ -359,7 +404,11 @@ class TestObjectsCassandraIntegration:
# Process objects from different collections
collections = ["import_jan", "import_feb", "import_mar"]
# Create all collections first
for coll in collections:
await processor.create_collection("analytics", coll)
for coll in collections:
obj = ExtractedObject(
metadata=Metadata(id=f"{coll}-1", user="analytics", collection=coll, metadata=[]),
@ -368,7 +417,7 @@ class TestObjectsCassandraIntegration:
confidence=0.9,
source_span="Data"
)
msg = MagicMock()
msg.value.return_value = obj
await processor.on_object(msg, None, None)
@ -436,9 +485,12 @@ class TestObjectsCassandraIntegration:
source_span="Multiple customers extracted from document"
)
# Create collection first
await processor.create_collection("test_user", "batch_import")
msg = MagicMock()
msg.value.return_value = batch_obj
await processor.on_object(msg, None, None)
# Verify table creation
@ -479,6 +531,9 @@ class TestObjectsCassandraIntegration:
fields=[Field(name="id", type="string", size=50, primary=True)]
)
# Create collection first
await processor.create_collection("test", "empty")
# Process empty batch object
empty_obj = ExtractedObject(
metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]),
@ -487,10 +542,10 @@ class TestObjectsCassandraIntegration:
confidence=1.0,
source_span="No objects found"
)
msg = MagicMock()
msg.value.return_value = empty_obj
await processor.on_object(msg, None, None)
# Should still create table
@ -517,6 +572,9 @@ class TestObjectsCassandraIntegration:
]
)
# Create collection first
await processor.create_collection("test", "mixed")
# Single object (backward compatibility)
single_obj = ExtractedObject(
metadata=Metadata(id="single", user="test", collection="mixed", metadata=[]),
@ -525,7 +583,7 @@ class TestObjectsCassandraIntegration:
confidence=0.9,
source_span="Single object"
)
# Batch object
batch_obj = ExtractedObject(
metadata=Metadata(id="batch", user="test", collection="mixed", metadata=[]),
@ -537,7 +595,7 @@ class TestObjectsCassandraIntegration:
confidence=0.85,
source_span="Batch objects"
)
# Process both
for obj in [single_obj, batch_obj]:
msg = MagicMock()

View file

@ -60,13 +60,13 @@ class TestTextCompletionIntegration:
"""Create text completion processor with test configuration"""
# Create a minimal processor instance for testing generate_content
processor = MagicMock()
processor.model = processor_config["model"]
processor.default_model = processor_config["model"]
processor.temperature = processor_config["temperature"]
processor.max_output = processor_config["max_output"]
# Add the actual generate_content method from Processor class
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
return processor
@pytest.mark.asyncio
@ -112,11 +112,11 @@ class TestTextCompletionIntegration:
for config in test_configs:
# Arrange - Create minimal processor mock
processor = MagicMock()
processor.model = config['model']
processor.default_model = config['model']
processor.temperature = config['temperature']
processor.max_output = config['max_output']
processor.openai = mock_openai_client
# Add the actual generate_content method
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
@ -242,7 +242,7 @@ class TestTextCompletionIntegration:
processors = []
for i in range(5):
processor = MagicMock()
processor.model = processor_config["model"]
processor.default_model = processor_config["model"]
processor.temperature = processor_config["temperature"]
processor.max_output = processor_config["max_output"]
processor.openai = mock_openai_client
@ -348,7 +348,7 @@ class TestTextCompletionIntegration:
"""Test that model parameters are correctly passed to OpenAI API"""
# Arrange
processor = MagicMock()
processor.model = "gpt-4"
processor.default_model = "gpt-4"
processor.temperature = 0.8
processor.max_output = 2048
processor.openai = mock_openai_client

View file

@ -0,0 +1,238 @@
"""
Unit tests for Flow Parameter Specification functionality
Testing parameter specification registration and handling in flow processors
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
from trustgraph.base.flow_processor import FlowProcessor
from trustgraph.base import ParameterSpec, ConsumerSpec, ProducerSpec
class MockAsyncProcessor:
def __init__(self, **params):
self.config_handlers = []
self.id = params.get('id', 'test-service')
self.specifications = []
class TestFlowParameterSpecs(IsolatedAsyncioTestCase):
"""Test flow processor parameter specification functionality"""
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_parameter_spec_registration(self):
"""Test that parameter specs can be registered with flow processors"""
# Arrange
config = {
'id': 'test-flow-processor',
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
# Create test parameter specs
model_spec = ParameterSpec(name="model")
temperature_spec = ParameterSpec(name="temperature")
# Act
processor.register_specification(model_spec)
processor.register_specification(temperature_spec)
# Assert
assert len(processor.specifications) >= 2
param_specs = [spec for spec in processor.specifications
if isinstance(spec, ParameterSpec)]
assert len(param_specs) >= 2
param_names = [spec.name for spec in param_specs]
assert "model" in param_names
assert "temperature" in param_names
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_mixed_specification_types(self):
"""Test registration of mixed specification types (parameters, consumers, producers)"""
# Arrange
config = {
'id': 'test-flow-processor',
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
# Create different spec types
param_spec = ParameterSpec(name="model")
consumer_spec = ConsumerSpec(name="input", schema=MagicMock(), handler=MagicMock())
producer_spec = ProducerSpec(name="output", schema=MagicMock())
# Act
processor.register_specification(param_spec)
processor.register_specification(consumer_spec)
processor.register_specification(producer_spec)
# Assert
assert len(processor.specifications) == 3
# Count each type
param_specs = [s for s in processor.specifications if isinstance(s, ParameterSpec)]
consumer_specs = [s for s in processor.specifications if isinstance(s, ConsumerSpec)]
producer_specs = [s for s in processor.specifications if isinstance(s, ProducerSpec)]
assert len(param_specs) == 1
assert len(consumer_specs) == 1
assert len(producer_specs) == 1
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_parameter_spec_metadata(self):
"""Test parameter specification metadata handling"""
# Arrange
config = {
'id': 'test-flow-processor',
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
# Create parameter specs with metadata (if supported)
model_spec = ParameterSpec(name="model")
temperature_spec = ParameterSpec(name="temperature")
# Act
processor.register_specification(model_spec)
processor.register_specification(temperature_spec)
# Assert
param_specs = [spec for spec in processor.specifications
if isinstance(spec, ParameterSpec)]
model_spec_registered = next((s for s in param_specs if s.name == "model"), None)
temperature_spec_registered = next((s for s in param_specs if s.name == "temperature"), None)
assert model_spec_registered is not None
assert temperature_spec_registered is not None
assert model_spec_registered.name == "model"
assert temperature_spec_registered.name == "temperature"
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_duplicate_parameter_spec_handling(self):
"""Test handling of duplicate parameter spec registration"""
# Arrange
config = {
'id': 'test-flow-processor',
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
# Create duplicate parameter specs
model_spec1 = ParameterSpec(name="model")
model_spec2 = ParameterSpec(name="model")
# Act
processor.register_specification(model_spec1)
processor.register_specification(model_spec2)
# Assert - registration should succeed whether duplicates are kept or deduplicated
param_specs = [spec for spec in processor.specifications
if isinstance(spec, ParameterSpec) and spec.name == "model"]
# At least one "model" spec must survive; the exact count depends on
# whether the processor deduplicates specifications
assert len(param_specs) >= 1
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
@patch('trustgraph.base.flow_processor.Flow')
async def test_parameter_specs_available_to_flows(self, mock_flow_class):
"""Test that parameter specs are available when flows are created"""
# Arrange
config = {
'id': 'test-flow-processor',
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
processor.id = 'test-processor'
# Register parameter specs
model_spec = ParameterSpec(name="model")
temperature_spec = ParameterSpec(name="temperature")
processor.register_specification(model_spec)
processor.register_specification(temperature_spec)
mock_flow = AsyncMock()
mock_flow_class.return_value = mock_flow
flow_name = 'test-flow'
flow_defn = {'config': 'test-config'}
# Act
await processor.start_flow(flow_name, flow_defn)
# Assert - Flow should be created with access to processor specifications
mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn)
# The flow should have access to the processor's specifications
# (The exact mechanism depends on Flow implementation)
assert len(processor.specifications) >= 2
class TestParameterSpecValidation(IsolatedAsyncioTestCase):
"""Test parameter specification validation functionality"""
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_parameter_spec_name_validation(self):
"""Test parameter spec name validation"""
# Arrange
config = {
'id': 'test-flow-processor',
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
# Act & Assert - Valid parameter names
valid_specs = [
ParameterSpec(name="model"),
ParameterSpec(name="temperature"),
ParameterSpec(name="max_tokens"),
ParameterSpec(name="api_key")
]
for spec in valid_specs:
# Should not raise any exceptions
processor.register_specification(spec)
assert len([s for s in processor.specifications if isinstance(s, ParameterSpec)]) >= 4
def test_parameter_spec_creation_validation(self):
"""Test parameter spec creation with various inputs"""
# Test valid parameter spec creation
valid_specs = [
ParameterSpec(name="model"),
ParameterSpec(name="temperature"),
ParameterSpec(name="max_output"),
]
for spec in valid_specs:
assert spec.name is not None
assert isinstance(spec.name, str)
# Edge case: empty name. Whether this raises depends on the actual
# ParameterSpec implementation, so both outcomes are tolerated here
try:
ParameterSpec(name="")
except Exception:
# If ParameterSpec validates names, rejecting "" is acceptable
pass
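To make the empty-name edge case above concrete, here is one plausible validation policy sketched with a hypothetical stand-in class (the real trustgraph `ParameterSpec` may accept or reject an empty name differently):

```python
# Hypothetical stand-in illustrating one validation policy; not the
# real trustgraph ParameterSpec, which may behave differently
class StandInParameterSpec:
    def __init__(self, name: str):
        if not isinstance(name, str) or not name:
            raise ValueError("parameter name must be a non-empty string")
        self.name = name

valid_names = [StandInParameterSpec(n).name for n in ("model", "temperature")]

try:
    StandInParameterSpec("")
    rejected = False
except ValueError:
    rejected = True
```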
if __name__ == '__main__':
pytest.main([__file__])

View file

@ -0,0 +1,264 @@
"""
Unit tests for LLM Service Parameter Specifications
Testing the new parameter-aware functionality added to the LLM base service
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
from trustgraph.base.llm_service import LlmService, LlmResult
from trustgraph.base import ParameterSpec, ConsumerSpec, ProducerSpec
from trustgraph.schema import TextCompletionRequest, TextCompletionResponse
class MockAsyncProcessor:
def __init__(self, **params):
self.config_handlers = []
self.id = params.get('id', 'test-service')
self.specifications = []
class TestLlmServiceParameters(IsolatedAsyncioTestCase):
"""Test LLM service parameter specification functionality"""
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_parameter_specs_registration(self):
"""Test that LLM service registers model and temperature parameter specs"""
# Arrange
config = {
'id': 'test-llm-service',
'concurrency': 1,
'taskgroup': AsyncMock() # Add required taskgroup
}
# Act
service = LlmService(**config)
# Assert
param_specs = {spec.name: spec for spec in service.specifications
if isinstance(spec, ParameterSpec)}
assert "model" in param_specs
assert "temperature" in param_specs
assert len(param_specs) >= 2 # May have other parameter specs
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_model_parameter_spec_properties(self):
"""Test that model parameter spec has correct properties"""
# Arrange
config = {
'id': 'test-llm-service',
'concurrency': 1,
'taskgroup': AsyncMock()
}
# Act
service = LlmService(**config)
# Assert
model_spec = None
for spec in service.specifications:
if isinstance(spec, ParameterSpec) and spec.name == "model":
model_spec = spec
break
assert model_spec is not None
assert model_spec.name == "model"
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_temperature_parameter_spec_properties(self):
"""Test that temperature parameter spec has correct properties"""
# Arrange
config = {
'id': 'test-llm-service',
'concurrency': 1,
'taskgroup': AsyncMock()
}
# Act
service = LlmService(**config)
# Assert
temperature_spec = None
for spec in service.specifications:
if isinstance(spec, ParameterSpec) and spec.name == "temperature":
temperature_spec = spec
break
assert temperature_spec is not None
assert temperature_spec.name == "temperature"
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_request_extracts_parameters_from_flow(self):
"""Test that on_request method extracts model and temperature from flow"""
# Arrange
config = {
'id': 'test-llm-service',
'concurrency': 1,
'taskgroup': AsyncMock()
}
service = LlmService(**config)
# Mock the metrics
service.text_completion_model_metric = MagicMock()
service.text_completion_model_metric.labels.return_value.info = AsyncMock()
# Mock the generate_content method to capture parameters
service.generate_content = AsyncMock(return_value=LlmResult(
text="test response",
in_token=10,
out_token=5,
model="gpt-4"
))
# Mock message and flow
mock_message = MagicMock()
mock_message.value.return_value = MagicMock()
mock_message.value.return_value.system = "system prompt"
mock_message.value.return_value.prompt = "user prompt"
mock_message.properties.return_value = {"id": "test-id"}
mock_consumer = MagicMock()
mock_consumer.name = "request"
mock_flow = MagicMock()
mock_flow.name = "test-flow"
# side_effect takes precedence over return_value, so configure only the mapping
mock_flow.side_effect = lambda param: {
"model": "gpt-4",
"temperature": 0.7
}.get(param, f"mock-{param}")
mock_producer = AsyncMock()
mock_flow.producer = {"response": mock_producer}
# Act
await service.on_request(mock_message, mock_consumer, mock_flow)
# Assert
# Verify that generate_content was called with parameters from flow
service.generate_content.assert_called_once()
call_args = service.generate_content.call_args
assert call_args[0][0] == "system prompt" # system
assert call_args[0][1] == "user prompt" # prompt
assert call_args[0][2] == "gpt-4" # model
assert call_args[0][3] == 0.7 # temperature
# Verify flow was queried for both parameters
mock_flow.assert_any_call("model")
mock_flow.assert_any_call("temperature")
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_request_handles_missing_parameters_gracefully(self):
"""Test that on_request handles missing parameters gracefully"""
# Arrange
config = {
'id': 'test-llm-service',
'concurrency': 1,
'taskgroup': AsyncMock()
}
service = LlmService(**config)
# Mock the metrics
service.text_completion_model_metric = MagicMock()
service.text_completion_model_metric.labels.return_value.info = AsyncMock()
# Mock the generate_content method
service.generate_content = AsyncMock(return_value=LlmResult(
text="test response",
in_token=10,
out_token=5,
model="default-model"
))
# Mock message and flow where flow returns None for parameters
mock_message = MagicMock()
mock_message.value.return_value = MagicMock()
mock_message.value.return_value.system = "system prompt"
mock_message.value.return_value.prompt = "user prompt"
mock_message.properties.return_value = {"id": "test-id"}
mock_consumer = MagicMock()
mock_consumer.name = "request"
mock_flow = MagicMock()
mock_flow.name = "test-flow"
mock_flow.return_value = None # Both parameters return None
mock_producer = AsyncMock()
mock_flow.producer = {"response": mock_producer}
# Act
await service.on_request(mock_message, mock_consumer, mock_flow)
# Assert
# Should still call generate_content, with None values that will use processor defaults
service.generate_content.assert_called_once()
call_args = service.generate_content.call_args
assert call_args[0][0] == "system prompt" # system
assert call_args[0][1] == "user prompt" # prompt
assert call_args[0][2] is None # model (will use processor default)
assert call_args[0][3] is None # temperature (will use processor default)
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_request_error_handling_preserves_behavior(self):
"""Test that parameter extraction doesn't break existing error handling"""
# Arrange
config = {
'id': 'test-llm-service',
'concurrency': 1,
'taskgroup': AsyncMock()
}
service = LlmService(**config)
# Mock the metrics
service.text_completion_model_metric = MagicMock()
service.text_completion_model_metric.labels.return_value.info = AsyncMock()
# Mock generate_content to raise an exception
service.generate_content = AsyncMock(side_effect=Exception("Test error"))
# Mock message and flow
mock_message = MagicMock()
mock_message.value.return_value = MagicMock()
mock_message.value.return_value.system = "system prompt"
mock_message.value.return_value.prompt = "user prompt"
mock_message.properties.return_value = {"id": "test-id"}
mock_consumer = MagicMock()
mock_consumer.name = "request"
mock_flow = MagicMock()
mock_flow.name = "test-flow"
mock_flow.side_effect = lambda param: {
"model": "gpt-4",
"temperature": 0.7
}.get(param, f"mock-{param}")
mock_producer = AsyncMock()
mock_flow.producer = {"response": mock_producer}
# Act
await service.on_request(mock_message, mock_consumer, mock_flow)
# Assert
# Should have sent error response
mock_producer.send.assert_called_once()
error_response = mock_producer.send.call_args[0][0]
assert error_response.error is not None
assert error_response.error.type == "llm-error"
assert "Test error" in error_response.error.message
assert error_response.response is None
if __name__ == '__main__':
pytest.main([__file__])

View file

@ -1,211 +1,236 @@
"""
Unit tests for trustgraph.chunking.recursive
Testing parameter override functionality for chunk-size and chunk-overlap
"""
import pytest
import asyncio
from unittest.mock import AsyncMock, Mock, patch, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.chunking.recursive.chunker import Processor
from trustgraph.schema import TextDocument, Chunk, Metadata
from trustgraph.chunking.recursive.chunker import Processor as RecursiveChunker
@pytest.fixture
def mock_flow():
output_mock = AsyncMock()
flow_mock = Mock(return_value=output_mock)
return flow_mock, output_mock
class MockAsyncProcessor:
def __init__(self, **params):
self.config_handlers = []
self.id = params.get('id', 'test-service')
self.specifications = []
@pytest.fixture
def mock_consumer():
consumer = Mock()
consumer.id = "test-consumer"
consumer.flow = "test-flow"
return consumer
class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
"""Test Recursive chunker functionality"""
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_processor_initialization_basic(self):
"""Test basic processor initialization"""
# Arrange
config = {
'id': 'test-chunker',
'chunk_size': 1500,
'chunk_overlap': 150,
'concurrency': 1,
'taskgroup': AsyncMock()
}
# Act
processor = Processor(**config)
# Assert
assert processor.default_chunk_size == 1500
assert processor.default_chunk_overlap == 150
assert hasattr(processor, 'text_splitter')
# Verify parameter specs are registered
param_specs = [spec for spec in processor.specifications
if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
assert len(param_specs) == 2
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_chunk_size_override(self):
"""Test chunk_document with chunk-size parameter override"""
# Arrange
config = {
'id': 'test-chunker',
'chunk_size': 1000, # Default chunk size
'chunk_overlap': 100,
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Mock message and flow
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
"chunk-size": 2000, # Override chunk size
"chunk-overlap": None # Use default chunk overlap
}.get(param)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
mock_message, mock_consumer, mock_flow, 1000, 100
)
# Assert
assert chunk_size == 2000 # Should use overridden value
assert chunk_overlap == 100 # Should use default value
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_chunk_overlap_override(self):
"""Test chunk_document with chunk-overlap parameter override"""
# Arrange
config = {
'id': 'test-chunker',
'chunk_size': 1000,
'chunk_overlap': 100, # Default chunk overlap
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Mock message and flow
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
"chunk-size": None, # Use default chunk size
"chunk-overlap": 200 # Override chunk overlap
}.get(param)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
mock_message, mock_consumer, mock_flow, 1000, 100
)
# Assert
assert chunk_size == 1000 # Should use default value
assert chunk_overlap == 200 # Should use overridden value
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_both_parameters_override(self):
"""Test chunk_document with both chunk-size and chunk-overlap overrides"""
# Arrange
config = {
'id': 'test-chunker',
'chunk_size': 1000,
'chunk_overlap': 100,
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Mock message and flow
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
"chunk-size": 1500, # Override chunk size
"chunk-overlap": 150 # Override chunk overlap
}.get(param)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
mock_message, mock_consumer, mock_flow, 1000, 100
)
# Assert
assert chunk_size == 1500 # Should use overridden value
assert chunk_overlap == 150 # Should use overridden value
@patch('trustgraph.chunking.recursive.chunker.RecursiveCharacterTextSplitter')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_uses_flow_parameters(self, mock_splitter_class):
"""Test that on_message method uses parameters from flow"""
# Arrange
mock_splitter = MagicMock()
mock_document = MagicMock()
mock_document.page_content = "Test chunk content"
mock_splitter.create_documents.return_value = [mock_document]
mock_splitter_class.return_value = mock_splitter
config = {
'id': 'test-chunker',
'chunk_size': 1000,
'chunk_overlap': 100,
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Mock message with TextDocument
mock_message = MagicMock()
mock_text_doc = MagicMock()
mock_text_doc.metadata = Metadata(
id="test-doc-123",
metadata=[],
user="test-user",
collection="test-collection"
)
mock_text_doc.text = b"This is test document content"
mock_message.value.return_value = mock_text_doc
# Mock consumer and flow with parameter overrides
mock_consumer = MagicMock()
mock_producer = AsyncMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
"chunk-size": 1500,
"chunk-overlap": 150,
"output": mock_producer
}.get(param)
# Act
await processor.on_message(mock_message, mock_consumer, mock_flow)
# Assert
# Verify RecursiveCharacterTextSplitter was called with overridden parameters (last call)
actual_last_call = mock_splitter_class.call_args_list[-1]
assert actual_last_call.kwargs['chunk_size'] == 1500
assert actual_last_call.kwargs['chunk_overlap'] == 150
assert actual_last_call.kwargs['length_function'] == len
assert actual_last_call.kwargs['is_separator_regex'] == False
# Verify chunk was sent to output
mock_producer.send.assert_called_once()
sent_chunk = mock_producer.send.call_args[0][0]
assert isinstance(sent_chunk, Chunk)
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_no_overrides(self):
"""Test chunk_document when no parameters are overridden (flow returns None)"""
# Arrange
config = {
'id': 'test-chunker',
'chunk_size': 1000,
'chunk_overlap': 100,
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Mock message and flow that returns None for all parameters
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.return_value = None # No overrides
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
mock_message, mock_consumer, mock_flow, 1000, 100
)
# Assert
assert chunk_size == 1000 # Should use default value
assert chunk_overlap == 100 # Should use default value
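The override-or-default pattern these tests exercise can be sketched in isolation; `flow` here is any callable that returns None when a parameter is unset (the `resolve` helper is illustrative, not the project's API):

```python
from unittest.mock import MagicMock

def resolve(flow, name, default):
    # Flow-supplied value wins; fall back to the processor default on None
    value = flow(name)
    return default if value is None else value

# Mimic the mock_flow pattern used in the tests above
flow = MagicMock()
flow.side_effect = lambda param: {"chunk-size": 1500}.get(param)

chunk_size = resolve(flow, "chunk-size", 1000)      # overridden by the flow
chunk_overlap = resolve(flow, "chunk-overlap", 100)  # falls back to default
```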
@pytest.fixture
def sample_document():
metadata = Metadata(
id="test-doc-1",
metadata=[],
user="test-user",
collection="test-collection"
)
text = "This is a test document. " * 100 # Create text long enough to be chunked
return TextDocument(
metadata=metadata,
text=text.encode("utf-8")
)
@pytest.fixture
def short_document():
metadata = Metadata(
id="test-doc-2",
metadata=[],
user="test-user",
collection="test-collection"
)
text = "This is a very short document."
return TextDocument(
metadata=metadata,
text=text.encode("utf-8")
)
class TestRecursiveChunker:
def test_init_default_params(self, mock_async_processor_init):
processor = RecursiveChunker()
assert processor.text_splitter._chunk_size == 2000
assert processor.text_splitter._chunk_overlap == 100
def test_init_custom_params(self, mock_async_processor_init):
processor = RecursiveChunker(chunk_size=500, chunk_overlap=50)
assert processor.text_splitter._chunk_size == 500
assert processor.text_splitter._chunk_overlap == 50
def test_init_with_id(self, mock_async_processor_init):
processor = RecursiveChunker(id="custom-chunker")
assert processor.id == "custom-chunker"
@pytest.mark.asyncio
async def test_on_message_single_chunk(self, mock_async_processor_init, mock_flow, mock_consumer, short_document):
flow_mock, output_mock = mock_flow
processor = RecursiveChunker(chunk_size=2000, chunk_overlap=100)
msg = Mock()
msg.value.return_value = short_document
await processor.on_message(msg, mock_consumer, flow_mock)
# Should produce exactly one chunk for short text
assert output_mock.send.call_count == 1
# Verify the chunk was created correctly
chunk_call = output_mock.send.call_args[0][0]
assert isinstance(chunk_call, Chunk)
assert chunk_call.metadata == short_document.metadata
assert chunk_call.chunk.decode("utf-8") == short_document.text.decode("utf-8")
@pytest.mark.asyncio
async def test_on_message_multiple_chunks(self, mock_async_processor_init, mock_flow, mock_consumer, sample_document):
flow_mock, output_mock = mock_flow
processor = RecursiveChunker(chunk_size=100, chunk_overlap=20)
msg = Mock()
msg.value.return_value = sample_document
await processor.on_message(msg, mock_consumer, flow_mock)
# Should produce multiple chunks
assert output_mock.send.call_count > 1
# Verify all chunks have correct metadata
for call in output_mock.send.call_args_list:
chunk = call[0][0]
assert isinstance(chunk, Chunk)
assert chunk.metadata == sample_document.metadata
assert len(chunk.chunk) > 0
@pytest.mark.asyncio
async def test_on_message_chunk_overlap(self, mock_async_processor_init, mock_flow, mock_consumer):
flow_mock, output_mock = mock_flow
processor = RecursiveChunker(chunk_size=50, chunk_overlap=10)
# Create a document with predictable content
metadata = Metadata(id="test", metadata=[], user="test-user", collection="test-collection")
text = "ABCDEFGHIJ" * 10 # 100 characters
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
msg = Mock()
msg.value.return_value = document
await processor.on_message(msg, mock_consumer, flow_mock)
# Collect all chunks
chunks = []
for call in output_mock.send.call_args_list:
chunk_text = call[0][0].chunk.decode("utf-8")
chunks.append(chunk_text)
# Verify chunks have expected overlap
for i in range(len(chunks) - 1):
# The end of chunk i should overlap with the beginning of chunk i+1
# Check if there's some overlap (exact overlap depends on text splitter logic)
assert len(chunks[i]) <= 50 + 10 # chunk_size + some tolerance
@pytest.mark.asyncio
async def test_on_message_empty_document(self, mock_async_processor_init, mock_flow, mock_consumer):
flow_mock, output_mock = mock_flow
processor = RecursiveChunker()
metadata = Metadata(id="empty", metadata=[], user="test-user", collection="test-collection")
document = TextDocument(metadata=metadata, text=b"")
msg = Mock()
msg.value.return_value = document
await processor.on_message(msg, mock_consumer, flow_mock)
# Empty documents typically don't produce chunks with langchain splitters
# This behavior is expected - no chunks should be produced
assert output_mock.send.call_count == 0
@pytest.mark.asyncio
async def test_on_message_unicode_handling(self, mock_async_processor_init, mock_flow, mock_consumer):
flow_mock, output_mock = mock_flow
processor = RecursiveChunker(chunk_size=500, chunk_overlap=20) # Fixed overlap < chunk_size
metadata = Metadata(id="unicode", metadata=[], user="test-user", collection="test-collection")
text = "Hello 世界! 🌍 This is a test with émojis and spëcial characters."
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
msg = Mock()
msg.value.return_value = document
await processor.on_message(msg, mock_consumer, flow_mock)
# Verify unicode is preserved correctly
all_chunks = []
for call in output_mock.send.call_args_list:
chunk_text = call[0][0].chunk.decode("utf-8")
all_chunks.append(chunk_text)
# Reconstruct text (approximately, due to overlap)
reconstructed = "".join(all_chunks)
assert "世界" in reconstructed
assert "🌍" in reconstructed
assert "émojis" in reconstructed
@pytest.mark.asyncio
async def test_metrics_recorded(self, mock_async_processor_init, mock_flow, mock_consumer, sample_document):
flow_mock, output_mock = mock_flow
processor = RecursiveChunker(chunk_size=100)
msg = Mock()
msg.value.return_value = sample_document
# Mock the metric
with patch.object(RecursiveChunker.chunk_metric, 'labels') as mock_labels:
mock_observe = Mock()
mock_labels.return_value.observe = mock_observe
await processor.on_message(msg, mock_consumer, flow_mock)
# Verify metrics were recorded
mock_labels.assert_called_with(id="test-consumer", flow="test-flow")
assert mock_observe.call_count > 0
# Verify chunk sizes were observed
for call in mock_observe.call_args_list:
chunk_size = call[0][0]
assert chunk_size > 0
def test_add_args(self):
parser = Mock()
RecursiveChunker.add_args(parser)
# Verify arguments were added
calls = parser.add_argument.call_args_list
arg_names = [call[0][0] for call in calls]
assert '-z' in arg_names or '--chunk-size' in arg_names
assert '-v' in arg_names or '--chunk-overlap' in arg_names
if __name__ == '__main__':
pytest.main([__file__])
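The overlap and chunk-size properties these tests assert can be illustrated with a minimal sliding-window chunker. This is a simplified sketch, not the langchain `RecursiveCharacterTextSplitter` the processor actually wraps, and `chunk_text` is a hypothetical helper name; it only demonstrates why each chunk stays within `chunk_size` and why consecutive chunks share `chunk_overlap` characters.

```python
def chunk_text(text, chunk_size=2000, chunk_overlap=100):
    """Simplified sliding-window chunker (not the langchain splitter).

    Steps forward by chunk_size - chunk_overlap, so each chunk repeats
    the last chunk_overlap characters of the previous one.
    """
    if not text:
        return []  # empty input yields no chunks, matching the tests above
    step = chunk_size - chunk_overlap
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]
```

With `chunk_size=50, chunk_overlap=10` on a 100-character input this produces three chunks, each at most 50 characters, with the last 10 characters of each chunk reappearing at the start of the next.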


@@ -1,275 +1,256 @@
"""
Unit tests for trustgraph.chunking.token
Testing parameter override functionality for chunk-size and chunk-overlap
"""
import pytest
import asyncio
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.chunking.token.chunker import Processor
from trustgraph.schema import TextDocument, Chunk, Metadata
from trustgraph.chunking.token.chunker import Processor as TokenChunker
@pytest.fixture
def mock_flow():
output_mock = AsyncMock()
flow_mock = Mock(return_value=output_mock)
return flow_mock, output_mock
class MockAsyncProcessor:
def __init__(self, **params):
self.config_handlers = []
self.id = params.get('id', 'test-service')
self.specifications = []
@pytest.fixture
def mock_consumer():
consumer = Mock()
consumer.id = "test-consumer"
consumer.flow = "test-flow"
return consumer
class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
"""Test Token chunker functionality"""
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_processor_initialization_basic(self):
"""Test basic processor initialization"""
# Arrange
config = {
'id': 'test-chunker',
'chunk_size': 300,
'chunk_overlap': 20,
'concurrency': 1,
'taskgroup': AsyncMock()
}
# Act
processor = Processor(**config)
# Assert
assert processor.default_chunk_size == 300
assert processor.default_chunk_overlap == 20
assert hasattr(processor, 'text_splitter')
# Verify parameter specs are registered
param_specs = [spec for spec in processor.specifications
if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
assert len(param_specs) == 2
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_chunk_size_override(self):
"""Test chunk_document with chunk-size parameter override"""
# Arrange
config = {
'id': 'test-chunker',
'chunk_size': 250, # Default chunk size
'chunk_overlap': 15,
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Mock message and flow
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
"chunk-size": 400, # Override chunk size
"chunk-overlap": None # Use default chunk overlap
}.get(param)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
mock_message, mock_consumer, mock_flow, 250, 15
)
# Assert
assert chunk_size == 400 # Should use overridden value
assert chunk_overlap == 15 # Should use default value
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_chunk_overlap_override(self):
"""Test chunk_document with chunk-overlap parameter override"""
# Arrange
config = {
'id': 'test-chunker',
'chunk_size': 250,
'chunk_overlap': 15, # Default chunk overlap
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Mock message and flow
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
"chunk-size": None, # Use default chunk size
"chunk-overlap": 25 # Override chunk overlap
}.get(param)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
mock_message, mock_consumer, mock_flow, 250, 15
)
# Assert
assert chunk_size == 250 # Should use default value
assert chunk_overlap == 25 # Should use overridden value
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_both_parameters_override(self):
"""Test chunk_document with both chunk-size and chunk-overlap overrides"""
# Arrange
config = {
'id': 'test-chunker',
'chunk_size': 250,
'chunk_overlap': 15,
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Mock message and flow
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
"chunk-size": 350, # Override chunk size
"chunk-overlap": 30 # Override chunk overlap
}.get(param)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
mock_message, mock_consumer, mock_flow, 250, 15
)
# Assert
assert chunk_size == 350 # Should use overridden value
assert chunk_overlap == 30 # Should use overridden value
@patch('trustgraph.chunking.token.chunker.TokenTextSplitter')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_uses_flow_parameters(self, mock_splitter_class):
"""Test that on_message method uses parameters from flow"""
# Arrange
mock_splitter = MagicMock()
mock_document = MagicMock()
mock_document.page_content = "Test token chunk content"
mock_splitter.create_documents.return_value = [mock_document]
mock_splitter_class.return_value = mock_splitter
config = {
'id': 'test-chunker',
'chunk_size': 250,
'chunk_overlap': 15,
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Mock message with TextDocument
mock_message = MagicMock()
mock_text_doc = MagicMock()
mock_text_doc.metadata = Metadata(
id="test-doc-456",
metadata=[],
user="test-user",
collection="test-collection"
)
mock_text_doc.text = b"This is test document content for token chunking"
mock_message.value.return_value = mock_text_doc
# Mock consumer and flow with parameter overrides
mock_consumer = MagicMock()
mock_producer = AsyncMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
"chunk-size": 400,
"chunk-overlap": 40,
"output": mock_producer
}.get(param)
# Act
await processor.on_message(mock_message, mock_consumer, mock_flow)
# Assert
# Verify TokenTextSplitter was called with overridden parameters (last call)
actual_last_call = mock_splitter_class.call_args_list[-1]
assert actual_last_call.kwargs['encoding_name'] == "cl100k_base"
assert actual_last_call.kwargs['chunk_size'] == 400
assert actual_last_call.kwargs['chunk_overlap'] == 40
# Verify chunk was sent to output
mock_producer.send.assert_called_once()
sent_chunk = mock_producer.send.call_args[0][0]
assert isinstance(sent_chunk, Chunk)
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_no_overrides(self):
"""Test chunk_document when no parameters are overridden (flow returns None)"""
# Arrange
config = {
'id': 'test-chunker',
'chunk_size': 250,
'chunk_overlap': 15,
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Mock message and flow that returns None for all parameters
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.return_value = None # No overrides
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
mock_message, mock_consumer, mock_flow, 250, 15
)
# Assert
assert chunk_size == 250 # Should use default value
assert chunk_overlap == 15 # Should use default value
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_token_chunker_uses_different_defaults(self):
"""Test that token chunker has different defaults than recursive chunker"""
# Arrange & Act
config = {
'id': 'test-chunker',
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Assert - Token chunker should have different defaults
assert processor.default_chunk_size == 250 # Token chunker default
assert processor.default_chunk_overlap == 15 # Token chunker default
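The override pattern these tests exercise — a flow-supplied value wins, otherwise the processor default applies — can be sketched as a small resolver. `resolve_chunk_params` is a hypothetical helper, not part of trustgraph; the real `chunk_document` is async and also receives the message and consumer, but the fallback logic is the same.

```python
def resolve_chunk_params(flow, default_size, default_overlap):
    """Return (chunk_size, chunk_overlap), preferring flow overrides.

    flow(param) yields a per-flow override or None, as in the mocks above;
    None means "use the processor default".
    """
    size = flow("chunk-size")
    overlap = flow("chunk-overlap")
    return (size if size is not None else default_size,
            overlap if overlap is not None else default_overlap)
```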
@pytest.fixture
def sample_document():
metadata = Metadata(
id="test-doc-1",
metadata=[],
user="test-user",
collection="test-collection"
)
# Create text that will result in multiple token chunks
text = "The quick brown fox jumps over the lazy dog. " * 50
return TextDocument(
metadata=metadata,
text=text.encode("utf-8")
)
@pytest.fixture
def short_document():
metadata = Metadata(
id="test-doc-2",
metadata=[],
user="test-user",
collection="test-collection"
)
text = "Short text."
return TextDocument(
metadata=metadata,
text=text.encode("utf-8")
)
class TestTokenChunker:
def test_init_default_params(self, mock_async_processor_init):
processor = TokenChunker()
assert processor.text_splitter._chunk_size == 250
assert processor.text_splitter._chunk_overlap == 15
# Just verify the text splitter was created (encoding verification is complex)
assert processor.text_splitter is not None
assert hasattr(processor.text_splitter, 'split_text')
def test_init_custom_params(self, mock_async_processor_init):
processor = TokenChunker(chunk_size=100, chunk_overlap=10)
assert processor.text_splitter._chunk_size == 100
assert processor.text_splitter._chunk_overlap == 10
def test_init_with_id(self, mock_async_processor_init):
processor = TokenChunker(id="custom-token-chunker")
assert processor.id == "custom-token-chunker"
@pytest.mark.asyncio
async def test_on_message_single_chunk(self, mock_async_processor_init, mock_flow, mock_consumer, short_document):
flow_mock, output_mock = mock_flow
processor = TokenChunker(chunk_size=250, chunk_overlap=15)
msg = Mock()
msg.value.return_value = short_document
await processor.on_message(msg, mock_consumer, flow_mock)
# Short text should produce exactly one chunk
assert output_mock.send.call_count == 1
# Verify the chunk was created correctly
chunk_call = output_mock.send.call_args[0][0]
assert isinstance(chunk_call, Chunk)
assert chunk_call.metadata == short_document.metadata
assert chunk_call.chunk.decode("utf-8") == short_document.text.decode("utf-8")
@pytest.mark.asyncio
async def test_on_message_multiple_chunks(self, mock_async_processor_init, mock_flow, mock_consumer, sample_document):
flow_mock, output_mock = mock_flow
processor = TokenChunker(chunk_size=50, chunk_overlap=5)
msg = Mock()
msg.value.return_value = sample_document
await processor.on_message(msg, mock_consumer, flow_mock)
# Should produce multiple chunks
assert output_mock.send.call_count > 1
# Verify all chunks have correct metadata
for call in output_mock.send.call_args_list:
chunk = call[0][0]
assert isinstance(chunk, Chunk)
assert chunk.metadata == sample_document.metadata
assert len(chunk.chunk) > 0
@pytest.mark.asyncio
async def test_on_message_token_overlap(self, mock_async_processor_init, mock_flow, mock_consumer):
flow_mock, output_mock = mock_flow
processor = TokenChunker(chunk_size=20, chunk_overlap=5)
# Create a document with repeated pattern
metadata = Metadata(id="test", metadata=[], user="test-user", collection="test-collection")
text = "one two three four five six seven eight nine ten " * 5
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
msg = Mock()
msg.value.return_value = document
await processor.on_message(msg, mock_consumer, flow_mock)
# Collect all chunks
chunks = []
for call in output_mock.send.call_args_list:
chunk_text = call[0][0].chunk.decode("utf-8")
chunks.append(chunk_text)
# Should have multiple chunks
assert len(chunks) > 1
# Verify chunks are not empty
for chunk in chunks:
assert len(chunk) > 0
@pytest.mark.asyncio
async def test_on_message_empty_document(self, mock_async_processor_init, mock_flow, mock_consumer):
flow_mock, output_mock = mock_flow
processor = TokenChunker()
metadata = Metadata(id="empty", metadata=[], user="test-user", collection="test-collection")
document = TextDocument(metadata=metadata, text=b"")
msg = Mock()
msg.value.return_value = document
await processor.on_message(msg, mock_consumer, flow_mock)
# Empty documents typically don't produce chunks with langchain splitters
# This behavior is expected - no chunks should be produced
assert output_mock.send.call_count == 0
@pytest.mark.asyncio
async def test_on_message_unicode_handling(self, mock_async_processor_init, mock_flow, mock_consumer):
flow_mock, output_mock = mock_flow
processor = TokenChunker(chunk_size=50)
metadata = Metadata(id="unicode", metadata=[], user="test-user", collection="test-collection")
# Test with various unicode characters
text = "Hello 世界! 🌍 Test émojis café naïve résumé. Greek: αβγδε Hebrew: אבגדה"
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
msg = Mock()
msg.value.return_value = document
await processor.on_message(msg, mock_consumer, flow_mock)
# Verify unicode is preserved correctly
all_chunks = []
for call in output_mock.send.call_args_list:
chunk_text = call[0][0].chunk.decode("utf-8")
all_chunks.append(chunk_text)
# Reconstruct text
reconstructed = "".join(all_chunks)
assert "世界" in reconstructed
assert "🌍" in reconstructed
assert "émojis" in reconstructed
assert "αβγδε" in reconstructed
assert "אבגדה" in reconstructed
@pytest.mark.asyncio
async def test_on_message_token_boundary_preservation(self, mock_async_processor_init, mock_flow, mock_consumer):
flow_mock, output_mock = mock_flow
processor = TokenChunker(chunk_size=10, chunk_overlap=2)
metadata = Metadata(id="boundary", metadata=[], user="test-user", collection="test-collection")
# Text with clear word boundaries
text = "This is a test of token boundaries and proper splitting."
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
msg = Mock()
msg.value.return_value = document
await processor.on_message(msg, mock_consumer, flow_mock)
# Collect all chunks
chunks = []
for call in output_mock.send.call_args_list:
chunk_text = call[0][0].chunk.decode("utf-8")
chunks.append(chunk_text)
# Token chunker should respect token boundaries
for chunk in chunks:
# Chunks should not start or end with partial words (in most cases)
assert len(chunk.strip()) > 0
@pytest.mark.asyncio
async def test_metrics_recorded(self, mock_async_processor_init, mock_flow, mock_consumer, sample_document):
flow_mock, output_mock = mock_flow
processor = TokenChunker(chunk_size=50)
msg = Mock()
msg.value.return_value = sample_document
# Mock the metric
with patch.object(TokenChunker.chunk_metric, 'labels') as mock_labels:
mock_observe = Mock()
mock_labels.return_value.observe = mock_observe
await processor.on_message(msg, mock_consumer, flow_mock)
# Verify metrics were recorded
mock_labels.assert_called_with(id="test-consumer", flow="test-flow")
assert mock_observe.call_count > 0
# Verify chunk sizes were observed
for call in mock_observe.call_args_list:
chunk_size = call[0][0]
assert chunk_size > 0
def test_add_args(self):
parser = Mock()
TokenChunker.add_args(parser)
# Verify arguments were added
calls = parser.add_argument.call_args_list
arg_names = [call[0][0] for call in calls]
assert '-z' in arg_names or '--chunk-size' in arg_names
assert '-v' in arg_names or '--chunk-overlap' in arg_names
@pytest.mark.asyncio
async def test_encoding_specific_behavior(self, mock_async_processor_init, mock_flow, mock_consumer):
flow_mock, output_mock = mock_flow
processor = TokenChunker(chunk_size=10, chunk_overlap=0)
metadata = Metadata(id="encoding", metadata=[], user="test-user", collection="test-collection")
# Test text that might tokenize differently with cl100k_base encoding
text = "GPT-4 is an AI model. It uses tokens."
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
msg = Mock()
msg.value.return_value = document
await processor.on_message(msg, mock_consumer, flow_mock)
# Verify chunking happened
assert output_mock.send.call_count >= 1
# Collect all chunks
chunks = []
for call in output_mock.send.call_args_list:
chunk_text = call[0][0].chunk.decode("utf-8")
chunks.append(chunk_text)
# Verify all text is preserved (allowing for overlap)
all_text = " ".join(chunks)
assert "GPT-4" in all_text
assert "AI model" in all_text
assert "tokens" in all_text
if __name__ == '__main__':
pytest.main([__file__])
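The token chunker's behaviour can be approximated with a sketch that counts whitespace-separated words instead of BPE tokens. The real `TokenTextSplitter` tokenizes with the `cl100k_base` encoding, so its boundaries differ; `chunk_tokens` is a hypothetical stand-in that only illustrates the windowing over a token stream.

```python
def chunk_tokens(text, chunk_size=250, chunk_overlap=15):
    """Simplified token chunker: splits on whitespace rather than the
    cl100k_base BPE encoding the real TokenTextSplitter uses."""
    tokens = text.split()
    if not tokens:
        return []  # empty documents produce no chunks, as the tests expect
    step = chunk_size - chunk_overlap
    return [" ".join(tokens[i:i + chunk_size])
            for i in range(0, len(tokens), step)]
```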


@@ -34,7 +34,9 @@ class TestGraphRag:
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
assert graph_rag.triples_client == mock_triples_client
assert graph_rag.verbose is False # Default value
assert graph_rag.label_cache == {} # Empty cache initially
# Verify label_cache is an LRUCacheWithTTL instance
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
def test_graph_rag_initialization_with_verbose(self):
"""Test GraphRag initialization with verbose enabled"""
@@ -59,7 +61,9 @@ class TestGraphRag:
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
assert graph_rag.triples_client == mock_triples_client
assert graph_rag.verbose is True
assert graph_rag.label_cache == {} # Empty cache initially
# Verify label_cache is an LRUCacheWithTTL instance
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
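The `get`/`put` semantics these tests mock can be sketched in a few lines using an `OrderedDict`. This is an illustrative sketch of the interface the tests rely on (miss returns `None`, `put` evicts the least recently used entry, entries expire after a TTL), not the actual `LRUCacheWithTTL` implementation in `trustgraph.retrieval.graph_rag.graph_rag`; the `max_size` and `ttl` defaults here are assumptions.

```python
import time
from collections import OrderedDict

class LRUCacheWithTTL:
    """Sketch of the cache interface mocked in the tests above."""

    def __init__(self, max_size=128, ttl=300.0):
        self.max_size = max_size
        self.ttl = ttl
        self._data = OrderedDict()  # key -> (value, stored_at)

    def get(self, key):
        item = self._data.get(key)
        if item is None:
            return None  # cache miss
        value, stored_at = item
        if time.monotonic() - stored_at > self.ttl:
            del self._data[key]  # expired entry counts as a miss
            return None
        self._data.move_to_end(key)  # mark as recently used
        return value

    def put(self, key, value):
        self._data[key] = (value, time.monotonic())
        self._data.move_to_end(key)
        if len(self._data) > self.max_size:
            self._data.popitem(last=False)  # evict least recently used
```

Note the tests also pin down the key format: labels are cached under `"{user}:{collection}:{entity}"`, so the same entity can carry different labels per collection.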
class TestQuery:
@@ -228,8 +232,11 @@ class TestQuery:
"""Test Query.maybe_label method with cached label"""
# Create mock GraphRag with label cache
mock_rag = MagicMock()
mock_rag.label_cache = {"entity1": "Entity One Label"}
# Create mock LRUCacheWithTTL
mock_cache = MagicMock()
mock_cache.get.return_value = "Entity One Label"
mock_rag.label_cache = mock_cache
# Initialize Query
query = Query(
rag=mock_rag,
@@ -237,27 +244,32 @@ class TestQuery:
collection="test_collection",
verbose=False
)
# Call maybe_label with cached entity
result = await query.maybe_label("entity1")
# Verify cached label is returned
assert result == "Entity One Label"
# Verify cache was checked with proper key format (user:collection:entity)
mock_cache.get.assert_called_once_with("test_user:test_collection:entity1")
@pytest.mark.asyncio
async def test_maybe_label_with_label_lookup(self):
"""Test Query.maybe_label method with database label lookup"""
# Create mock GraphRag with triples client
mock_rag = MagicMock()
mock_rag.label_cache = {} # Empty cache
# Create mock LRUCacheWithTTL that returns None (cache miss)
mock_cache = MagicMock()
mock_cache.get.return_value = None
mock_rag.label_cache = mock_cache
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
# Mock triple result with label
mock_triple = MagicMock()
mock_triple.o = "Human Readable Label"
mock_triples_client.query.return_value = [mock_triple]
# Initialize Query
query = Query(
rag=mock_rag,
@@ -265,10 +277,10 @@ class TestQuery:
collection="test_collection",
verbose=False
)
# Call maybe_label
result = await query.maybe_label("http://example.com/entity")
# Verify triples client was called correctly
mock_triples_client.query.assert_called_once_with(
s="http://example.com/entity",
@@ -278,17 +290,21 @@ class TestQuery:
user="test_user",
collection="test_collection"
)
# Verify result and cache update
# Verify result and cache update with proper key
assert result == "Human Readable Label"
assert mock_rag.label_cache["http://example.com/entity"] == "Human Readable Label"
cache_key = "test_user:test_collection:http://example.com/entity"
mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label")
@pytest.mark.asyncio
async def test_maybe_label_with_no_label_found(self):
"""Test Query.maybe_label method when no label is found"""
# Create mock GraphRag with triples client
mock_rag = MagicMock()
mock_rag.label_cache = {} # Empty cache
# Create mock LRUCacheWithTTL that returns None (cache miss)
mock_cache = MagicMock()
mock_cache.get.return_value = None
mock_rag.label_cache = mock_cache
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
@@ -318,7 +334,8 @@ class TestQuery:
# Verify result is entity itself and cache is updated
assert result == "unlabeled_entity"
assert mock_rag.label_cache["unlabeled_entity"] == "unlabeled_entity"
cache_key = "test_user:test_collection:unlabeled_entity"
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
@pytest.mark.asyncio
async def test_follow_edges_basic_functionality(self):
@@ -441,40 +458,40 @@ class TestQuery:
@pytest.mark.asyncio
async def test_get_subgraph_method(self):
"""Test Query.get_subgraph method orchestrates entity and edge discovery"""
# Create mock Query that patches get_entities and follow_edges
# Create mock Query that patches get_entities and follow_edges_batch
mock_rag = MagicMock()
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False,
max_path_length=1
)
# Mock get_entities to return test entities
query.get_entities = AsyncMock(return_value=["entity1", "entity2"])
# Mock follow_edges to add triples to subgraph
async def mock_follow_edges(ent, subgraph, path_length):
subgraph.add((ent, "predicate", "object"))
query.follow_edges = AsyncMock(side_effect=mock_follow_edges)
# Mock follow_edges_batch to return test triples
query.follow_edges_batch = AsyncMock(return_value={
("entity1", "predicate1", "object1"),
("entity2", "predicate2", "object2")
})
# Call get_subgraph
result = await query.get_subgraph("test query")
# Verify get_entities was called
query.get_entities.assert_called_once_with("test query")
# Verify follow_edges was called for each entity
assert query.follow_edges.call_count == 2
query.follow_edges.assert_any_call("entity1", unittest.mock.ANY, 1)
query.follow_edges.assert_any_call("entity2", unittest.mock.ANY, 1)
# Verify result is list format
# Verify follow_edges_batch was called with entities and max_path_length
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
# Verify result is list format and contains expected triples
assert isinstance(result, list)
assert len(result) == 2
assert ("entity1", "predicate1", "object1") in result
assert ("entity2", "predicate2", "object2") in result
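The change this test captures — expanding all entities in one batched call instead of per-entity `follow_edges` calls — can be sketched as a small orchestration function. Both callables here are hypothetical stand-ins for the `Query` methods, and the return shapes (a list of entities, a set of triples) are taken from the mocks above.

```python
import asyncio

async def get_subgraph(get_entities, follow_edges_batch, query, max_path_length=1):
    # Sketch of the batched orchestration tested above: resolve entities
    # for the query, expand them all in a single call, and return a list.
    entities = await get_entities(query)
    triples = await follow_edges_batch(entities, max_path_length)  # a set
    return list(triples)
```

Deduplication falls out of `follow_edges_batch` returning a set, so a triple reachable from several entities appears once in the result.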
@pytest.mark.asyncio
async def test_get_labelgraph_method(self):


@@ -178,37 +178,24 @@ class TestPineconeDocEmbeddingsStorageProcessor:
assert calls[2][1]['vectors'][0]['metadata']['doc'] == "This is the second document chunk"
@pytest.mark.asyncio
async def test_store_document_embeddings_index_creation(self, processor):
"""Test automatic index creation when index doesn't exist"""
async def test_store_document_embeddings_index_validation(self, processor):
"""Test that writing to non-existent index raises ValueError"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
# Mock index doesn't exist initially
# Mock index doesn't exist
processor.pinecone.has_index.return_value = False
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# Mock index creation
processor.pinecone.describe_index.return_value.status = {"ready": True}
with patch('uuid.uuid4', return_value='test-id'):
with pytest.raises(ValueError, match="Collection .* does not exist"):
await processor.store_document_embeddings(message)
# Verify index creation was called
expected_index_name = "d-test_user-test_collection"
processor.pinecone.create_index.assert_called_once()
create_call = processor.pinecone.create_index.call_args
assert create_call[1]['name'] == expected_index_name
assert create_call[1]['dimension'] == 3
assert create_call[1]['metric'] == "cosine"
@pytest.mark.asyncio
async def test_store_document_embeddings_empty_chunk(self, processor):
@@ -357,47 +344,44 @@ class TestPineconeDocEmbeddingsStorageProcessor:
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_index_creation_failure(self, processor):
"""Test handling of index creation failure"""
async def test_store_document_embeddings_validation_before_creation(self, processor):
"""Test that validation error occurs before creation attempts"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
# Mock index doesn't exist and creation fails
# Mock index doesn't exist
processor.pinecone.has_index.return_value = False
processor.pinecone.create_index.side_effect = Exception("Index creation failed")
with pytest.raises(Exception, match="Index creation failed"):
with pytest.raises(ValueError, match="Collection .* does not exist"):
await processor.store_document_embeddings(message)
@pytest.mark.asyncio
async def test_store_document_embeddings_index_creation_timeout(self, processor):
"""Test handling of index creation timeout"""
async def test_store_document_embeddings_validates_before_timeout(self, processor):
"""Test that validation error occurs before timeout checks"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
# Mock index doesn't exist and never becomes ready
# Mock index doesn't exist
processor.pinecone.has_index.return_value = False
processor.pinecone.describe_index.return_value.status = {"ready": False}
with patch('time.sleep'): # Speed up the test
with pytest.raises(RuntimeError, match="Gave up waiting for index creation"):
await processor.store_document_embeddings(message)
with pytest.raises(ValueError, match="Collection .* does not exist"):
await processor.store_document_embeddings(message)
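The fail-fast behaviour these tests pin down — writes no longer auto-create indexes; a missing index raises `ValueError` before any creation or timeout logic runs — can be sketched as a validation step. `validate_index` is a hypothetical helper, not the processor's actual method name; the `d-{user}-{collection}` naming is taken from the expected index name in the tests above.

```python
def validate_index(pinecone, user, collection):
    """Fail fast if the doc-embeddings index for this collection is absent.

    `pinecone` is any client exposing has_index(name), as mocked above.
    """
    index_name = f"d-{user}-{collection}"
    if not pinecone.has_index(index_name):
        raise ValueError(f"Collection {collection} does not exist")
    return index_name
```

Because this check runs first, the old failure modes (creation exceptions, "gave up waiting for index creation" timeouts) can no longer be reached from the write path.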
@pytest.mark.asyncio
async def test_store_document_embeddings_unicode_content(self, processor):


@@ -43,8 +43,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Verify processor attributes
assert hasattr(processor, 'qdrant')
assert processor.qdrant == mock_qdrant_instance
assert hasattr(processor, 'last_collection')
assert processor.last_collection is None
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
@@ -245,8 +243,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True # Collection exists
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@@ -255,36 +254,37 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# Create mock message with empty chunk
mock_message = MagicMock()
mock_message.metadata.user = 'empty_user'
mock_message.metadata.collection = 'empty_collection'
mock_chunk_empty = MagicMock()
mock_chunk_empty.chunk.decode.return_value = "" # Empty string
mock_chunk_empty.vectors = [[0.1, 0.2]]
mock_message.chunks = [mock_chunk_empty]
# Act
await processor.store_document_embeddings(mock_message)
# Assert
# Should not call upsert for empty chunks
mock_qdrant_instance.upsert.assert_not_called()
mock_qdrant_instance.collection_exists.assert_not_called()
# But collection_exists should be called for validation
mock_qdrant_instance.collection_exists.assert_called_once()
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
async def test_collection_creation_when_not_exists(self, mock_base_init, mock_qdrant_client):
"""Test collection creation when it doesn't exist"""
"""Test that writing to non-existent collection raises ValueError"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@@ -293,46 +293,32 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'new_user'
mock_message.metadata.collection = 'new_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'test chunk'
mock_chunk.vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] # 5 dimensions
mock_message.chunks = [mock_chunk]
# Act
await processor.store_document_embeddings(mock_message)
# Assert
expected_collection = 'd_new_user_new_collection'
# Verify collection existence check and creation
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
mock_qdrant_instance.create_collection.assert_called_once()
# Verify create_collection was called with correct parameters
create_call_args = mock_qdrant_instance.create_collection.call_args
assert create_call_args[1]['collection_name'] == expected_collection
# Verify upsert was still called after collection creation
mock_qdrant_instance.upsert.assert_called_once()
mock_message.chunks = [mock_chunk]
# Act & Assert
with pytest.raises(ValueError, match="Collection .* does not exist"):
await processor.store_document_embeddings(mock_message)
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
async def test_collection_creation_exception(self, mock_base_init, mock_qdrant_client):
"""Test collection creation handles exceptions"""
"""Test that validation error occurs before connection errors"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = False
mock_qdrant_instance.create_collection.side_effect = Exception("Qdrant connection failed")
mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@ -341,32 +327,35 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'error_user'
mock_message.metadata.collection = 'error_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'test chunk'
mock_chunk.vectors = [[0.1, 0.2]]
mock_message.chunks = [mock_chunk]
# Act & Assert
with pytest.raises(Exception, match="Qdrant connection failed"):
with pytest.raises(ValueError, match="Collection .* does not exist"):
await processor.store_document_embeddings(mock_message)
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
async def test_collection_caching_behavior(self, mock_base_init, mock_qdrant_client):
"""Test collection caching with last_collection"""
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
async def test_collection_validation_on_write(self, mock_uuid, mock_base_init, mock_qdrant_client):
"""Test collection validation checks collection exists before writing"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value = MagicMock()
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@ -375,46 +364,45 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# Create first mock message
mock_message1 = MagicMock()
mock_message1.metadata.user = 'cache_user'
mock_message1.metadata.collection = 'cache_collection'
mock_chunk1 = MagicMock()
mock_chunk1.chunk.decode.return_value = 'first chunk'
mock_chunk1.vectors = [[0.1, 0.2, 0.3]]
mock_message1.chunks = [mock_chunk1]
# First call
await processor.store_document_embeddings(mock_message1)
# Reset mock to track second call
mock_qdrant_instance.reset_mock()
mock_qdrant_instance.collection_exists.return_value = True
# Create second mock message with same dimensions
mock_message2 = MagicMock()
mock_message2.metadata.user = 'cache_user'
mock_message2.metadata.collection = 'cache_collection'
mock_chunk2 = MagicMock()
mock_chunk2.chunk.decode.return_value = 'second chunk'
mock_chunk2.vectors = [[0.4, 0.5, 0.6]] # Same dimension (3)
mock_message2.chunks = [mock_chunk2]
# Act - Second call with same collection
await processor.store_document_embeddings(mock_message2)
# Assert
expected_collection = 'd_cache_user_cache_collection'
assert processor.last_collection == expected_collection
# Verify second call skipped existence check (cached)
mock_qdrant_instance.collection_exists.assert_not_called()
mock_qdrant_instance.create_collection.assert_not_called()
# Verify collection existence is checked on each write
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
# But upsert should still be called
mock_qdrant_instance.upsert.assert_called_once()

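The Qdrant doc-embeddings behaviour these updated tests pin down can be sketched roughly as below. The method name `store_document_embeddings`, the `collection_exists` check, the `d_{user}_{collection}` naming, and the "Collection ... does not exist" ValueError are taken from the assertions above; the class scaffolding and the upsert payload shape are assumptions for illustration, not the real processor:

```python
import uuid


class Processor:
    """Minimal sketch of the validate-then-write behaviour under test."""

    def __init__(self, qdrant):
        self.qdrant = qdrant

    async def store_document_embeddings(self, message):
        collection = f"d_{message.metadata.user}_{message.metadata.collection}"

        # Validation happens up front on every write: no implicit collection
        # creation, and no caching of the last collection seen.
        if not self.qdrant.collection_exists(collection):
            raise ValueError(f"Collection {collection} does not exist")

        for chunk in message.chunks:
            text = chunk.chunk.decode("utf-8")
            if not text:
                continue  # empty chunks are skipped, but validation still ran
            for vector in chunk.vectors:
                self.qdrant.upsert(
                    collection_name=collection,
                    points=[{
                        "id": str(uuid.uuid4()),
                        "vector": vector,
                        "payload": {"doc": text},
                    }],
                )
```

This is why the tests now expect `collection_exists` on every call and a ValueError instead of `create_collection`.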
View file

@ -178,37 +178,24 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
assert calls[2][1]['vectors'][0]['metadata']['entity'] == "entity2"
@pytest.mark.asyncio
async def test_store_graph_embeddings_index_creation(self, processor):
"""Test automatic index creation when index doesn't exist"""
async def test_store_graph_embeddings_index_validation(self, processor):
"""Test that writing to non-existent index raises ValueError"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
# Mock index doesn't exist initially
# Mock index doesn't exist
processor.pinecone.has_index.return_value = False
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# Mock index creation
processor.pinecone.describe_index.return_value.status = {"ready": True}
with patch('uuid.uuid4', return_value='test-id'):
with pytest.raises(ValueError, match="Collection .* does not exist"):
await processor.store_graph_embeddings(message)
# Verify index creation was called
expected_index_name = "t-test_user-test_collection"
processor.pinecone.create_index.assert_called_once()
create_call = processor.pinecone.create_index.call_args
assert create_call[1]['name'] == expected_index_name
assert create_call[1]['dimension'] == 3
assert create_call[1]['metric'] == "cosine"
@pytest.mark.asyncio
async def test_store_graph_embeddings_empty_entity_value(self, processor):
@ -328,47 +315,44 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_graph_embeddings_index_creation_failure(self, processor):
"""Test handling of index creation failure"""
async def test_store_graph_embeddings_validation_before_creation(self, processor):
"""Test that validation error occurs before any creation attempts"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
# Mock index doesn't exist and creation fails
# Mock index doesn't exist
processor.pinecone.has_index.return_value = False
processor.pinecone.create_index.side_effect = Exception("Index creation failed")
with pytest.raises(Exception, match="Index creation failed"):
with pytest.raises(ValueError, match="Collection .* does not exist"):
await processor.store_graph_embeddings(message)
@pytest.mark.asyncio
async def test_store_graph_embeddings_index_creation_timeout(self, processor):
"""Test handling of index creation timeout"""
async def test_store_graph_embeddings_validates_before_timeout(self, processor):
"""Test that validation error occurs before timeout checks"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
# Mock index doesn't exist and never becomes ready
# Mock index doesn't exist
processor.pinecone.has_index.return_value = False
processor.pinecone.describe_index.return_value.status = {"ready": False}
with patch('time.sleep'): # Speed up the test
with pytest.raises(RuntimeError, match="Gave up waiting for index creation"):
await processor.store_graph_embeddings(message)
with pytest.raises(ValueError, match="Collection .* does not exist"):
await processor.store_graph_embeddings(message)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""

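The Pinecone tests above assume the same fail-fast pattern on the index side. A rough sketch, where the `t-{user}-{collection}` index naming, `has_index`, and the ValueError message are taken from the test assertions, and everything else (class shape, upsert payload) is assumed:

```python
class GraphEmbeddingsProcessor:
    """Sketch: graph-embedding writes fail fast when the index is absent."""

    def __init__(self, pinecone):
        self.pinecone = pinecone

    async def store_graph_embeddings(self, message):
        index_name = f"t-{message.metadata.user}-{message.metadata.collection}"

        # The check replaces the old create-on-demand path, so it must fire
        # before create_index, describe_index, or any readiness polling.
        if not self.pinecone.has_index(index_name):
            raise ValueError(f"Collection {index_name} does not exist")

        index = self.pinecone.Index(index_name)
        for entity in message.entities:
            if not entity.entity.value:
                continue  # skip empty entity values without touching the index
            index.upsert(vectors=[
                {"id": f"{entity.entity.value}-{i}",
                 "values": vec,
                 "metadata": {"entity": entity.entity.value}}
                for i, vec in enumerate(entity.vectors)
            ])
```

Because validation precedes any creation or readiness logic, the old "creation failure" and "creation timeout" tests collapse into plain ValueError checks.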
View file

@ -43,19 +43,17 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Verify processor attributes
assert hasattr(processor, 'qdrant')
assert processor.qdrant == mock_qdrant_instance
assert hasattr(processor, 'last_collection')
assert processor.last_collection is None
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
async def test_get_collection_creates_new_collection(self, mock_base_init, mock_qdrant_client):
"""Test get_collection creates a new collection when it doesn't exist"""
async def test_get_collection_validates_existence(self, mock_base_init, mock_qdrant_client):
"""Test get_collection validates that collection exists"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = False
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@ -64,22 +62,10 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# Act
collection_name = processor.get_collection(dim=512, user='test_user', collection='test_collection')
# Assert
expected_name = 't_test_user_test_collection'
assert collection_name == expected_name
assert processor.last_collection == expected_name
# Verify collection existence check and creation
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
mock_qdrant_instance.create_collection.assert_called_once()
# Verify create_collection was called with correct parameters
create_call_args = mock_qdrant_instance.create_collection.call_args
assert create_call_args[1]['collection_name'] == expected_name
# Act & Assert
with pytest.raises(ValueError, match="Collection .* does not exist"):
processor.get_collection(user='test_user', collection='test_collection')
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
@ -142,7 +128,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True # Collection exists
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@ -151,15 +137,14 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# Act
collection_name = processor.get_collection(dim=256, user='existing_user', collection='existing_collection')
collection_name = processor.get_collection(user='existing_user', collection='existing_collection')
# Assert
expected_name = 't_existing_user_existing_collection'
assert collection_name == expected_name
assert processor.last_collection == expected_name
# Verify collection existence check was performed
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
# Verify create_collection was NOT called
@ -167,14 +152,14 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
async def test_get_collection_caches_last_collection(self, mock_base_init, mock_qdrant_client):
"""Test get_collection skips checks when using same collection"""
async def test_get_collection_validates_on_each_call(self, mock_base_init, mock_qdrant_client):
"""Test get_collection validates collection existence on each call"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@ -183,36 +168,36 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# First call
collection_name1 = processor.get_collection(dim=128, user='cache_user', collection='cache_collection')
collection_name1 = processor.get_collection(user='cache_user', collection='cache_collection')
# Reset mock to track second call
mock_qdrant_instance.reset_mock()
mock_qdrant_instance.collection_exists.return_value = True
# Act - Second call with same parameters
collection_name2 = processor.get_collection(dim=128, user='cache_user', collection='cache_collection')
collection_name2 = processor.get_collection(user='cache_user', collection='cache_collection')
# Assert
expected_name = 't_cache_user_cache_collection'
assert collection_name1 == expected_name
assert collection_name2 == expected_name
# Verify second call skipped existence check (cached)
mock_qdrant_instance.collection_exists.assert_not_called()
# Verify collection existence check happens on each call
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
mock_qdrant_instance.create_collection.assert_not_called()
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
async def test_get_collection_creation_exception(self, mock_base_init, mock_qdrant_client):
"""Test get_collection handles collection creation exceptions"""
"""Test get_collection raises ValueError when collection doesn't exist"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = False
mock_qdrant_instance.create_collection.side_effect = Exception("Qdrant connection failed")
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@ -221,10 +206,10 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# Act & Assert
with pytest.raises(Exception, match="Qdrant connection failed"):
processor.get_collection(dim=512, user='error_user', collection='error_collection')
with pytest.raises(ValueError, match="Collection .* does not exist"):
processor.get_collection(user='error_user', collection='error_collection')
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')

View file

@ -47,7 +47,7 @@ class TestMemgraphUserCollectionIsolation:
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
@ -55,28 +55,30 @@ class TestMemgraphUserCollectionIsolation:
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=MagicMock())
# Create mock triple with URI object
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.o.value = "http://example.com/object"
triple.o.is_uri = True
# Create mock message with metadata
mock_message = MagicMock()
mock_message.triples = [triple]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
await processor.store_triples(mock_message)
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Verify user/collection parameters were passed to all operations
# Should have: create_node (subject), create_node (object), relate_node = 3 calls
assert mock_driver.execute_query.call_count == 3
# Check that user and collection were included in all calls
for call in mock_driver.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
@ -93,7 +95,7 @@ class TestMemgraphUserCollectionIsolation:
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
@ -101,24 +103,26 @@ class TestMemgraphUserCollectionIsolation:
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=MagicMock())
# Create mock triple
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.o.value = "literal_value"
triple.o.is_uri = False
# Create mock message without user/collection metadata
mock_message = MagicMock()
mock_message.triples = [triple]
mock_message.metadata.user = None
mock_message.metadata.collection = None
await processor.store_triples(mock_message)
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Verify defaults were used
for call in mock_driver.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
@ -295,7 +299,7 @@ class TestMemgraphUserCollectionRegression:
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
@ -303,23 +307,25 @@ class TestMemgraphUserCollectionRegression:
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=MagicMock())
# Store data for user1
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.o.value = "user1_data"
triple.o.is_uri = False
message_user1 = MagicMock()
message_user1.triples = [triple]
message_user1.metadata.user = "user1"
message_user1.metadata.collection = "collection1"
await processor.store_triples(message_user1)
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message_user1)
# Verify that all storage operations included user1/collection1 parameters
for call in mock_driver.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]

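The `patch.object(processor, 'collection_exists', return_value=True)` idiom repeated across these graph-store tests could be factored into a small helper. Only the `collection_exists` name comes from the diffs; the helper itself is an assumed convenience, not part of the change:

```python
from contextlib import contextmanager
from unittest.mock import patch


@contextmanager
def bypass_collection_validation(processor, exists=True):
    """Stub collection_exists so unit tests reach the store_triples logic."""
    with patch.object(processor, "collection_exists",
                      return_value=exists) as mock_exists:
        yield mock_exists
```

Usage would mirror the inline pattern: `with bypass_collection_validation(processor): await processor.store_triples(message)`.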
View file

@ -75,8 +75,10 @@ class TestNeo4jUserCollectionIsolation:
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary
await processor.store_triples(message)
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
# Verify nodes and relationships were created with user/collection properties
expected_calls = [
@ -141,8 +143,10 @@ class TestNeo4jUserCollectionIsolation:
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary
await processor.store_triples(message)
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
# Verify defaults were used
mock_driver.execute_query.assert_any_call(
@ -273,10 +277,12 @@ class TestNeo4jUserCollectionIsolation:
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary
# Store data for both users
await processor.store_triples(message_user1)
await processor.store_triples(message_user2)
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
# Store data for both users
await processor.store_triples(message_user1)
await processor.store_triples(message_user2)
# Verify user1 data was stored with user1/coll1
mock_driver.execute_query.assert_any_call(
@ -446,9 +452,11 @@ class TestNeo4jUserCollectionRegression:
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary
await processor.store_triples(message_user1)
await processor.store_triples(message_user2)
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message_user1)
await processor.store_triples(message_user2)
# Verify two separate nodes were created with same URI but different user/collection
user1_node_call = call(

View file

@ -251,6 +251,8 @@ class TestObjectsCassandraStorageLogic:
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.known_tables = {"test_user": set()} # Pre-populate
# Create test object
test_obj = ExtractedObject(
@ -291,18 +293,19 @@ class TestObjectsCassandraStorageLogic:
"""Test that secondary indexes are created for indexed fields"""
processor = MagicMock()
processor.schemas = {}
processor.known_keyspaces = set()
processor.known_tables = {}
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.known_tables = {"test_user": set()} # Pre-populate
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
def mock_ensure_keyspace(keyspace):
processor.known_keyspaces.add(keyspace)
processor.known_tables[keyspace] = set()
if keyspace not in processor.known_tables:
processor.known_tables[keyspace] = set()
processor.ensure_keyspace = mock_ensure_keyspace
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
# Create schema with indexed field
schema = RowSchema(
name="products",
@ -313,10 +316,10 @@ class TestObjectsCassandraStorageLogic:
Field(name="price", type="float", size=8, indexed=True)
]
)
# Call ensure_table
processor.ensure_table("test_user", "products", schema)
# Should have 3 calls: create table + 2 indexes
assert processor.session.execute.call_count == 3
@ -346,9 +349,10 @@ class TestObjectsCassandraStorageBatchLogic:
]
)
}
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.ensure_table = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
@ -415,6 +419,8 @@ class TestObjectsCassandraStorageBatchLogic:
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.known_tables = {"test_user": set()} # Pre-populate
# Create empty batch object
empty_batch_obj = ExtractedObject(
@ -461,6 +467,8 @@ class TestObjectsCassandraStorageBatchLogic:
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.known_tables = {"test_user": set()} # Pre-populate
# Create single-item batch object (backward compatibility case)
single_batch_obj = ExtractedObject(

View file

@ -194,7 +194,13 @@ class TestFalkorDBStorageProcessor:
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
await processor.store_triples(message)
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
# Verify queries were called in the correct order
expected_calls = [
@ -225,7 +231,13 @@ class TestFalkorDBStorageProcessor:
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
await processor.store_triples(mock_message)
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Verify queries were called in the correct order
expected_calls = [
@ -273,7 +285,13 @@ class TestFalkorDBStorageProcessor:
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
await processor.store_triples(message)
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
# Verify total number of queries (3 per triple)
assert processor.io.query.call_count == 6
@ -299,7 +317,13 @@ class TestFalkorDBStorageProcessor:
message.metadata.collection = 'test_collection'
message.triples = []
await processor.store_triples(message)
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
# Verify no queries were made
processor.io.query.assert_not_called()
@ -329,7 +353,13 @@ class TestFalkorDBStorageProcessor:
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
await processor.store_triples(message)
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
# Verify total number of queries (3 per triple)
assert processor.io.query.call_count == 6

View file

@ -308,7 +308,13 @@ class TestMemgraphStorageProcessor:
# Reset the mock to clear initialization calls
processor.io.execute_query.reset_mock()
await processor.store_triples(mock_message)
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Verify execute_query was called for create_node, create_literal, and relate_literal
# (since mock_message has a literal object)
@ -352,7 +358,13 @@ class TestMemgraphStorageProcessor:
)
message.triples = [triple1, triple2]
await processor.store_triples(message)
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
# Verify execute_query was called:
# Triple1: create_node(s) + create_literal(o) + relate_literal = 3 calls
@ -381,7 +393,13 @@ class TestMemgraphStorageProcessor:
message.metadata.collection = 'test_collection'
message.triples = []
await processor.store_triples(message)
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
# Verify no session calls were made (no triples to process)
processor.io.session.assert_not_called()

View file

@ -268,7 +268,9 @@ class TestNeo4jStorageProcessor:
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
await processor.store_triples(mock_message)
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Verify create_node was called for subject and object
# Verify relate_node was called
@ -336,7 +338,9 @@ class TestNeo4jStorageProcessor:
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
await processor.store_triples(mock_message)
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Verify create_node was called for subject
# Verify create_literal was called for object
@ -411,7 +415,9 @@ class TestNeo4jStorageProcessor:
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
await processor.store_triples(mock_message)
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Should have processed both triples
# Triple1: 2 nodes + 1 relationship = 3 calls
@ -437,7 +443,9 @@ class TestNeo4jStorageProcessor:
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
await processor.store_triples(mock_message)
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Should not have made any execute_query calls beyond index creation
# Only index creation calls should have been made during initialization
@ -552,7 +560,9 @@ class TestNeo4jStorageProcessor:
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
await processor.store_triples(mock_message)
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Verify the triple was processed with special characters preserved
mock_driver.execute_query.assert_any_call(

View file

@ -44,7 +44,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gpt-4'
assert processor.default_model == 'gpt-4'
assert processor.temperature == 0.0
assert processor.max_output == 4192
assert hasattr(processor, 'openai')
@ -254,7 +254,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gpt-35-turbo'
assert processor.default_model == 'gpt-35-turbo'
assert processor.temperature == 0.7
assert processor.max_output == 2048
mock_azure_openai_class.assert_called_once_with(
@ -289,7 +289,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gpt-4'
assert processor.default_model == 'gpt-4'
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 4192 # default_max_output
mock_azure_openai_class.assert_called_once_with(
@ -402,6 +402,156 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
assert call_args[1]['max_tokens'] == 1024
assert call_args[1]['top_p'] == 1
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_azure_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = 'Response with custom temperature'
mock_response.usage.prompt_tokens = 20
mock_response.usage.completion_tokens = 12
mock_azure_client.chat.completions.create.return_value = mock_response
mock_azure_openai_class.return_value = mock_azure_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-4',
'endpoint': 'https://test.openai.azure.com/',
'token': 'test-token',
'api_version': '2024-12-01-preview',
'temperature': 0.0, # Default temperature
'max_output': 4192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify Azure OpenAI API was called with overridden temperature
call_args = mock_azure_client.chat.completions.create.call_args
assert call_args[1]['temperature'] == 0.8 # Should use runtime override
assert call_args[1]['model'] == 'gpt-4'
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
"""Test model parameter override functionality"""
# Arrange
mock_azure_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = 'Response with custom model'
mock_response.usage.prompt_tokens = 18
mock_response.usage.completion_tokens = 14
mock_azure_client.chat.completions.create.return_value = mock_response
mock_azure_openai_class.return_value = mock_azure_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-4', # Default model
'endpoint': 'https://test.openai.azure.com/',
'token': 'test-token',
'api_version': '2024-12-01-preview',
'temperature': 0.1, # Default temperature
'max_output': 4192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gpt-4o", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify Azure OpenAI API was called with overridden model
call_args = mock_azure_client.chat.completions.create.call_args
assert call_args[1]['model'] == 'gpt-4o' # Should use runtime override
assert call_args[1]['temperature'] == 0.1 # Should use processor default
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_azure_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = 'Response with both overrides'
mock_response.usage.prompt_tokens = 22
mock_response.usage.completion_tokens = 16
mock_azure_client.chat.completions.create.return_value = mock_response
mock_azure_openai_class.return_value = mock_azure_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-4', # Default model
'endpoint': 'https://test.openai.azure.com/',
'token': 'test-token',
'api_version': '2024-12-01-preview',
'temperature': 0.0, # Default temperature
'max_output': 4192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gpt-4o-mini", # Override model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify Azure OpenAI API was called with both overrides
call_args = mock_azure_client.chat.completions.create.call_args
assert call_args[1]['model'] == 'gpt-4o-mini' # Should use runtime override
assert call_args[1]['temperature'] == 0.9 # Should use runtime override
if __name__ == '__main__':
pytest.main([__file__])
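The override tests above all exercise the same resolution rule: a runtime argument wins, and `None` means "fall back to the processor default". A minimal, self-contained sketch of that rule (illustrative only, not the TrustGraph implementation):

```python
from dataclasses import dataclass

@dataclass
class OverrideResolver:
    # Hypothetical helper, not part of the TrustGraph API: holds the
    # processor-level defaults the tests configure at construction time.
    default_model: str
    default_temperature: float

    def resolve(self, model=None, temperature=None):
        # A non-None runtime argument overrides; None falls back to the default.
        effective_model = model if model is not None else self.default_model
        effective_temp = (
            temperature if temperature is not None else self.default_temperature
        )
        return effective_model, effective_temp

resolver = OverrideResolver("gpt-4", 0.0)
assert resolver.resolve() == ("gpt-4", 0.0)                  # both defaults
assert resolver.resolve(temperature=0.8) == ("gpt-4", 0.8)   # temperature override
assert resolver.resolve(model="gpt-4o") == ("gpt-4o", 0.0)   # model override
assert resolver.resolve(model="gpt-4o-mini", temperature=0.9) == ("gpt-4o-mini", 0.9)
```

This mirrors the three test cases per processor: temperature-only, model-only, and both overrides at once.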
View file
@ -43,7 +43,7 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
assert processor.token == 'test-token'
assert processor.temperature == 0.0
assert processor.max_output == 4192
assert processor.model == 'AzureAI'
assert processor.default_model == 'AzureAI'
@patch('trustgraph.model.text_completion.azure.llm.requests')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -261,7 +261,7 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
assert processor.token == 'custom-token'
assert processor.temperature == 0.7
assert processor.max_output == 2048
assert processor.model == 'AzureAI'
assert processor.default_model == 'AzureAI'
@patch('trustgraph.model.text_completion.azure.llm.requests')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -289,7 +289,7 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
assert processor.token == 'test-token'
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 4192 # default_max_output
assert processor.model == 'AzureAI' # default_model
assert processor.default_model == 'AzureAI' # default_model
@patch('trustgraph.model.text_completion.azure.llm.requests')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -459,5 +459,150 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
)
@patch('trustgraph.model.text_completion.azure.llm.requests')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_requests):
"""Test generate_content with model parameter override"""
# Arrange
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
'choices': [{
'message': {
'content': 'Response with model override'
}
}],
'usage': {
'prompt_tokens': 15,
'completion_tokens': 10
}
}
mock_requests.post.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
'token': 'test-token',
'temperature': 0.0,
'max_output': 4192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model
result = await processor.generate_content("System", "Prompt", model="custom-azure-model")
# Assert
assert result.model == "custom-azure-model" # Should use overridden model
assert result.text == "Response with model override"
@patch('trustgraph.model.text_completion.azure.llm.requests')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_requests):
"""Test generate_content with temperature parameter override"""
# Arrange
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
'choices': [{
'message': {
'content': 'Response with temperature override'
}
}],
'usage': {
'prompt_tokens': 15,
'completion_tokens': 10
}
}
mock_requests.post.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
'token': 'test-token',
'temperature': 0.0, # Default temperature
'max_output': 4192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature
result = await processor.generate_content("System", "Prompt", temperature=0.8)
# Assert
assert result.text == "Response with temperature override"
# Verify the request was made with the overridden temperature
mock_requests.post.assert_called_once()
call_args = mock_requests.post.call_args
import json
request_body = json.loads(call_args[1]['data'])
assert request_body['temperature'] == 0.8
@patch('trustgraph.model.text_completion.azure.llm.requests')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_requests):
"""Test generate_content with both model and temperature overrides"""
# Arrange
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
'choices': [{
'message': {
'content': 'Response with both parameters override'
}
}],
'usage': {
'prompt_tokens': 18,
'completion_tokens': 12
}
}
mock_requests.post.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
'token': 'test-token',
'temperature': 0.0,
'max_output': 4192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.9)
# Assert
assert result.model == "override-model"
assert result.text == "Response with both parameters override"
# Verify the request was made with overridden temperature
mock_requests.post.assert_called_once()
call_args = mock_requests.post.call_args
import json
request_body = json.loads(call_args[1]['data'])
assert request_body['temperature'] == 0.9
if __name__ == '__main__':
pytest.main([__file__])
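The Azure tests above verify the HTTP payload by decoding the JSON body captured on the mocked `requests.post`. A stripped-down illustration of that verification pattern (names here are illustrative, not the TrustGraph API):

```python
import json
from unittest.mock import MagicMock

post = MagicMock()

def send(temperature):
    # Stand-in for the code under test: posts a JSON-encoded body.
    body = json.dumps({"temperature": temperature, "max_tokens": 4192})
    post("https://example.invalid/v1/chat/completions", data=body)

send(0.9)

# call_args[1] is the kwargs dict of the recorded call; decode the
# payload and assert on individual fields rather than the raw string.
post.assert_called_once()
request_body = json.loads(post.call_args[1]["data"])
assert request_body["temperature"] == 0.9
```

Asserting on decoded fields keeps the tests robust to key ordering and whitespace differences in the serialized body.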
View file
@ -0,0 +1,280 @@
"""
Unit tests for trustgraph.model.text_completion.bedrock
Following the same successful pattern as other processor tests
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
import json
# Import the service under test
from trustgraph.model.text_completion.bedrock.llm import Processor, Mistral, Anthropic
from trustgraph.base import LlmResult
class TestBedrockProcessorSimple(IsolatedAsyncioTestCase):
"""Test Bedrock processor functionality"""
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test basic processor initialization"""
# Arrange
mock_session = MagicMock()
mock_bedrock = MagicMock()
mock_session.client.return_value = mock_bedrock
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'mistral.mistral-large-2407-v1:0',
'temperature': 0.1,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.default_model == 'mistral.mistral-large-2407-v1:0'
assert processor.temperature == 0.1
assert hasattr(processor, 'bedrock')
mock_session_class.assert_called_once()
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success_mistral(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test successful content generation with Mistral model"""
# Arrange
mock_session = MagicMock()
mock_bedrock = MagicMock()
mock_session.client.return_value = mock_bedrock
mock_session_class.return_value = mock_session
mock_response = {
'body': MagicMock(),
'ResponseMetadata': {
'HTTPHeaders': {
'x-amzn-bedrock-input-token-count': '15',
'x-amzn-bedrock-output-token-count': '8'
}
}
}
mock_response['body'].read.return_value = json.dumps({
'outputs': [{'text': 'Generated response from Bedrock'}]
})
mock_bedrock.invoke_model.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'mistral.mistral-large-2407-v1:0',
'temperature': 0.0,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "User prompt")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Generated response from Bedrock"
assert result.in_token == 15
assert result.out_token == 8
assert result.model == 'mistral.mistral-large-2407-v1:0'
mock_bedrock.invoke_model.assert_called_once()
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_session = MagicMock()
mock_bedrock = MagicMock()
mock_session.client.return_value = mock_bedrock
mock_session_class.return_value = mock_session
mock_response = {
'body': MagicMock(),
'ResponseMetadata': {
'HTTPHeaders': {
'x-amzn-bedrock-input-token-count': '20',
'x-amzn-bedrock-output-token-count': '12'
}
}
}
mock_response['body'].read.return_value = json.dumps({
'outputs': [{'text': 'Response with custom temperature'}]
})
mock_bedrock.invoke_model.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'mistral.mistral-large-2407-v1:0',
'temperature': 0.0, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify the model variant was created with overridden temperature
# The cache key should include the temperature
cache_key = "mistral.mistral-large-2407-v1:0:0.8"
assert cache_key in processor.model_variants
variant = processor.model_variants[cache_key]
assert variant.temperature == 0.8
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test model parameter override functionality"""
# Arrange
mock_session = MagicMock()
mock_bedrock = MagicMock()
mock_session.client.return_value = mock_bedrock
mock_session_class.return_value = mock_session
mock_response = {
'body': MagicMock(),
'ResponseMetadata': {
'HTTPHeaders': {
'x-amzn-bedrock-input-token-count': '18',
'x-amzn-bedrock-output-token-count': '14'
}
}
}
mock_response['body'].read.return_value = json.dumps({
'content': [{'text': 'Response with custom model'}]
})
mock_bedrock.invoke_model.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'mistral.mistral-large-2407-v1:0', # Default model
'temperature': 0.1, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="anthropic.claude-3-sonnet-20240229-v1:0", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify Bedrock API was called with overridden model
mock_bedrock.invoke_model.assert_called_once()
call_args = mock_bedrock.invoke_model.call_args
assert call_args[1]['modelId'] == "anthropic.claude-3-sonnet-20240229-v1:0"
# Verify the correct model variant (Anthropic) was used
cache_key = "anthropic.claude-3-sonnet-20240229-v1:0:0.1"
assert cache_key in processor.model_variants
variant = processor.model_variants[cache_key]
assert isinstance(variant, Anthropic)
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_session = MagicMock()
mock_bedrock = MagicMock()
mock_session.client.return_value = mock_bedrock
mock_session_class.return_value = mock_session
mock_response = {
'body': MagicMock(),
'ResponseMetadata': {
'HTTPHeaders': {
'x-amzn-bedrock-input-token-count': '22',
'x-amzn-bedrock-output-token-count': '16'
}
}
}
mock_response['body'].read.return_value = json.dumps({
'generation': 'Response with both overrides'
})
mock_bedrock.invoke_model.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'mistral.mistral-large-2407-v1:0', # Default model
'temperature': 0.0, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="meta.llama3-70b-instruct-v1:0", # Override model (Meta/Llama)
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify Bedrock API was called with both overrides
mock_bedrock.invoke_model.assert_called_once()
call_args = mock_bedrock.invoke_model.call_args
assert call_args[1]['modelId'] == "meta.llama3-70b-instruct-v1:0"
# Verify the correct model variant (Meta) was used with correct temperature
cache_key = "meta.llama3-70b-instruct-v1:0:0.9"
assert cache_key in processor.model_variants
variant = processor.model_variants[cache_key]
assert variant.temperature == 0.9
if __name__ == '__main__':
pytest.main([__file__])
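The Bedrock tests assert against `processor.model_variants` keyed by `"model:temperature"`, so each (model, temperature) pair is built once and reused. A hedged sketch of that variant-cache pattern (not the actual Bedrock processor code):

```python
class Variant:
    # Illustrative stand-in for the per-model adapters (Mistral, Anthropic, ...).
    def __init__(self, model, temperature):
        self.model = model
        self.temperature = temperature

class VariantCache:
    def __init__(self, default_model, default_temperature):
        self.default_model = default_model
        self.default_temperature = default_temperature
        self.model_variants = {}

    def get(self, model=None, temperature=None):
        # Runtime overrides win; None falls back to the defaults.
        model = model if model is not None else self.default_model
        temperature = (
            temperature if temperature is not None else self.default_temperature
        )
        key = f"{model}:{temperature}"
        # Build the variant once; later calls with the same pair reuse it.
        if key not in self.model_variants:
            self.model_variants[key] = Variant(model, temperature)
        return self.model_variants[key]

cache = VariantCache("mistral.mistral-large-2407-v1:0", 0.0)
v = cache.get(temperature=0.8)
assert "mistral.mistral-large-2407-v1:0:0.8" in cache.model_variants
assert v.temperature == 0.8
assert cache.get(temperature=0.8) is v  # cached, not rebuilt
```

Keying on both model and temperature is what lets the tests check that an override produced a distinct, correctly-configured variant.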
View file
@ -42,7 +42,7 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'claude-3-5-sonnet-20240620'
assert processor.default_model == 'claude-3-5-sonnet-20240620'
assert processor.temperature == 0.0
assert processor.max_output == 8192
assert hasattr(processor, 'claude')
@ -217,7 +217,7 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'claude-3-haiku-20240307'
assert processor.default_model == 'claude-3-haiku-20240307'
assert processor.temperature == 0.7
assert processor.max_output == 4096
mock_anthropic_class.assert_called_once_with(api_key='custom-api-key')
@ -246,7 +246,7 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'claude-3-5-sonnet-20240620' # default_model
assert processor.default_model == 'claude-3-5-sonnet-20240620' # default_model
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 8192 # default_max_output
mock_anthropic_class.assert_called_once_with(api_key='test-api-key')
@ -433,7 +433,157 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
# Verify processor has the client
assert processor.claude == mock_claude_client
assert processor.model == 'claude-3-opus-20240229'
assert processor.default_model == 'claude-3-opus-20240229'
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_anthropic_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_claude_client = MagicMock()
mock_response = MagicMock()
mock_response.content = [MagicMock()]
mock_response.content[0].text = "Response with custom temperature"
mock_response.usage.input_tokens = 20
mock_response.usage.output_tokens = 12
mock_claude_client.messages.create.return_value = mock_response
mock_anthropic_class.return_value = mock_claude_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'claude-3-5-sonnet-20240620',
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify Claude API was called with overridden temperature
mock_claude_client.messages.create.assert_called_once()
call_kwargs = mock_claude_client.messages.create.call_args.kwargs
assert call_kwargs['temperature'] == 0.9 # Should use runtime override
assert call_kwargs['model'] == 'claude-3-5-sonnet-20240620' # Should use processor default
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_anthropic_class):
"""Test model parameter override functionality"""
# Arrange
mock_claude_client = MagicMock()
mock_response = MagicMock()
mock_response.content = [MagicMock()]
mock_response.content[0].text = "Response with custom model"
mock_response.usage.input_tokens = 18
mock_response.usage.output_tokens = 14
mock_claude_client.messages.create.return_value = mock_response
mock_anthropic_class.return_value = mock_claude_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'claude-3-5-sonnet-20240620', # Default model
'api_key': 'test-api-key',
'temperature': 0.2, # Default temperature
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="claude-3-haiku-20240307", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify Claude API was called with overridden model
mock_claude_client.messages.create.assert_called_once()
call_kwargs = mock_claude_client.messages.create.call_args.kwargs
assert call_kwargs['model'] == 'claude-3-haiku-20240307' # Should use runtime override
assert call_kwargs['temperature'] == 0.2 # Should use processor default
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_anthropic_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_claude_client = MagicMock()
mock_response = MagicMock()
mock_response.content = [MagicMock()]
mock_response.content[0].text = "Response with both overrides"
mock_response.usage.input_tokens = 22
mock_response.usage.output_tokens = 16
mock_claude_client.messages.create.return_value = mock_response
mock_anthropic_class.return_value = mock_claude_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'claude-3-5-sonnet-20240620', # Default model
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="claude-3-opus-20240229", # Override model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify Claude API was called with both overrides
mock_claude_client.messages.create.assert_called_once()
call_kwargs = mock_claude_client.messages.create.call_args.kwargs
assert call_kwargs['model'] == 'claude-3-opus-20240229' # Should use runtime override
assert call_kwargs['temperature'] == 0.8 # Should use runtime override
if __name__ == '__main__':
pytest.main([__file__])
View file
@ -41,7 +41,7 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'c4ai-aya-23-8b'
assert processor.default_model == 'c4ai-aya-23-8b'
assert processor.temperature == 0.0
assert hasattr(processor, 'cohere')
mock_cohere_class.assert_called_once_with(api_key='test-api-key')
@ -201,7 +201,7 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'command-light'
assert processor.default_model == 'command-light'
assert processor.temperature == 0.7
mock_cohere_class.assert_called_once_with(api_key='custom-api-key')
@ -229,7 +229,7 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'c4ai-aya-23-8b' # default_model
assert processor.default_model == 'c4ai-aya-23-8b' # default_model
assert processor.temperature == 0.0 # default_temperature
mock_cohere_class.assert_called_once_with(api_key='test-api-key')
@ -395,7 +395,7 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
# Verify processor has the client
assert processor.cohere == mock_cohere_client
assert processor.model == 'command-r'
assert processor.default_model == 'command-r'
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -442,6 +442,162 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
assert call_args[1]['prompt_truncation'] == 'auto'
assert call_args[1]['connectors'] == []
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_cohere_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_cohere_client = MagicMock()
mock_output = MagicMock()
mock_output.text = 'Response with custom temperature'
mock_output.meta.billed_units.input_tokens = 20
mock_output.meta.billed_units.output_tokens = 12
mock_cohere_client.chat.return_value = mock_output
mock_cohere_class.return_value = mock_cohere_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'c4ai-aya-23-8b',
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify Cohere API was called with overridden temperature
mock_cohere_client.chat.assert_called_once_with(
model='c4ai-aya-23-8b',
message='User prompt',
preamble='System prompt',
temperature=0.8, # Should use runtime override
chat_history=[],
prompt_truncation='auto',
connectors=[]
)
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_cohere_class):
"""Test model parameter override functionality"""
# Arrange
mock_cohere_client = MagicMock()
mock_output = MagicMock()
mock_output.text = 'Response with custom model'
mock_output.meta.billed_units.input_tokens = 18
mock_output.meta.billed_units.output_tokens = 14
mock_cohere_client.chat.return_value = mock_output
mock_cohere_class.return_value = mock_cohere_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'c4ai-aya-23-8b', # Default model
'api_key': 'test-api-key',
'temperature': 0.1, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="command-r-plus", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify Cohere API was called with overridden model
mock_cohere_client.chat.assert_called_once_with(
model='command-r-plus', # Should use runtime override
message='User prompt',
preamble='System prompt',
temperature=0.1, # Should use processor default
chat_history=[],
prompt_truncation='auto',
connectors=[]
)
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_cohere_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_cohere_client = MagicMock()
mock_output = MagicMock()
mock_output.text = 'Response with both overrides'
mock_output.meta.billed_units.input_tokens = 22
mock_output.meta.billed_units.output_tokens = 16
mock_cohere_client.chat.return_value = mock_output
mock_cohere_class.return_value = mock_cohere_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'c4ai-aya-23-8b', # Default model
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="command-r", # Override model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify Cohere API was called with both overrides
mock_cohere_client.chat.assert_called_once_with(
model='command-r', # Should use runtime override
message='User prompt',
preamble='System prompt',
temperature=0.9, # Should use runtime override
chat_history=[],
prompt_truncation='auto',
connectors=[]
)
if __name__ == '__main__':
pytest.main([__file__])
View file
@ -42,7 +42,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gemini-2.0-flash-001'
assert processor.default_model == 'gemini-2.0-flash-001'
assert processor.temperature == 0.0
assert processor.max_output == 8192
assert hasattr(processor, 'client')
@ -205,7 +205,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gemini-1.5-pro'
assert processor.default_model == 'gemini-1.5-pro'
assert processor.temperature == 0.7
assert processor.max_output == 4096
mock_genai_class.assert_called_once_with(api_key='custom-api-key')
@ -234,7 +234,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gemini-2.0-flash-001' # default_model
assert processor.default_model == 'gemini-2.0-flash-001' # default_model
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 8192 # default_max_output
mock_genai_class.assert_called_once_with(api_key='test-api-key')
@ -431,7 +431,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
# Verify processor has the client
assert processor.client == mock_genai_client
assert processor.model == 'gemini-1.5-flash'
assert processor.default_model == 'gemini-1.5-flash'
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -477,6 +477,156 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
# The system instruction should be in the config object
assert call_args[1]['contents'] == "Explain quantum computing"
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_genai_client = MagicMock()
mock_response = MagicMock()
mock_response.text = 'Response with custom temperature'
mock_response.usage_metadata.prompt_token_count = 20
mock_response.usage_metadata.candidates_token_count = 12
mock_genai_client.models.generate_content.return_value = mock_response
mock_genai_class.return_value = mock_genai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify the generation config was created with overridden temperature
cache_key = "gemini-2.0-flash-001:0.8"
assert cache_key in processor.generation_configs
config_obj = processor.generation_configs[cache_key]
assert config_obj.temperature == 0.8
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test model parameter override functionality"""
# Arrange
mock_genai_client = MagicMock()
mock_response = MagicMock()
mock_response.text = 'Response with custom model'
mock_response.usage_metadata.prompt_token_count = 18
mock_response.usage_metadata.candidates_token_count = 14
mock_genai_client.models.generate_content.return_value = mock_response
mock_genai_class.return_value = mock_genai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001', # Default model
'api_key': 'test-api-key',
'temperature': 0.1, # Default temperature
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gemini-1.5-pro", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify Google AI Studio API was called with overridden model
call_args = mock_genai_client.models.generate_content.call_args
assert call_args[1]['model'] == 'gemini-1.5-pro' # Should use runtime override
# Verify the generation config was created for the correct model
cache_key = "gemini-1.5-pro:0.1"
assert cache_key in processor.generation_configs
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_genai_client = MagicMock()
mock_response = MagicMock()
mock_response.text = 'Response with both overrides'
mock_response.usage_metadata.prompt_token_count = 22
mock_response.usage_metadata.candidates_token_count = 16
mock_genai_client.models.generate_content.return_value = mock_response
mock_genai_class.return_value = mock_genai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001', # Default model
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gemini-1.5-flash", # Override model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify Google AI Studio API was called with both overrides
call_args = mock_genai_client.models.generate_content.call_args
assert call_args[1]['model'] == 'gemini-1.5-flash' # Should use runtime override
# Verify the generation config was created with both overrides
cache_key = "gemini-1.5-flash:0.9"
assert cache_key in processor.generation_configs
config_obj = processor.generation_configs[cache_key]
assert config_obj.temperature == 0.9
if __name__ == '__main__':
pytest.main([__file__])
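The Google AI Studio tests above assert that a generation config is cached under the key `f"{model}:{temperature}"` in `processor.generation_configs`. A minimal standalone sketch of that caching pattern (the `ConfigCache` class name and dict-valued configs are assumptions for illustration, not the actual implementation):

```python
class ConfigCache:
    """Sketch of per-(model, temperature) generation-config caching."""

    def __init__(self, default_model, default_temperature):
        self.default_model = default_model
        self.default_temperature = default_temperature
        self.generation_configs = {}

    def get_config(self, model=None, temperature=None):
        # Runtime overrides fall back to processor defaults when None,
        # mirroring how the tests call generate_content(model=None, ...).
        model = model if model is not None else self.default_model
        temperature = (
            temperature if temperature is not None else self.default_temperature
        )
        key = f"{model}:{temperature}"
        # Build the config once per distinct (model, temperature) pair.
        if key not in self.generation_configs:
            self.generation_configs[key] = {"temperature": temperature}
        return key, self.generation_configs[key]
```

Caching per key avoids rebuilding an identical config object on every request while still honoring per-call overrides.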

View file

@ -42,7 +42,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'LLaMA_CPP'
assert processor.default_model == 'LLaMA_CPP'
assert processor.llamafile == 'http://localhost:8080/v1'
assert processor.temperature == 0.0
assert processor.max_output == 4096
@ -91,7 +91,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
assert result.text == "Generated response from LlamaFile"
assert result.in_token == 20
assert result.out_token == 12
assert result.model == 'llama.cpp' # Note: model in result is hardcoded to 'llama.cpp'
assert result.model == 'LLaMA_CPP' # Uses the default model name
# Verify the OpenAI API call structure
mock_openai_client.chat.completions.create.assert_called_once_with(
@ -99,7 +99,15 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
messages=[{
"role": "user",
"content": "System prompt\n\nUser prompt"
}]
}],
temperature=0.0,
max_tokens=4096,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
response_format={
"type": "text"
}
)
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@ -157,7 +165,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'custom-llama'
assert processor.default_model == 'custom-llama'
assert processor.llamafile == 'http://custom-host:8080/v1'
assert processor.temperature == 0.7
assert processor.max_output == 2048
@ -189,7 +197,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'LLaMA_CPP' # default_model
assert processor.default_model == 'LLaMA_CPP' # default_model
assert processor.llamafile == 'http://localhost:8080/v1' # default_llamafile
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 4096 # default_max_output
@ -237,7 +245,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
assert result.text == "Default response"
assert result.in_token == 2
assert result.out_token == 3
assert result.model == 'llama.cpp'
assert result.model == 'LLaMA_CPP'
# Verify the combined prompt is sent correctly
call_args = mock_openai_client.chat.completions.create.call_args
@ -408,8 +416,8 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
result = await processor.generate_content("System", "User")
# Assert
assert result.model == 'llama.cpp' # Should always be 'llama.cpp', not 'custom-model-name'
assert processor.model == 'custom-model-name' # But processor.model should still be custom
assert result.model == 'custom-model-name' # Uses the actual model name passed to generate_content
assert processor.default_model == 'custom-model-name'  # But the processor default should still be custom
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -450,5 +458,132 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
# No specific rate limit error handling tested since SLM presumably has no rate limits
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test generate_content with model parameter override"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response from overridden model"
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 10
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'LLaMA_CPP',
'llamafile': 'http://localhost:8080/v1',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model
result = await processor.generate_content("System", "Prompt", model="custom-llamafile-model")
# Assert
assert result.model == "custom-llamafile-model" # Should use overridden model
assert result.text == "Response from overridden model"
# Verify the API call was made with overridden model
call_args = mock_openai_client.chat.completions.create.call_args
assert call_args[1]['model'] == "custom-llamafile-model"
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test generate_content with temperature parameter override"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response with temperature override"
mock_response.usage.prompt_tokens = 18
mock_response.usage.completion_tokens = 12
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'LLaMA_CPP',
'llamafile': 'http://localhost:8080/v1',
'temperature': 0.0, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature
result = await processor.generate_content("System", "Prompt", temperature=0.7)
# Assert
assert result.text == "Response with temperature override"
# Verify the API call was made with overridden temperature
call_args = mock_openai_client.chat.completions.create.call_args
assert call_args[1]['temperature'] == 0.7
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test generate_content with both model and temperature overrides"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response with both parameters override"
mock_response.usage.prompt_tokens = 20
mock_response.usage.completion_tokens = 15
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'LLaMA_CPP',
'llamafile': 'http://localhost:8080/v1',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
# Assert
assert result.model == "override-model"
assert result.text == "Response with both parameters override"
# Verify the API call was made with overridden parameters
call_args = mock_openai_client.chat.completions.create.call_args
assert call_args[1]['model'] == "override-model"
assert call_args[1]['temperature'] == 0.8
if __name__ == '__main__':
pytest.main([__file__])
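The llamafile override tests above all exercise the same fallback rule: a per-call `model` or `temperature` wins, otherwise the processor default applies, and the result reports the model actually used. A minimal sketch of that resolution step (the `resolve_params` helper name is an assumption, not the real API):

```python
def resolve_params(default_model, default_temperature, model=None, temperature=None):
    """Resolve effective (model, temperature) for one request.

    Explicit `is not None` checks are used so that legitimate falsy
    overrides such as temperature=0.0 are not silently discarded.
    """
    effective_model = model if model is not None else default_model
    effective_temperature = (
        temperature if temperature is not None else default_temperature
    )
    return effective_model, effective_temperature
```

For example, `resolve_params("LLaMA_CPP", 0.0, model="override-model", temperature=0.8)` yields `("override-model", 0.8)`, matching what the tests assert reaches `chat.completions.create` and `result.model`.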

View file

@ -0,0 +1,229 @@
"""
Unit tests for trustgraph.model.text_completion.lmstudio
Following the same successful pattern as previous tests
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.model.text_completion.lmstudio.llm import Processor
from trustgraph.base import LlmResult
from trustgraph.exceptions import TooManyRequests
class TestLMStudioProcessorSimple(IsolatedAsyncioTestCase):
"""Test LMStudio processor functionality"""
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test basic processor initialization"""
# Arrange
mock_openai = MagicMock()
mock_openai_class.return_value = mock_openai
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemma3:9b',
'url': 'http://localhost:1234/',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.default_model == 'gemma3:9b'
assert processor.url == 'http://localhost:1234/v1/'
assert processor.temperature == 0.0
assert processor.max_output == 4096
assert hasattr(processor, 'openai')
mock_openai_class.assert_called_once_with(
base_url='http://localhost:1234/v1/',
api_key='sk-no-key-required'
)
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test successful content generation"""
# Arrange
mock_openai = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Generated response from LMStudio'
mock_response.usage.prompt_tokens = 20
mock_response.usage.completion_tokens = 12
mock_openai.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemma3:9b',
'url': 'http://localhost:1234/',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "User prompt")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Generated response from LMStudio"
assert result.in_token == 20
assert result.out_token == 12
assert result.model == 'gemma3:9b'
# Verify the API call was made correctly
mock_openai.chat.completions.create.assert_called_once()
call_args = mock_openai.chat.completions.create.call_args
# Check model and temperature
assert call_args[1]['model'] == 'gemma3:9b'
assert call_args[1]['temperature'] == 0.0
assert call_args[1]['max_tokens'] == 4096
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test generate_content with model parameter override"""
# Arrange
mock_openai = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Response from overridden model'
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 10
mock_openai.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemma3:9b',
'url': 'http://localhost:1234/',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model
result = await processor.generate_content("System", "Prompt", model="custom-lmstudio-model")
# Assert
assert result.model == "custom-lmstudio-model" # Should use overridden model
assert result.text == "Response from overridden model"
# Verify the API call was made with overridden model
call_args = mock_openai.chat.completions.create.call_args
assert call_args[1]['model'] == "custom-lmstudio-model"
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test generate_content with temperature parameter override"""
# Arrange
mock_openai = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Response with temperature override'
mock_response.usage.prompt_tokens = 18
mock_response.usage.completion_tokens = 12
mock_openai.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemma3:9b',
'url': 'http://localhost:1234/',
'temperature': 0.0, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature
result = await processor.generate_content("System", "Prompt", temperature=0.7)
# Assert
assert result.text == "Response with temperature override"
# Verify the API call was made with overridden temperature
call_args = mock_openai.chat.completions.create.call_args
assert call_args[1]['temperature'] == 0.7
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test generate_content with both model and temperature overrides"""
# Arrange
mock_openai = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Response with both parameters override'
mock_response.usage.prompt_tokens = 20
mock_response.usage.completion_tokens = 15
mock_openai.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemma3:9b',
'url': 'http://localhost:1234/',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
# Assert
assert result.model == "override-model"
assert result.text == "Response with both parameters override"
# Verify the API call was made with overridden parameters
call_args = mock_openai.chat.completions.create.call_args
assert call_args[1]['model'] == "override-model"
assert call_args[1]['temperature'] == 0.8
if __name__ == '__main__':
pytest.main([__file__])

View file

@ -0,0 +1,275 @@
"""
Unit tests for trustgraph.model.text_completion.mistral
Following the same successful pattern as other processor tests
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.model.text_completion.mistral.llm import Processor
from trustgraph.base import LlmResult
class TestMistralProcessorSimple(IsolatedAsyncioTestCase):
"""Test Mistral processor functionality"""
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_mistral_class):
"""Test basic processor initialization"""
# Arrange
mock_mistral_client = MagicMock()
mock_mistral_class.return_value = mock_mistral_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'ministral-8b-latest',
'api_key': 'test-api-key',
'temperature': 0.1,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.default_model == 'ministral-8b-latest'
assert processor.temperature == 0.1
assert processor.max_output == 2048
assert hasattr(processor, 'mistral')
mock_mistral_class.assert_called_once_with(api_key='test-api-key')
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_mistral_class):
"""Test successful content generation"""
# Arrange
mock_mistral_client = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Generated response from Mistral'
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 8
mock_mistral_client.chat.complete.return_value = mock_response
mock_mistral_class.return_value = mock_mistral_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'ministral-8b-latest',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "User prompt")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Generated response from Mistral"
assert result.in_token == 15
assert result.out_token == 8
assert result.model == 'ministral-8b-latest'
mock_mistral_client.chat.complete.assert_called_once()
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_mistral_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_mistral_client = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Response with custom temperature'
mock_response.usage.prompt_tokens = 20
mock_response.usage.completion_tokens = 12
mock_mistral_client.chat.complete.return_value = mock_response
mock_mistral_class.return_value = mock_mistral_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'ministral-8b-latest',
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify Mistral API was called with overridden temperature
call_args = mock_mistral_client.chat.complete.call_args
assert call_args[1]['temperature'] == 0.8 # Should use runtime override
assert call_args[1]['model'] == 'ministral-8b-latest'
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_mistral_class):
"""Test model parameter override functionality"""
# Arrange
mock_mistral_client = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Response with custom model'
mock_response.usage.prompt_tokens = 18
mock_response.usage.completion_tokens = 14
mock_mistral_client.chat.complete.return_value = mock_response
mock_mistral_class.return_value = mock_mistral_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'ministral-8b-latest', # Default model
'api_key': 'test-api-key',
'temperature': 0.1, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="mistral-large-latest", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify Mistral API was called with overridden model
call_args = mock_mistral_client.chat.complete.call_args
assert call_args[1]['model'] == 'mistral-large-latest' # Should use runtime override
assert call_args[1]['temperature'] == 0.1 # Should use processor default
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_mistral_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_mistral_client = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Response with both overrides'
mock_response.usage.prompt_tokens = 22
mock_response.usage.completion_tokens = 16
mock_mistral_client.chat.complete.return_value = mock_response
mock_mistral_class.return_value = mock_mistral_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'ministral-8b-latest', # Default model
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="mistral-large-latest", # Override model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify Mistral API was called with both overrides
call_args = mock_mistral_client.chat.complete.call_args
assert call_args[1]['model'] == 'mistral-large-latest' # Should use runtime override
assert call_args[1]['temperature'] == 0.9 # Should use runtime override
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_mistral_class):
"""Test prompt construction with system and user prompts"""
# Arrange
mock_mistral_client = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Response with system instructions'
mock_response.usage.prompt_tokens = 25
mock_response.usage.completion_tokens = 15
mock_mistral_client.chat.complete.return_value = mock_response
mock_mistral_class.return_value = mock_mistral_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'ministral-8b-latest',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
# Assert
assert result.text == "Response with system instructions"
assert result.in_token == 25
assert result.out_token == 15
# Verify the combined prompt structure
call_args = mock_mistral_client.chat.complete.call_args
messages = call_args[1]['messages']
assert len(messages) == 1
assert messages[0]['role'] == 'user'
assert messages[0]['content'][0]['type'] == 'text'
assert messages[0]['content'][0]['text'] == "You are a helpful assistant\n\nWhat is AI?"
if __name__ == '__main__':
pytest.main([__file__])
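The prompt-construction test above pins down the exact message shape sent to the Mistral client: a single user message whose content is a one-element list of text parts, with the system and user prompts joined by a blank line. A standalone sketch of that builder (the `build_messages` name is assumed for illustration):

```python
def build_messages(system, prompt):
    # Combine system and user prompts into one user-role message,
    # using the list-of-text-parts content structure the test asserts.
    return [{
        "role": "user",
        "content": [{"type": "text", "text": f"{system}\n\n{prompt}"}],
    }]
```

Folding the system prompt into the user message this way keeps the request a single turn, which is why the test expects `len(messages) == 1` rather than a separate system-role entry.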

View file

@ -40,7 +40,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'llama2'
assert processor.default_model == 'llama2'
assert hasattr(processor, 'llm')
mock_client_class.assert_called_once_with(host='http://localhost:11434')
@ -81,7 +81,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
assert result.in_token == 15
assert result.out_token == 8
assert result.model == 'llama2'
mock_client.generate.assert_called_once_with('llama2', "System prompt\n\nUser prompt")
mock_client.generate.assert_called_once_with('llama2', "System prompt\n\nUser prompt", options={'temperature': 0.0})
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -134,7 +134,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'mistral'
assert processor.default_model == 'mistral'
mock_client_class.assert_called_once_with(host='http://192.168.1.100:11434')
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@ -160,7 +160,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gemma2:9b' # default_model
assert processor.default_model == 'gemma2:9b' # default_model
# Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env)
mock_client_class.assert_called_once()
@ -203,7 +203,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
assert result.model == 'llama2'
# The prompt should be "" + "\n\n" + "" = "\n\n"
mock_client.generate.assert_called_once_with('llama2', "\n\n")
mock_client.generate.assert_called_once_with('llama2', "\n\n", options={'temperature': 0.0})
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -310,7 +310,151 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
assert result.out_token == 15
# Verify the combined prompt
mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?")
mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?", options={'temperature': 0.0})
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_client = MagicMock()
mock_response = {
'response': 'Response with custom temperature',
'prompt_eval_count': 20,
'eval_count': 12
}
mock_client.generate.return_value = mock_response
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'llama2',
'ollama': 'http://localhost:11434',
'temperature': 0.0, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify Ollama API was called with overridden temperature
mock_client.generate.assert_called_once_with(
'llama2',
"System prompt\n\nUser prompt",
options={'temperature': 0.8} # Should use runtime override
)
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test model parameter override functionality"""
# Arrange
mock_client = MagicMock()
mock_response = {
'response': 'Response with custom model',
'prompt_eval_count': 18,
'eval_count': 14
}
mock_client.generate.return_value = mock_response
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'llama2', # Default model
'ollama': 'http://localhost:11434',
'temperature': 0.1, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="mistral", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify Ollama API was called with overridden model
mock_client.generate.assert_called_once_with(
'mistral', # Should use runtime override
"System prompt\n\nUser prompt",
options={'temperature': 0.1} # Should use processor default
)
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_client = MagicMock()
mock_response = {
'response': 'Response with both overrides',
'prompt_eval_count': 22,
'eval_count': 16
}
mock_client.generate.return_value = mock_response
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'llama2', # Default model
'ollama': 'http://localhost:11434',
'temperature': 0.0, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="codellama", # Override model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify Ollama API was called with both overrides
mock_client.generate.assert_called_once_with(
'codellama', # Should use runtime override
"System prompt\n\nUser prompt",
options={'temperature': 0.9} # Should use runtime override
)
if __name__ == '__main__':

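The Ollama tests above all pin down the same default/override contract: a runtime `model` or `temperature` of `None` falls back to the value the processor was constructed with, anything else wins. A minimal stand-in (not the trustgraph implementation) of that resolution logic:

```python
# Sketch of the override-resolution pattern the tests exercise.
# ProcessorDefaults is a hypothetical stand-in for the real processor.

class ProcessorDefaults:
    def __init__(self, default_model, default_temperature):
        self.default_model = default_model
        self.temperature = default_temperature

    def resolve(self, model=None, temperature=None):
        # None means "use the processor default"; any other value overrides it
        effective_model = model if model is not None else self.default_model
        effective_temp = temperature if temperature is not None else self.temperature
        return effective_model, effective_temp

p = ProcessorDefaults("llama2", 0.0)
assert p.resolve() == ("llama2", 0.0)               # both defaults
assert p.resolve(temperature=0.8) == ("llama2", 0.8)  # temperature override
assert p.resolve(model="mistral") == ("mistral", 0.0)  # model override
```

The three override tests (temperature only, model only, both) each cover one branch of this fallback.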
View file

@ -43,7 +43,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gpt-3.5-turbo'
assert processor.default_model == 'gpt-3.5-turbo'
assert processor.temperature == 0.0
assert processor.max_output == 4096
assert hasattr(processor, 'openai')
@ -222,7 +222,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gpt-4'
assert processor.default_model == 'gpt-4'
assert processor.temperature == 0.7
assert processor.max_output == 2048
mock_openai_class.assert_called_once_with(base_url='https://custom-openai-url.com/v1', api_key='custom-api-key')
@ -251,7 +251,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gpt-3.5-turbo' # default_model
assert processor.default_model == 'gpt-3.5-turbo' # default_model
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 4096 # default_max_output
mock_openai_class.assert_called_once_with(base_url='https://api.openai.com/v1', api_key='test-api-key')
@ -391,5 +391,210 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
assert call_args[1]['response_format'] == {"type": "text"}
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response with custom temperature"
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 10
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-3.5-turbo',
'api_key': 'test-api-key',
'url': 'https://api.openai.com/v1',
'temperature': 0.0, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify the OpenAI API was called with overridden temperature
mock_openai_client.chat.completions.create.assert_called_once()
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
assert call_kwargs['temperature'] == 0.9 # Should use runtime override
assert call_kwargs['model'] == 'gpt-3.5-turbo' # Should use processor default
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test model parameter override functionality"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response with custom model"
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 10
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-3.5-turbo', # Default model
'api_key': 'test-api-key',
'url': 'https://api.openai.com/v1',
'temperature': 0.2,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gpt-4", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify the OpenAI API was called with overridden model
mock_openai_client.chat.completions.create.assert_called_once()
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
assert call_kwargs['model'] == 'gpt-4' # Should use runtime override
assert call_kwargs['temperature'] == 0.2 # Should use processor default
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response with both overrides"
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 10
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-3.5-turbo', # Default model
'api_key': 'test-api-key',
'url': 'https://api.openai.com/v1',
'temperature': 0.0, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gpt-4", # Override model
temperature=0.7 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify the OpenAI API was called with both overrides
mock_openai_client.chat.completions.create.assert_called_once()
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
assert call_kwargs['model'] == 'gpt-4' # Should use runtime override
assert call_kwargs['temperature'] == 0.7 # Should use runtime override
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_no_override_uses_defaults(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test that when no parameters are overridden, processor defaults are used"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response with defaults"
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 10
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-4', # Default model
'api_key': 'test-api-key',
'url': 'https://api.openai.com/v1',
'temperature': 0.5, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Don't override any parameters (pass None)
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with defaults"
# Verify the OpenAI API was called with processor defaults
mock_openai_client.chat.completions.create.assert_called_once()
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
assert call_kwargs['model'] == 'gpt-4' # Should use processor default
assert call_kwargs['temperature'] == 0.5 # Should use processor default
if __name__ == '__main__':
pytest.main([__file__])

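The OpenAI tests rely on `MagicMock` recording every call so the keyword arguments can be inspected afterwards via `call_args.kwargs`. A self-contained illustration of that stdlib mechanism:

```python
from unittest.mock import MagicMock

# MagicMock auto-creates nested attributes, so the whole client can be faked
client = MagicMock()
client.chat.completions.create(model="gpt-4", temperature=0.7)

# call_args.kwargs exposes the keyword arguments of the most recent call
kwargs = client.chat.completions.create.call_args.kwargs
assert kwargs["model"] == "gpt-4"
assert kwargs["temperature"] == 0.7
client.chat.completions.create.assert_called_once()
```

This is why the tests can assert on `model` and `temperature` individually rather than reconstructing the full argument list.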
View file

@ -0,0 +1,186 @@
"""
Unit tests for Parameter-Based Caching in LLM Processors
Testing processors that cache based on temperature parameters (Bedrock, GoogleAIStudio)
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
from trustgraph.model.text_completion.googleaistudio.llm import Processor as GoogleAIProcessor
from trustgraph.base import LlmResult
class TestParameterCaching(IsolatedAsyncioTestCase):
"""Test parameter-based caching functionality"""
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_googleai_temperature_cache_keys(self, mock_llm_init, mock_async_init, mock_genai):
"""Test that GoogleAI processor creates separate cache entries for different temperatures"""
# Arrange
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
mock_response = MagicMock()
mock_response.text = "Generated response"
mock_response.usage_metadata.prompt_token_count = 10
mock_response.usage_metadata.candidates_token_count = 5
mock_client.models.generate_content.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'max_output': 1024,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = GoogleAIProcessor(**config)
# Act - Call with different temperatures
await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.0)
await processor.generate_content("System", "Prompt 2", model="gemini-2.0-flash-001", temperature=0.5)
await processor.generate_content("System", "Prompt 3", model="gemini-2.0-flash-001", temperature=1.0)
# Assert - Should have 3 different cache entries
cache_keys = list(processor.generation_configs.keys())
assert len(cache_keys) == 3
assert "gemini-2.0-flash-001:0.0" in cache_keys
assert "gemini-2.0-flash-001:0.5" in cache_keys
assert "gemini-2.0-flash-001:1.0" in cache_keys
# Verify each cached config has the correct temperature
assert processor.generation_configs["gemini-2.0-flash-001:0.0"].temperature == 0.0
assert processor.generation_configs["gemini-2.0-flash-001:0.5"].temperature == 0.5
assert processor.generation_configs["gemini-2.0-flash-001:1.0"].temperature == 1.0
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_googleai_cache_reuse_same_parameters(self, mock_llm_init, mock_async_init, mock_genai):
"""Test that GoogleAI processor reuses cache for identical model+temperature combinations"""
# Arrange
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
mock_response = MagicMock()
mock_response.text = "Generated response"
mock_response.usage_metadata.prompt_token_count = 10
mock_response.usage_metadata.candidates_token_count = 5
mock_client.models.generate_content.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 1024,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = GoogleAIProcessor(**config)
# Act - Call multiple times with same parameters
await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.7)
await processor.generate_content("System", "Prompt 2", model="gemini-2.0-flash-001", temperature=0.7)
await processor.generate_content("System", "Prompt 3", model="gemini-2.0-flash-001", temperature=0.7)
# Assert - Should have only 1 cache entry for the repeated parameters
cache_keys = list(processor.generation_configs.keys())
assert len(cache_keys) == 1
assert "gemini-2.0-flash-001:0.7" in cache_keys
# The same config object should be reused
config_obj = processor.generation_configs["gemini-2.0-flash-001:0.7"]
assert config_obj.temperature == 0.7
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_googleai_different_models_separate_caches(self, mock_llm_init, mock_async_init, mock_genai):
"""Test that different models create separate cache entries even with same temperature"""
# Arrange
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
mock_response = MagicMock()
mock_response.text = "Generated response"
mock_response.usage_metadata.prompt_token_count = 10
mock_response.usage_metadata.candidates_token_count = 5
mock_client.models.generate_content.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 1024,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = GoogleAIProcessor(**config)
# Act - Call with different models, same temperature
await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.5)
await processor.generate_content("System", "Prompt 2", model="gemini-1.5-flash-001", temperature=0.5)
# Assert - Should have separate cache entries for different models
cache_keys = list(processor.generation_configs.keys())
assert len(cache_keys) == 2
assert "gemini-2.0-flash-001:0.5" in cache_keys
assert "gemini-1.5-flash-001:0.5" in cache_keys
# Note: Bedrock tests would be similar but testing the Bedrock processor's caching behavior
# The Bedrock processor caches model variants with temperature in the cache key
async def test_bedrock_temperature_cache_keys(self):
"""Test Bedrock processor temperature-aware caching"""
# This would test the Bedrock processor's _get_or_create_variant method
# with different temperature values to ensure proper cache key generation
# Implementation would follow similar pattern to GoogleAI tests above
# but using the Bedrock processor and testing model_variants cache
pass
async def test_bedrock_cache_isolation_different_temperatures(self):
"""Test that Bedrock processor isolates cache entries by temperature"""
pass
async def test_cache_memory_efficiency(self):
"""Test that caches don't grow unbounded with many different parameter combinations"""
# This could test cache size limits or cleanup behavior if implemented
pass
class TestCachePerformance(IsolatedAsyncioTestCase):
"""Test caching performance characteristics"""
async def test_cache_hit_performance(self):
"""Test that cache hits are faster than cache misses"""
# This would measure timing differences between cache hits and misses
pass
async def test_concurrent_cache_access(self):
"""Test concurrent access to cached configurations"""
# This would test thread-safety of cache access
pass
if __name__ == '__main__':
pytest.main([__file__])
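The cache keys asserted above follow a `"model:temperature"` scheme. A hedged sketch of that memoization pattern, using a plain dict where the real processor stores `GenerationConfig` objects (`ConfigCache` and `get_config` are illustrative names, not the trustgraph API):

```python
# Sketch of temperature-aware config caching: one entry per
# (model, temperature) combination, reused on repeat calls.

class ConfigCache:
    def __init__(self):
        self.generation_configs = {}

    def get_config(self, model, temperature):
        key = f"{model}:{temperature}"
        if key not in self.generation_configs:
            # Cache miss: build and memoize the config for this combination
            self.generation_configs[key] = {"temperature": temperature}
        return self.generation_configs[key]

cache = ConfigCache()
cache.get_config("gemini-2.0-flash-001", 0.0)
cache.get_config("gemini-2.0-flash-001", 0.5)
cache.get_config("gemini-2.0-flash-001", 0.5)   # reuses the cached entry
assert len(cache.generation_configs) == 2
assert "gemini-2.0-flash-001:0.5" in cache.generation_configs
```

Distinct models with the same temperature produce distinct keys, which is what the separate-caches test checks.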

View file

@ -0,0 +1,271 @@
"""
Unit tests for trustgraph.model.text_completion.tgi
Following the same successful pattern as previous tests
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.model.text_completion.tgi.llm import Processor
from trustgraph.base import LlmResult
from trustgraph.exceptions import TooManyRequests
class TestTGIProcessorSimple(IsolatedAsyncioTestCase):
"""Test TGI processor functionality"""
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test basic processor initialization"""
# Arrange
mock_session = MagicMock()
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'tgi',
'url': 'http://tgi-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.default_model == 'tgi'
assert processor.base_url == 'http://tgi-service:8899/v1'
assert processor.temperature == 0.0
assert processor.max_output == 2048
assert hasattr(processor, 'session')
mock_session_class.assert_called_once()
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test successful content generation"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'message': {
'content': 'Generated response from TGI'
}
}],
'usage': {
'prompt_tokens': 20,
'completion_tokens': 12
}
})
# Mock the async context manager
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'tgi',
'url': 'http://tgi-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "User prompt")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Generated response from TGI"
assert result.in_token == 20
assert result.out_token == 12
assert result.model == 'tgi'
# Verify the API call was made correctly
mock_session.post.assert_called_once()
call_args = mock_session.post.call_args
# Check URL
assert call_args[0][0] == 'http://tgi-service:8899/v1/chat/completions'
# Check request structure
request_body = call_args[1]['json']
assert request_body['model'] == 'tgi'
assert request_body['temperature'] == 0.0
assert request_body['max_tokens'] == 2048
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test generate_content with model parameter override"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'message': {
'content': 'Response from overridden model'
}
}],
'usage': {
'prompt_tokens': 15,
'completion_tokens': 10
}
})
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'tgi',
'url': 'http://tgi-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model
result = await processor.generate_content("System", "Prompt", model="custom-tgi-model")
# Assert
assert result.model == "custom-tgi-model" # Should use overridden model
assert result.text == "Response from overridden model"
# Verify the API call was made with overridden model
call_args = mock_session.post.call_args
assert call_args[1]['json']['model'] == "custom-tgi-model"
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test generate_content with temperature parameter override"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'message': {
'content': 'Response with temperature override'
}
}],
'usage': {
'prompt_tokens': 18,
'completion_tokens': 12
}
})
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'tgi',
'url': 'http://tgi-service:8899/v1',
'temperature': 0.0, # Default temperature
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature
result = await processor.generate_content("System", "Prompt", temperature=0.7)
# Assert
assert result.text == "Response with temperature override"
# Verify the API call was made with overridden temperature
call_args = mock_session.post.call_args
assert call_args[1]['json']['temperature'] == 0.7
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test generate_content with both model and temperature overrides"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'message': {
'content': 'Response with both parameters override'
}
}],
'usage': {
'prompt_tokens': 20,
'completion_tokens': 15
}
})
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'tgi',
'url': 'http://tgi-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
# Assert
assert result.model == "override-model"
assert result.text == "Response with both parameters override"
# Verify the API call was made with overridden parameters
call_args = mock_session.post.call_args
assert call_args[1]['json']['model'] == "override-model"
assert call_args[1]['json']['temperature'] == 0.8
if __name__ == '__main__':
pytest.main([__file__])

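The TGI tests assert on the URL and the `model`, `temperature`, and `max_tokens` fields of the request body. A sketch of building that OpenAI-compatible request; the `messages` shape and the `build_request` helper are assumptions for illustration, not the trustgraph implementation:

```python
# Sketch of the chat/completions request the TGI tests expect.

def build_request(base_url, model, system, prompt, temperature, max_output):
    url = f"{base_url}/chat/completions"
    body = {
        "model": model,
        # Assumed message layout for an OpenAI-compatible endpoint
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": prompt},
        ],
        "temperature": temperature,
        "max_tokens": max_output,
    }
    return url, body

url, body = build_request("http://tgi-service:8899/v1", "tgi",
                          "System prompt", "User prompt", 0.0, 2048)
assert url == "http://tgi-service:8899/v1/chat/completions"
assert body["model"] == "tgi"
assert body["max_tokens"] == 2048
```

Runtime model/temperature overrides simply substitute the corresponding arguments before the body is built, which is what the override tests verify via `call_args[1]['json']`.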
View file

@ -47,10 +47,10 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gemini-2.0-flash-001' # It's stored as 'model', not 'model_name'
assert hasattr(processor, 'generation_config')
assert processor.default_model == 'gemini-2.0-flash-001' # It's stored as 'model', not 'model_name'
assert hasattr(processor, 'generation_configs') # Now a cache dictionary
assert hasattr(processor, 'safety_settings')
assert hasattr(processor, 'llm')
assert hasattr(processor, 'model_clients') # LLM clients are now cached
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('private.json')
mock_vertexai.init.assert_called_once()
@ -102,7 +102,8 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
mock_model.generate_content.assert_called_once()
# Verify the call was made with the expected parameters
call_args = mock_model.generate_content.call_args
assert call_args[1]['generation_config'] == processor.generation_config
# Generation config is now created dynamically per model
assert 'generation_config' in call_args[1]
assert call_args[1]['safety_settings'] == processor.safety_settings
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@ -223,7 +224,7 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gemini-2.0-flash-001'
assert processor.default_model == 'gemini-2.0-flash-001'
mock_auth_default.assert_called_once()
mock_vertexai.init.assert_called_once_with(
location='us-central1',
@ -296,11 +297,11 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gemini-1.5-pro'
assert processor.default_model == 'gemini-1.5-pro'
# Verify that generation_config object exists (can't easily check internal values)
assert hasattr(processor, 'generation_config')
assert processor.generation_config is not None
assert hasattr(processor, 'generation_configs') # Now a cache dictionary
assert processor.generation_configs == {} # Empty cache initially
# Verify that safety settings are configured
assert len(processor.safety_settings) == 4
@ -353,8 +354,8 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
project='test-project-123'
)
# Verify GenerativeModel was created with the right model name
mock_generative_model.assert_called_once_with('gemini-2.0-flash-001')
# GenerativeModel is now created lazily on first use, not at initialization
mock_generative_model.assert_not_called()
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@ -440,8 +441,8 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'claude-3-sonnet@20240229'
assert processor.is_anthropic == True
assert processor.default_model == 'claude-3-sonnet@20240229'
# is_anthropic logic is now determined dynamically per request
# Verify service account was called with private key
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('anthropic-key.json')
@ -459,6 +460,180 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
assert processor.api_params["top_p"] == 1.0
assert processor.api_params["top_k"] == 32
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
"""Test temperature parameter override functionality"""
# Arrange
mock_credentials = MagicMock()
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Response with custom temperature"
mock_response.usage_metadata.prompt_token_count = 20
mock_response.usage_metadata.candidates_token_count = 12
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'region': 'us-central1',
'model': 'gemini-2.0-flash-001',
'temperature': 0.0, # Default temperature
'max_output': 8192,
'private_key': 'private.json',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify Gemini API was called with overridden temperature
mock_model.generate_content.assert_called_once()
call_args = mock_model.generate_content.call_args
# Check that generation_config was created (we can't directly access temperature from mock)
generation_config = call_args.kwargs['generation_config']
assert generation_config is not None # Should use overridden temperature configuration
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
"""Test model parameter override functionality"""
# Arrange
mock_credentials = MagicMock()
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
# Mock different models
mock_model_default = MagicMock()
mock_model_override = MagicMock()
mock_response = MagicMock()
mock_response.text = "Response with custom model"
mock_response.usage_metadata.prompt_token_count = 18
mock_response.usage_metadata.candidates_token_count = 14
mock_model_override.generate_content.return_value = mock_response
# GenerativeModel should return different models based on input
def model_factory(model_name):
if model_name == 'gemini-1.5-pro':
return mock_model_override
return mock_model_default
mock_generative_model.side_effect = model_factory
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'region': 'us-central1',
'model': 'gemini-2.0-flash-001', # Default model
'temperature': 0.2, # Default temperature
'max_output': 8192,
'private_key': 'private.json',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gemini-1.5-pro", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify the overridden model was used
mock_model_override.generate_content.assert_called_once()
# Verify GenerativeModel was called with the override model
mock_generative_model.assert_called_with('gemini-1.5-pro')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_credentials = MagicMock()
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Response with both overrides"
mock_response.usage_metadata.prompt_token_count = 22
mock_response.usage_metadata.candidates_token_count = 16
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'region': 'us-central1',
'model': 'gemini-2.0-flash-001', # Default model
'temperature': 0.0, # Default temperature
'max_output': 8192,
'private_key': 'private.json',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gemini-1.5-flash-001", # Override model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify both overrides were used
mock_model.generate_content.assert_called_once()
call_args = mock_model.generate_content.call_args
# Verify model override
mock_generative_model.assert_called_with('gemini-1.5-flash-001') # Should use runtime override
# Verify temperature override (we can't directly access temperature from mock)
generation_config = call_args.kwargs['generation_config']
assert generation_config is not None # Should use overridden temperature configuration
if __name__ == '__main__':
pytest.main([__file__])


@@ -42,7 +42,7 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'TheBloke/Mistral-7B-v0.1-AWQ'
assert processor.default_model == 'TheBloke/Mistral-7B-v0.1-AWQ'
assert processor.base_url == 'http://vllm-service:8899/v1'
assert processor.temperature == 0.0
assert processor.max_output == 2048
@@ -199,7 +199,7 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'custom-model'
assert processor.default_model == 'custom-model'
assert processor.base_url == 'http://custom-vllm:8080/v1'
assert processor.temperature == 0.7
assert processor.max_output == 1024
@@ -228,7 +228,7 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'TheBloke/Mistral-7B-v0.1-AWQ' # default_model
assert processor.default_model == 'TheBloke/Mistral-7B-v0.1-AWQ' # default_model
assert processor.base_url == 'http://vllm-service:8899/v1' # default_base_url
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 2048 # default_max_output
@@ -485,5 +485,148 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
assert call_args[1]['json']['prompt'] == expected_prompt
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test generate_content with model parameter override"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'text': 'Response from overridden model'
}],
'usage': {
'prompt_tokens': 12,
'completion_tokens': 8
}
})
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
'url': 'http://vllm-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model
result = await processor.generate_content("System", "Prompt", model="custom-vllm-model")
# Assert
assert result.model == "custom-vllm-model" # Should use overridden model
assert result.text == "Response from overridden model"
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test generate_content with temperature parameter override"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'text': 'Response with temperature override'
}],
'usage': {
'prompt_tokens': 15,
'completion_tokens': 10
}
})
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
'url': 'http://vllm-service:8899/v1',
'temperature': 0.0, # Default temperature
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature
result = await processor.generate_content("System", "Prompt", temperature=0.7)
# Assert
assert result.text == "Response with temperature override"
# Verify the request was made with overridden temperature
call_args = mock_session.post.call_args
assert call_args[1]['json']['temperature'] == 0.7
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test generate_content with both model and temperature overrides"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'text': 'Response with both parameters override'
}],
'usage': {
'prompt_tokens': 18,
'completion_tokens': 12
}
})
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
'url': 'http://vllm-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
# Assert
assert result.model == "override-model"
assert result.text == "Response with both parameters override"
# Verify the request was made with overridden temperature
call_args = mock_session.post.call_args
assert call_args[1]['json']['temperature'] == 0.8
if __name__ == '__main__':
pytest.main([__file__])
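The three override tests above all exercise the same precedence rule: a runtime argument wins over the processor default, but the two parameters are resolved differently. A hedged standalone sketch (the helper name is hypothetical, not part of the Processor API): the model falls back via `or` because it is a string, while temperature must be compared against `None`, so an explicit override of `0.0` still wins.

```python
# Sketch of the override precedence exercised by the tests above
# (hypothetical helper, not the actual Processor implementation).
def resolve_overrides(default_model, default_temperature,
                      model=None, temperature=None):
    effective_model = model or default_model
    # `is not None` matters here: temperature=0.0 is falsy but valid
    effective_temperature = (
        temperature if temperature is not None else default_temperature
    )
    return effective_model, effective_temperature

print(resolve_overrides("TheBloke/Mistral-7B-v0.1-AWQ", 0.0,
                        model="override-model", temperature=0.8))
# ('override-model', 0.8)
```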


@@ -12,6 +12,7 @@ requires-python = ">=3.8"
dependencies = [
"pulsar-client",
"prometheus-client",
"requests",
]
classifiers = [
"Programming Language :: Python :: 3",


@@ -87,7 +87,7 @@ class Flow:
return json.loads(self.request(request = input)["flow"])
def start(self, class_name, id, description):
def start(self, class_name, id, description, parameters=None):
# The input consists of system and prompt strings
input = {
@@ -97,6 +97,9 @@
"description": description,
}
if parameters:
input["parameters"] = parameters
self.request(request = input)
def stop(self, id):


@@ -8,11 +8,12 @@ from . subscriber import Subscriber
from . metrics import ProcessorMetrics, ConsumerMetrics, ProducerMetrics
from . flow_processor import FlowProcessor
from . consumer_spec import ConsumerSpec
from . setting_spec import SettingSpec
from . parameter_spec import ParameterSpec
from . producer_spec import ProducerSpec
from . subscriber_spec import SubscriberSpec
from . request_response_spec import RequestResponseSpec
from . llm_service import LlmService, LlmResult
from . chunking_service import ChunkingService
from . embeddings_service import EmbeddingsService
from . embeddings_client import EmbeddingsClientSpec
from . text_completion_client import TextCompletionClientSpec


@@ -0,0 +1,62 @@
"""
Base chunking service that provides parameter specification functionality
for chunk-size and chunk-overlap parameters
"""
import logging
from .flow_processor import FlowProcessor
from .parameter_spec import ParameterSpec
# Module logger
logger = logging.getLogger(__name__)
class ChunkingService(FlowProcessor):
"""Base service for chunking processors with parameter specification support"""
def __init__(self, **params):
# Call parent constructor
super(ChunkingService, self).__init__(**params)
# Register parameter specifications for chunk-size and chunk-overlap
self.register_specification(
ParameterSpec(name="chunk-size")
)
self.register_specification(
ParameterSpec(name="chunk-overlap")
)
logger.debug("ChunkingService initialized with parameter specifications")
async def chunk_document(self, msg, consumer, flow, default_chunk_size, default_chunk_overlap):
"""
Extract chunk parameters from flow and return effective values
Args:
msg: The message containing the document to chunk
consumer: The consumer spec
flow: The flow context
default_chunk_size: Default chunk size from processor config
default_chunk_overlap: Default chunk overlap from processor config
Returns:
tuple: (chunk_size, chunk_overlap) - effective values to use
"""
# Extract parameters from flow (flow-configurable parameters)
chunk_size = flow("chunk-size")
chunk_overlap = flow("chunk-overlap")
# Use provided values or fall back to defaults
effective_chunk_size = chunk_size if chunk_size is not None else default_chunk_size
effective_chunk_overlap = chunk_overlap if chunk_overlap is not None else default_chunk_overlap
logger.debug(f"Using chunk-size: {effective_chunk_size}")
logger.debug(f"Using chunk-overlap: {effective_chunk_overlap}")
return effective_chunk_size, effective_chunk_overlap
@staticmethod
def add_args(parser):
"""Add chunking service arguments to parser"""
FlowProcessor.add_args(parser)
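The fallback logic in `chunk_document` reduces to a few lines. A minimal sketch, assuming a `flow` callable that returns `None` for unset parameters (the lambda here is a stand-in for the real flow context):

```python
# Stand-in for the flow lookup: returns None for unset parameters
def make_flow(params):
    return lambda key: params.get(key)

def resolve_chunk_params(flow, default_chunk_size, default_chunk_overlap):
    chunk_size = flow("chunk-size")
    chunk_overlap = flow("chunk-overlap")
    # Flow-supplied values win; otherwise fall back to processor defaults
    size = chunk_size if chunk_size is not None else default_chunk_size
    overlap = chunk_overlap if chunk_overlap is not None else default_chunk_overlap
    return size, overlap

print(resolve_chunk_params(make_flow({"chunk-size": 500}), 1000, 100))
# (500, 100)
```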


@@ -12,7 +12,7 @@ class Flow:
# Consumers and publishers. Is this a bit untidy?
self.consumer = {}
self.setting = {}
self.parameter = {}
for spec in processor.specifications:
spec.add(self, processor, defn)
@@ -28,5 +28,5 @@
def __call__(self, key):
if key in self.producer: return self.producer[key]
if key in self.consumer: return self.consumer[key]
if key in self.setting: return self.setting[key].value
if key in self.parameter: return self.parameter[key].value
return None
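The lookup order in `__call__` above (producer, then consumer, then parameter value, else `None`) can be checked in isolation; a small sketch with stand-in classes (names hypothetical):

```python
# Minimal stand-ins mirroring the Flow.__call__ lookup order above
class Param:
    def __init__(self, value):
        self.value = value

class MiniFlow:
    def __init__(self):
        self.producer = {}
        self.consumer = {}
        self.parameter = {}

    def __call__(self, key):
        if key in self.producer: return self.producer[key]
        if key in self.consumer: return self.consumer[key]
        if key in self.parameter: return self.parameter[key].value
        return None

flow = MiniFlow()
flow.parameter["temperature"] = Param(0.7)
print(flow("temperature"))  # 0.7
print(flow("missing"))      # None
```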


@@ -35,7 +35,7 @@ class FlowProcessor(AsyncProcessor):
# These can be overridden by a derived class:
# Array of specifications: ConsumerSpec, ProducerSpec, SettingSpec
# Array of specifications: ConsumerSpec, ProducerSpec, ParameterSpec
self.specifications = []
logger.info("Service initialised.")


@@ -5,11 +5,11 @@ LLM text completion base class
import time
import logging
from prometheus_client import Histogram
from prometheus_client import Histogram, Info
from .. schema import TextCompletionRequest, TextCompletionResponse, Error
from .. exceptions import TooManyRequests
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec, ParameterSpec
# Module logger
logger = logging.getLogger(__name__)
@@ -32,7 +32,7 @@ class LlmService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
id = params.get("id", default_ident)
concurrency = params.get("concurrency", 1)
super(LlmService, self).__init__(**params | {
@@ -56,6 +56,18 @@ class LlmService(FlowProcessor):
)
)
self.register_specification(
ParameterSpec(
name = "model",
)
)
self.register_specification(
ParameterSpec(
name = "temperature",
)
)
if not hasattr(__class__, "text_completion_metric"):
__class__.text_completion_metric = Histogram(
'text_completion_duration',
@@ -70,6 +82,13 @@
]
)
if not hasattr(__class__, "text_completion_model_metric"):
__class__.text_completion_model_metric = Info(
'text_completion_model',
'Text completion model',
["processor", "flow"]
)
async def on_request(self, msg, consumer, flow):
try:
@@ -85,10 +104,21 @@
flow=f"{flow.name}-{consumer.name}",
).time():
model = flow("model")
temperature = flow("temperature")
response = await self.generate_content(
request.system, request.prompt
request.system, request.prompt, model, temperature
)
__class__.text_completion_model_metric.labels(
processor = self.id,
flow = flow.name
).info({
"model": str(model) if model is not None else "",
"temperature": str(temperature) if temperature is not None else "",
})
await flow("response").send(
TextCompletionResponse(
error=None,


@@ -1,7 +1,7 @@
from . spec import Spec
class Setting:
class Parameter:
def __init__(self, value):
self.value = value
async def start():
@@ -9,11 +9,13 @@ class Setting:
async def stop():
pass
class SettingSpec(Spec):
class ParameterSpec(Spec):
def __init__(self, name):
self.name = name
def add(self, flow, processor, definition):
flow.config[self.name] = Setting(definition[self.name])
value = definition.get(self.name, None)
flow.parameter[self.name] = Parameter(value)
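The switch from `definition[self.name]` to `definition.get(self.name, None)` above is what makes parameters optional: a flow definition that omits the key now yields a `Parameter` holding `None` instead of raising `KeyError`. A minimal sketch (helper name hypothetical):

```python
# Sketch of the optional-parameter behaviour above
class Parameter:
    def __init__(self, value):
        self.value = value

def add_parameter(flow_parameters, name, definition):
    # .get(..., None): missing keys become Parameter(None), no KeyError
    value = definition.get(name, None)
    flow_parameters[name] = Parameter(value)

params = {}
add_parameter(params, "model", {"model": "gemini-2.0-flash-001"})
add_parameter(params, "temperature", {})  # key absent: still succeeds
print(params["model"].value)        # gemini-2.0-flash-001
print(params["temperature"].value)  # None
```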


@@ -49,8 +49,6 @@ class RequestResponse(Subscriber):
id = str(uuid.uuid4())
logger.debug(f"Sending request {id}...")
q = await self.subscribe(id)
try:
@@ -75,8 +73,6 @@
timeout=timeout
)
logger.debug("Received response")
if recipient is None:
# If no recipient handler, just return the first


@@ -12,12 +12,13 @@ class FlowRequestTranslator(MessageTranslator):
class_name=data.get("class-name"),
class_definition=data.get("class-definition"),
description=data.get("description"),
flow_id=data.get("flow-id")
flow_id=data.get("flow-id"),
parameters=data.get("parameters")
)
def from_pulsar(self, obj: FlowRequest) -> Dict[str, Any]:
result = {}
if obj.operation is not None:
result["operation"] = obj.operation
if obj.class_name is not None:
@@ -28,7 +29,9 @@ class FlowRequestTranslator(MessageTranslator):
result["description"] = obj.description
if obj.flow_id is not None:
result["flow-id"] = obj.flow_id
if obj.parameters is not None:
result["parameters"] = obj.parameters
return result
@@ -40,7 +43,7 @@
def from_pulsar(self, obj: FlowResponse) -> Dict[str, Any]:
result = {}
if obj.class_names is not None:
result["class-names"] = obj.class_names
if obj.flow_ids is not None:
@@ -51,7 +54,9 @@
result["flow"] = obj.flow
if obj.description is not None:
result["description"] = obj.description
if obj.parameters is not None:
result["parameters"] = obj.parameters
return result
def from_response_with_completion(self, obj: FlowResponse) -> Tuple[Dict[str, Any], bool]:


@@ -35,6 +35,9 @@ class FlowRequest(Record):
# get_flow, start_flow, stop_flow
flow_id = String()
# start_flow - optional parameters for flow customization
parameters = Map(String())
class FlowResponse(Record):
# list_classes
@@ -52,6 +55,9 @@ class FlowResponse(Record):
# get_flow
description = String()
# get_flow - parameters used when flow was started
parameters = Map(String())
# Everything
error = Error()


@@ -183,13 +183,13 @@ class Processor(LlmService):
}
)
self.model = model
# Store default configuration
self.default_model = model
self.temperature = temperature
self.max_output = max_output
self.variant = self.determine_variant(self.model)()
self.variant.set_temperature(temperature)
self.variant.set_max_output(max_output)
# Cache for model variants to avoid re-initialization
self.model_variants = {}
self.session = boto3.Session(
aws_access_key_id=aws_access_key_id,
@@ -208,47 +208,75 @@
# FIXME: Missing, Amazon models, Deepseek
# This set of conditions deals with normal bedrock on-demand usage
if self.model.startswith("mistral"):
if model.startswith("mistral"):
return Mistral
elif self.model.startswith("meta"):
elif model.startswith("meta"):
return Meta
elif self.model.startswith("anthropic"):
elif model.startswith("anthropic"):
return Anthropic
elif self.model.startswith("ai21"):
elif model.startswith("ai21"):
return Ai21
elif self.model.startswith("cohere"):
elif model.startswith("cohere"):
return Cohere
# The inference profiles
if self.model.startswith("us.meta"):
if model.startswith("us.meta"):
return Meta
elif self.model.startswith("us.anthropic"):
elif model.startswith("us.anthropic"):
return Anthropic
elif self.model.startswith("eu.meta"):
elif model.startswith("eu.meta"):
return Meta
elif self.model.startswith("eu.anthropic"):
elif model.startswith("eu.anthropic"):
return Anthropic
return Default
async def generate_content(self, system, prompt):
def _get_or_create_variant(self, model_name, temperature=None):
"""Get cached model variant or create new one"""
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
# Create a cache key that includes temperature to avoid conflicts
cache_key = f"{model_name}:{effective_temperature}"
if cache_key not in self.model_variants:
logger.info(f"Creating model variant for '{model_name}' with temperature {effective_temperature}")
variant_class = self.determine_variant(model_name)
variant = variant_class()
variant.set_temperature(effective_temperature)
variant.set_max_output(self.max_output)
self.model_variants[cache_key] = variant
return self.model_variants[cache_key]
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
try:
# Get the appropriate variant for this model
variant = self._get_or_create_variant(model_name, effective_temperature)
promptbody = self.variant.encode_request(system, prompt)
promptbody = variant.encode_request(system, prompt)
accept = 'application/json'
contentType = 'application/json'
response = self.bedrock.invoke_model(
body=promptbody,
modelId=self.model,
modelId=model_name,
accept=accept,
contentType=contentType
)
# Response structure decode
outputtext = self.variant.decode_response(response)
outputtext = variant.decode_response(response)
metadata = response['ResponseMetadata']['HTTPHeaders']
inputtokens = int(metadata['x-amzn-bedrock-input-token-count'])
@@ -262,7 +290,7 @@
text = outputtext,
in_token = inputtokens,
out_token = outputtokens,
model = self.model
model = model_name
)
return resp
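The caching strategy above keys each variant on both model name and effective temperature, so repeated requests with the same pair share one configured instance. A hedged sketch of just the cache (class name hypothetical; a tuple stands in for the real variant object):

```python
# Sketch of the model:temperature variant cache above
class VariantCache:
    def __init__(self, default_temperature):
        self.default_temperature = default_temperature
        self._cache = {}

    def get(self, model_name, temperature=None):
        effective = (temperature if temperature is not None
                     else self.default_temperature)
        key = f"{model_name}:{effective}"
        if key not in self._cache:
            # Stand-in for: determine_variant(...)() + set_temperature(...)
            self._cache[key] = (model_name, effective)
        return self._cache[key]

cache = VariantCache(0.0)
a = cache.get("anthropic.claude-3", 0.7)
b = cache.get("anthropic.claude-3", 0.7)
print(a is b)  # True: one cached instance per model:temperature pair
```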


@@ -72,6 +72,7 @@ tg-show-kg-cores = "trustgraph.cli.show_kg_cores:main"
tg-show-library-documents = "trustgraph.cli.show_library_documents:main"
tg-show-library-processing = "trustgraph.cli.show_library_processing:main"
tg-show-mcp-tools = "trustgraph.cli.show_mcp_tools:main"
tg-show-parameter-types = "trustgraph.cli.show_parameter_types:main"
tg-show-processor-state = "trustgraph.cli.show_processor_state:main"
tg-show-prompts = "trustgraph.cli.show_prompts:main"
tg-show-token-costs = "trustgraph.cli.show_token_costs:main"


@@ -599,8 +599,7 @@ def _send_to_trustgraph(objects, api_url, flow, batch_size=1000):
imported_count += 1
if imported_count % 100 == 0:
logger.info(f"Imported {imported_count}/{len(objects)} records...")
print(f"✅ Imported {imported_count}/{len(objects)} records...")
logger.debug(f"Imported {imported_count}/{len(objects)} records...")
except Exception as e:
logger.error(f"Failed to send record {imported_count + 1}: {e}")


@@ -5,38 +5,93 @@ Shows all defined flow classes.
import argparse
import os
import tabulate
from trustgraph.api import Api
from trustgraph.api import Api, ConfigKey
import json
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
def format_parameters(params_metadata, config_api):
"""
Format parameter metadata for display
Args:
params_metadata: Parameter definitions from flow class
config_api: API client to get parameter type information
Returns:
Formatted string describing parameters
"""
if not params_metadata:
return "None"
param_list = []
# Sort parameters by order if available
sorted_params = sorted(
params_metadata.items(),
key=lambda x: x[1].get("order", 999)
)
for param_name, param_meta in sorted_params:
description = param_meta.get("description", param_name)
param_type = param_meta.get("type", "unknown")
# Get type information if available
type_info = param_type
if config_api:
try:
key = ConfigKey("parameter-types", param_type)
type_def_value = config_api.get([key])[0].value
param_type_def = json.loads(type_def_value)
# Add default value if available
default = param_type_def.get("default")
if default is not None:
type_info = f"{param_type} (default: {default})"
except Exception:
# If we can't get type definition, just show the type name
pass
param_list.append(f" {param_name}: {description} [{type_info}]")
return "\n".join(param_list)
def show_flow_classes(url):
api = Api(url).flow()
api = Api(url)
flow_api = api.flow()
config_api = api.config()
class_names = api.list_classes()
class_names = flow_api.list_classes()
if len(class_names) == 0:
print("No flows.")
print("No flow classes.")
return
classes = []
for class_name in class_names:
cls = api.get_class(class_name)
classes.append((
class_name,
cls.get("description", ""),
", ".join(cls.get("tags", [])),
))
cls = flow_api.get_class(class_name)
print(tabulate.tabulate(
classes,
tablefmt="pretty",
maxcolwidths=[None, 40, 20],
stralign="left",
headers = ["flow class", "description", "tags"],
))
table = []
table.append(("name", class_name))
table.append(("description", cls.get("description", "")))
tags = cls.get("tags", [])
if tags:
table.append(("tags", ", ".join(tags)))
# Show parameters if they exist
parameters = cls.get("parameters", {})
if parameters:
param_str = format_parameters(parameters, config_api)
table.append(("parameters", param_str))
print(tabulate.tabulate(
table,
tablefmt="pretty",
stralign="left",
))
print()
def main():


@@ -45,6 +45,89 @@ def describe_interfaces(intdefs, flow):
return "\n".join(lst)
def get_enum_description(param_value, param_type_def):
"""
Get the human-readable description for an enum value
Args:
param_value: The actual parameter value (e.g., "gpt-4")
param_type_def: The parameter type definition containing enum objects
Returns:
Human-readable description or the original value if not found
"""
enum_list = param_type_def.get("enum", [])
# Handle both old format (strings) and new format (objects with id/description)
for enum_item in enum_list:
if isinstance(enum_item, dict):
if enum_item.get("id") == param_value:
return enum_item.get("description", param_value)
elif enum_item == param_value:
return param_value
# If not found in enum, return original value
return param_value
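`get_enum_description` above is self-contained and easy to check against both enum formats; for example (sample data is illustrative only):

```python
# Same lookup as above, exercised against both enum formats
def get_enum_description(param_value, param_type_def):
    for enum_item in param_type_def.get("enum", []):
        if isinstance(enum_item, dict):
            if enum_item.get("id") == param_value:
                return enum_item.get("description", param_value)
        elif enum_item == param_value:
            return param_value
    return param_value

# New format: enum entries are {id, description} objects
new_style = {"enum": [{"id": "gpt-4", "description": "GPT-4 (most capable)"}]}
# Old format: enum entries are plain strings
old_style = {"enum": ["gpt-4"]}

print(get_enum_description("gpt-4", new_style))    # GPT-4 (most capable)
print(get_enum_description("gpt-4", old_style))    # gpt-4
print(get_enum_description("unknown", new_style))  # unknown
```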
def format_parameters(flow_params, class_params_metadata, config_api):
"""
Format flow parameters with their human-readable descriptions
Args:
flow_params: The actual parameter values used in the flow
class_params_metadata: The parameter metadata from the flow class definition
config_api: API client to retrieve parameter type definitions
Returns:
Formatted string of parameters with descriptions
"""
if not flow_params:
return "None"
param_list = []
# Sort parameters by order if available
sorted_params = sorted(
class_params_metadata.items(),
key=lambda x: x[1].get("order", 999)
)
for param_name, param_meta in sorted_params:
if param_name in flow_params:
value = flow_params[param_name]
description = param_meta.get("description", param_name)
param_type = param_meta.get("type", "")
controlled_by = param_meta.get("controlled-by", None)
# Try to get enum description if this parameter has a type definition
display_value = value
if param_type and config_api:
try:
from trustgraph.api import ConfigKey
key = ConfigKey("parameter-types", param_type)
type_def_value = config_api.get([key])[0].value
param_type_def = json.loads(type_def_value)
display_value = get_enum_description(value, param_type_def)
except Exception:
# If we can't get the type definition, just use the original value
display_value = value
# Format the parameter line
line = f"{description}: {display_value}"
# Add controlled-by indicator if present
if controlled_by:
line += f" (controlled by {controlled_by})"
param_list.append(line)
# Add any parameters that aren't in the class metadata (shouldn't happen normally)
for param_name, value in flow_params.items():
if param_name not in class_params_metadata:
param_list.append(f"{param_name}: {value} (undefined)")
return "\n".join(param_list) if param_list else "None"
def show_flows(url):
api = Api(url)
@@ -74,6 +157,26 @@ def show_flows(url):
table.append(("id", id))
table.append(("class", flow.get("class-name", "")))
table.append(("desc", flow.get("description", "")))
# Display parameters with human-readable descriptions
parameters = flow.get("parameters", {})
if parameters:
# Try to get the flow class definition for parameter metadata
class_name = flow.get("class-name", "")
if class_name:
try:
flow_class = flow_api.get_class(class_name)
class_params_metadata = flow_class.get("parameters", {})
param_str = format_parameters(parameters, class_params_metadata, config_api)
except Exception as e:
# Fallback to JSON if we can't get the class definition
param_str = json.dumps(parameters, indent=2)
else:
# No class name, fallback to JSON
param_str = json.dumps(parameters, indent=2)
table.append(("parameters", param_str))
table.append(("queue", describe_interfaces(interface_defs, flow)))
print(tabulate.tabulate(


@@ -0,0 +1,210 @@
"""
Shows all defined parameter types used in flow classes.
Parameter types define the schema and constraints for parameters that can
be used in flow class definitions. This includes data types, default values,
valid enums, and validation rules.
"""
import argparse
import os
import tabulate
from trustgraph.api import Api, ConfigKey
import json
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
def format_enum_values(enum_list):
"""
Format enum values for display, handling both old and new formats
Args:
enum_list: List of enum values (strings or objects with id/description)
Returns:
Formatted string describing enum options
"""
if not enum_list:
return "Any value"
enum_items = []
for item in enum_list:
if isinstance(item, dict):
# New format: objects with id and description
enum_id = item.get("id", "")
description = item.get("description", "")
if description:
enum_items.append(f"{enum_id} ({description})")
else:
enum_items.append(enum_id)
else:
# Old format: simple strings
enum_items.append(str(item))
return "\n".join(enum_items)
def format_constraints(param_type_def):
"""
Format validation constraints for display
Args:
param_type_def: Parameter type definition
Returns:
Formatted string describing constraints
"""
constraints = []
# Handle numeric constraints
if "minimum" in param_type_def:
constraints.append(f"min: {param_type_def['minimum']}")
if "maximum" in param_type_def:
constraints.append(f"max: {param_type_def['maximum']}")
# Handle string constraints
if "minLength" in param_type_def:
constraints.append(f"min length: {param_type_def['minLength']}")
if "maxLength" in param_type_def:
constraints.append(f"max length: {param_type_def['maxLength']}")
if "pattern" in param_type_def:
constraints.append(f"pattern: {param_type_def['pattern']}")
# Handle required field
if param_type_def.get("required", False):
constraints.append("required")
return ", ".join(constraints) if constraints else "None"
def show_parameter_types(url):
"""
Show all parameter type definitions
"""
api = Api(url)
config_api = api.config()
# Get list of all parameter types
try:
param_type_names = config_api.list("parameter-types")
except Exception as e:
print(f"Error retrieving parameter types: {e}")
return
if len(param_type_names) == 0:
print("No parameter types defined.")
return
for param_type_name in param_type_names:
try:
# Get the parameter type definition
key = ConfigKey("parameter-types", param_type_name)
type_def_value = config_api.get([key])[0].value
param_type_def = json.loads(type_def_value)
table = []
table.append(("name", param_type_name))
table.append(("description", param_type_def.get("description", "")))
table.append(("type", param_type_def.get("type", "unknown")))
# Show default value if present
default = param_type_def.get("default")
if default is not None:
table.append(("default", str(default)))
# Show enum values if present
enum_list = param_type_def.get("enum")
if enum_list:
enum_str = format_enum_values(enum_list)
table.append(("valid values", enum_str))
# Show constraints
constraints = format_constraints(param_type_def)
if constraints != "None":
table.append(("constraints", constraints))
print(tabulate.tabulate(
table,
tablefmt="pretty",
stralign="left",
))
print()
except Exception as e:
print(f"Error retrieving parameter type '{param_type_name}': {e}")
print()
def main():
parser = argparse.ArgumentParser(
prog='tg-show-parameter-types',
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
'-u', '--api-url',
default=default_url,
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-t', '--type',
help='Show only the specified parameter type',
)
args = parser.parse_args()
try:
if args.type:
# Show specific parameter type
show_specific_parameter_type(args.api_url, args.type)
else:
# Show all parameter types
show_parameter_types(args.api_url)
except Exception as e:
print("Exception:", e, flush=True)
def show_specific_parameter_type(url, param_type_name):
"""
Show a specific parameter type definition
"""
api = Api(url)
config_api = api.config()
try:
# Get the parameter type definition
key = ConfigKey("parameter-types", param_type_name)
type_def_value = config_api.get([key])[0].value
param_type_def = json.loads(type_def_value)
table = []
table.append(("name", param_type_name))
table.append(("description", param_type_def.get("description", "")))
table.append(("type", param_type_def.get("type", "unknown")))
# Show default value if present
default = param_type_def.get("default")
if default is not None:
table.append(("default", str(default)))
# Show enum values if present
enum_list = param_type_def.get("enum")
if enum_list:
enum_str = format_enum_values(enum_list)
table.append(("valid values", enum_str))
# Show constraints
constraints = format_constraints(param_type_def)
if constraints != "None":
table.append(("constraints", constraints))
print(tabulate.tabulate(
table,
tablefmt="pretty",
stralign="left",
))
except Exception as e:
print(f"Error retrieving parameter type '{param_type_name}': {e}")
if __name__ == "__main__":
main()
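To illustrate the constraint formatting above, here is a standalone sketch that mirrors `format_constraints` (re-declared here rather than imported from the CLI module, so it runs on its own):

```python
def format_constraints(param_type_def):
    # Mirrors the CLI helper: collect numeric, string, and required constraints
    constraints = []
    if "minimum" in param_type_def:
        constraints.append(f"min: {param_type_def['minimum']}")
    if "maximum" in param_type_def:
        constraints.append(f"max: {param_type_def['maximum']}")
    if "minLength" in param_type_def:
        constraints.append(f"min length: {param_type_def['minLength']}")
    if "maxLength" in param_type_def:
        constraints.append(f"max length: {param_type_def['maxLength']}")
    if "pattern" in param_type_def:
        constraints.append(f"pattern: {param_type_def['pattern']}")
    if param_type_def.get("required", False):
        constraints.append("required")
    return ", ".join(constraints) if constraints else "None"

print(format_constraints({"minimum": 0, "maximum": 2, "required": True}))
# min: 0, max: 2, required
```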

View file

@ -1,5 +1,13 @@
"""
Starts a processing flow using a defined flow class
Starts a processing flow using a defined flow class.
Parameters can be provided in three ways:
1. As key=value pairs: --param model=gpt-4 --param temp=0.7
2. As JSON string: -p '{"model": "gpt-4", "temp": 0.7}'
3. As JSON file: --parameters-file params.json
Note: All parameter values are stored as strings internally, regardless of their
input format. Numbers and booleans will be converted to string representation.
"""
import argparse
@ -10,7 +18,7 @@ import json
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
def start_flow(url, class_name, flow_id, description):
def start_flow(url, class_name, flow_id, description, parameters=None):
api = Api(url).flow()
@ -18,6 +26,7 @@ def start_flow(url, class_name, flow_id, description):
class_name = class_name,
id = flow_id,
description = description,
parameters = parameters,
)
def main():
@ -51,15 +60,58 @@ def main():
help='Flow description',
)
parser.add_argument(
'-p', '--parameters',
help='Flow parameters as JSON string (e.g., \'{"model": "gpt-4", "temp": 0.7}\')',
)
parser.add_argument(
'--parameters-file',
help='Path to JSON file containing flow parameters',
)
parser.add_argument(
'--param',
action='append',
help='Flow parameter as key=value pair (can be used multiple times, e.g., --param model=gpt-4 --param temp=0.7)',
)
args = parser.parse_args()
try:
# Parse parameters from command line arguments
parameters = None
if args.parameters_file:
with open(args.parameters_file, 'r') as f:
params_data = json.load(f)
# Convert all values to strings
parameters = {k: str(v) for k, v in params_data.items()}
elif args.parameters:
params_data = json.loads(args.parameters)
# Convert all values to strings
parameters = {k: str(v) for k, v in params_data.items()}
elif args.param:
# Parse key=value pairs
parameters = {}
for param in args.param:
if '=' not in param:
raise ValueError(f"Invalid parameter format: {param}. Expected key=value")
key, value = param.split('=', 1)
key = key.strip()
value = value.strip()
# All parameter values must be strings for Pulsar
# Just store everything as a string
parameters[key] = value
start_flow(
url = args.api_url,
class_name = args.class_name,
flow_id = args.flow_id,
description = args.description,
parameters = parameters,
)
except Exception as e:
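The three input modes above all normalize to a flat dict of string values. A minimal standalone sketch of that normalization (function name is illustrative, not part of the CLI):

```python
import json

def parse_params(json_str=None, pairs=None):
    # Mirrors tg-start-flow's behaviour: every value is stored as a string
    if json_str:
        return {k: str(v) for k, v in json.loads(json_str).items()}
    params = {}
    for pair in pairs or []:
        if '=' not in pair:
            raise ValueError(f"Invalid parameter format: {pair}. Expected key=value")
        key, value = pair.split('=', 1)
        params[key.strip()] = value.strip()
    return params

print(parse_params(json_str='{"model": "gpt-4", "temp": 0.7}'))
# {'model': 'gpt-4', 'temp': '0.7'}
print(parse_params(pairs=["model=gpt-4", "temp=0.7"]))
# {'model': 'gpt-4', 'temp': '0.7'}
```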

View file

@ -9,14 +9,14 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter
from prometheus_client import Histogram
from ... schema import TextDocument, Chunk
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import ChunkingService, ConsumerSpec, ProducerSpec
# Module logger
logger = logging.getLogger(__name__)
default_ident = "chunker"
class Processor(FlowProcessor):
class Processor(ChunkingService):
def __init__(self, **params):
@ -28,6 +28,10 @@ class Processor(FlowProcessor):
**params | { "id": id }
)
# Store default values for parameter override
self.default_chunk_size = chunk_size
self.default_chunk_overlap = chunk_overlap
if not hasattr(__class__, "chunk_metric"):
__class__.chunk_metric = Histogram(
'chunk_size', 'Chunk size',
@ -65,7 +69,22 @@ class Processor(FlowProcessor):
v = msg.value()
logger.info(f"Chunking document {v.metadata.id}...")
texts = self.text_splitter.create_documents(
# Extract chunk parameters from flow (allows runtime override)
chunk_size, chunk_overlap = await self.chunk_document(
msg, consumer, flow,
self.default_chunk_size,
self.default_chunk_overlap
)
# Create text splitter with effective parameters
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len,
is_separator_regex=False,
)
texts = text_splitter.create_documents(
[v.text.decode("utf-8")]
)
@ -89,7 +108,7 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
ChunkingService.add_args(parser)
parser.add_argument(
'-z', '--chunk-size',

View file

@ -9,14 +9,14 @@ from langchain_text_splitters import TokenTextSplitter
from prometheus_client import Histogram
from ... schema import TextDocument, Chunk
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import ChunkingService, ConsumerSpec, ProducerSpec
# Module logger
logger = logging.getLogger(__name__)
default_ident = "chunker"
class Processor(FlowProcessor):
class Processor(ChunkingService):
def __init__(self, **params):
@ -28,6 +28,10 @@ class Processor(FlowProcessor):
**params | { "id": id }
)
# Store default values for parameter override
self.default_chunk_size = chunk_size
self.default_chunk_overlap = chunk_overlap
if not hasattr(__class__, "chunk_metric"):
__class__.chunk_metric = Histogram(
'chunk_size', 'Chunk size',
@ -64,7 +68,21 @@ class Processor(FlowProcessor):
v = msg.value()
logger.info(f"Chunking document {v.metadata.id}...")
texts = self.text_splitter.create_documents(
# Extract chunk parameters from flow (allows runtime override)
chunk_size, chunk_overlap = await self.chunk_document(
msg, consumer, flow,
self.default_chunk_size,
self.default_chunk_overlap
)
# Create text splitter with effective parameters
text_splitter = TokenTextSplitter(
encoding_name="cl100k_base",
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
texts = text_splitter.create_documents(
[v.text.decode("utf-8")]
)
@ -88,7 +106,7 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
ChunkingService.add_args(parser)
parser.add_argument(
'-z', '--chunk-size',

View file

@ -10,6 +10,95 @@ class FlowConfig:
def __init__(self, config):
self.config = config
# Cache for parameter type definitions to avoid repeated lookups
self.param_type_cache = {}
async def resolve_parameters(self, flow_class, user_params):
"""
Resolve parameters by merging user-provided values with defaults.
Args:
flow_class: The flow class definition dict
user_params: User-provided parameters dict (may be None or empty)
Returns:
Complete parameter dict with user values and defaults merged (all values as strings)
"""
# If the flow class has no parameters section, return user params as-is (stringified)
if "parameters" not in flow_class:
if not user_params:
return {}
# Ensure all values are strings
return {k: str(v) for k, v in user_params.items()}
resolved = {}
flow_params = flow_class["parameters"]
user_params = user_params if user_params else {}
# First pass: resolve parameters with explicit values or defaults
for param_name, param_meta in flow_params.items():
# Check if user provided a value
if param_name in user_params:
# Store as string
resolved[param_name] = str(user_params[param_name])
else:
# Look up the parameter type definition
param_type = param_meta.get("type")
if param_type:
# Check cache first
if param_type not in self.param_type_cache:
try:
# Fetch parameter type definition from config store
type_def = await self.config.get("parameter-types").get(param_type)
if type_def:
self.param_type_cache[param_type] = json.loads(type_def)
else:
logger.warning(f"Parameter type '{param_type}' not found in config")
self.param_type_cache[param_type] = {}
except Exception as e:
logger.error(f"Error fetching parameter type '{param_type}': {e}")
self.param_type_cache[param_type] = {}
# Apply default from type definition (as string)
type_def = self.param_type_cache[param_type]
if "default" in type_def:
default_value = type_def["default"]
# Convert to string based on type
if isinstance(default_value, bool):
resolved[param_name] = "true" if default_value else "false"
else:
resolved[param_name] = str(default_value)
elif type_def.get("required", False):
# Required parameter with no default and no user value
raise RuntimeError(f"Required parameter '{param_name}' not provided and has no default")
# Second pass: handle controlled-by relationships
for param_name, param_meta in flow_params.items():
if param_name not in resolved and "controlled-by" in param_meta:
controller = param_meta["controlled-by"]
if controller in resolved:
# Inherit value from controlling parameter (already a string)
resolved[param_name] = resolved[controller]
else:
# Controller has no value, try to get default from type definition
param_type = param_meta.get("type")
if param_type and param_type in self.param_type_cache:
type_def = self.param_type_cache[param_type]
if "default" in type_def:
default_value = type_def["default"]
# Convert to string based on type
if isinstance(default_value, bool):
resolved[param_name] = "true" if default_value else "false"
else:
resolved[param_name] = str(default_value)
# Include any extra parameters from user that weren't in flow class definition
# This allows for forward compatibility (ensure they're strings)
for key, value in user_params.items():
if key not in resolved:
resolved[key] = str(value)
return resolved
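The resolution rules above reduce to: user values win, otherwise the type definition's default applies (with booleans rendered as "true"/"false"), then controlled-by parameters inherit from their controller. A simplified standalone sketch, with the config-store lookup replaced by a plain dict of type definitions:

```python
# Simplified sketch of resolve_parameters; type_defs stands in for the
# config store, and all names here are illustrative.
def resolve(flow_params, type_defs, user_params):
    resolved = {}
    for name, meta in flow_params.items():
        if name in user_params:
            resolved[name] = str(user_params[name])
            continue
        type_def = type_defs.get(meta.get("type"), {})
        if "default" in type_def:
            d = type_def["default"]
            resolved[name] = ("true" if d else "false") if isinstance(d, bool) else str(d)
        elif type_def.get("required", False):
            raise RuntimeError(f"Required parameter '{name}' not provided")
    # controlled-by: inherit the controller's resolved value when unset
    for name, meta in flow_params.items():
        if name not in resolved and "controlled-by" in meta:
            ctrl = meta["controlled-by"]
            if ctrl in resolved:
                resolved[name] = resolved[ctrl]
    # Extra user parameters pass through for forward compatibility
    for k, v in user_params.items():
        resolved.setdefault(k, str(v))
    return resolved

print(resolve(
    {"model": {"type": "llm-model"}, "temp": {"type": "temperature"}},
    {"llm-model": {"default": "gemma2"}, "temperature": {"default": 0.7}},
    {"model": "gpt-4"}))
# {'model': 'gpt-4', 'temp': '0.7'}
```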
async def handle_list_classes(self, msg):
@ -68,11 +157,14 @@ class FlowConfig:
async def handle_get_flow(self, msg):
flow = await self.config.get("flows").get(msg.flow_id)
flow_data = await self.config.get("flows").get(msg.flow_id)
flow = json.loads(flow_data)
return FlowResponse(
error = None,
flow = flow,
flow = flow_data,
description = flow.get("description", ""),
parameters = flow.get("parameters", {}),
)
async def handle_start_flow(self, msg):
@ -83,45 +175,65 @@ class FlowConfig:
if msg.flow_id is None:
raise RuntimeError("No flow ID")
if msg.flow_id in await self.config.get("flows").values():
if msg.flow_id in await self.config.get("flows").keys():
raise RuntimeError("Flow already exists")
if msg.description is None:
raise RuntimeError("No description")
if msg.class_name not in await self.config.get("flow-classes").values():
if msg.class_name not in await self.config.get("flow-classes").keys():
raise RuntimeError("Class does not exist")
def repl_template(tmp):
return tmp.replace(
"{class}", msg.class_name
).replace(
"{id}", msg.flow_id
)
cls = json.loads(
await self.config.get("flow-classes").get(msg.class_name)
)
# Resolve parameters by merging user-provided values with defaults
user_params = msg.parameters if msg.parameters else {}
parameters = await self.resolve_parameters(cls, user_params)
# Log the resolved parameters for debugging
logger.debug(f"User provided parameters: {user_params}")
logger.debug(f"Resolved parameters (with defaults): {parameters}")
# Apply parameter substitution to template replacement function
def repl_template_with_params(tmp):
result = tmp.replace(
"{class}", msg.class_name
).replace(
"{id}", msg.flow_id
)
# Apply parameter substitutions
for param_name, param_value in parameters.items():
result = result.replace(f"{{{param_name}}}", str(param_value))
return result
for kind in ("class", "flow"):
for k, v in cls[kind].items():
processor, variant = k.split(":", 1)
variant = repl_template(variant)
variant = repl_template_with_params(variant)
v = {
repl_template(k2): repl_template(v2)
repl_template_with_params(k2): repl_template_with_params(v2)
for k2, v2 in v.items()
}
flac = await self.config.get("flows-active").values()
if processor in flac:
target = json.loads(flac[processor])
flac = await self.config.get("flows-active").get(processor)
if flac is not None:
target = json.loads(flac)
else:
target = {}
# Only add the configuration if the variant doesn't already
# exist; an existing variant keeps its old values and is not
# updated.
if variant not in target:
target[variant] = v
@ -131,10 +243,10 @@ class FlowConfig:
def repl_interface(i):
if isinstance(i, str):
return repl_template(i)
return repl_template_with_params(i)
else:
return {
k: repl_template(v)
k: repl_template_with_params(v)
for k, v in i.items()
}
@ -152,6 +264,7 @@ class FlowConfig:
"description": msg.description,
"class-name": msg.class_name,
"interfaces": interfaces,
"parameters": parameters,
})
)
@ -177,15 +290,20 @@ class FlowConfig:
raise RuntimeError("Internal error: flow has no flow class")
class_name = flow["class-name"]
parameters = flow.get("parameters", {})
cls = json.loads(await self.config.get("flow-classes").get(class_name))
def repl_template(tmp):
return tmp.replace(
result = tmp.replace(
"{class}", class_name
).replace(
"{id}", msg.flow_id
)
# Apply parameter substitutions
for param_name, param_value in parameters.items():
result = result.replace(f"{{{param_name}}}", str(param_value))
return result
for kind in ("flow",):
@ -195,10 +313,10 @@ class FlowConfig:
variant = repl_template(variant)
flac = await self.config.get("flows-active").values()
flac = await self.config.get("flows-active").get(processor)
if processor in flac:
target = json.loads(flac[processor])
if flac is not None:
target = json.loads(flac)
else:
target = {}
@ -209,7 +327,7 @@ class FlowConfig:
processor, json.dumps(target)
)
if msg.flow_id in await self.config.get("flows").values():
if msg.flow_id in await self.config.get("flows").keys():
await self.config.get("flows").delete(msg.flow_id)
await self.config.inc_version()
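The substitution performed by `repl_template_with_params` above can be sketched in isolation; the template and parameter names below are illustrative:

```python
def repl_template_with_params(tmpl, class_name, flow_id, parameters):
    # {class} and {id} are substituted first, then each resolved parameter
    result = tmpl.replace("{class}", class_name).replace("{id}", flow_id)
    for name, value in parameters.items():
        result = result.replace(f"{{{name}}}", str(value))
    return result

out = repl_template_with_params(
    "chunker:{id}-{chunk_size}", "document-rag", "flow1",
    {"chunk_size": "2000"})
print(out)  # chunker:flow1-2000
```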

View file

@ -24,16 +24,12 @@ class KnowledgeGraph:
self.keyspace = keyspace
self.username = username
# Multi-table schema design for optimal performance
self.use_legacy = os.getenv('CASSANDRA_USE_LEGACY', 'false').lower() == 'true'
if self.use_legacy:
self.table = "triples" # Legacy single table
else:
# New optimized tables
self.subject_table = "triples_s"
self.po_table = "triples_p"
self.object_table = "triples_o"
# Optimized multi-table schema with collection deletion support
self.subject_table = "triples_s"
self.po_table = "triples_p"
self.object_table = "triples_o"
self.collection_table = "triples_collection" # For SPO queries and deletion
self.collection_metadata_table = "collection_metadata" # For tracking which collections exist
if username and password:
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
@ -47,9 +43,7 @@ class KnowledgeGraph:
_active_clusters.append(self.cluster)
self.init()
if not self.use_legacy:
self.prepare_statements()
self.prepare_statements()
def clear(self):
@ -70,42 +64,13 @@ class KnowledgeGraph:
""");
self.session.set_keyspace(self.keyspace)
self.init_optimized_schema()
if self.use_legacy:
self.init_legacy_schema()
else:
self.init_optimized_schema()
def init_legacy_schema(self):
"""Initialize legacy single-table schema for backward compatibility"""
self.session.execute(f"""
create table if not exists {self.table} (
collection text,
s text,
p text,
o text,
PRIMARY KEY (collection, s, p, o)
);
""");
self.session.execute(f"""
create index if not exists {self.table}_s
ON {self.table} (s);
""");
self.session.execute(f"""
create index if not exists {self.table}_p
ON {self.table} (p);
""");
self.session.execute(f"""
create index if not exists {self.table}_o
ON {self.table} (o);
""");
def init_optimized_schema(self):
"""Initialize optimized multi-table schema for performance"""
# Table 1: Subject-centric queries (get_s, get_sp, get_spo, get_os)
# Table 1: Subject-centric queries (get_s, get_sp, get_os)
# Compound partition key for optimal data distribution
self.session.execute(f"""
CREATE TABLE IF NOT EXISTS {self.subject_table} (
collection text,
@ -117,6 +82,7 @@ class KnowledgeGraph:
""");
# Table 2: Predicate-Object queries (get_p, get_po) - eliminates ALLOW FILTERING!
# Compound partition key for optimal data distribution
self.session.execute(f"""
CREATE TABLE IF NOT EXISTS {self.po_table} (
collection text,
@ -128,6 +94,7 @@ class KnowledgeGraph:
""");
# Table 3: Object-centric queries (get_o)
# Compound partition key for optimal data distribution
self.session.execute(f"""
CREATE TABLE IF NOT EXISTS {self.object_table} (
collection text,
@ -138,7 +105,29 @@ class KnowledgeGraph:
);
""");
logger.info("Optimized multi-table schema initialized")
# Table 4: Collection management and SPO queries (get_spo)
# Simple partition key enables efficient collection deletion
self.session.execute(f"""
CREATE TABLE IF NOT EXISTS {self.collection_table} (
collection text,
s text,
p text,
o text,
PRIMARY KEY (collection, s, p, o)
);
""");
# Table 5: Collection metadata tracking
# Tracks which collections exist without polluting triple data
self.session.execute(f"""
CREATE TABLE IF NOT EXISTS {self.collection_metadata_table} (
collection text,
created_at timestamp,
PRIMARY KEY (collection)
);
""");
logger.info("Optimized multi-table schema initialized (5 tables)")
def prepare_statements(self):
"""Prepare statements for optimal performance"""
@ -155,6 +144,10 @@ class KnowledgeGraph:
f"INSERT INTO {self.object_table} (collection, o, s, p) VALUES (?, ?, ?, ?)"
)
self.insert_collection_stmt = self.session.prepare(
f"INSERT INTO {self.collection_table} (collection, s, p, o) VALUES (?, ?, ?, ?)"
)
# Query statements for optimized access
self.get_all_stmt = self.session.prepare(
f"SELECT s, p, o FROM {self.subject_table} WHERE collection = ? LIMIT ? ALLOW FILTERING"
@ -186,158 +179,177 @@ class KnowledgeGraph:
)
self.get_spo_stmt = self.session.prepare(
f"SELECT s as x FROM {self.subject_table} WHERE collection = ? AND s = ? AND p = ? AND o = ? LIMIT ?"
f"SELECT s as x FROM {self.collection_table} WHERE collection = ? AND s = ? AND p = ? AND o = ? LIMIT ?"
)
logger.info("Prepared statements initialized for optimal performance")
# Delete statements for collection deletion
self.delete_subject_stmt = self.session.prepare(
f"DELETE FROM {self.subject_table} WHERE collection = ? AND s = ? AND p = ? AND o = ?"
)
self.delete_po_stmt = self.session.prepare(
f"DELETE FROM {self.po_table} WHERE collection = ? AND p = ? AND o = ? AND s = ?"
)
self.delete_object_stmt = self.session.prepare(
f"DELETE FROM {self.object_table} WHERE collection = ? AND o = ? AND s = ? AND p = ?"
)
self.delete_collection_stmt = self.session.prepare(
f"DELETE FROM {self.collection_table} WHERE collection = ? AND s = ? AND p = ? AND o = ?"
)
logger.info("Prepared statements initialized for optimal performance (4 tables)")
def insert(self, collection, s, p, o):
# Batch write to all four tables for consistency
batch = BatchStatement()
if self.use_legacy:
self.session.execute(
f"insert into {self.table} (collection, s, p, o) values (%s, %s, %s, %s)",
(collection, s, p, o)
)
else:
# Batch write to all three tables for consistency
batch = BatchStatement()
# Insert into subject table
batch.add(self.insert_subject_stmt, (collection, s, p, o))
# Insert into subject table
batch.add(self.insert_subject_stmt, (collection, s, p, o))
# Insert into predicate-object table (column order: collection, p, o, s)
batch.add(self.insert_po_stmt, (collection, p, o, s))
# Insert into predicate-object table (column order: collection, p, o, s)
batch.add(self.insert_po_stmt, (collection, p, o, s))
# Insert into object table (column order: collection, o, s, p)
batch.add(self.insert_object_stmt, (collection, o, s, p))
# Insert into object table (column order: collection, o, s, p)
batch.add(self.insert_object_stmt, (collection, o, s, p))
# Insert into collection table for SPO queries and deletion tracking
batch.add(self.insert_collection_stmt, (collection, s, p, o))
self.session.execute(batch)
self.session.execute(batch)
def get_all(self, collection, limit=50):
if self.use_legacy:
return self.session.execute(
f"select s, p, o from {self.table} where collection = %s limit {limit}",
(collection,)
)
else:
# Use subject table for get_all queries
return self.session.execute(
self.get_all_stmt,
(collection, limit)
)
# Use subject table for get_all queries
return self.session.execute(
self.get_all_stmt,
(collection, limit)
)
def get_s(self, collection, s, limit=10):
if self.use_legacy:
return self.session.execute(
f"select p, o from {self.table} where collection = %s and s = %s limit {limit}",
(collection, s)
)
else:
# Optimized: Direct partition access with (collection, s)
return self.session.execute(
self.get_s_stmt,
(collection, s, limit)
)
# Optimized: Direct partition access with (collection, s)
return self.session.execute(
self.get_s_stmt,
(collection, s, limit)
)
def get_p(self, collection, p, limit=10):
if self.use_legacy:
return self.session.execute(
f"select s, o from {self.table} where collection = %s and p = %s limit {limit}",
(collection, p)
)
else:
# Optimized: Use po_table for direct partition access
return self.session.execute(
self.get_p_stmt,
(collection, p, limit)
)
# Optimized: Use po_table for direct partition access
return self.session.execute(
self.get_p_stmt,
(collection, p, limit)
)
def get_o(self, collection, o, limit=10):
if self.use_legacy:
return self.session.execute(
f"select s, p from {self.table} where collection = %s and o = %s limit {limit}",
(collection, o)
)
else:
# Optimized: Use object_table for direct partition access
return self.session.execute(
self.get_o_stmt,
(collection, o, limit)
)
# Optimized: Use object_table for direct partition access
return self.session.execute(
self.get_o_stmt,
(collection, o, limit)
)
def get_sp(self, collection, s, p, limit=10):
if self.use_legacy:
return self.session.execute(
f"select o from {self.table} where collection = %s and s = %s and p = %s limit {limit}",
(collection, s, p)
)
else:
# Optimized: Use subject_table with clustering key access
return self.session.execute(
self.get_sp_stmt,
(collection, s, p, limit)
)
# Optimized: Use subject_table with clustering key access
return self.session.execute(
self.get_sp_stmt,
(collection, s, p, limit)
)
def get_po(self, collection, p, o, limit=10):
if self.use_legacy:
return self.session.execute(
f"select s from {self.table} where collection = %s and p = %s and o = %s limit {limit} allow filtering",
(collection, p, o)
)
else:
# CRITICAL OPTIMIZATION: Use po_table - NO MORE ALLOW FILTERING!
return self.session.execute(
self.get_po_stmt,
(collection, p, o, limit)
)
# CRITICAL OPTIMIZATION: Use po_table - NO MORE ALLOW FILTERING!
return self.session.execute(
self.get_po_stmt,
(collection, p, o, limit)
)
def get_os(self, collection, o, s, limit=10):
if self.use_legacy:
return self.session.execute(
f"select p from {self.table} where collection = %s and o = %s and s = %s limit {limit} allow filtering",
(collection, o, s)
)
else:
# Optimized: Use subject_table with clustering access (no more ALLOW FILTERING)
return self.session.execute(
self.get_os_stmt,
(collection, s, o, limit)
)
# Optimized: Use subject_table with clustering access (no more ALLOW FILTERING)
return self.session.execute(
self.get_os_stmt,
(collection, s, o, limit)
)
def get_spo(self, collection, s, p, o, limit=10):
if self.use_legacy:
return self.session.execute(
f"""select s as x from {self.table} where collection = %s and s = %s and p = %s and o = %s limit {limit}""",
(collection, s, p, o)
# Optimized: Use collection_table for exact key lookup
return self.session.execute(
self.get_spo_stmt,
(collection, s, p, o, limit)
)
def collection_exists(self, collection):
"""Check if collection exists by querying collection_metadata table"""
try:
result = self.session.execute(
f"SELECT collection FROM {self.collection_metadata_table} WHERE collection = %s LIMIT 1",
(collection,)
)
else:
# Optimized: Use subject_table for exact key lookup
return self.session.execute(
self.get_spo_stmt,
(collection, s, p, o, limit)
return bool(list(result))
except Exception as e:
logger.error(f"Error checking collection existence: {e}")
return False
def create_collection(self, collection):
"""Create collection by inserting metadata row"""
try:
import datetime
self.session.execute(
f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)",
(collection, datetime.datetime.now())
)
logger.info(f"Created collection metadata for {collection}")
except Exception as e:
logger.error(f"Error creating collection: {e}")
raise e
def delete_collection(self, collection):
"""Delete all triples for a specific collection"""
if self.use_legacy:
self.session.execute(
f"delete from {self.table} where collection = %s",
(collection,)
)
else:
# Delete from all three tables
self.session.execute(
f"delete from {self.subject_table} where collection = %s",
(collection,)
)
self.session.execute(
f"delete from {self.po_table} where collection = %s",
(collection,)
)
self.session.execute(
f"delete from {self.object_table} where collection = %s",
(collection,)
)
"""Delete all triples for a specific collection
Uses collection_table to enumerate all triples, then deletes from all 4 tables
using full partition keys for optimal performance with compound keys.
"""
# Step 1: Read all triples from collection_table (single partition read)
rows = self.session.execute(
f"SELECT s, p, o FROM {self.collection_table} WHERE collection = %s",
(collection,)
)
# Step 2: Delete each triple from all 4 tables using full partition keys
# Batch deletions for efficiency
batch = BatchStatement()
count = 0
for row in rows:
s, p, o = row.s, row.p, row.o
# Delete from subject table (partition key: collection, s)
batch.add(self.delete_subject_stmt, (collection, s, p, o))
# Delete from predicate-object table (partition key: collection, p)
batch.add(self.delete_po_stmt, (collection, p, o, s))
# Delete from object table (partition key: collection, o)
batch.add(self.delete_object_stmt, (collection, o, s, p))
# Delete from collection table (partition key: collection only)
batch.add(self.delete_collection_stmt, (collection, s, p, o))
count += 1
# Execute batch every 100 triples to avoid oversized batches
if count % 100 == 0:
self.session.execute(batch)
batch = BatchStatement()
# Execute remaining deletions
if count % 100 != 0:
self.session.execute(batch)
# Step 3: Delete collection metadata
self.session.execute(
f"DELETE FROM {self.collection_metadata_table} WHERE collection = %s",
(collection,)
)
logger.info(f"Deleted {count} triples from collection {collection}")
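The chunked-batch pattern in `delete_collection` (flush every 100 triples so Cassandra batches stay small) can be tested without a cluster; the sketch below substitutes a fake batch and a callback for the real session:

```python
class FakeBatch:
    """Stand-in for cassandra.query.BatchStatement (illustrative only)."""
    def __init__(self):
        self.ops = []
    def add(self, stmt, args):
        self.ops.append((stmt, args))

def delete_in_chunks(rows, execute, chunk=100):
    # Same flush logic as delete_collection: execute every `chunk` rows,
    # then flush any remainder
    batch = FakeBatch()
    count = 0
    for row in rows:
        batch.add("delete", row)
        count += 1
        if count % chunk == 0:
            execute(batch)
            batch = FakeBatch()
    if count % chunk != 0:
        execute(batch)
    return count

executed = []
n = delete_in_chunks(range(250), executed.append)
print(n, [len(b.ops) for b in executed])  # 250 [100, 100, 50]
```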
def close(self):
"""Close the Cassandra session and cluster connections properly"""

View file

@ -49,6 +49,22 @@ class DocVectors:
self.next_reload = time.time() + self.reload_time
logger.debug(f"Reload at {self.next_reload}")
def collection_exists(self, user, collection):
"""Check if collection exists (dimension-independent check)"""
collection_name = make_safe_collection_name(user, collection, self.prefix)
return self.client.has_collection(collection_name)
def create_collection(self, user, collection, dimension=384):
"""Create collection with default dimension"""
collection_name = make_safe_collection_name(user, collection, self.prefix)
if self.client.has_collection(collection_name):
logger.info(f"Collection {collection_name} already exists")
return
self.init_collection(dimension, user, collection)
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
def init_collection(self, dimension, user, collection):
collection_name = make_safe_collection_name(user, collection, self.prefix)
@ -128,14 +144,6 @@ class DocVectors:
coll = self.collections[(dim, user, collection)]
search_params = {
"metric_type": "COSINE",
"params": {
"radius": 0.1,
"range_filter": 0.8
}
}
logger.debug("Loading...")
self.client.load_collection(
collection_name=coll,
@ -145,10 +153,11 @@ class DocVectors:
res = self.client.search(
collection_name=coll,
anns_field="vector",
data=[embeds],
limit=limit,
output_fields=fields,
search_params=search_params,
search_params={ "metric_type": "COSINE" },
)[0]

View file

@ -49,6 +49,22 @@ class EntityVectors:
self.next_reload = time.time() + self.reload_time
logger.debug(f"Reload at {self.next_reload}")
def collection_exists(self, user, collection):
"""Check if collection exists (dimension-independent check)"""
collection_name = make_safe_collection_name(user, collection, self.prefix)
return self.client.has_collection(collection_name)
def create_collection(self, user, collection, dimension=384):
"""Create collection with default dimension"""
collection_name = make_safe_collection_name(user, collection, self.prefix)
if self.client.has_collection(collection_name):
logger.info(f"Collection {collection_name} already exists")
return
self.init_collection(dimension, user, collection)
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
def init_collection(self, dimension, user, collection):
collection_name = make_safe_collection_name(user, collection, self.prefix)
@ -128,14 +144,6 @@ class EntityVectors:
coll = self.collections[(dim, user, collection)]
search_params = {
"metric_type": "COSINE",
"params": {
"radius": 0.1,
"range_filter": 0.8
}
}
logger.debug("Loading...")
self.client.load_collection(
collection_name=coll,
@ -145,10 +153,11 @@ class EntityVectors:
res = self.client.search(
collection_name=coll,
anns_field="vector",
data=[embeds],
limit=limit,
output_fields=fields,
search_params=search_params,
search_params={ "metric_type": "COSINE" },
)[0]

View file

@ -60,7 +60,7 @@ class CollectionManager:
async def ensure_collection_exists(self, user: str, collection: str):
"""
Ensure a collection exists, creating it if necessary (lazy creation)
Ensure a collection exists, creating it if necessary with broadcast to storage
Args:
user: User ID
@ -74,7 +74,7 @@ class CollectionManager:
return
# Create new collection with default metadata
logger.info(f"Creating new collection {user}/{collection}")
logger.info(f"Auto-creating collection {user}/{collection} from document submission")
await self.table_store.create_collection(
user=user,
collection=collection,
@ -83,10 +83,64 @@ class CollectionManager:
tags=set()
)
# Broadcast collection creation to all storage backends
creation_key = (user, collection)
logger.info(f"Broadcasting create-collection for {creation_key}")
self.pending_deletions[creation_key] = {
"responses_pending": 4, # doc-embeddings, graph-embeddings, object, triples
"responses_received": [],
"all_successful": True,
"error_messages": [],
"deletion_complete": asyncio.Event()
}
storage_request = StorageManagementRequest(
operation="create-collection",
user=user,
collection=collection
)
# Send creation requests to all storage types
if self.vector_storage_producer:
await self.vector_storage_producer.send(storage_request)
if self.object_storage_producer:
await self.object_storage_producer.send(storage_request)
if self.triples_storage_producer:
await self.triples_storage_producer.send(storage_request)
# Wait for all storage creations to complete (with timeout)
creation_info = self.pending_deletions[creation_key]
try:
await asyncio.wait_for(
creation_info["deletion_complete"].wait(),
timeout=30.0 # 30 second timeout
)
except asyncio.TimeoutError:
logger.error(f"Timeout waiting for storage creation responses for {creation_key}")
creation_info["all_successful"] = False
creation_info["error_messages"].append("Timeout waiting for storage creation")
# Check if all creations succeeded
if not creation_info["all_successful"]:
error_msg = f"Storage creation failed: {'; '.join(creation_info['error_messages'])}"
logger.error(error_msg)
# Clean up metadata on failure
await self.table_store.delete_collection(user, collection)
# Clean up tracking
del self.pending_deletions[creation_key]
raise RuntimeError(error_msg)
# Clean up tracking
del self.pending_deletions[creation_key]
logger.info(f"Collection {creation_key} auto-created successfully in all storage backends")
except Exception as e:
logger.error(f"Error ensuring collection exists: {e}")
# Don't fail the operation if collection creation fails
# This maintains backward compatibility
raise e
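The auto-creation path above fans one `StorageManagementRequest` out to several storage producers, then blocks on an `asyncio.Event` with a timeout. A minimal sketch of that broadcast-and-wait pattern, with stand-in producer names (the real producers are Pulsar-backed, not shown here):

```python
import asyncio

async def broadcast_and_wait(producers, request, tracker, timeout=30.0):
    # tracker mirrors a pending_deletions entry: a countdown plus an Event
    for producer in producers:
        await producer.send(request)
    try:
        await asyncio.wait_for(tracker["complete"].wait(), timeout=timeout)
    except asyncio.TimeoutError:
        tracker["all_successful"] = False
        tracker["error_messages"].append("Timeout waiting for storage responses")
    return tracker["all_successful"]

class FakeProducer:
    # Stand-in backend that acknowledges immediately
    def __init__(self, tracker):
        self.tracker = tracker
    async def send(self, request):
        self.tracker["responses_pending"] -= 1
        if self.tracker["responses_pending"] == 0:
            self.tracker["complete"].set()

async def demo():
    tracker = {
        "responses_pending": 3,
        "all_successful": True,
        "error_messages": [],
        "complete": asyncio.Event(),
    }
    producers = [FakeProducer(tracker) for _ in range(3)]
    return await broadcast_and_wait(producers, {"op": "create-collection"}, tracker)
```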
async def list_collections(self, request: CollectionManagementRequest) -> CollectionManagementResponse:
"""
@ -154,6 +208,67 @@ class CollectionManager:
tags=tags
)
# Broadcast collection creation to all storage backends
creation_key = (request.user, request.collection)
logger.info(f"Broadcasting create-collection for {creation_key}")
self.pending_deletions[creation_key] = {
"responses_pending": 4, # doc-embeddings, graph-embeddings, object, triples
"responses_received": [],
"all_successful": True,
"error_messages": [],
"deletion_complete": asyncio.Event()
}
storage_request = StorageManagementRequest(
operation="create-collection",
user=request.user,
collection=request.collection
)
# Send creation requests to all storage types
if self.vector_storage_producer:
await self.vector_storage_producer.send(storage_request)
if self.object_storage_producer:
await self.object_storage_producer.send(storage_request)
if self.triples_storage_producer:
await self.triples_storage_producer.send(storage_request)
# Wait for all storage creations to complete (with timeout)
creation_info = self.pending_deletions[creation_key]
try:
await asyncio.wait_for(
creation_info["deletion_complete"].wait(),
timeout=30.0 # 30 second timeout
)
except asyncio.TimeoutError:
logger.error(f"Timeout waiting for storage creation responses for {creation_key}")
creation_info["all_successful"] = False
creation_info["error_messages"].append("Timeout waiting for storage creation")
# Check if all creations succeeded
if not creation_info["all_successful"]:
error_msg = f"Storage creation failed: {'; '.join(creation_info['error_messages'])}"
logger.error(error_msg)
# Clean up metadata on failure
await self.table_store.delete_collection(request.user, request.collection)
# Clean up tracking
del self.pending_deletions[creation_key]
return CollectionManagementResponse(
error=Error(
type="storage_creation_error",
message=error_msg
),
timestamp=datetime.now().isoformat()
)
# Clean up tracking
del self.pending_deletions[creation_key]
logger.info(f"Collection {creation_key} created successfully in all storage backends")
# Get the newly created collection for response
created_collection = await self.table_store.get_collection(request.user, request.collection)
@ -213,7 +328,7 @@ class CollectionManager:
# Track this deletion request
self.pending_deletions[deletion_key] = {
"responses_pending": 3, # vector, object, triples
"responses_pending": 4, # doc-embeddings, graph-embeddings, object, triples
"responses_received": [],
"all_successful": True,
"error_messages": [],
@ -303,9 +418,9 @@ class CollectionManager:
if response.error and response.error.message:
info["all_successful"] = False
info["error_messages"].append(response.error.message)
logger.warning(f"Storage deletion failed for {deletion_key}: {response.error.message}")
logger.warning(f"Storage operation failed for {deletion_key}: {response.error.message}")
else:
logger.debug(f"Storage deletion succeeded for {deletion_key}")
logger.debug(f"Storage operation succeeded for {deletion_key}")
# If all responses received, signal completion
if info["responses_pending"] == 0:
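The response handler above counts pending storage responses down and records failures before signalling completion. A condensed sketch of that aggregation step, using the same tracking-dict shape:

```python
import asyncio

def record_response(info, error_message=None):
    # Mirrors the handler above: count down, collect errors, signal when done
    info["responses_pending"] -= 1
    info["responses_received"].append(error_message)
    if error_message:
        info["all_successful"] = False
        info["error_messages"].append(error_message)
    if info["responses_pending"] == 0:
        info["deletion_complete"].set()

info = {
    "responses_pending": 4,
    "responses_received": [],
    "all_successful": True,
    "error_messages": [],
    "deletion_complete": asyncio.Event(),
}
# One of the four storage backends reports a failure
for err in (None, None, "qdrant write failed", None):
    record_response(info, err)
```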

View file

@ -32,7 +32,7 @@ class Processor(LlmService):
token = params.get("token", default_token)
temperature = params.get("temperature", default_temperature)
max_output = params.get("max_output", default_max_output)
model = default_model
model = params.get("model", default_model)
if endpoint is None:
raise RuntimeError("Azure endpoint not specified")
@ -53,9 +53,11 @@ class Processor(LlmService):
self.token = token
self.temperature = temperature
self.max_output = max_output
self.model = model
self.default_model = model
def build_prompt(self, system, content):
def build_prompt(self, system, content, temperature=None):
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
data = {
"messages": [
@ -67,7 +69,7 @@ class Processor(LlmService):
}
],
"max_tokens": self.max_output,
"temperature": self.temperature,
"temperature": effective_temperature,
"top_p": 1
}
@ -100,13 +102,22 @@ class Processor(LlmService):
return result
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
try:
prompt = self.build_prompt(
system,
prompt
prompt,
effective_temperature
)
response = self.call_llm(prompt)
@ -125,7 +136,7 @@ class Processor(LlmService):
text = resp,
in_token = inputtokens,
out_token = outputtokens,
model = self.model
model = model_name
)
return resp
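Every provider touched by this change repeats the same fallback: a per-request `model`/`temperature` overrides the constructor default. A sketch of the shared logic, and why the two cases differ:

```python
def resolve_params(default_model, default_temperature, model=None, temperature=None):
    # `or` is fine for model (None/empty fall through to the default), but
    # temperature needs an explicit None check so 0.0 stays a valid override.
    model_name = model or default_model
    effective_temperature = (
        temperature if temperature is not None else default_temperature
    )
    return model_name, effective_temperature
```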

View file

@ -54,7 +54,7 @@ class Processor(LlmService):
self.temperature = temperature
self.max_output = max_output
self.model = model
self.default_model = model
self.openai = AzureOpenAI(
api_key=token,
@ -62,14 +62,22 @@ class Processor(LlmService):
azure_endpoint = endpoint,
)
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
prompt = system + "\n\n" + prompt
try:
resp = self.openai.chat.completions.create(
model=self.model,
model=model_name,
messages=[
{
"role": "user",
@ -81,7 +89,7 @@ class Processor(LlmService):
]
}
],
temperature=self.temperature,
temperature=effective_temperature,
max_tokens=self.max_output,
top_p=1,
)
@ -97,7 +105,7 @@ class Processor(LlmService):
text = resp.choices[0].message.content,
in_token = inputtokens,
out_token = outputtokens,
model = self.model
model = model_name
)
return r

View file

@ -41,21 +41,29 @@ class Processor(LlmService):
}
)
self.model = model
self.default_model = model
self.claude = anthropic.Anthropic(api_key=api_key)
self.temperature = temperature
self.max_output = max_output
logger.info("Claude LLM service initialized")
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
try:
response = message = self.claude.messages.create(
model=self.model,
model=model_name,
max_tokens=self.max_output,
temperature=self.temperature,
temperature=effective_temperature,
system = system,
messages=[
{
@ -81,7 +89,7 @@ class Processor(LlmService):
text = resp,
in_token = inputtokens,
out_token = outputtokens,
model = self.model
model = model_name
)
return resp

View file

@ -39,21 +39,29 @@ class Processor(LlmService):
}
)
self.model = model
self.default_model = model
self.temperature = temperature
self.cohere = cohere.Client(api_key=api_key)
logger.info("Cohere LLM service initialized")
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
try:
output = self.cohere.chat(
model=self.model,
output = self.cohere.chat(
model=model_name,
message=prompt,
preamble = system,
temperature=self.temperature,
temperature=effective_temperature,
chat_history=[],
prompt_truncation='auto',
connectors=[]
@ -71,7 +79,7 @@ class Processor(LlmService):
text = resp,
in_token = inputtokens,
out_token = outputtokens,
model = self.model
model = model_name
)
return resp

View file

@ -53,10 +53,13 @@ class Processor(LlmService):
)
self.client = genai.Client(api_key=api_key)
self.model = model
self.default_model = model
self.temperature = temperature
self.max_output = max_output
# Cache for generation configs per model
self.generation_configs = {}
block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
self.safety_settings = [
@ -83,22 +86,45 @@ class Processor(LlmService):
logger.info("GoogleAIStudio LLM service initialized")
async def generate_content(self, system, prompt):
def _get_or_create_config(self, model_name, temperature=None):
"""Get or create generation config with dynamic temperature"""
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
generation_config = types.GenerateContentConfig(
temperature = self.temperature,
top_p = 1,
top_k = 40,
max_output_tokens = self.max_output,
response_mime_type = "text/plain",
system_instruction = system,
safety_settings = self.safety_settings,
)
# Create cache key that includes temperature to avoid conflicts
cache_key = f"{model_name}:{effective_temperature}"
if cache_key not in self.generation_configs:
logger.info(f"Creating generation config for '{model_name}' with temperature {effective_temperature}")
self.generation_configs[cache_key] = types.GenerateContentConfig(
temperature = effective_temperature,
top_p = 1,
top_k = 40,
max_output_tokens = self.max_output,
response_mime_type = "text/plain",
safety_settings = self.safety_settings,
)
return self.generation_configs[cache_key]
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
generation_config = self._get_or_create_config(model_name, effective_temperature)
# Set system instruction per request (can't be cached)
generation_config.system_instruction = system
try:
response = self.client.models.generate_content(
model=self.model,
model=model_name,
config=generation_config,
contents=prompt,
)
@ -114,7 +140,7 @@ class Processor(LlmService):
text = resp,
in_token = inputtokens,
out_token = outputtokens,
model = self.model
model = model_name
)
return resp

View file

@ -39,7 +39,7 @@ class Processor(LlmService):
}
)
self.model = model
self.default_model = model
self.llamafile=llamafile
self.temperature = temperature
self.max_output = max_output
@ -50,25 +50,33 @@ class Processor(LlmService):
logger.info("Llamafile LLM service initialized")
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
prompt = system + "\n\n" + prompt
try:
resp = self.openai.chat.completions.create(
model=self.model,
model=model_name,
messages=[
{"role": "user", "content": prompt}
]
#temperature=self.temperature,
#max_tokens=self.max_output,
#top_p=1,
#frequency_penalty=0,
#presence_penalty=0,
#response_format={
# "type": "text"
#}
],
temperature=effective_temperature,
max_tokens=self.max_output,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
response_format={
"type": "text"
}
)
inputtokens = resp.usage.prompt_tokens
@ -82,7 +90,7 @@ class Processor(LlmService):
text = resp.choices[0].message.content,
in_token = inputtokens,
out_token = outputtokens,
model = "llama.cpp",
model = model_name,
)
return resp

View file

@ -39,7 +39,7 @@ class Processor(LlmService):
}
)
self.model = model
self.default_model = model
self.url = url + "v1/"
self.temperature = temperature
self.max_output = max_output
@ -50,7 +50,15 @@ class Processor(LlmService):
logger.info("LMStudio LLM service initialized")
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
prompt = system + "\n\n" + prompt
@ -59,18 +67,18 @@ class Processor(LlmService):
logger.debug(f"Prompt: {prompt}")
resp = self.openai.chat.completions.create(
model=self.model,
model=model_name,
messages=[
{"role": "user", "content": prompt}
]
#temperature=self.temperature,
#max_tokens=self.max_output,
#top_p=1,
#frequency_penalty=0,
#presence_penalty=0,
#response_format={
# "type": "text"
#}
],
temperature=effective_temperature,
max_tokens=self.max_output,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
response_format={
"type": "text"
}
)
logger.debug(f"Full response: {resp}")
@ -86,7 +94,7 @@ class Processor(LlmService):
text = resp.choices[0].message.content,
in_token = inputtokens,
out_token = outputtokens,
model = self.model
model = model_name
)
return resp

View file

@ -41,21 +41,29 @@ class Processor(LlmService):
}
)
self.model = model
self.default_model = model
self.temperature = temperature
self.max_output = max_output
self.mistral = Mistral(api_key=api_key)
logger.info("Mistral LLM service initialized")
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
prompt = system + "\n\n" + prompt
try:
resp = self.mistral.chat.complete(
model=self.model,
model=model_name,
messages=[
{
"role": "user",
@ -67,7 +75,7 @@ class Processor(LlmService):
]
}
],
temperature=self.temperature,
temperature=effective_temperature,
max_tokens=self.max_output,
top_p=1,
frequency_penalty=0,
@ -87,7 +95,7 @@ class Processor(LlmService):
text = resp.choices[0].message.content,
in_token = inputtokens,
out_token = outputtokens,
model = self.model
model = model_name
)
return resp

View file

@ -17,6 +17,7 @@ from .... base import LlmService, LlmResult
default_ident = "text-completion"
default_model = 'gemma2:9b'
default_temperature = 0.0
default_ollama = os.getenv("OLLAMA_HOST", 'http://localhost:11434')
class Processor(LlmService):
@ -24,25 +25,36 @@ class Processor(LlmService):
def __init__(self, **params):
model = params.get("model", default_model)
temperature = params.get("temperature", default_temperature)
ollama = params.get("ollama", default_ollama)
super(Processor, self).__init__(
**params | {
"model": model,
"temperature": temperature,
"ollama": ollama,
}
)
self.model = model
self.default_model = model
self.temperature = temperature
self.llm = Client(host=ollama)
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
prompt = system + "\n\n" + prompt
try:
response = self.llm.generate(self.model, prompt)
response = self.llm.generate(model_name, prompt, options={'temperature': effective_temperature})
response_text = response['response']
logger.debug("Sending response...")
@ -55,7 +67,7 @@ class Processor(LlmService):
text = response_text,
in_token = inputtokens,
out_token = outputtokens,
model = self.model
model = model_name
)
return resp
@ -84,6 +96,13 @@ class Processor(LlmService):
help=f'ollama (default: {default_ollama})'
)
parser.add_argument(
'-t', '--temperature',
type=float,
default=default_temperature,
help=f'LLM temperature parameter (default: {default_temperature})'
)
def run():
Processor.launch(default_ident, __doc__)

View file

@ -47,7 +47,7 @@ class Processor(LlmService):
}
)
self.model = model
self.default_model = model
self.temperature = temperature
self.max_output = max_output
@ -58,14 +58,22 @@ class Processor(LlmService):
logger.info("OpenAI LLM service initialized")
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
prompt = system + "\n\n" + prompt
try:
resp = self.openai.chat.completions.create(
model=self.model,
model=model_name,
messages=[
{
"role": "user",
@ -77,7 +85,7 @@ class Processor(LlmService):
]
}
],
temperature=self.temperature,
temperature=effective_temperature,
max_tokens=self.max_output,
top_p=1,
frequency_penalty=0,
@ -97,7 +105,7 @@ class Processor(LlmService):
text = resp.choices[0].message.content,
in_token = inputtokens,
out_token = outputtokens,
model = self.model
model = model_name
)
return resp

View file

@ -30,32 +30,43 @@ class Processor(LlmService):
base_url = params.get("url", default_base_url)
temperature = params.get("temperature", default_temperature)
max_output = params.get("max_output", default_max_output)
model = params.get("model", "tgi")
super(Processor, self).__init__(
**params | {
"temperature": temperature,
"max_output": max_output,
"url": base_url,
"model": model,
}
)
self.base_url = base_url
self.temperature = temperature
self.max_output = max_output
self.default_model = model
self.session = aiohttp.ClientSession()
logger.info(f"Using TGI service at {base_url}")
logger.info("TGI LLM service initialized")
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
headers = {
"Content-Type": "application/json",
}
request = {
"model": "tgi",
"model": model_name,
"messages": [
{
"role": "system",
@ -67,7 +78,7 @@ class Processor(LlmService):
}
],
"max_tokens": self.max_output,
"temperature": self.temperature,
"temperature": effective_temperature,
}
try:
@ -96,7 +107,7 @@ class Processor(LlmService):
text = ans,
in_token = inputtokens,
out_token = outputtokens,
model = "tgi",
model = model_name,
)
return resp

View file

@ -45,24 +45,32 @@ class Processor(LlmService):
self.base_url = base_url
self.temperature = temperature
self.max_output = max_output
self.model = model
self.default_model = model
self.session = aiohttp.ClientSession()
logger.info(f"Using vLLM service at {base_url}")
logger.info("vLLM LLM service initialized")
async def generate_content(self, system, prompt):
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
headers = {
"Content-Type": "application/json",
}
request = {
"model": self.model,
"model": model_name,
"prompt": system + "\n\n" + prompt,
"max_tokens": self.max_output,
"temperature": self.temperature,
"temperature": effective_temperature,
}
try:
@ -91,7 +99,7 @@ class Processor(LlmService):
text = ans,
in_token = inputtokens,
out_token = outputtokens,
model = self.model,
model = model_name,
)
return resp
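The TGI and vLLM processors above assemble their request bodies by hand before POSTing with aiohttp. An illustrative builder for the chat-style payload the TGI path sends to an OpenAI-compatible endpoint (the vLLM path uses a plain `prompt` field instead of `messages`):

```python
def build_chat_request(model_name, system, prompt, max_tokens, temperature):
    # Shape of the JSON body POSTed to an OpenAI-compatible chat endpoint
    return {
        "model": model_name,
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": prompt},
        ],
        "max_tokens": max_tokens,
        "temperature": temperature,
    }

req = build_chat_request("tgi", "You are terse.", "Say hi", 256, 0.0)
```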

View file

@ -57,21 +57,26 @@ class Processor(DocumentEmbeddingsQueryService):
raise e
self.last_collection = collection
def collection_exists(self, collection):
"""Check if collection exists (no implicit creation)"""
return self.qdrant.collection_exists(collection)
async def query_document_embeddings(self, msg):
try:
chunks = []
collection = (
"d_" + msg.user + "_" + msg.collection
)
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist, returning empty results")
return []
for vec in msg.vectors:
dim = len(vec)
collection = (
"d_" + msg.user + "_" + msg.collection
)
self.ensure_collection_exists(collection, dim)
search_result = self.qdrant.query_points(
collection_name=collection,
query=vec,

View file

@ -57,6 +57,10 @@ class Processor(GraphEmbeddingsQueryService):
raise e
self.last_collection = collection
def collection_exists(self, collection):
"""Check if collection exists (no implicit creation)"""
return self.qdrant.collection_exists(collection)
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
return Value(value=ent, is_uri=True)
@ -70,12 +74,16 @@ class Processor(GraphEmbeddingsQueryService):
entity_set = set()
entities = []
for vec in msg.vectors:
collection = (
"t_" + msg.user + "_" + msg.collection
)
dim = len(vec)
collection = (
"t_" + msg.user + "_" + msg.collection
)
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist, returning empty results")
return []
for vec in msg.vectors:
self.ensure_collection_exists(collection, dim)

View file

@ -1,12 +1,56 @@
import asyncio
import logging
import time
from collections import OrderedDict
# Module logger
logger = logging.getLogger(__name__)
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
class LRUCacheWithTTL:
"""LRU cache with TTL for label caching
CRITICAL SECURITY WARNING:
This cache is shared within a GraphRag instance but GraphRag instances
are created per-request. Cache keys MUST include user:collection prefix
to ensure data isolation between different security contexts.
"""
def __init__(self, max_size=5000, ttl=300):
self.cache = OrderedDict()
self.access_times = {}
self.max_size = max_size
self.ttl = ttl
def get(self, key):
if key not in self.cache:
return None
# Check TTL expiration
if time.time() - self.access_times[key] > self.ttl:
del self.cache[key]
del self.access_times[key]
return None
# Move to end (most recently used)
self.cache.move_to_end(key)
return self.cache[key]
def put(self, key, value):
if key in self.cache:
self.cache.move_to_end(key)
else:
if len(self.cache) >= self.max_size:
# Remove least recently used
oldest_key = next(iter(self.cache))
del self.cache[oldest_key]
del self.access_times[oldest_key]
self.cache[key] = value
self.access_times[key] = time.time()
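A short demonstration of the eviction behaviour, using a condensed copy of the class above with `max_size=2`. Note that keys carry the `user:collection:` prefix the security note requires:

```python
import time
from collections import OrderedDict

class LRUCacheWithTTL:
    # Condensed copy of the class above, for demonstration
    def __init__(self, max_size=5000, ttl=300):
        self.cache = OrderedDict()
        self.access_times = {}
        self.max_size = max_size
        self.ttl = ttl
    def get(self, key):
        if key not in self.cache:
            return None
        if time.time() - self.access_times[key] > self.ttl:
            del self.cache[key]
            del self.access_times[key]
            return None
        self.cache.move_to_end(key)
        return self.cache[key]
    def put(self, key, value):
        if key in self.cache:
            self.cache.move_to_end(key)
        elif len(self.cache) >= self.max_size:
            oldest_key = next(iter(self.cache))
            del self.cache[oldest_key]
            del self.access_times[oldest_key]
        self.cache[key] = value
        self.access_times[key] = time.time()

cache = LRUCacheWithTTL(max_size=2, ttl=300)
cache.put("u1:docs:e1", "Entity One")
cache.put("u1:docs:e2", "Entity Two")
cache.get("u1:docs:e1")                  # touch e1, so e2 is now least recent
cache.put("u1:docs:e3", "Entity Three")  # at capacity: evicts e2
```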
class Query:
def __init__(
@ -61,8 +105,14 @@ class Query:
async def maybe_label(self, e):
if e in self.rag.label_cache:
return self.rag.label_cache[e]
# CRITICAL SECURITY: Cache key MUST include user and collection
# to prevent data leakage between different contexts
cache_key = f"{self.user}:{self.collection}:{e}"
# Check LRU cache first with isolated key
cached_label = self.rag.label_cache.get(cache_key)
if cached_label is not None:
return cached_label
res = await self.rag.triples_client.query(
s=e, p=LABEL, o=None, limit=1,
@ -70,60 +120,104 @@ class Query:
)
if len(res) == 0:
self.rag.label_cache[e] = e
self.rag.label_cache.put(cache_key, e)
return e
self.rag.label_cache[e] = str(res[0].o)
return self.rag.label_cache[e]
label = str(res[0].o)
self.rag.label_cache.put(cache_key, label)
return label
async def execute_batch_triple_queries(self, entities, limit_per_entity):
"""Execute triple queries for multiple entities concurrently"""
tasks = []
for entity in entities:
# Create concurrent tasks for all 3 query types per entity
tasks.extend([
self.rag.triples_client.query(
s=entity, p=None, o=None,
limit=limit_per_entity,
user=self.user, collection=self.collection
),
self.rag.triples_client.query(
s=None, p=entity, o=None,
limit=limit_per_entity,
user=self.user, collection=self.collection
),
self.rag.triples_client.query(
s=None, p=None, o=entity,
limit=limit_per_entity,
user=self.user, collection=self.collection
)
])
# Execute all queries concurrently
results = await asyncio.gather(*tasks, return_exceptions=True)
# Combine all results
all_triples = []
for result in results:
if not isinstance(result, Exception):
all_triples.extend(result)
return all_triples
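The batch helper above fans out three queries per entity through `asyncio.gather(..., return_exceptions=True)` and drops any that failed. A runnable sketch with a stub in place of `triples_client.query`, showing that one failed query does not sink the batch:

```python
import asyncio

async def query(s=None, p=None, o=None):
    # Stand-in for triples_client.query; one entity fails to exercise filtering
    if s == "bad":
        raise RuntimeError("backend unavailable")
    return [(s or "?", p or "?", o or "?")]

async def batch(entities):
    tasks = []
    for entity in entities:
        tasks.extend([query(s=entity), query(p=entity), query(o=entity)])
    results = await asyncio.gather(*tasks, return_exceptions=True)
    triples = []
    for result in results:
        if not isinstance(result, Exception):
            triples.extend(result)
    return triples

# 3 results for "good", 2 for "bad" (its subject-position query raised)
triples = asyncio.run(batch(["good", "bad"]))
```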
async def follow_edges_batch(self, entities, max_depth):
"""Optimized iterative graph traversal with batching"""
visited = set()
current_level = set(entities)
subgraph = set()
for depth in range(max_depth):
if not current_level or len(subgraph) >= self.max_subgraph_size:
break
# Filter out already visited entities
unvisited_entities = [e for e in current_level if e not in visited]
if not unvisited_entities:
break
# Batch query all unvisited entities at current level
triples = await self.execute_batch_triple_queries(
unvisited_entities, self.triple_limit
)
# Process results and collect next level entities
next_level = set()
for triple in triples:
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
subgraph.add(triple_tuple)
# Collect entities for next level (only from s and o positions)
if depth < max_depth - 1: # Don't collect for final depth
s, p, o = triple_tuple
if s not in visited:
next_level.add(s)
if o not in visited:
next_level.add(o)
# Stop if subgraph size limit reached
if len(subgraph) >= self.max_subgraph_size:
return subgraph
# Update for next iteration
visited.update(current_level)
current_level = next_level
return subgraph
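The traversal above replaces per-edge recursion with level-by-level batches, a visited set, and a size cap. A synchronous sketch of the same shape over an in-memory triple list (matching on any position, as the three query types above do together):

```python
def follow_edges_batch(store, entities, max_depth, max_size=100):
    # store: list of (s, p, o) triples; expand level by level from the seeds
    visited, current_level, subgraph = set(), set(entities), set()
    for depth in range(max_depth):
        unvisited = [e for e in current_level if e not in visited]
        if not unvisited or len(subgraph) >= max_size:
            break
        triples = [t for t in store
                   if t[0] in unvisited or t[1] in unvisited or t[2] in unvisited]
        next_level = set()
        for s, p, o in triples:
            subgraph.add((s, p, o))
            if depth < max_depth - 1:       # don't expand past the final depth
                if s not in visited:
                    next_level.add(s)
                if o not in visited:
                    next_level.add(o)
            if len(subgraph) >= max_size:
                return subgraph
        visited.update(current_level)
        current_level = next_level
    return subgraph

store = [("a", "knows", "b"), ("b", "knows", "c"), ("c", "knows", "d")]
sg = follow_edges_batch(store, ["a"], 2)  # depth 2 reaches b's edges, not c's
```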
async def follow_edges(self, ent, subgraph, path_length):
# Not needed?
"""Legacy method - replaced by follow_edges_batch"""
# Maintain backward compatibility with early termination checks
if path_length <= 0:
return
# Stop spanning around if the subgraph is already maxed out
if len(subgraph) >= self.max_subgraph_size:
return
res = await self.rag.triples_client.query(
s=ent, p=None, o=None,
limit=self.triple_limit,
user=self.user, collection=self.collection,
)
for triple in res:
subgraph.add(
(str(triple.s), str(triple.p), str(triple.o))
)
if path_length > 1:
await self.follow_edges(str(triple.o), subgraph, path_length-1)
res = await self.rag.triples_client.query(
s=None, p=ent, o=None,
limit=self.triple_limit,
user=self.user, collection=self.collection,
)
for triple in res:
subgraph.add(
(str(triple.s), str(triple.p), str(triple.o))
)
res = await self.rag.triples_client.query(
s=None, p=None, o=ent,
limit=self.triple_limit,
user=self.user, collection=self.collection,
)
for triple in res:
subgraph.add(
(str(triple.s), str(triple.p), str(triple.o))
)
if path_length > 1:
await self.follow_edges(
str(triple.s), subgraph, path_length-1
)
# For backward compatibility, convert to new approach
batch_result = await self.follow_edges_batch([ent], path_length)
subgraph.update(batch_result)
async def get_subgraph(self, query):
@ -132,31 +226,52 @@ class Query:
if self.verbose:
logger.debug("Getting subgraph...")
subgraph = set()
# Use optimized batch traversal instead of sequential processing
subgraph = await self.follow_edges_batch(entities, self.max_path_length)
for ent in entities:
await self.follow_edges(ent, subgraph, self.max_path_length)
return list(subgraph)
subgraph = list(subgraph)
async def resolve_labels_batch(self, entities):
"""Resolve labels for multiple entities in parallel"""
tasks = []
for entity in entities:
tasks.append(self.maybe_label(entity))
return subgraph
return await asyncio.gather(*tasks, return_exceptions=True)
async def get_labelgraph(self, query):
subgraph = await self.get_subgraph(query)
# Filter out label triples
filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
# Collect all unique entities that need label resolution
entities_to_resolve = set()
for s, p, o in filtered_subgraph:
entities_to_resolve.update([s, p, o])
# Batch resolve labels for all entities in parallel
entity_list = list(entities_to_resolve)
resolved_labels = await self.resolve_labels_batch(entity_list)
# Create entity-to-label mapping
label_map = {}
for entity, label in zip(entity_list, resolved_labels):
if not isinstance(label, Exception):
label_map[entity] = label
else:
label_map[entity] = entity # Fallback to entity itself
# Apply labels to subgraph
sg2 = []
for edge in subgraph:
if edge[1] == LABEL:
continue
s = await self.maybe_label(edge[0])
p = await self.maybe_label(edge[1])
o = await self.maybe_label(edge[2])
sg2.append((s, p, o))
for s, p, o in filtered_subgraph:
labeled_triple = (
label_map.get(s, s),
label_map.get(p, p),
label_map.get(o, o)
)
sg2.append(labeled_triple)
sg2 = sg2[0:self.max_subgraph_size]
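The rewritten `get_labelgraph` resolves every unique entity once, in parallel, and treats a failed lookup as "use the entity itself". A condensed, self-contained sketch of that pattern (the `lookup_label` coroutine is a stand-in for `maybe_label`):

```python
import asyncio

async def lookup_label(entity, labels):
    # Stand-in for maybe_label; unknown entities raise KeyError
    return labels[entity]

async def resolve_labels_batch(entities, labels):
    # return_exceptions=True keeps one failed lookup from
    # cancelling the rest of the batch
    tasks = [lookup_label(e, labels) for e in entities]
    return await asyncio.gather(*tasks, return_exceptions=True)

def label_subgraph(subgraph, labels):
    # Resolve each unique entity exactly once, then map every triple
    entities = list({x for triple in subgraph for x in triple})
    resolved = asyncio.run(resolve_labels_batch(entities, labels))
    label_map = {
        e: (e if isinstance(r, Exception) else r)  # fallback: entity itself
        for e, r in zip(entities, resolved)
    }
    return [tuple(label_map[x] for x in t) for t in subgraph]
```

`label_subgraph([("e1", "p1", "e2")], {"e1": "Alice", "e2": "Bob"})` yields one triple with `p1` left unlabelled, because its lookup raised and fell back to the entity itself.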
@ -171,6 +286,13 @@ class Query:
return sg2
class GraphRag:
"""
CRITICAL SECURITY:
This class MUST be instantiated per-request to ensure proper isolation
between users and collections. The cache within this instance will only
live for the duration of a single request, preventing cross-contamination
of data between different security contexts.
"""
def __init__(
self, prompt_client, embeddings_client, graph_embeddings_client,
@ -184,7 +306,9 @@ class GraphRag:
self.graph_embeddings_client = graph_embeddings_client
self.triples_client = triples_client
self.label_cache = {}
# Replace simple dict with LRU cache with TTL
# CRITICAL: This cache only lives for one request due to per-request instantiation
self.label_cache = LRUCacheWithTTL(max_size=5000, ttl=300)
if self.verbose:
logger.debug("GraphRag initialized")
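The `LRUCacheWithTTL` class itself is not defined in this diff; a minimal sketch of such a cache (size-bounded, entries expiring after `ttl` seconds — an assumption about its behaviour, not the project's actual implementation) could be:

```python
import time
from collections import OrderedDict

class LRUCacheWithTTL:
    """Size-bounded mapping whose entries expire after ttl seconds."""

    def __init__(self, max_size=5000, ttl=300):
        self.max_size = max_size
        self.ttl = ttl
        self._data = OrderedDict()  # key -> (expiry_timestamp, value)

    def get(self, key, default=None):
        item = self._data.get(key)
        if item is None:
            return default
        expires, value = item
        if time.monotonic() >= expires:
            del self._data[key]      # expired: drop and report a miss
            return default
        self._data.move_to_end(key)  # mark as most recently used
        return value

    def put(self, key, value):
        if key in self._data:
            self._data.move_to_end(key)
        self._data[key] = (time.monotonic() + self.ttl, value)
        while len(self._data) > self.max_size:
            self._data.popitem(last=False)  # evict least recently used
```

Because the cache is bounded and expiring, even a pathological request cannot grow it without limit during its (single-request) lifetime.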

View file

@ -45,6 +45,10 @@ class Processor(FlowProcessor):
self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length
# CRITICAL SECURITY: NEVER share data between users or collections
# Each user/collection combination MUST have isolated data access
# Caching must NEVER allow information leakage across these boundaries
self.register_specification(
ConsumerSpec(
name = "request",
@ -93,11 +97,14 @@ class Processor(FlowProcessor):
try:
self.rag = GraphRag(
embeddings_client = flow("embeddings-request"),
graph_embeddings_client = flow("graph-embeddings-request"),
triples_client = flow("triples-request"),
prompt_client = flow("prompt-request"),
# CRITICAL SECURITY: Create new GraphRag instance per request
# This ensures proper isolation between users and collections
# Flow clients are request-scoped and must not be shared
rag = GraphRag(
embeddings_client=flow("embeddings-request"),
graph_embeddings_client=flow("graph-embeddings-request"),
triples_client=flow("triples-request"),
prompt_client=flow("prompt-request"),
verbose=True,
)
@ -128,7 +135,7 @@ class Processor(FlowProcessor):
else:
max_path_length = self.default_max_path_length
response = await self.rag.query(
response = await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
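The per-request instantiation above is a general isolation pattern: construct the collaborator inside the handler, so any state it holds dies with the request. A hypothetical reduction of the idea:

```python
class Worker:
    """Holds request-lifetime state only."""

    def __init__(self):
        self.cache = {}  # never survives beyond one request

    def run(self, user, collection, query):
        self.cache[(user, collection)] = query
        return len(self.cache)

class PerRequestHandler:
    """Builds a fresh worker per request so caches cannot leak across users."""

    def __init__(self, worker_factory):
        self.worker_factory = worker_factory

    def handle(self, user, collection, query):
        # A new worker (and therefore a new cache) for every request
        worker = self.worker_factory()
        return worker.run(user, collection, query)
```

Calling `handle` twice always sees a cache of size one: nothing from the first request is visible to the second.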

View file

@ -60,19 +60,34 @@ class Processor(DocumentEmbeddingsStoreService):
metrics=storage_response_metrics,
)
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
await self.storage_request_consumer.start()
await self.storage_response_producer.start()
async def store_document_embeddings(self, message):
# Validate collection exists before accepting writes
if not self.vecstore.collection_exists(message.metadata.user, message.metadata.collection):
error_msg = (
f"Collection {message.metadata.collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
for emb in message.chunks:
if emb.chunk is None or emb.chunk == b"": continue
chunk = emb.chunk.decode("utf-8")
if chunk == "": continue
for vec in emb.vectors:
self.vecstore.insert(
vec, chunk,
message.metadata.user,
vec, chunk,
message.metadata.user,
message.metadata.collection
)
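The write path now validates rather than lazily creating storage, and the same guard recurs in the Milvus, Pinecone, and Qdrant processors. A generic sketch (the store interface and message shape here are simplified stand-ins, not the project's schema types):

```python
class MissingCollectionError(ValueError):
    pass

def store_embeddings(vecstore, message):
    """Reject writes to collections that were never explicitly created."""
    user = message["user"]
    collection = message["collection"]
    if not vecstore.collection_exists(user, collection):
        raise MissingCollectionError(
            f"Collection {collection} does not exist. "
            f"Create it first with tg-set-collection."
        )
    written = 0
    for chunk, vectors in message["chunks"]:
        if not chunk:
            continue  # skip empty chunks, as the real processor does
        for vec in vectors:
            vecstore.insert(vec, chunk, user, collection)
            written += 1
    return written

class FakeStore:
    """In-memory stand-in used to exercise the guard."""
    def __init__(self, existing):
        self.existing = set(existing)
        self.rows = []
    def collection_exists(self, user, collection):
        return (user, collection) in self.existing
    def insert(self, vec, chunk, user, collection):
        self.rows.append((user, collection, chunk, vec))
```

Failing fast here moves the error to ingest time, where the caller can react, instead of silently creating collections with whatever dimension the first write happened to carry.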
@ -87,18 +102,21 @@ class Processor(DocumentEmbeddingsStoreService):
help=f'Milvus store URI (default: {default_store_uri})'
)
async def on_storage_management(self, message):
async def on_storage_management(self, message, consumer, flow):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
request = message.value()
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
if request.operation == "create-collection":
await self.handle_create_collection(request)
elif request.operation == "delete-collection":
await self.handle_delete_collection(request)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@ -113,17 +131,40 @@ class Processor(DocumentEmbeddingsStoreService):
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
async def handle_create_collection(self, request):
"""Create a Milvus collection for document embeddings"""
try:
if self.vecstore.collection_exists(request.user, request.collection):
logger.info(f"Collection {request.user}/{request.collection} already exists")
else:
self.vecstore.create_collection(request.user, request.collection)
logger.info(f"Created collection {request.user}/{request.collection}")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete the collection for document embeddings"""
try:
self.vecstore.delete_collection(message.user, message.collection)
self.vecstore.delete_collection(request.user, request.collection)
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
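Every store processor in this change shares the same storage-management shape: unwrap the request, branch on `operation`, and reply with a response whose `error` field is `None` on success. A condensed sketch with plain dicts standing in for `StorageManagementResponse` and `Error`:

```python
def handle_storage_management(request, create, delete):
    """Dispatch a storage-management request; error=None signals success."""
    op = request["operation"]
    try:
        if op == "create-collection":
            create(request["user"], request["collection"])
        elif op == "delete-collection":
            delete(request["user"], request["collection"])
        else:
            return {"error": {"type": "invalid_operation",
                              "message": f"Unknown operation: {op}"}}
        return {"error": None}
    except Exception as e:
        # Any handler failure is reported to the caller, not swallowed
        return {"error": {"type": "storage_error", "message": str(e)}}
```

Keeping the dispatch in one place means each backend only supplies `create` and `delete` callables, and unknown operations produce a structured error rather than a silent no-op.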

View file

@ -115,38 +115,36 @@ class Processor(DocumentEmbeddingsStoreService):
"Gave up waiting for index creation"
)
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
await self.storage_request_consumer.start()
await self.storage_response_producer.start()
async def store_document_embeddings(self, message):
index_name = (
"d-" + message.metadata.user + "-" + message.metadata.collection
)
# Validate collection exists before accepting writes
if not self.pinecone.has_index(index_name):
error_msg = (
f"Collection {message.metadata.collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
for emb in message.chunks:
if emb.chunk is None or emb.chunk == b"": continue
chunk = emb.chunk.decode("utf-8")
if chunk == "": continue
for vec in emb.vectors:
dim = len(vec)
index_name = (
"d-" + message.metadata.user + "-" + message.metadata.collection
)
if index_name != self.last_index_name:
if not self.pinecone.has_index(index_name):
try:
self.create_index(index_name, dim)
except Exception as e:
logger.error("Pinecone index creation failed")
raise e
logger.info(f"Index {index_name} created")
self.last_index_name = index_name
index = self.pinecone.Index(index_name)
# Generate unique ID for each vector
@ -192,18 +190,21 @@ class Processor(DocumentEmbeddingsStoreService):
help=f'Pinecone region (default: {default_region})'
)
async def on_storage_management(self, message):
async def on_storage_management(self, message, consumer, flow):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
request = message.value()
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
if request.operation == "create-collection":
await self.handle_create_collection(request)
elif request.operation == "delete-collection":
await self.handle_delete_collection(request)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@ -218,10 +219,36 @@ class Processor(DocumentEmbeddingsStoreService):
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
async def handle_create_collection(self, request):
"""Create a Pinecone index for document embeddings"""
try:
index_name = f"d-{request.user}-{request.collection}"
if self.pinecone.has_index(index_name):
logger.info(f"Pinecone index {index_name} already exists")
else:
# Create with default dimension - will need to be recreated if dimension doesn't match
self.create_index(index_name, dim=384)
logger.info(f"Created Pinecone index: {index_name}")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete the collection for document embeddings"""
try:
index_name = f"d-{message.user}-{message.collection}"
index_name = f"d-{request.user}-{request.collection}"
if self.pinecone.has_index(index_name):
self.pinecone.delete_index(index_name)
@ -234,7 +261,7 @@ class Processor(DocumentEmbeddingsStoreService):
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
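The Pinecone processors derive one index per user/collection pair — a `d-` prefix for document embeddings and `t-` for graph embeddings, judging from the f-strings above. A small helper capturing that inferred convention, with an idempotent create (the fake client is a test stand-in, not the Pinecone SDK):

```python
PREFIXES = {"document": "d", "graph": "t"}

def index_name(kind, user, collection):
    """Per-user, per-collection Pinecone index name."""
    return f"{PREFIXES[kind]}-{user}-{collection}"

def ensure_index(client, kind, user, collection, dim=384):
    # Idempotent: reuse the index if it already exists
    name = index_name(kind, user, collection)
    if not client.has_index(name):
        client.create_index(name, dim)
    return name

class FakePinecone:
    """Minimal in-memory stand-in for the index client."""
    def __init__(self):
        self.indexes = {}
    def has_index(self, name):
        return name in self.indexes
    def create_index(self, name, dim):
        self.indexes[name] = dim
```

The default `dim=384` mirrors the placeholder dimension used by `handle_create_collection`; as the diff's own comment notes, an index created this way must be recreated if real embeddings arrive with a different dimension.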

View file

@ -36,8 +36,6 @@ class Processor(DocumentEmbeddingsStoreService):
}
)
self.last_collection = None
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
# Set up storage management if base class attributes are available
@ -71,8 +69,30 @@ class Processor(DocumentEmbeddingsStoreService):
metrics=storage_response_metrics,
)
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
if hasattr(self, 'storage_request_consumer'):
await self.storage_request_consumer.start()
if hasattr(self, 'storage_response_producer'):
await self.storage_response_producer.start()
async def store_document_embeddings(self, message):
# Validate collection exists before accepting writes
collection = (
"d_" + message.metadata.user + "_" +
message.metadata.collection
)
if not self.qdrant.collection_exists(collection):
error_msg = (
f"Collection {message.metadata.collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
for emb in message.chunks:
chunk = emb.chunk.decode("utf-8")
@ -80,29 +100,6 @@ class Processor(DocumentEmbeddingsStoreService):
for vec in emb.vectors:
dim = len(vec)
collection = (
"d_" + message.metadata.user + "_" +
message.metadata.collection
)
if collection != self.last_collection:
if not self.qdrant.collection_exists(collection):
try:
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
)
except Exception as e:
logger.error("Qdrant collection creation failed")
raise e
self.last_collection = collection
self.qdrant.upsert(
collection_name=collection,
points=[
@ -133,18 +130,21 @@ class Processor(DocumentEmbeddingsStoreService):
help=f'Qdrant API key (default: None)'
)
async def on_storage_management(self, message):
async def on_storage_management(self, message, consumer, flow):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
request = message.value()
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
if request.operation == "create-collection":
await self.handle_create_collection(request)
elif request.operation == "delete-collection":
await self.handle_delete_collection(request)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@ -159,10 +159,43 @@ class Processor(DocumentEmbeddingsStoreService):
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
async def handle_create_collection(self, request):
"""Create a Qdrant collection for document embeddings"""
try:
collection_name = f"d_{request.user}_{request.collection}"
if self.qdrant.collection_exists(collection_name):
logger.info(f"Qdrant collection {collection_name} already exists")
else:
# Create collection with default dimension (will be recreated with correct dim on first write if needed)
# Using a placeholder dimension - actual dimension determined by first embedding
self.qdrant.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(
size=384, # Default dimension, common for many models
distance=Distance.COSINE
)
)
logger.info(f"Created Qdrant collection: {collection_name}")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete the collection for document embeddings"""
try:
collection_name = f"d_{message.user}_{message.collection}"
collection_name = f"d_{request.user}_{request.collection}"
if self.qdrant.collection_exists(collection_name):
self.qdrant.delete_collection(collection_name)
@ -175,7 +208,7 @@ class Processor(DocumentEmbeddingsStoreService):
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View file

@ -60,8 +60,23 @@ class Processor(GraphEmbeddingsStoreService):
metrics=storage_response_metrics,
)
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
await self.storage_request_consumer.start()
await self.storage_response_producer.start()
async def store_graph_embeddings(self, message):
# Validate collection exists before accepting writes
if not self.vecstore.collection_exists(message.metadata.user, message.metadata.collection):
error_msg = (
f"Collection {message.metadata.collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
for entity in message.entities:
if entity.entity.value != "" and entity.entity.value is not None:
@ -83,18 +98,21 @@ class Processor(GraphEmbeddingsStoreService):
help=f'Milvus store URI (default: {default_store_uri})'
)
async def on_storage_management(self, message):
async def on_storage_management(self, message, consumer, flow):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
request = message.value()
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
if request.operation == "create-collection":
await self.handle_create_collection(request)
elif request.operation == "delete-collection":
await self.handle_delete_collection(request)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@ -109,17 +127,40 @@ class Processor(GraphEmbeddingsStoreService):
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
async def handle_create_collection(self, request):
"""Create a Milvus collection for graph embeddings"""
try:
if self.vecstore.collection_exists(request.user, request.collection):
logger.info(f"Collection {request.user}/{request.collection} already exists")
else:
self.vecstore.create_collection(request.user, request.collection)
logger.info(f"Created collection {request.user}/{request.collection}")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete the collection for graph embeddings"""
try:
self.vecstore.delete_collection(message.user, message.collection)
self.vecstore.delete_collection(request.user, request.collection)
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View file

@ -115,8 +115,27 @@ class Processor(GraphEmbeddingsStoreService):
"Gave up waiting for index creation"
)
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
await self.storage_request_consumer.start()
await self.storage_response_producer.start()
async def store_graph_embeddings(self, message):
index_name = (
"t-" + message.metadata.user + "-" + message.metadata.collection
)
# Validate collection exists before accepting writes
if not self.pinecone.has_index(index_name):
error_msg = (
f"Collection {message.metadata.collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
for entity in message.entities:
if entity.entity.value == "" or entity.entity.value is None:
@ -124,28 +143,6 @@ class Processor(GraphEmbeddingsStoreService):
for vec in entity.vectors:
dim = len(vec)
index_name = (
"t-" + message.metadata.user + "-" + message.metadata.collection
)
if index_name != self.last_index_name:
if not self.pinecone.has_index(index_name):
try:
self.create_index(index_name, dim)
except Exception as e:
logger.error("Pinecone index creation failed")
raise e
logger.info(f"Index {index_name} created")
self.last_index_name = index_name
index = self.pinecone.Index(index_name)
# Generate unique ID for each vector
@ -191,18 +188,21 @@ class Processor(GraphEmbeddingsStoreService):
help=f'Pinecone region (default: {default_region})'
)
async def on_storage_management(self, message):
async def on_storage_management(self, message, consumer, flow):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
request = message.value()
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
if request.operation == "create-collection":
await self.handle_create_collection(request)
elif request.operation == "delete-collection":
await self.handle_delete_collection(request)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@ -217,10 +217,36 @@ class Processor(GraphEmbeddingsStoreService):
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
async def handle_create_collection(self, request):
"""Create a Pinecone index for graph embeddings"""
try:
index_name = f"t-{request.user}-{request.collection}"
if self.pinecone.has_index(index_name):
logger.info(f"Pinecone index {index_name} already exists")
else:
# Create with default dimension - will need to be recreated if dimension doesn't match
self.create_index(index_name, dim=384)
logger.info(f"Created Pinecone index: {index_name}")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete the collection for graph embeddings"""
try:
index_name = f"t-{message.user}-{message.collection}"
index_name = f"t-{request.user}-{request.collection}"
if self.pinecone.has_index(index_name):
self.pinecone.delete_index(index_name)
@ -233,7 +259,7 @@ class Processor(GraphEmbeddingsStoreService):
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View file

@ -36,8 +36,6 @@ class Processor(GraphEmbeddingsStoreService):
}
)
self.last_collection = None
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
# Set up storage management if base class attributes are available
@ -71,31 +69,30 @@ class Processor(GraphEmbeddingsStoreService):
metrics=storage_response_metrics,
)
def get_collection(self, dim, user, collection):
def get_collection(self, user, collection):
"""Get collection name and validate it exists"""
cname = (
"t_" + user + "_" + collection
)
if cname != self.last_collection:
if not self.qdrant.collection_exists(cname):
try:
self.qdrant.create_collection(
collection_name=cname,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
)
except Exception as e:
logger.error("Qdrant collection creation failed")
raise e
self.last_collection = cname
if not self.qdrant.collection_exists(cname):
error_msg = (
f"Collection {collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
return cname
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
if hasattr(self, 'storage_request_consumer'):
await self.storage_request_consumer.start()
if hasattr(self, 'storage_response_producer'):
await self.storage_response_producer.start()
async def store_graph_embeddings(self, message):
for entity in message.entities:
@ -104,10 +101,8 @@ class Processor(GraphEmbeddingsStoreService):
for vec in entity.vectors:
dim = len(vec)
collection = self.get_collection(
dim, message.metadata.user, message.metadata.collection
message.metadata.user, message.metadata.collection
)
self.qdrant.upsert(
@ -140,18 +135,21 @@ class Processor(GraphEmbeddingsStoreService):
help=f'Qdrant API key'
)
async def on_storage_management(self, message):
async def on_storage_management(self, message, consumer, flow):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
request = message.value()
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
if request.operation == "create-collection":
await self.handle_create_collection(request)
elif request.operation == "delete-collection":
await self.handle_delete_collection(request)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@ -166,10 +164,43 @@ class Processor(GraphEmbeddingsStoreService):
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
async def handle_create_collection(self, request):
"""Create a Qdrant collection for graph embeddings"""
try:
collection_name = f"t_{request.user}_{request.collection}"
if self.qdrant.collection_exists(collection_name):
logger.info(f"Qdrant collection {collection_name} already exists")
else:
# Create collection with default dimension (will be recreated with correct dim on first write if needed)
# Using a placeholder dimension - actual dimension determined by first embedding
self.qdrant.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(
size=384, # Default dimension, common for many models
distance=Distance.COSINE
)
)
logger.info(f"Created Qdrant collection: {collection_name}")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete the collection for graph embeddings"""
try:
collection_name = f"t_{message.user}_{message.collection}"
collection_name = f"t_{request.user}_{request.collection}"
if self.qdrant.collection_exists(collection_name):
self.qdrant.delete_collection(collection_name)
@ -182,7 +213,7 @@ class Processor(GraphEmbeddingsStoreService):
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View file

@ -295,6 +295,8 @@ class Processor(FlowProcessor):
try:
self.session.execute(create_table_cql)
if keyspace not in self.known_tables:
self.known_tables[keyspace] = set()
self.known_tables[keyspace].add(table_key)
logger.info(f"Ensured table exists: {safe_keyspace}.{safe_table}")
@ -340,18 +342,47 @@ class Processor(FlowProcessor):
logger.warning(f"Failed to convert value {value} to type {field_type}: {e}")
return str(value)
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
await self.storage_request_consumer.start()
await self.storage_response_producer.start()
async def on_object(self, msg, consumer, flow):
"""Process incoming ExtractedObject and store in Cassandra"""
obj = msg.value()
logger.info(f"Storing {len(obj.values)} objects for schema {obj.schema_name} from {obj.metadata.id}")
# Validate collection/keyspace exists before accepting writes
safe_keyspace = self.sanitize_name(obj.metadata.user)
if safe_keyspace not in self.known_keyspaces:
# Check if keyspace actually exists in Cassandra
self.connect_cassandra()
check_keyspace_cql = """
SELECT keyspace_name FROM system_schema.keyspaces
WHERE keyspace_name = %s
"""
result = self.session.execute(check_keyspace_cql, (safe_keyspace,))
# Check if result is None (mock case) or has no rows
if result is None or not result.one():
error_msg = (
f"Collection {obj.metadata.collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
# Cache it if it exists
self.known_keyspaces.add(safe_keyspace)
if safe_keyspace not in self.known_tables:
self.known_tables[safe_keyspace] = set()
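The keyspace validation above pairs name sanitization with a cached probe of `system_schema.keyspaces`. A sketch of that validate-and-cache step (the sanitizer is a plausible guess, and the fake session stands in for a Cassandra driver session):

```python
import re

def sanitize_name(name):
    # Lowercase and replace anything outside [a-z0-9_] so the result
    # is a legal Cassandra identifier
    return re.sub(r"[^a-z0-9_]", "_", name.lower())

class KeyspaceGuard:
    def __init__(self, session):
        self.session = session
        self.known = set()  # cache of keyspaces already verified

    def check(self, user, collection):
        keyspace = sanitize_name(user)
        if keyspace in self.known:
            return keyspace  # cached hit: no round trip
        rows = self.session.execute(
            "SELECT keyspace_name FROM system_schema.keyspaces "
            "WHERE keyspace_name = %s", (keyspace,))
        if not rows:
            raise ValueError(
                f"Collection {collection} does not exist. "
                f"Create it first with tg-set-collection.")
        self.known.add(keyspace)
        return keyspace

class FakeSession:
    """Returns a non-empty result only for known keyspaces."""
    def __init__(self, keyspaces):
        self.keyspaces = set(keyspaces)
    def execute(self, cql, params):
        return [params[0]] if params[0] in self.keyspaces else []
```

The cache means the `system_schema` query is paid once per keyspace per process, not once per message.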
# Get schema definition
schema = self.schemas.get(obj.schema_name)
if not schema:
logger.warning(f"No schema found for {obj.schema_name} - skipping")
return
# Ensure table exists
keyspace = obj.metadata.user
table_name = obj.schema_name
@ -425,26 +456,36 @@ class Processor(FlowProcessor):
async def on_storage_management(self, msg, consumer, flow):
"""Handle storage management requests for collection operations"""
logger.info(f"Received storage management request: {msg.operation} for {msg.user}/{msg.collection}")
request = msg.value()
logger.info(f"Received storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if msg.operation == "delete-collection":
await self.delete_collection(msg.user, msg.collection)
if request.operation == "create-collection":
await self.create_collection(request.user, request.collection)
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {msg.user}/{msg.collection}")
logger.info(f"Successfully created collection {request.user}/{request.collection}")
elif request.operation == "delete-collection":
await self.delete_collection(request.user, request.collection)
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
else:
logger.warning(f"Unknown storage management operation: {msg.operation}")
logger.warning(f"Unknown storage management operation: {request.operation}")
# Send error response
from .... schema import Error
response = StorageManagementResponse(
error=Error(
type="unknown_operation",
message=f"Unknown operation: {msg.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@ -459,10 +500,28 @@ class Processor(FlowProcessor):
message=str(e)
)
)
await self.send("storage-response", response)
await self.storage_response_producer.send(response)
async def create_collection(self, user: str, collection: str):
"""Create/verify collection exists in Cassandra object store"""
# Connect if not already connected
self.connect_cassandra()
# Sanitize names for safety
safe_keyspace = self.sanitize_name(user)
# Ensure keyspace exists
if safe_keyspace not in self.known_keyspaces:
self.ensure_keyspace(safe_keyspace)
self.known_keyspaces.add(safe_keyspace)
# For Cassandra objects, collection is just a property in rows
# No need to create separate tables per collection
# Just mark that we've seen this collection
logger.info(f"Collection {collection} ready for user {user} (using keyspace {safe_keyspace})")
async def delete_collection(self, user: str, collection: str):
"""Delete all data for a specific collection"""
"""Delete all data for a specific collection using schema information"""
# Connect if not already connected
self.connect_cassandra()
@ -482,40 +541,78 @@ class Processor(FlowProcessor):
return
self.known_keyspaces.add(safe_keyspace)
# Get all tables in the keyspace that might contain collection data
get_tables_cql = """
SELECT table_name FROM system_schema.tables
WHERE keyspace_name = %s
"""
tables = self.session.execute(get_tables_cql, (safe_keyspace,))
# Iterate over schemas we manage to delete from relevant tables
tables_deleted = 0
for row in tables:
table_name = row.table_name
for schema_name, schema in self.schemas.items():
safe_table = self.sanitize_table(schema_name)
# Check if the table has a collection column
check_column_cql = """
SELECT column_name FROM system_schema.columns
WHERE keyspace_name = %s AND table_name = %s AND column_name = 'collection'
"""
# Check if table exists
table_key = f"{user}.{schema_name}"
if table_key not in self.known_tables.get(user, set()):
logger.debug(f"Table {safe_keyspace}.{safe_table} not in known tables, skipping")
continue
result = self.session.execute(check_column_cql, (safe_keyspace, table_name))
if result.one():
# Table has collection column, delete data for this collection
try:
delete_cql = f"""
DELETE FROM {safe_keyspace}.{table_name}
try:
# Get primary key fields from schema
primary_key_fields = [field for field in schema.fields if field.primary]
if primary_key_fields:
# Schema has primary keys: need to query for partition keys first
# Build SELECT query for primary key fields
pk_field_names = [self.sanitize_name(field.name) for field in primary_key_fields]
select_cql = f"""
SELECT {', '.join(pk_field_names)}
FROM {safe_keyspace}.{safe_table}
WHERE collection = %s
ALLOW FILTERING
"""
self.session.execute(delete_cql, (collection,))
tables_deleted += 1
logger.info(f"Deleted collection {collection} from table {safe_keyspace}.{table_name}")
except Exception as e:
logger.error(f"Failed to delete from table {safe_keyspace}.{table_name}: {e}")
raise
logger.info(f"Deleted collection {collection} from {tables_deleted} tables in keyspace {safe_keyspace}")
rows = self.session.execute(select_cql, (collection,))
# Delete each row using full partition key
for row in rows:
where_clauses = ["collection = %s"]
values = [collection]
for field_name in pk_field_names:
where_clauses.append(f"{field_name} = %s")
values.append(getattr(row, field_name))
delete_cql = f"""
DELETE FROM {safe_keyspace}.{safe_table}
WHERE {' AND '.join(where_clauses)}
"""
self.session.execute(delete_cql, tuple(values))
else:
# No primary key fields: rows are keyed by synthetic_id
# Need to query for synthetic_ids first
select_cql = f"""
SELECT synthetic_id
FROM {safe_keyspace}.{safe_table}
WHERE collection = %s
ALLOW FILTERING
"""
rows = self.session.execute(select_cql, (collection,))
# Delete each row using collection and synthetic_id
for row in rows:
delete_cql = f"""
DELETE FROM {safe_keyspace}.{safe_table}
WHERE collection = %s AND synthetic_id = %s
"""
self.session.execute(delete_cql, (collection, row.synthetic_id))
tables_deleted += 1
logger.info(f"Deleted collection {collection} from table {safe_keyspace}.{safe_table}")
except Exception as e:
logger.error(f"Failed to delete from table {safe_keyspace}.{safe_table}: {e}")
raise
logger.info(f"Deleted collection {collection} from {tables_deleted} schema-based tables in keyspace {safe_keyspace}")
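The select-then-delete loop above exists because the partition key is compound: a collection cannot be dropped with a single DELETE, so full partition keys are fetched first and one DELETE is issued per row. A minimal sketch of the statement-building step (function and field names here are assumptions for illustration, not part of the commit):

```python
def build_collection_delete(keyspace, table, pk_field_names, row, collection):
    """Build a DELETE statement keyed on the full partition key of one row."""
    where_clauses = ["collection = %s"]
    values = [collection]
    for name in pk_field_names:
        where_clauses.append(f"{name} = %s")
        values.append(row[name])
    cql = (
        f"DELETE FROM {keyspace}.{table} "
        f"WHERE {' AND '.join(where_clauses)}"
    )
    return cql, tuple(values)
```

The real handler passes `values` to `session.execute` rather than interpolating them, which keeps the statement parameterized.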
def close(self):
"""Clean up Cassandra connections"""
@@ -109,6 +109,15 @@ class Processor(TriplesStoreService):
self.table = user
# Validate collection exists before accepting writes
if not self.tg.collection_exists(message.metadata.collection):
error_msg = (
f"Collection {message.metadata.collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
for t in message.triples:
self.tg.insert(
message.metadata.collection,
@@ -117,18 +126,27 @@ class Processor(TriplesStoreService):
t.o.value
)
async def on_storage_management(self, message):
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
await self.storage_request_consumer.start()
await self.storage_response_producer.start()
async def on_storage_management(self, message, consumer, flow):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
request = message.value()
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
if request.operation == "create-collection":
await self.handle_create_collection(request)
elif request.operation == "delete-collection":
await self.handle_delete_collection(request)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@@ -143,42 +161,85 @@ class Processor(TriplesStoreService):
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete all data for a specific collection from the unified triples table"""
async def handle_create_collection(self, request):
"""Create a collection in Cassandra triple store"""
try:
# Create or reuse connection for this user's keyspace
if self.table is None or self.table != message.user:
if self.table is None or self.table != request.user:
self.tg = None
try:
if self.cassandra_username and self.cassandra_password:
self.tg = KnowledgeGraph(
hosts=self.cassandra_host,
keyspace=message.user,
keyspace=request.user,
username=self.cassandra_username,
password=self.cassandra_password
)
else:
self.tg = KnowledgeGraph(
hosts=self.cassandra_host,
keyspace=message.user,
keyspace=request.user,
)
except Exception as e:
logger.error(f"Failed to connect to Cassandra for user {message.user}: {e}")
logger.error(f"Failed to connect to Cassandra for user {request.user}: {e}")
raise
self.table = message.user
self.table = request.user
# Delete all triples for this collection from the unified table
# In the unified table schema, collection is the partition key
delete_cql = """
DELETE FROM triples
WHERE collection = ?
"""
# Create collection using the built-in method
logger.info(f"Creating collection {request.collection} for user {request.user}")
if self.tg.collection_exists(request.collection):
logger.info(f"Collection {request.collection} already exists")
else:
self.tg.create_collection(request.collection)
logger.info(f"Created collection {request.collection}")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete all data for a specific collection from the unified triples table"""
try:
# Create or reuse connection for this user's keyspace
if self.table is None or self.table != request.user:
self.tg = None
try:
if self.cassandra_username and self.cassandra_password:
self.tg = KnowledgeGraph(
hosts=self.cassandra_host,
keyspace=request.user,
username=self.cassandra_username,
password=self.cassandra_password
)
else:
self.tg = KnowledgeGraph(
hosts=self.cassandra_host,
keyspace=request.user,
)
except Exception as e:
logger.error(f"Failed to connect to Cassandra for user {request.user}: {e}")
raise
self.table = request.user
# Delete all triples for this collection using the built-in method
try:
self.tg.session.execute(delete_cql, (message.collection,))
logger.info(f"Deleted all triples for collection {message.collection} from keyspace {message.user}")
self.tg.delete_collection(request.collection)
logger.info(f"Deleted all triples for collection {request.collection} from keyspace {request.user}")
except Exception as e:
logger.error(f"Failed to delete collection data: {e}")
raise
@@ -188,7 +249,7 @@ class Processor(TriplesStoreService):
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
@@ -152,11 +152,43 @@ class Processor(TriplesStoreService):
time=res.run_time_ms
))
def collection_exists(self, user, collection):
"""Check if collection metadata node exists"""
result = self.io.query(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
"RETURN c LIMIT 1",
params={"user": user, "collection": collection}
)
return result.result_set is not None and len(result.result_set) > 0
def create_collection(self, user, collection):
"""Create collection metadata node"""
import datetime
self.io.query(
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) "
"SET c.created_at = $created_at",
params={
"user": user,
"collection": collection,
"created_at": datetime.datetime.now().isoformat()
}
)
logger.info(f"Created collection metadata node for {user}/{collection}")
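The `MERGE` in `create_collection` makes the call idempotent: repeated creates leave exactly one `CollectionMetadata` node per (user, collection) pair. A stand-in sketch of that upsert semantics using an in-memory dict (purely illustrative; the class and method names are assumptions, not FalkorDB API):

```python
import datetime

class MetadataStore:
    def __init__(self):
        self.nodes = {}

    def merge(self, user, collection):
        # MERGE-like upsert: create the node if absent, then set properties.
        key = (user, collection)
        node = self.nodes.setdefault(key, {"user": user, "collection": collection})
        node["created_at"] = datetime.datetime.now().isoformat()
        return node

    def exists(self, user, collection):
        return (user, collection) in self.nodes
```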
async def store_triples(self, message):
# Extract user and collection from metadata
user = message.metadata.user if message.metadata.user else "default"
collection = message.metadata.collection if message.metadata.collection else "default"
# Validate collection exists before accepting writes
if not self.collection_exists(user, collection):
error_msg = (
f"Collection {collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
for t in message.triples:
self.create_node(t.s.value, user, collection)
@@ -185,18 +217,27 @@ class Processor(TriplesStoreService):
help=f'FalkorDB database (default: {default_database})'
)
async def on_storage_management(self, message):
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
await self.storage_request_consumer.start()
await self.storage_response_producer.start()
async def on_storage_management(self, message, consumer, flow):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
request = message.value()
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
if request.operation == "create-collection":
await self.handle_create_collection(request)
elif request.operation == "delete-collection":
await self.handle_delete_collection(request)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@@ -211,28 +252,57 @@ class Processor(TriplesStoreService):
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
async def handle_create_collection(self, request):
"""Create collection metadata in FalkorDB"""
try:
if self.collection_exists(request.user, request.collection):
logger.info(f"Collection {request.user}/{request.collection} already exists")
else:
self.create_collection(request.user, request.collection)
logger.info(f"Created collection {request.user}/{request.collection}")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete the collection for FalkorDB triples"""
try:
# Delete all nodes and literals for this user/collection
node_result = self.io.query(
"MATCH (n:Node {user: $user, collection: $collection}) DETACH DELETE n",
params={"user": message.user, "collection": message.collection}
params={"user": request.user, "collection": request.collection}
)
literal_result = self.io.query(
"MATCH (n:Literal {user: $user, collection: $collection}) DETACH DELETE n",
params={"user": message.user, "collection": message.collection}
params={"user": request.user, "collection": request.collection}
)
logger.info(f"Deleted {node_result.nodes_deleted} nodes and {literal_result.nodes_deleted} literals for collection {message.user}/{message.collection}")
# Delete collection metadata node
metadata_result = self.io.query(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) DELETE c",
params={"user": request.user, "collection": request.collection}
)
logger.info(f"Deleted {node_result.nodes_deleted} nodes, {literal_result.nodes_deleted} literals, and {metadata_result.nodes_deleted} metadata nodes for collection {request.user}/{request.collection}")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
@@ -267,12 +267,43 @@ class Processor(TriplesStoreService):
src=t.s.value, dest=t.o.value, uri=t.p.value, user=user, collection=collection,
)
def collection_exists(self, user, collection):
"""Check if collection metadata node exists"""
with self.io.session(database=self.db) as session:
result = session.run(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
"RETURN c LIMIT 1",
user=user, collection=collection
)
return bool(list(result))
def create_collection(self, user, collection):
"""Create collection metadata node"""
import datetime
with self.io.session(database=self.db) as session:
session.run(
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) "
"SET c.created_at = $created_at",
user=user, collection=collection,
created_at=datetime.datetime.now().isoformat()
)
logger.info(f"Created collection metadata node for {user}/{collection}")
async def store_triples(self, message):
# Extract user and collection from metadata
user = message.metadata.user if message.metadata.user else "default"
collection = message.metadata.collection if message.metadata.collection else "default"
# Validate collection exists before accepting writes
if not self.collection_exists(user, collection):
error_msg = (
f"Collection {collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
for t in message.triples:
self.create_node(t.s.value, user, collection)
@@ -317,18 +348,27 @@ class Processor(TriplesStoreService):
help=f'Memgraph database (default: {default_database})'
)
async def on_storage_management(self, message):
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
await self.storage_request_consumer.start()
await self.storage_response_producer.start()
async def on_storage_management(self, message, consumer, flow):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
request = message.value()
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
if request.operation == "create-collection":
await self.handle_create_collection(request)
elif request.operation == "delete-collection":
await self.handle_delete_collection(request)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@@ -343,7 +383,30 @@ class Processor(TriplesStoreService):
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
async def handle_create_collection(self, request):
"""Create collection metadata in Memgraph"""
try:
if self.collection_exists(request.user, request.collection):
logger.info(f"Collection {request.user}/{request.collection} already exists")
else:
self.create_collection(request.user, request.collection)
logger.info(f"Created collection {request.user}/{request.collection}")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete all data for a specific collection"""
try:
with self.io.session(database=self.db) as session:
@@ -351,7 +414,7 @@ class Processor(TriplesStoreService):
node_result = session.run(
"MATCH (n:Node {user: $user, collection: $collection}) "
"DETACH DELETE n",
user=message.user, collection=message.collection
user=request.user, collection=request.collection
)
nodes_deleted = node_result.consume().counters.nodes_deleted
@@ -359,20 +422,28 @@ class Processor(TriplesStoreService):
literal_result = session.run(
"MATCH (n:Literal {user: $user, collection: $collection}) "
"DETACH DELETE n",
user=message.user, collection=message.collection
user=request.user, collection=request.collection
)
literals_deleted = literal_result.consume().counters.nodes_deleted
# Delete collection metadata node
metadata_result = session.run(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
"DELETE c",
user=request.user, collection=request.collection
)
metadata_deleted = metadata_result.consume().counters.nodes_deleted
# Note: Relationships are automatically deleted with DETACH DELETE
logger.info(f"Deleted {nodes_deleted} nodes and {literals_deleted} literals for {message.user}/{message.collection}")
logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {request.user}/{request.collection}")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
@@ -228,6 +228,15 @@ class Processor(TriplesStoreService):
user = message.metadata.user if message.metadata.user else "default"
collection = message.metadata.collection if message.metadata.collection else "default"
# Validate collection exists before accepting writes
if not self.collection_exists(user, collection):
error_msg = (
f"Collection {collection} does not exist. "
f"Create it first with tg-set-collection."
)
logger.error(error_msg)
raise ValueError(error_msg)
for t in message.triples:
self.create_node(t.s.value, user, collection)
@@ -268,18 +277,27 @@ class Processor(TriplesStoreService):
help=f'Neo4j database (default: {default_database})'
)
async def on_storage_management(self, message):
async def start(self):
"""Start the processor and its storage management consumer"""
await super().start()
await self.storage_request_consumer.start()
await self.storage_response_producer.start()
async def on_storage_management(self, message, consumer, flow):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
request = message.value()
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
if request.operation == "create-collection":
await self.handle_create_collection(request)
elif request.operation == "delete-collection":
await self.handle_delete_collection(request)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
message=f"Unknown operation: {request.operation}"
)
)
await self.storage_response_producer.send(response)
@@ -294,7 +312,52 @@ class Processor(TriplesStoreService):
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
def collection_exists(self, user, collection):
"""Check if collection metadata node exists"""
with self.io.session(database=self.db) as session:
result = session.run(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
"RETURN c LIMIT 1",
user=user, collection=collection
)
return bool(list(result))
def create_collection(self, user, collection):
"""Create collection metadata node"""
import datetime
with self.io.session(database=self.db) as session:
session.run(
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) "
"SET c.created_at = $created_at",
user=user, collection=collection,
created_at=datetime.datetime.now().isoformat()
)
logger.info(f"Created collection metadata node for {user}/{collection}")
async def handle_create_collection(self, request):
"""Create collection metadata in Neo4j"""
try:
if self.collection_exists(request.user, request.collection):
logger.info(f"Collection {request.user}/{request.collection} already exists")
else:
self.create_collection(request.user, request.collection)
logger.info(f"Created collection {request.user}/{request.collection}")
# Send success response
response = StorageManagementResponse(error=None)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Failed to create collection: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="creation_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, request):
"""Delete all data for a specific collection"""
try:
with self.io.session(database=self.db) as session:
@@ -302,7 +365,7 @@ class Processor(TriplesStoreService):
node_result = session.run(
"MATCH (n:Node {user: $user, collection: $collection}) "
"DETACH DELETE n",
user=message.user, collection=message.collection
user=request.user, collection=request.collection
)
nodes_deleted = node_result.consume().counters.nodes_deleted
@@ -310,20 +373,28 @@ class Processor(TriplesStoreService):
literal_result = session.run(
"MATCH (n:Literal {user: $user, collection: $collection}) "
"DETACH DELETE n",
user=message.user, collection=message.collection
user=request.user, collection=request.collection
)
literals_deleted = literal_result.consume().counters.nodes_deleted
# Note: Relationships are automatically deleted with DETACH DELETE
logger.info(f"Deleted {nodes_deleted} nodes and {literals_deleted} literals for {message.user}/{message.collection}")
# Delete collection metadata node
metadata_result = session.run(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
"DELETE c",
user=request.user, collection=request.collection
)
metadata_deleted = metadata_result.consume().counters.nodes_deleted
logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {request.user}/{request.collection}")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
@@ -145,7 +145,7 @@ class ConfigTableStore:
""")
self.get_all_stmt = self.cassandra.prepare("""
SELECT class, key, value FROM config;
SELECT class AS cls, key, value FROM config;
""")
self.get_values_stmt = self.cassandra.prepare("""
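The likely motivation for the `class AS cls` alias above: the cassandra-driver materializes result rows as namedtuples by default, and namedtuple field names cannot be Python keywords, so a column literally named `class` is unusable as a row attribute. A quick demonstration of the collision and the aliased access:

```python
from collections import namedtuple

# A column named "class" cannot become a namedtuple field -- it's a keyword.
try:
    namedtuple("Row", ["class", "key", "value"])
except ValueError as exc:
    print(f"rejected: {exc}")

# With the SQL alias in place, the attribute is plain to access.
Row = namedtuple("Row", ["cls", "key", "value"])
row = Row(cls="flow", key="k", value="v")
print(row.cls)
```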
@@ -18,6 +18,7 @@ Supports both Google's Gemini models and Anthropic's Claude models.
from google.oauth2 import service_account
import google.auth
import google.api_core.exceptions
import vertexai
import logging
@@ -59,8 +60,17 @@ class Processor(LlmService):
super(Processor, self).__init__(**params)
self.model = model
self.is_anthropic = 'claude' in self.model.lower()
# Store default model and configuration parameters
self.default_model = model
self.region = region
self.temperature = temperature
self.max_output = max_output
self.private_key = private_key
# Model client caches
self.model_clients = {} # Cache for model instances
self.generation_configs = {} # Cache for generation configs (Gemini only)
self.anthropic_client = None # Single Anthropic client (handles multiple models)
# Shared parameters for both model types
self.api_params = {
@@ -89,75 +99,101 @@ class Processor(LlmService):
"Ensure it's set in your environment or service account."
)
# Initialize the appropriate client based on the model type
if self.is_anthropic:
logger.info(f"Initializing Anthropic model '{model}' via AnthropicVertex SDK")
# Initialize AnthropicVertex with credentials if provided, otherwise use ADC
anthropic_kwargs = {'region': region, 'project_id': project_id}
if credentials and private_key: # Pass credentials only if from a file
anthropic_kwargs['credentials'] = credentials
logger.debug(f"Using service account credentials for Anthropic model")
else:
logger.debug(f"Using Application Default Credentials for Anthropic model")
self.llm = AnthropicVertex(**anthropic_kwargs)
else:
# For Gemini models, initialize the Vertex AI SDK
logger.info(f"Initializing Google model '{model}' via Vertex AI SDK")
init_kwargs = {'location': region, 'project': project_id}
if credentials and private_key: # Pass credentials only if from a file
init_kwargs['credentials'] = credentials
vertexai.init(**init_kwargs)
# Store credentials and project info for later use
self.credentials = credentials
self.project_id = project_id
self.llm = GenerativeModel(model)
# Initialize Vertex AI SDK for Gemini models
init_kwargs = {'location': region, 'project': project_id}
if credentials and private_key: # Pass credentials only if from a file
init_kwargs['credentials'] = credentials
self.generation_config = GenerationConfig(
temperature=temperature,
top_p=1.0,
top_k=10,
candidate_count=1,
max_output_tokens=max_output,
)
vertexai.init(**init_kwargs)
# Block none doesn't seem to work
block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
# block_level = HarmBlockThreshold.BLOCK_NONE
self.safety_settings = [
SafetySetting(
category = HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold = block_level,
),
SafetySetting(
category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold = block_level,
),
SafetySetting(
category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold = block_level,
),
SafetySetting(
category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold = block_level,
),
]
# Pre-initialize Anthropic client if needed (single client handles all Claude models)
if 'claude' in self.default_model.lower():
self._get_anthropic_client()
# Safety settings for Gemini models
block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
self.safety_settings = [
SafetySetting(
category = HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold = block_level,
),
SafetySetting(
category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold = block_level,
),
SafetySetting(
category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold = block_level,
),
SafetySetting(
category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold = block_level,
),
]
logger.info("VertexAI initialization complete")
async def generate_content(self, system, prompt):
def _get_anthropic_client(self):
"""Get or create the Anthropic client (single client for all Claude models)"""
if self.anthropic_client is None:
logger.info(f"Initializing AnthropicVertex client")
anthropic_kwargs = {'region': self.region, 'project_id': self.project_id}
if self.credentials and self.private_key: # Pass credentials only if from a file
anthropic_kwargs['credentials'] = self.credentials
logger.debug(f"Using service account credentials for Anthropic models")
else:
logger.debug(f"Using Application Default Credentials for Anthropic models")
self.anthropic_client = AnthropicVertex(**anthropic_kwargs)
return self.anthropic_client
def _get_gemini_model(self, model_name, temperature=None):
"""Get or create a Gemini model instance"""
if model_name not in self.model_clients:
logger.info(f"Creating GenerativeModel instance for '{model_name}'")
self.model_clients[model_name] = GenerativeModel(model_name)
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
# Create generation config with the effective temperature
generation_config = GenerationConfig(
temperature=effective_temperature,
top_p=1.0,
top_k=10,
candidate_count=1,
max_output_tokens=self.max_output,
)
return self.model_clients[model_name], generation_config
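The per-model cache in `_get_gemini_model` constructs each client once on first request and reuses it afterwards. The same pattern in isolation (the factory callable below stands in for `GenerativeModel` and is an assumption, not the real constructor):

```python
class ClientCache:
    """Construct one client per model name; reuse on subsequent lookups."""

    def __init__(self, factory):
        self.factory = factory
        self._clients = {}
        self.creations = 0  # instrumentation for the sketch only

    def get(self, name):
        if name not in self._clients:
            self._clients[name] = self.factory(name)
            self.creations += 1
        return self._clients[name]
```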
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
logger.debug(f"Using temperature: {effective_temperature}")
try:
if self.is_anthropic:
if 'claude' in model_name.lower():
# Anthropic API uses a dedicated system prompt
logger.debug("Sending request to Anthropic model...")
response = self.llm.messages.create(
model=self.model,
logger.debug(f"Sending request to Anthropic model '{model_name}'...")
client = self._get_anthropic_client()
response = client.messages.create(
model=model_name,
system=system,
messages=[{"role": "user", "content": prompt}],
max_tokens=self.api_params['max_output_tokens'],
temperature=self.api_params['temperature'],
temperature=effective_temperature,
top_p=self.api_params['top_p'],
top_k=self.api_params['top_k'],
)
@@ -166,15 +202,17 @@ class Processor(LlmService):
text=response.content[0].text,
in_token=response.usage.input_tokens,
out_token=response.usage.output_tokens,
model=self.model
model=model_name
)
else:
# Gemini API combines system and user prompts
logger.debug("Sending request to Gemini model...")
logger.debug(f"Sending request to Gemini model '{model_name}'...")
full_prompt = system + "\n\n" + prompt
response = self.llm.generate_content(
full_prompt, generation_config = self.generation_config,
llm, generation_config = self._get_gemini_model(model_name, effective_temperature)
response = llm.generate_content(
full_prompt, generation_config = generation_config,
safety_settings = self.safety_settings,
)
@@ -182,7 +220,7 @@ class Processor(LlmService):
text = response.text,
in_token = response.usage_metadata.prompt_token_count,
out_token = response.usage_metadata.candidates_token_count,
model = self.model
model = model_name
)
logger.info(f"Input Tokens: {resp.in_token}")
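The explicit `temperature if temperature is not None else self.temperature` fallback used throughout this file matters for one edge case: a plain `temperature or default` would silently discard a valid temperature of 0.0, since 0.0 is falsy. A minimal sketch of the distinction (the helper name is illustrative):

```python
def effective_temperature(requested, default):
    # Explicit None check preserves falsy-but-valid values like 0.0,
    # which "requested or default" would wrongly replace.
    return requested if requested is not None else default
```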