mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
release/v1.4 -> master (#548)
This commit is contained in:
parent
3ec2cd54f9
commit
2bd68ed7f4
94 changed files with 8571 additions and 1740 deletions
2
Makefile
2
Makefile
|
|
@ -70,6 +70,8 @@ some-containers:
|
|||
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
|
||||
${DOCKER} build -f containers/Containerfile.flow \
|
||||
-t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
|
||||
${DOCKER} build -f containers/Containerfile.vertexai \
|
||||
-t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
|
||||
# ${DOCKER} build -f containers/Containerfile.mcp \
|
||||
# -t ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} .
|
||||
# ${DOCKER} build -f containers/Containerfile.vertexai \
|
||||
|
|
|
|||
|
|
@ -158,17 +158,17 @@ The current primary key `PRIMARY KEY (collection, s, p, o)` provides minimal clu
|
|||
- Uneven load distribution across cluster nodes
|
||||
- Scalability bottlenecks as collections grow
|
||||
|
||||
## Proposed Solution: Multi-Table Denormalization Strategy
|
||||
## Proposed Solution: 4-Table Denormalization Strategy
|
||||
|
||||
### Overview
|
||||
|
||||
Replace the single `triples` table with three purpose-built tables, each optimized for specific query patterns. This eliminates the need for secondary indexes and ALLOW FILTERING while providing optimal performance for all query types.
|
||||
Replace the single `triples` table with four purpose-built tables, each optimized for specific query patterns. This eliminates the need for secondary indexes and ALLOW FILTERING while providing optimal performance for all query types. The fourth table enables efficient collection deletion despite compound partition keys.
|
||||
|
||||
### New Schema Design
|
||||
|
||||
**Table 1: Subject-Centric Queries**
|
||||
**Table 1: Subject-Centric Queries (triples_s)**
|
||||
```sql
|
||||
CREATE TABLE triples_by_subject (
|
||||
CREATE TABLE triples_s (
|
||||
collection text,
|
||||
s text,
|
||||
p text,
|
||||
|
|
@ -176,13 +176,13 @@ CREATE TABLE triples_by_subject (
|
|||
PRIMARY KEY ((collection, s), p, o)
|
||||
);
|
||||
```
|
||||
- **Optimizes:** get_s, get_sp, get_spo, get_os
|
||||
- **Optimizes:** get_s, get_sp, get_os
|
||||
- **Partition Key:** (collection, s) - Better distribution than collection alone
|
||||
- **Clustering:** (p, o) - Enables efficient predicate/object lookups for a subject
|
||||
|
||||
**Table 2: Predicate-Object Queries**
|
||||
**Table 2: Predicate-Object Queries (triples_p)**
|
||||
```sql
|
||||
CREATE TABLE triples_by_po (
|
||||
CREATE TABLE triples_p (
|
||||
collection text,
|
||||
p text,
|
||||
o text,
|
||||
|
|
@ -194,9 +194,9 @@ CREATE TABLE triples_by_po (
|
|||
- **Partition Key:** (collection, p) - Direct access by predicate
|
||||
- **Clustering:** (o, s) - Efficient object-subject traversal
|
||||
|
||||
**Table 3: Object-Centric Queries**
|
||||
**Table 3: Object-Centric Queries (triples_o)**
|
||||
```sql
|
||||
CREATE TABLE triples_by_object (
|
||||
CREATE TABLE triples_o (
|
||||
collection text,
|
||||
o text,
|
||||
s text,
|
||||
|
|
@ -204,30 +204,72 @@ CREATE TABLE triples_by_object (
|
|||
PRIMARY KEY ((collection, o), s, p)
|
||||
);
|
||||
```
|
||||
- **Optimizes:** get_o, get_os
|
||||
- **Optimizes:** get_o
|
||||
- **Partition Key:** (collection, o) - Direct access by object
|
||||
- **Clustering:** (s, p) - Efficient subject-predicate traversal
|
||||
|
||||
**Table 4: Collection Management & SPO Queries (triples_collection)**
|
||||
```sql
|
||||
CREATE TABLE triples_collection (
|
||||
collection text,
|
||||
s text,
|
||||
p text,
|
||||
o text,
|
||||
PRIMARY KEY (collection, s, p, o)
|
||||
);
|
||||
```
|
||||
- **Optimizes:** get_spo, delete_collection
|
||||
- **Partition Key:** collection only - Enables efficient collection-level operations
|
||||
- **Clustering:** (s, p, o) - Standard triple ordering
|
||||
- **Purpose:** Dual use for exact SPO lookups and as deletion index
|
||||
|
||||
### Query Mapping
|
||||
|
||||
| Original Query | Target Table | Performance Improvement |
|
||||
|----------------|-------------|------------------------|
|
||||
| get_all(collection) | triples_by_subject | Token-based pagination |
|
||||
| get_s(collection, s) | triples_by_subject | Direct partition access |
|
||||
| get_p(collection, p) | triples_by_po | Direct partition access |
|
||||
| get_o(collection, o) | triples_by_object | Direct partition access |
|
||||
| get_sp(collection, s, p) | triples_by_subject | Partition + clustering |
|
||||
| get_po(collection, p, o) | triples_by_po | **No more ALLOW FILTERING!** |
|
||||
| get_os(collection, o, s) | triples_by_subject | Partition + clustering |
|
||||
| get_spo(collection, s, p, o) | triples_by_subject | Exact key lookup |
|
||||
| get_all(collection) | triples_s | ALLOW FILTERING (acceptable for scan) |
|
||||
| get_s(collection, s) | triples_s | Direct partition access |
|
||||
| get_p(collection, p) | triples_p | Direct partition access |
|
||||
| get_o(collection, o) | triples_o | Direct partition access |
|
||||
| get_sp(collection, s, p) | triples_s | Partition + clustering |
|
||||
| get_po(collection, p, o) | triples_p | **No more ALLOW FILTERING!** |
|
||||
| get_os(collection, o, s) | triples_o | Partition + clustering |
|
||||
| get_spo(collection, s, p, o) | triples_collection | Exact key lookup |
|
||||
| delete_collection(collection) | triples_collection | Read index, batch delete all |
|
||||
|
||||
### Collection Deletion Strategy
|
||||
|
||||
With compound partition keys, we cannot simply execute `DELETE FROM table WHERE collection = ?`. Instead:
|
||||
|
||||
1. **Read Phase:** Query `triples_collection` to enumerate all triples:
|
||||
```sql
|
||||
SELECT s, p, o FROM triples_collection WHERE collection = ?
|
||||
```
|
||||
This is efficient since `collection` is the partition key for this table.
|
||||
|
||||
2. **Delete Phase:** For each triple (s, p, o), delete from all 4 tables using full partition keys:
|
||||
```sql
|
||||
DELETE FROM triples_s WHERE collection = ? AND s = ? AND p = ? AND o = ?
|
||||
DELETE FROM triples_p WHERE collection = ? AND p = ? AND o = ? AND s = ?
|
||||
DELETE FROM triples_o WHERE collection = ? AND o = ? AND s = ? AND p = ?
|
||||
DELETE FROM triples_collection WHERE collection = ? AND s = ? AND p = ? AND o = ?
|
||||
```
|
||||
Batched in groups of 100 for efficiency.
|
||||
|
||||
**Trade-off Analysis:**
|
||||
- ✅ Maintains optimal query performance with distributed partitions
|
||||
- ✅ No hot partitions for large collections
|
||||
- ❌ More complex deletion logic (read-then-delete)
|
||||
- ❌ Deletion time proportional to collection size
|
||||
|
||||
### Benefits
|
||||
|
||||
1. **Eliminates ALLOW FILTERING** - Every query has an optimal access path
|
||||
1. **Eliminates ALLOW FILTERING** - Every query has an optimal access path (except get_all scan)
|
||||
2. **No Secondary Indexes** - Each table IS the index for its query pattern
|
||||
3. **Better Data Distribution** - Composite partition keys spread load effectively
|
||||
4. **Predictable Performance** - Query time proportional to result size, not total data
|
||||
5. **Leverages Cassandra Strengths** - Designed for Cassandra's architecture
|
||||
6. **Enables Collection Deletion** - triples_collection serves as deletion index
|
||||
|
||||
## Implementation Plan
|
||||
|
||||
|
|
@ -295,10 +337,11 @@ def delete_collection(self, collection) -> None # Delete from all three tables
|
|||
### Implementation Strategy
|
||||
|
||||
#### Phase 1: Schema and Core Methods
|
||||
1. **Rewrite `init()` method** - Create three tables instead of one
|
||||
2. **Rewrite `insert()` method** - Batch writes to all three tables
|
||||
1. **Rewrite `init()` method** - Create four tables instead of one
|
||||
2. **Rewrite `insert()` method** - Batch writes to all four tables
|
||||
3. **Implement prepared statements** - For optimal performance
|
||||
4. **Add table routing logic** - Direct queries to optimal tables
|
||||
5. **Implement collection deletion** - Read from triples_collection, batch delete from all tables
|
||||
|
||||
#### Phase 2: Query Method Optimization
|
||||
1. **Rewrite each get_* method** to use optimal table
|
||||
|
|
@ -318,18 +361,11 @@ def delete_collection(self, collection) -> None # Delete from all three tables
|
|||
def insert(self, collection, s, p, o):
|
||||
batch = BatchStatement()
|
||||
|
||||
# Insert into all three tables
|
||||
batch.add(SimpleStatement(
|
||||
"INSERT INTO triples_by_subject (collection, s, p, o) VALUES (?, ?, ?, ?)"
|
||||
), (collection, s, p, o))
|
||||
|
||||
batch.add(SimpleStatement(
|
||||
"INSERT INTO triples_by_po (collection, p, o, s) VALUES (?, ?, ?, ?)"
|
||||
), (collection, p, o, s))
|
||||
|
||||
batch.add(SimpleStatement(
|
||||
"INSERT INTO triples_by_object (collection, o, s, p) VALUES (?, ?, ?, ?)"
|
||||
), (collection, o, s, p))
|
||||
# Insert into all four tables
|
||||
batch.add(self.insert_subject_stmt, (collection, s, p, o))
|
||||
batch.add(self.insert_po_stmt, (collection, p, o, s))
|
||||
batch.add(self.insert_object_stmt, (collection, o, s, p))
|
||||
batch.add(self.insert_collection_stmt, (collection, s, p, o))
|
||||
|
||||
self.session.execute(batch)
|
||||
```
|
||||
|
|
@ -337,11 +373,65 @@ def insert(self, collection, s, p, o):
|
|||
#### Query Routing Logic
|
||||
```python
|
||||
def get_po(self, collection, p, o, limit=10):
|
||||
# Route to triples_by_po table - NO ALLOW FILTERING!
|
||||
# Route to triples_p table - NO ALLOW FILTERING!
|
||||
return self.session.execute(
|
||||
"SELECT s FROM triples_by_po WHERE collection = ? AND p = ? AND o = ? LIMIT ?",
|
||||
self.get_po_stmt,
|
||||
(collection, p, o, limit)
|
||||
)
|
||||
|
||||
def get_spo(self, collection, s, p, o, limit=10):
|
||||
# Route to triples_collection table for exact SPO lookup
|
||||
return self.session.execute(
|
||||
self.get_spo_stmt,
|
||||
(collection, s, p, o, limit)
|
||||
)
|
||||
```
|
||||
|
||||
#### Collection Deletion Logic
|
||||
```python
|
||||
def delete_collection(self, collection):
|
||||
# Step 1: Read all triples from collection table
|
||||
rows = self.session.execute(
|
||||
f"SELECT s, p, o FROM {self.collection_table} WHERE collection = %s",
|
||||
(collection,)
|
||||
)
|
||||
|
||||
# Step 2: Batch delete from all 4 tables
|
||||
batch = BatchStatement()
|
||||
count = 0
|
||||
|
||||
for row in rows:
|
||||
s, p, o = row.s, row.p, row.o
|
||||
|
||||
# Delete using full partition keys for each table
|
||||
batch.add(SimpleStatement(
|
||||
f"DELETE FROM {self.subject_table} WHERE collection = ? AND s = ? AND p = ? AND o = ?"
|
||||
), (collection, s, p, o))
|
||||
|
||||
batch.add(SimpleStatement(
|
||||
f"DELETE FROM {self.po_table} WHERE collection = ? AND p = ? AND o = ? AND s = ?"
|
||||
), (collection, p, o, s))
|
||||
|
||||
batch.add(SimpleStatement(
|
||||
f"DELETE FROM {self.object_table} WHERE collection = ? AND o = ? AND s = ? AND p = ?"
|
||||
), (collection, o, s, p))
|
||||
|
||||
batch.add(SimpleStatement(
|
||||
f"DELETE FROM {self.collection_table} WHERE collection = ? AND s = ? AND p = ? AND o = ?"
|
||||
), (collection, s, p, o))
|
||||
|
||||
count += 1
|
||||
|
||||
# Execute every 100 triples to avoid oversized batches
|
||||
if count % 100 == 0:
|
||||
self.session.execute(batch)
|
||||
batch = BatchStatement()
|
||||
|
||||
# Execute remaining deletions
|
||||
if count % 100 != 0:
|
||||
self.session.execute(batch)
|
||||
|
||||
logger.info(f"Deleted {count} triples from collection {collection}")
|
||||
```
|
||||
|
||||
#### Prepared Statement Optimization
|
||||
|
|
@ -349,12 +439,18 @@ def get_po(self, collection, p, o, limit=10):
|
|||
def prepare_statements(self):
|
||||
# Cache prepared statements for better performance
|
||||
self.insert_subject_stmt = self.session.prepare(
|
||||
"INSERT INTO triples_by_subject (collection, s, p, o) VALUES (?, ?, ?, ?)"
|
||||
f"INSERT INTO {self.subject_table} (collection, s, p, o) VALUES (?, ?, ?, ?)"
|
||||
)
|
||||
self.insert_po_stmt = self.session.prepare(
|
||||
"INSERT INTO triples_by_po (collection, p, o, s) VALUES (?, ?, ?, ?)"
|
||||
f"INSERT INTO {self.po_table} (collection, p, o, s) VALUES (?, ?, ?, ?)"
|
||||
)
|
||||
# ... etc for all tables and queries
|
||||
self.insert_object_stmt = self.session.prepare(
|
||||
f"INSERT INTO {self.object_table} (collection, o, s, p) VALUES (?, ?, ?, ?)"
|
||||
)
|
||||
self.insert_collection_stmt = self.session.prepare(
|
||||
f"INSERT INTO {self.collection_table} (collection, s, p, o) VALUES (?, ?, ?, ?)"
|
||||
)
|
||||
# ... query statements
|
||||
```
|
||||
|
||||
## Migration Strategy
|
||||
|
|
@ -511,9 +607,10 @@ def rollback_to_legacy():
|
|||
## Risks and Considerations
|
||||
|
||||
### Performance Risks
|
||||
- **Write latency increase** - 3x write operations per insert
|
||||
- **Storage overhead** - 3x storage requirement
|
||||
- **Write latency increase** - 4x write operations per insert (33% more than 3-table approach)
|
||||
- **Storage overhead** - 4x storage requirement (33% more than 3-table approach)
|
||||
- **Batch write failures** - Need proper error handling
|
||||
- **Deletion complexity** - Collection deletion requires read-then-delete loop
|
||||
|
||||
### Operational Risks
|
||||
- **Migration complexity** - Data migration for large datasets
|
||||
|
|
|
|||
|
|
@ -2,16 +2,17 @@
|
|||
|
||||
## Overview
|
||||
|
||||
This specification describes the collection management capabilities for TrustGraph, enabling users to have explicit control over collections that are currently implicitly created during data loading and querying operations. The feature supports four primary use cases:
|
||||
This specification describes the collection management capabilities for TrustGraph, requiring explicit collection creation and providing direct control over the collection lifecycle. Collections must be explicitly created before use, ensuring proper synchronization between the librarian metadata and all storage backends. The feature supports four primary use cases:
|
||||
|
||||
1. **Collection Listing**: View all existing collections in the system
|
||||
2. **Collection Deletion**: Remove unwanted collections and their associated data
|
||||
3. **Collection Labeling**: Associate descriptive labels with collections for better organization
|
||||
4. **Collection Tagging**: Apply tags to collections for categorization and easier discovery
|
||||
1. **Collection Creation**: Explicitly create collections before storing data
|
||||
2. **Collection Listing**: View all existing collections in the system
|
||||
3. **Collection Metadata Management**: Update collection names, descriptions, and tags
|
||||
4. **Collection Deletion**: Remove collections and their associated data across all storage types
|
||||
|
||||
## Goals
|
||||
|
||||
- **Explicit Collection Control**: Provide users with direct management capabilities over collections beyond implicit creation
|
||||
- **Explicit Collection Creation**: Require collections to be created before data can be stored
|
||||
- **Storage Synchronization**: Ensure collections exist in all storage backends (vectors, objects, triples)
|
||||
- **Collection Visibility**: Enable users to list and inspect all collections in their environment
|
||||
- **Collection Cleanup**: Allow deletion of collections that are no longer needed
|
||||
- **Collection Organization**: Support labels and tags for better collection tracking and discovery
|
||||
|
|
@ -19,22 +20,25 @@ This specification describes the collection management capabilities for TrustGra
|
|||
- **Collection Discovery**: Make it easier to find specific collections through filtering and search
|
||||
- **Operational Transparency**: Provide clear visibility into collection lifecycle and usage
|
||||
- **Resource Management**: Enable cleanup of unused collections to optimize resource utilization
|
||||
- **Data Integrity**: Prevent orphaned collections in storage without metadata tracking
|
||||
|
||||
## Background
|
||||
|
||||
Currently, collections in TrustGraph are implicitly created during data loading operations and query execution. While this provides convenience for users, it lacks the explicit control needed for production environments and long-term data management.
|
||||
Previously, collections in TrustGraph were implicitly created during data loading operations, leading to synchronization issues where collections could exist in storage backends without corresponding metadata in the librarian. This created management challenges and potential orphaned data.
|
||||
|
||||
Current limitations include:
|
||||
- No way to list existing collections
|
||||
- No mechanism to delete unwanted collections
|
||||
- No ability to associate metadata with collections for tracking purposes
|
||||
- Difficulty in organizing and discovering collections over time
|
||||
The explicit collection creation model addresses these issues by:
|
||||
- Requiring collections to be created before use via `tg-set-collection`
|
||||
- Broadcasting collection creation to all storage backends
|
||||
- Maintaining synchronized state between librarian metadata and storage
|
||||
- Preventing writes to non-existent collections
|
||||
- Providing clear collection lifecycle management
|
||||
|
||||
This specification addresses these gaps by introducing explicit collection management operations. By providing collection management APIs and commands, TrustGraph can:
|
||||
- Give users full control over their collection lifecycle
|
||||
- Enable better organization through labels and tags
|
||||
- Support collection cleanup for resource optimization
|
||||
- Improve operational visibility and management
|
||||
This specification defines the explicit collection management model. By requiring explicit collection creation, TrustGraph ensures:
|
||||
- Collections are tracked in librarian metadata from creation
|
||||
- All storage backends are aware of collections before receiving data
|
||||
- No orphaned collections exist in storage
|
||||
- Clear operational visibility and control over collection lifecycle
|
||||
- Consistent error handling when operations reference non-existent collections
|
||||
|
||||
## Technical Design
|
||||
|
||||
|
|
@ -98,24 +102,52 @@ This approach allows:
|
|||
|
||||
#### Collection Lifecycle
|
||||
|
||||
Collections follow a lazy-creation pattern that aligns with existing TrustGraph behavior:
|
||||
Collections are explicitly created in the librarian before data operations can proceed:
|
||||
|
||||
1. **Lazy Creation**: Collections are automatically created when first referenced during data loading or query operations. No explicit create operation is needed.
|
||||
1. **Collection Creation** (Two Paths):
|
||||
|
||||
2. **Implicit Registration**: When a collection is used (data loading, querying), the system checks if a metadata record exists. If not, a new record is created with default values:
|
||||
- `name`: defaults to collection_id
|
||||
- `description`: empty
|
||||
- `tags`: empty set
|
||||
- `created_at`: current timestamp
|
||||
**Path A: User-Initiated Creation** via `tg-set-collection`:
|
||||
- User provides collection ID, name, description, and tags
|
||||
- Librarian creates metadata record in `collections` table
|
||||
- Librarian broadcasts "create-collection" to all storage backends
|
||||
- All storage processors create collection and confirm success
|
||||
- Collection is now ready for data operations
|
||||
|
||||
3. **Explicit Updates**: Users can update collection metadata (name, description, tags) through management operations after lazy creation.
|
||||
**Path B: Automatic Creation on Document Submission**:
|
||||
- User submits document specifying a collection ID
|
||||
- Librarian checks if collection exists in metadata table
|
||||
- If not exists: Librarian creates metadata with defaults (name=collection_id, empty description/tags)
|
||||
- Librarian broadcasts "create-collection" to all storage backends
|
||||
- All storage processors create collection and confirm success
|
||||
- Document processing proceeds with collection now established
|
||||
|
||||
4. **Explicit Deletion**: Users can delete collections, which removes both the metadata record and the underlying collection data across all store types.
|
||||
Both paths ensure collection exists in librarian metadata AND all storage backends before data operations.
|
||||
|
||||
5. **Multi-Store Deletion**: Collection deletion cascades across all storage backends (vector stores, object stores, triple stores) as each implements lazy creation and must support collection deletion.
|
||||
2. **Storage Validation**: Write operations validate collection exists:
|
||||
- Storage processors check collection state before accepting writes
|
||||
- Writes to non-existent collections return error
|
||||
- This prevents direct writes bypassing the librarian's collection creation logic
|
||||
|
||||
3. **Query Behavior**: Query operations handle non-existent collections gracefully:
|
||||
- Queries to non-existent collections return empty results
|
||||
- No error thrown for query operations
|
||||
- Allows exploration without requiring collection to exist
|
||||
|
||||
4. **Metadata Updates**: Users can update collection metadata after creation:
|
||||
- Update name, description, and tags via `tg-set-collection`
|
||||
- Updates apply to librarian metadata only
|
||||
- Storage backends maintain collection but metadata updates don't propagate
|
||||
|
||||
5. **Explicit Deletion**: Users delete collections via `tg-delete-collection`:
|
||||
- Librarian broadcasts "delete-collection" to all storage backends
|
||||
- Waits for confirmation from all storage processors
|
||||
- Deletes librarian metadata record only after storage cleanup complete
|
||||
- Ensures no orphaned data remains in storage
|
||||
|
||||
**Key Principle**: The librarian is the single point of control for collection creation. Whether initiated by user command or document submission, the librarian ensures proper metadata tracking and storage backend synchronization before allowing data operations.
|
||||
|
||||
Operations required:
|
||||
- **Collection Use Notification**: Internal operation triggered during data loading/querying to ensure metadata record exists
|
||||
- **Create Collection**: User operation via `tg-set-collection` OR automatic on document submission
|
||||
- **Update Collection Metadata**: User operation to modify name, description, and tags
|
||||
- **Delete Collection**: User operation to remove collection and its data across all stores
|
||||
- **List Collections**: User operation to view collections with filtering by tags
|
||||
|
|
@ -123,32 +155,65 @@ Operations required:
|
|||
#### Multi-Store Collection Management
|
||||
|
||||
Collections exist across multiple storage backends in TrustGraph:
|
||||
- **Vector Stores**: Store embeddings and vector data for collections
|
||||
- **Object Stores**: Store documents and file data for collections
|
||||
- **Triple Stores**: Store graph/RDF data for collections
|
||||
- **Vector Stores** (Qdrant, Milvus, Pinecone): Store embeddings and vector data
|
||||
- **Object Stores** (Cassandra): Store documents and file data
|
||||
- **Triple Stores** (Cassandra, Neo4j, Memgraph, FalkorDB): Store graph/RDF data
|
||||
|
||||
Each store type implements:
|
||||
- **Lazy Creation**: Collections are created implicitly when data is first stored
|
||||
- **Collection Deletion**: Store-specific deletion operations to remove collection data
|
||||
- **Collection State Tracking**: Maintain knowledge of which collections exist
|
||||
- **Collection Creation**: Accept and process "create-collection" operations
|
||||
- **Collection Validation**: Check collection exists before accepting writes
|
||||
- **Collection Deletion**: Remove all data for specified collection
|
||||
|
||||
The librarian service coordinates collection operations across all store types, ensuring consistent collection lifecycle management.
|
||||
The librarian service coordinates collection operations across all store types, ensuring:
|
||||
- Collections created in all backends before use
|
||||
- All backends confirm creation before returning success
|
||||
- Synchronized collection lifecycle across storage types
|
||||
- Consistent error handling when collections don't exist
|
||||
|
||||
#### Collection State Tracking by Storage Type
|
||||
|
||||
Each storage backend tracks collection state differently based on its capabilities:
|
||||
|
||||
**Cassandra Triple Store:**
|
||||
- Uses existing `triples_collection` table
|
||||
- Creates system marker triple when collection created
|
||||
- Query: `SELECT collection FROM triples_collection WHERE collection = ? LIMIT 1`
|
||||
- Efficient single-partition check for collection existence
|
||||
|
||||
**Qdrant/Milvus/Pinecone Vector Stores:**
|
||||
- Native collection APIs provide existence checking
|
||||
- Collections created with proper vector configuration
|
||||
- `collection_exists()` method uses storage API
|
||||
- Collection creation validates dimension requirements
|
||||
|
||||
**Neo4j/Memgraph/FalkorDB Graph Stores:**
|
||||
- Use `:CollectionMetadata` nodes to track collections
|
||||
- Node properties: `{user, collection, created_at}`
|
||||
- Query: `MATCH (c:CollectionMetadata {user: $user, collection: $collection})`
|
||||
- Separate from data nodes for clean separation
|
||||
- Enables efficient collection listing and validation
|
||||
|
||||
**Cassandra Object Store:**
|
||||
- Uses collection metadata table or marker rows
|
||||
- Similar pattern to triple store
|
||||
- Validates collection before document writes
|
||||
|
||||
### APIs
|
||||
|
||||
New APIs:
|
||||
Collection Management APIs (Librarian):
|
||||
- **Create/Update Collection**: Create new collection or update existing metadata via `tg-set-collection`
|
||||
- **List Collections**: Retrieve collections for a user with optional tag filtering
|
||||
- **Update Collection Metadata**: Modify collection name, description, and tags
|
||||
- **Delete Collection**: Remove collection and associated data with confirmation, cascading to all store types
|
||||
- **Collection Use Notification** (Internal): Ensure metadata record exists when collection is referenced
|
||||
- **Delete Collection**: Remove collection and associated data, cascading to all store types
|
||||
|
||||
Store Writer APIs (Enhanced):
|
||||
- **Vector Store Collection Deletion**: Remove vector data for specified user and collection
|
||||
- **Object Store Collection Deletion**: Remove object/document data for specified user and collection
|
||||
- **Triple Store Collection Deletion**: Remove graph/RDF data for specified user and collection
|
||||
Storage Management APIs (All Storage Processors):
|
||||
- **Create Collection**: Handle "create-collection" operation, establish collection in storage
|
||||
- **Delete Collection**: Handle "delete-collection" operation, remove all collection data
|
||||
- **Collection Exists Check**: Internal validation before accepting write operations
|
||||
|
||||
Modified APIs:
|
||||
- **Data Loading APIs**: Enhanced to trigger collection use notification for lazy metadata creation
|
||||
- **Query APIs**: Enhanced to trigger collection use notification and optionally include metadata in responses
|
||||
Data Operation APIs (Modified Behavior):
|
||||
- **Write APIs**: Validate collection exists before accepting data, return error if not
|
||||
- **Query APIs**: Return empty results for non-existent collections without error
|
||||
|
||||
### Implementation Details
|
||||
|
||||
|
|
@ -168,32 +233,35 @@ When a user initiates collection deletion through the librarian service:
|
|||
|
||||
#### Collection Management Interface
|
||||
|
||||
All store writers will implement a standardized collection management interface with a common schema across store types:
|
||||
All store writers implement a standardized collection management interface with a common schema:
|
||||
|
||||
**Message Schema:**
|
||||
**Message Schema (`StorageManagementRequest`):**
|
||||
```json
|
||||
{
|
||||
"operation": "delete-collection",
|
||||
"operation": "create-collection" | "delete-collection",
|
||||
"user": "user123",
|
||||
"collection": "documents-2024",
|
||||
"timestamp": "2024-01-15T10:30:00Z"
|
||||
"collection": "documents-2024"
|
||||
}
|
||||
```
|
||||
|
||||
**Queue Architecture:**
|
||||
- **Object Store Collection Management Queue**: Handles collection operations for object/document stores
|
||||
- **Vector Store Collection Management Queue**: Handles collection operations for vector/embedding stores
|
||||
- **Triple Store Collection Management Queue**: Handles collection operations for graph/RDF stores
|
||||
- **Vector Store Management Queue** (`vector-storage-management`): Vector/embedding stores
|
||||
- **Object Store Management Queue** (`object-storage-management`): Object/document stores
|
||||
- **Triple Store Management Queue** (`triples-storage-management`): Graph/RDF stores
|
||||
- **Storage Response Queue** (`storage-management-response`): All responses sent here
|
||||
|
||||
Each store writer implements:
|
||||
- **Collection Management Handler**: Separate from standard data storage handlers
|
||||
- **Delete Collection Operation**: Removes all data associated with the specified collection
|
||||
- **Message Processing**: Consumes from dedicated collection management queue
|
||||
- **Status Reporting**: Returns success/failure status for coordination
|
||||
- **Idempotent Operations**: Handles cases where collection doesn't exist (no-op)
|
||||
- **Collection Management Handler**: Processes `StorageManagementRequest` messages
|
||||
- **Create Collection Operation**: Establishes collection in storage backend
|
||||
- **Delete Collection Operation**: Removes all data associated with collection
|
||||
- **Collection State Tracking**: Maintains knowledge of which collections exist
|
||||
- **Message Processing**: Consumes from dedicated management queue
|
||||
- **Status Reporting**: Returns success/failure via `StorageManagementResponse`
|
||||
- **Idempotent Operations**: Safe to call create/delete multiple times
|
||||
|
||||
**Initial Implementation:**
|
||||
Only `delete-collection` operation will be implemented initially. The interface supports future operations like `archive-collection`, `migrate-collection`, etc.
|
||||
**Supported Operations:**
|
||||
- `create-collection`: Create collection in storage backend
|
||||
- `delete-collection`: Remove all collection data from storage backend
|
||||
|
||||
#### Cassandra Triple Store Refactor
|
||||
|
||||
|
|
@ -244,13 +312,11 @@ As part of this implementation, the Cassandra triple store will be refactored fr
|
|||
- Maintain same query logic with collection parameter
|
||||
|
||||
**Benefits:**
|
||||
- **Simplified Collection Deletion**: Simple `DELETE FROM triples WHERE collection = ?` instead of dropping tables
|
||||
- **Simplified Collection Deletion**: Delete using `collection` partition key across all 4 tables
|
||||
- **Resource Efficiency**: Fewer database connections and table objects
|
||||
- **Cross-Collection Operations**: Easier to implement operations spanning multiple collections
|
||||
- **Consistent Architecture**: Aligns with unified collection metadata approach
|
||||
|
||||
**Migration Strategy:**
|
||||
Existing table-per-collection data will need migration to the new unified schema during the upgrade process.
|
||||
- **Collection Validation**: Easy to check collection existence via `triples_collection` table
|
||||
|
||||
Collection operations will be atomic where possible and provide appropriate error handling and validation.
|
||||
|
||||
|
|
@ -264,37 +330,25 @@ Collection listing operations may need pagination for environments with large nu
|
|||
|
||||
## Testing Strategy
|
||||
|
||||
Comprehensive testing will cover collection lifecycle operations, metadata management, and CLI command functionality with both unit and integration tests.
|
||||
|
||||
## Migration Plan
|
||||
|
||||
This implementation requires both metadata and storage migrations:
|
||||
|
||||
### Collection Metadata Migration
|
||||
Existing collections will need to be registered in the new Cassandra collections metadata table. A migration process will:
|
||||
- Scan existing keyspaces and tables to identify collections
|
||||
- Create metadata records with default values (name=collection_id, empty description/tags)
|
||||
- Preserve creation timestamps where possible
|
||||
|
||||
### Cassandra Triple Store Migration
|
||||
The Cassandra storage refactor requires data migration from table-per-collection to unified table:
|
||||
- **Pre-migration**: Identify all user keyspaces and collection tables
|
||||
- **Data Transfer**: Copy triples from individual collection tables to unified "triples" table with collection
|
||||
- **Schema Validation**: Ensure new primary key structure maintains query performance
|
||||
- **Cleanup**: Remove old collection tables after successful migration
|
||||
- **Rollback Plan**: Maintain ability to restore table-per-collection structure if needed
|
||||
|
||||
Migration will be performed during a maintenance window to ensure data consistency.
|
||||
Comprehensive testing will cover:
|
||||
- Collection creation workflow end-to-end
|
||||
- Storage backend synchronization
|
||||
- Write validation for non-existent collections
|
||||
- Query handling of non-existent collections
|
||||
- Collection deletion cascade across all stores
|
||||
- Error handling and recovery scenarios
|
||||
- Unit tests for each storage backend
|
||||
- Integration tests for cross-store operations
|
||||
|
||||
## Implementation Status
|
||||
|
||||
### ✅ Completed Components
|
||||
|
||||
1. **Librarian Collection Management Service** (`trustgraph-flow/trustgraph/librarian/collection_service.py`)
|
||||
- Complete collection CRUD operations (list, update, delete)
|
||||
1. **Librarian Collection Management Service** (`trustgraph-flow/trustgraph/librarian/collection_manager.py`)
|
||||
- Collection metadata CRUD operations (list, update, delete)
|
||||
- Cassandra collection metadata table integration via `LibraryTableStore`
|
||||
- Async request/response handling with proper error management
|
||||
- Collection deletion cascade coordination across all storage types
|
||||
- Async request/response handling with proper error management
|
||||
|
||||
2. **Collection Metadata Schema** (`trustgraph-base/trustgraph/schema/services/collection.py`)
|
||||
- `CollectionManagementRequest` and `CollectionManagementResponse` schemas
|
||||
|
|
@ -303,47 +357,70 @@ Migration will be performed during a maintenance window to ensure data consisten
|
|||
|
||||
3. **Storage Management Schema** (`trustgraph-base/trustgraph/schema/services/storage.py`)
|
||||
- `StorageManagementRequest` and `StorageManagementResponse` schemas
|
||||
- Storage management queue topics defined
|
||||
- Message format for storage-level collection operations
|
||||
|
||||
### ❌ Missing Components
|
||||
4. **Cassandra 4-Table Schema** (`trustgraph-flow/trustgraph/direct/cassandra_kg.py`)
|
||||
- Compound partition keys for query performance
|
||||
- `triples_collection` table for SPO queries and deletion tracking
|
||||
- Collection deletion implemented with read-then-delete pattern
|
||||
|
||||
1. **Storage Management Queue Topics**
|
||||
- Missing topic definitions in schema for:
|
||||
- `vector_storage_management_topic`
|
||||
- `object_storage_management_topic`
|
||||
- `triples_storage_management_topic`
|
||||
- `storage_management_response_topic`
|
||||
- These are referenced by the librarian service but not yet defined
|
||||
### 🔄 In Progress Components
|
||||
|
||||
2. **Store Collection Management Handlers**
|
||||
- **Vector Store Writers** (Qdrant, Milvus, Pinecone): No collection deletion handlers
|
||||
- **Object Store Writers** (Cassandra): No collection deletion handlers
|
||||
- **Triple Store Writers** (Cassandra, Neo4j, Memgraph, FalkorDB): No collection deletion handlers
|
||||
- Need to implement `StorageManagementRequest` processing in each store writer
|
||||
1. **Collection Creation Broadcast** (`trustgraph-flow/trustgraph/librarian/collection_manager.py`)
|
||||
- Update `update_collection()` to send "create-collection" to storage backends
|
||||
- Wait for confirmations from all storage processors
|
||||
- Handle creation failures appropriately
|
||||
|
||||
3. **Collection Management Interface Implementation**
|
||||
- Store writers need collection management message consumers
|
||||
- Collection deletion operations need to be implemented per store type
|
||||
- Response handling back to librarian service
|
||||
2. **Document Submission Handler** (`trustgraph-flow/trustgraph/librarian/service.py` or similar)
|
||||
- Check if collection exists when document submitted
|
||||
- If not exists: Create collection with defaults before processing document
|
||||
- Trigger same "create-collection" broadcast as `tg-set-collection`
|
||||
- Ensure collection established before document flows to storage processors
|
||||
|
||||
### ❌ Pending Components
|
||||
|
||||
1. **Collection State Tracking** - Need to implement in each storage backend:
|
||||
- **Cassandra Triples**: Use `triples_collection` table with marker triples
|
||||
- **Neo4j/Memgraph/FalkorDB**: Create `:CollectionMetadata` nodes
|
||||
- **Qdrant/Milvus/Pinecone**: Use native collection APIs
|
||||
- **Cassandra Objects**: Add collection metadata tracking
|
||||
|
||||
2. **Storage Management Handlers** - Need "create-collection" support in 12 files:
|
||||
- `trustgraph-flow/trustgraph/storage/triples/cassandra/write.py`
|
||||
- `trustgraph-flow/trustgraph/storage/triples/neo4j/write.py`
|
||||
- `trustgraph-flow/trustgraph/storage/triples/memgraph/write.py`
|
||||
- `trustgraph-flow/trustgraph/storage/triples/falkordb/write.py`
|
||||
- `trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py`
|
||||
- `trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py`
|
||||
- `trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py`
|
||||
- `trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py`
|
||||
- `trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py`
|
||||
- `trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py`
|
||||
- `trustgraph-flow/trustgraph/storage/objects/cassandra/write.py`
|
||||
- Plus any other storage implementations
|
||||
|
||||
3. **Write Operation Validation** - Add collection existence checks to all `store_*` methods
|
||||
|
||||
4. **Query Operation Handling** - Update queries to return empty for non-existent collections
|
||||
|
||||
### Next Implementation Steps
|
||||
|
||||
1. **Define Storage Management Topics** in `trustgraph-base/trustgraph/schema/services/storage.py`
|
||||
2. **Implement Collection Management Handlers** in each storage writer:
|
||||
- Add `StorageManagementRequest` consumers
|
||||
- Implement collection deletion operations
|
||||
- Add response producers for status reporting
|
||||
3. **Test End-to-End Collection Deletion** across all storage types
|
||||
**Phase 1: Core Infrastructure (2-3 days)**
|
||||
1. Add collection state tracking methods to all storage backends
|
||||
2. Implement `collection_exists()` and `create_collection()` methods
|
||||
|
||||
## Timeline
|
||||
**Phase 2: Storage Handlers (1 week)**
|
||||
3. Add "create-collection" handlers to all storage processors
|
||||
4. Add write validation to reject non-existent collections
|
||||
5. Update query handling for non-existent collections
|
||||
|
||||
Phase 1 (Storage Topics): 1-2 days
|
||||
Phase 2 (Store Handlers): 1-2 weeks depending on number of storage backends
|
||||
Phase 3 (Testing & Integration): 3-5 days
|
||||
**Phase 3: Collection Manager (2-3 days)**
|
||||
6. Update collection_manager to broadcast creates
|
||||
7. Implement response tracking and error handling
|
||||
|
||||
## Open Questions
|
||||
|
||||
- Should collection deletion be soft or hard delete by default?
|
||||
- What metadata fields should be required vs optional?
|
||||
- Should we implement storage management handlers incrementally by store type?
|
||||
**Phase 4: Testing (3-5 days)**
|
||||
8. End-to-end testing of explicit creation workflow
|
||||
9. Test all storage backends
|
||||
10. Validate error handling and edge cases
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ A flow class defines a complete dataflow pattern template in the TrustGraph syst
|
|||
|
||||
## Structure
|
||||
|
||||
A flow class definition consists of four main sections:
|
||||
A flow class definition consists of five main sections:
|
||||
|
||||
### 1. Class Section
|
||||
Defines shared service processors that are instantiated once per flow class. These processors handle requests from all flow instances of this class.
|
||||
|
|
@ -15,7 +15,11 @@ Defines shared service processors that are instantiated once per flow class. The
|
|||
"class": {
|
||||
"service-name:{class}": {
|
||||
"request": "queue-pattern:{class}",
|
||||
"response": "queue-pattern:{class}"
|
||||
"response": "queue-pattern:{class}",
|
||||
"settings": {
|
||||
"setting-name": "fixed-value",
|
||||
"parameterized-setting": "{parameter-name}"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
|
@ -24,6 +28,7 @@ Defines shared service processors that are instantiated once per flow class. The
|
|||
- Shared across all flow instances of the same class
|
||||
- Typically expensive or stateless services (LLMs, embedding models)
|
||||
- Use `{class}` template variable for queue naming
|
||||
- Settings can be fixed values or parameterized with `{parameter-name}` syntax
|
||||
- Examples: `embeddings:{class}`, `text-completion:{class}`, `graph-rag:{class}`
|
||||
|
||||
### 2. Flow Section
|
||||
|
|
@ -33,7 +38,11 @@ Defines flow-specific processors that are instantiated for each individual flow
|
|||
"flow": {
|
||||
"processor-name:{id}": {
|
||||
"input": "queue-pattern:{id}",
|
||||
"output": "queue-pattern:{id}"
|
||||
"output": "queue-pattern:{id}",
|
||||
"settings": {
|
||||
"setting-name": "fixed-value",
|
||||
"parameterized-setting": "{parameter-name}"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
|
@ -42,6 +51,7 @@ Defines flow-specific processors that are instantiated for each individual flow
|
|||
- Unique instance per flow
|
||||
- Handle flow-specific data and state
|
||||
- Use `{id}` template variable for queue naming
|
||||
- Settings can be fixed values or parameterized with `{parameter-name}` syntax
|
||||
- Examples: `chunker:{id}`, `pdf-decoder:{id}`, `kg-extract-relationships:{id}`
|
||||
|
||||
### 3. Interfaces Section
|
||||
|
|
@ -72,7 +82,24 @@ Interfaces can take two forms:
|
|||
- **Service Interfaces**: Request/response patterns for services (`embeddings`, `text-completion`)
|
||||
- **Data Interfaces**: Fire-and-forget data flow connection points (`triples-store`, `entity-contexts-load`)
|
||||
|
||||
### 4. Metadata
|
||||
### 4. Parameters Section
|
||||
Maps flow-specific parameter names to centrally-stored parameter definitions:
|
||||
|
||||
```json
|
||||
"parameters": {
|
||||
"model": "llm-model",
|
||||
"temp": "temperature",
|
||||
"chunk": "chunk-size"
|
||||
}
|
||||
```
|
||||
|
||||
**Characteristics:**
|
||||
- Keys are parameter names used in processor settings (e.g., `{model}`)
|
||||
- Values reference parameter definitions stored in schema/config
|
||||
- Enables reuse of common parameter definitions across flows
|
||||
- Reduces duplication of parameter schemas
|
||||
|
||||
### 5. Metadata
|
||||
Additional information about the flow class:
|
||||
|
||||
```json
|
||||
|
|
@ -82,16 +109,98 @@ Additional information about the flow class:
|
|||
|
||||
## Template Variables
|
||||
|
||||
### {id}
|
||||
### System Variables
|
||||
|
||||
#### {id}
|
||||
- Replaced with the unique flow instance identifier
|
||||
- Creates isolated resources for each flow
|
||||
- Example: `flow-123`, `customer-A-flow`
|
||||
|
||||
### {class}
|
||||
#### {class}
|
||||
- Replaced with the flow class name
|
||||
- Creates shared resources across flows of the same class
|
||||
- Example: `standard-rag`, `enterprise-rag`
|
||||
|
||||
### Parameter Variables
|
||||
|
||||
#### {parameter-name}
|
||||
- Custom parameters defined at flow launch time
|
||||
- Parameter names match keys in the flow's `parameters` section
|
||||
- Used in processor settings to customize behavior
|
||||
- Examples: `{model}`, `{temp}`, `{chunk}`
|
||||
- Replaced with values provided when launching the flow
|
||||
- Validated against centrally-stored parameter definitions
|
||||
|
||||
## Processor Settings
|
||||
|
||||
Settings provide configuration values to processors at instantiation time. They can be:
|
||||
|
||||
### Fixed Settings
|
||||
Direct values that don't change:
|
||||
```json
|
||||
"settings": {
|
||||
"model": "gemma3:12b",
|
||||
"temperature": 0.7,
|
||||
"max_retries": 3
|
||||
}
|
||||
```
|
||||
|
||||
### Parameterized Settings
|
||||
Values that use parameters provided at flow launch:
|
||||
```json
|
||||
"settings": {
|
||||
"model": "{model}",
|
||||
"temperature": "{temp}",
|
||||
"endpoint": "https://{region}.api.example.com"
|
||||
}
|
||||
```
|
||||
|
||||
Parameter names in settings correspond to keys in the flow's `parameters` section.
|
||||
|
||||
### Settings Examples
|
||||
|
||||
**LLM Processor with Parameters:**
|
||||
```json
|
||||
// In parameters section:
|
||||
"parameters": {
|
||||
"model": "llm-model",
|
||||
"temp": "temperature",
|
||||
"tokens": "max-tokens",
|
||||
"key": "openai-api-key"
|
||||
}
|
||||
|
||||
// In processor definition:
|
||||
"text-completion:{class}": {
|
||||
"request": "non-persistent://tg/request/text-completion:{class}",
|
||||
"response": "non-persistent://tg/response/text-completion:{class}",
|
||||
"settings": {
|
||||
"model": "{model}",
|
||||
"temperature": "{temp}",
|
||||
"max_tokens": "{tokens}",
|
||||
"api_key": "{key}"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Chunker with Fixed and Parameterized Settings:**
|
||||
```json
|
||||
// In parameters section:
|
||||
"parameters": {
|
||||
"chunk": "chunk-size"
|
||||
}
|
||||
|
||||
// In processor definition:
|
||||
"chunker:{id}": {
|
||||
"input": "persistent://tg/flow/chunk:{id}",
|
||||
"output": "persistent://tg/flow/chunk-load:{id}",
|
||||
"settings": {
|
||||
"chunk_size": "{chunk}",
|
||||
"chunk_overlap": 100,
|
||||
"encoding": "utf-8"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Queue Patterns (Pulsar)
|
||||
|
||||
Flow classes use Apache Pulsar for messaging. Queue names follow the Pulsar format:
|
||||
|
|
@ -137,15 +246,27 @@ All processors (both `{id}` and `{class}`) work together as a cohesive dataflow
|
|||
Given:
|
||||
- Flow Instance ID: `customer-A-flow`
|
||||
- Flow Class: `standard-rag`
|
||||
- Flow parameter mappings:
|
||||
- `"model": "llm-model"`
|
||||
- `"temp": "temperature"`
|
||||
- `"chunk": "chunk-size"`
|
||||
- User-provided parameters:
|
||||
- `model`: `gpt-4`
|
||||
- `temp`: `0.5`
|
||||
- `chunk`: `512`
|
||||
|
||||
Template expansions:
|
||||
- `persistent://tg/flow/chunk-load:{id}` → `persistent://tg/flow/chunk-load:customer-A-flow`
|
||||
- `non-persistent://tg/request/embeddings:{class}` → `non-persistent://tg/request/embeddings:standard-rag`
|
||||
- `"model": "{model}"` → `"model": "gpt-4"`
|
||||
- `"temperature": "{temp}"` → `"temperature": "0.5"`
|
||||
- `"chunk_size": "{chunk}"` → `"chunk_size": "512"`
|
||||
|
||||
This creates:
|
||||
- Isolated document processing pipeline for `customer-A-flow`
|
||||
- Shared embedding service for all `standard-rag` flows
|
||||
- Complete dataflow from document ingestion through querying
|
||||
- Processors configured with the provided parameter values
|
||||
|
||||
## Benefits
|
||||
|
||||
|
|
|
|||
485
docs/tech-specs/flow-configurable-parameters.md
Normal file
485
docs/tech-specs/flow-configurable-parameters.md
Normal file
|
|
@ -0,0 +1,485 @@
|
|||
# Flow Class Configurable Parameters Technical Specification
|
||||
|
||||
## Overview
|
||||
|
||||
This specification describes the implementation of configurable parameters for flow classes in TrustGraph. Parameters enable users to customize processor parameters at flow launch time by providing values that replace parameter placeholders in the flow class definition.
|
||||
|
||||
Parameters work through template variable substitution in processor parameters, similar to how `{id}` and `{class}` variables work, but with user-provided values.
|
||||
|
||||
The integration supports four primary use cases:
|
||||
|
||||
1. **Model Selection**: Allowing users to choose different LLM models (e.g., `gemma3:8b`, `gpt-4`, `claude-3`) for processors
|
||||
2. **Resource Configuration**: Adjusting processor parameters like chunk sizes, batch sizes, and concurrency limits
|
||||
3. **Behavioral Tuning**: Modifying processor behavior through parameters like temperature, max-tokens, or retrieval thresholds
|
||||
4. **Environment-Specific Parameters**: Configuring endpoints, API keys, or region-specific URLs per deployment
|
||||
|
||||
## Goals
|
||||
|
||||
- **Dynamic Processor Configuration**: Enable runtime configuration of processor parameters through parameter substitution
|
||||
- **Parameter Validation**: Provide type checking and validation for parameters at flow launch time
|
||||
- **Default Values**: Support sensible defaults while allowing overrides for advanced users
|
||||
- **Template Substitution**: Seamlessly replace parameter placeholders in processor parameters
|
||||
- **UI Integration**: Enable parameter input through both API and UI interfaces
|
||||
- **Type Safety**: Ensure parameter types match expected processor parameter types
|
||||
- **Documentation**: Self-documenting parameter schemas within flow class definitions
|
||||
- **Backward Compatibility**: Maintain compatibility with existing flow classes that don't use parameters
|
||||
|
||||
## Background
|
||||
|
||||
Flow classes in TrustGraph now support processor parameters that can contain either fixed values or parameter placeholders. This creates an opportunity for runtime customization.
|
||||
|
||||
Current processor parameters support:
|
||||
- Fixed values: `"model": "gemma3:12b"`
|
||||
- Parameter placeholders: `"model": "gemma3:{model-size}"`
|
||||
|
||||
This specification defines how parameters are:
|
||||
- Declared in flow class definitions
|
||||
- Validated when flows are launched
|
||||
- Substituted in processor parameters
|
||||
- Exposed through APIs and UI
|
||||
|
||||
By leveraging parameterized processor parameters, TrustGraph can:
|
||||
- Reduce flow class duplication by using parameters for variations
|
||||
- Enable users to tune processor behavior without modifying definitions
|
||||
- Support environment-specific configurations through parameter values
|
||||
- Maintain type safety through parameter schema validation
|
||||
|
||||
## Technical Design
|
||||
|
||||
### Architecture
|
||||
|
||||
The configurable parameters system requires the following technical components:
|
||||
|
||||
1. **Parameter Schema Definition**
|
||||
- JSON Schema-based parameter definitions within flow class metadata
|
||||
- Type definitions including string, number, boolean, enum, and object types
|
||||
- Validation rules including min/max values, patterns, and required fields
|
||||
|
||||
Module: trustgraph-flow/trustgraph/flow/definition.py
|
||||
|
||||
2. **Parameter Resolution Engine**
|
||||
- Runtime parameter validation against schema
|
||||
- Default value application for unspecified parameters
|
||||
- Parameter injection into flow execution context
|
||||
- Type coercion and conversion as needed
|
||||
|
||||
Module: trustgraph-flow/trustgraph/flow/parameter_resolver.py
|
||||
|
||||
3. **Parameter Store Integration**
|
||||
- Retrieval of parameter definitions from schema/config store
|
||||
- Caching of frequently-used parameter definitions
|
||||
- Validation against centrally-stored schemas
|
||||
|
||||
Module: trustgraph-flow/trustgraph/flow/parameter_store.py
|
||||
|
||||
4. **Flow Launcher Extensions**
|
||||
- API extensions to accept parameter values during flow launch
|
||||
- Parameter mapping resolution (flow names to definition names)
|
||||
- Error handling for invalid parameter combinations
|
||||
|
||||
Module: trustgraph-flow/trustgraph/flow/launcher.py
|
||||
|
||||
5. **UI Parameter Forms**
|
||||
- Dynamic form generation from flow parameter metadata
|
||||
- Ordered parameter display using `order` field
|
||||
- Descriptive parameter labels using `description` field
|
||||
- Input validation against parameter type definitions
|
||||
- Parameter presets and templates
|
||||
|
||||
Module: trustgraph-ui/components/flow-parameters/
|
||||
|
||||
### Data Models
|
||||
|
||||
#### Parameter Definitions (Stored in Schema/Config)
|
||||
|
||||
Parameter definitions are stored centrally in the schema and config system with type "parameter-types":
|
||||
|
||||
```json
|
||||
{
|
||||
"llm-model": {
|
||||
"type": "string",
|
||||
"description": "LLM model to use",
|
||||
"default": "gpt-4",
|
||||
"enum": [
|
||||
{
|
||||
"id": "gpt-4",
|
||||
"description": "OpenAI GPT-4 (Most Capable)"
|
||||
},
|
||||
{
|
||||
"id": "gpt-3.5-turbo",
|
||||
"description": "OpenAI GPT-3.5 Turbo (Fast & Efficient)"
|
||||
},
|
||||
{
|
||||
"id": "claude-3",
|
||||
"description": "Anthropic Claude 3 (Thoughtful & Safe)"
|
||||
},
|
||||
{
|
||||
"id": "gemma3:8b",
|
||||
"description": "Google Gemma 3 8B (Open Source)"
|
||||
}
|
||||
],
|
||||
"required": false
|
||||
},
|
||||
"model-size": {
|
||||
"type": "string",
|
||||
"description": "Model size variant",
|
||||
"default": "8b",
|
||||
"enum": ["2b", "8b", "12b", "70b"],
|
||||
"required": false
|
||||
},
|
||||
"temperature": {
|
||||
"type": "number",
|
||||
"description": "Model temperature for generation",
|
||||
"default": 0.7,
|
||||
"minimum": 0.0,
|
||||
"maximum": 2.0,
|
||||
"required": false
|
||||
},
|
||||
"chunk-size": {
|
||||
"type": "integer",
|
||||
"description": "Document chunk size",
|
||||
"default": 512,
|
||||
"minimum": 128,
|
||||
"maximum": 2048,
|
||||
"required": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Flow Class with Parameter References
|
||||
|
||||
Flow classes define parameter metadata with type references, descriptions, and ordering:
|
||||
|
||||
```json
|
||||
{
|
||||
"flow_class": "document-analysis",
|
||||
"parameters": {
|
||||
"llm-model": {
|
||||
"type": "llm-model",
|
||||
"description": "Primary LLM model for text completion",
|
||||
"order": 1
|
||||
},
|
||||
"llm-rag-model": {
|
||||
"type": "llm-model",
|
||||
"description": "LLM model for RAG operations",
|
||||
"order": 2,
|
||||
"advanced": true,
|
||||
"controlled-by": "llm-model"
|
||||
},
|
||||
"llm-temperature": {
|
||||
"type": "temperature",
|
||||
"description": "Generation temperature for creativity control",
|
||||
"order": 3,
|
||||
"advanced": true
|
||||
},
|
||||
"chunk-size": {
|
||||
"type": "chunk-size",
|
||||
"description": "Document chunk size for processing",
|
||||
"order": 4,
|
||||
"advanced": true
|
||||
},
|
||||
"chunk-overlap": {
|
||||
"type": "integer",
|
||||
"description": "Overlap between document chunks",
|
||||
"order": 5,
|
||||
"advanced": true,
|
||||
"controlled-by": "chunk-size"
|
||||
}
|
||||
},
|
||||
"class": {
|
||||
"text-completion:{class}": {
|
||||
"request": "non-persistent://tg/request/text-completion:{class}",
|
||||
"response": "non-persistent://tg/response/text-completion:{class}",
|
||||
"parameters": {
|
||||
"model": "{llm-model}",
|
||||
"temperature": "{llm-temperature}"
|
||||
}
|
||||
},
|
||||
"rag-completion:{class}": {
|
||||
"request": "non-persistent://tg/request/rag-completion:{class}",
|
||||
"response": "non-persistent://tg/response/rag-completion:{class}",
|
||||
"parameters": {
|
||||
"model": "{llm-rag-model}",
|
||||
"temperature": "{llm-temperature}"
|
||||
}
|
||||
}
|
||||
},
|
||||
"flow": {
|
||||
"chunker:{id}": {
|
||||
"input": "persistent://tg/flow/chunk:{id}",
|
||||
"output": "persistent://tg/flow/chunk-load:{id}",
|
||||
"parameters": {
|
||||
"chunk_size": "{chunk-size}",
|
||||
"chunk_overlap": "{chunk-overlap}"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
The `parameters` section maps flow-specific parameter names (keys) to parameter metadata objects containing:
|
||||
- `type`: Reference to centrally-defined parameter definition (e.g., "llm-model")
|
||||
- `description`: Human-readable description for UI display
|
||||
- `order`: Display order for parameter forms (lower numbers appear first)
|
||||
- `advanced` (optional): Boolean flag indicating if this is an advanced parameter (default: false). When set to true, the UI may hide this parameter by default or place it in an "Advanced" section
|
||||
- `controlled-by` (optional): Name of another parameter that controls this parameter's value when in simple mode. When specified, this parameter inherits its value from the controlling parameter unless explicitly overridden
|
||||
|
||||
This approach allows:
|
||||
- Reusable parameter type definitions across multiple flow classes
|
||||
- Centralized parameter type management and validation
|
||||
- Flow-specific parameter descriptions and ordering
|
||||
- Enhanced UI experience with descriptive parameter forms
|
||||
- Consistent parameter validation across flows
|
||||
- Easy addition of new standard parameter types
|
||||
- Simplified UI with basic/advanced mode separation
|
||||
- Parameter value inheritance for related settings
|
||||
|
||||
#### Flow Launch Request
|
||||
|
||||
The flow launch API accepts parameters using the flow's parameter names:
|
||||
|
||||
```json
|
||||
{
|
||||
"flow_class": "document-analysis",
|
||||
"flow_id": "customer-A-flow",
|
||||
"parameters": {
|
||||
"llm-model": "claude-3",
|
||||
"llm-temperature": 0.5,
|
||||
"chunk-size": 1024
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Note: In this example, `llm-rag-model` is not explicitly provided but will inherit the value "claude-3" from `llm-model` due to its `controlled-by` relationship. Similarly, `chunk-overlap` could inherit a calculated value based on `chunk-size`.
|
||||
|
||||
The system will:
|
||||
1. Extract parameter metadata from flow class definition
|
||||
2. Map flow parameter names to their type definitions (e.g., `llm-model` → `llm-model` type)
|
||||
3. Resolve controlled-by relationships (e.g., `llm-rag-model` inherits from `llm-model`)
|
||||
4. Validate user-provided and inherited values against the parameter type definitions
|
||||
5. Substitute resolved values into processor parameters during flow instantiation
|
||||
|
||||
### Implementation Details
|
||||
|
||||
#### Parameter Resolution Process
|
||||
|
||||
When a flow is started, the system performs the following parameter resolution steps:
|
||||
|
||||
1. **Flow Class Loading**: Load flow class definition and extract parameter metadata
|
||||
2. **Metadata Extraction**: Extract `type`, `description`, `order`, `advanced`, and `controlled-by` for each parameter defined in the flow class's `parameters` section
|
||||
3. **Type Definition Lookup**: For each parameter in the flow class:
|
||||
- Retrieve the parameter type definition from schema/config store using the `type` field
|
||||
- The type definitions are stored with type "parameter-types" in the config system
|
||||
- Each type definition contains the parameter's schema, default value, and validation rules
|
||||
4. **Default Value Resolution**:
|
||||
- For each parameter defined in the flow class:
|
||||
- Check if the user provided a value for this parameter
|
||||
- If no user value provided, use the `default` value from the parameter type definition
|
||||
- Build a complete parameter map containing both user-provided and default values
|
||||
5. **Parameter Inheritance Resolution** (controlled-by relationships):
|
||||
- For parameters with `controlled-by` field, check if a value was explicitly provided
|
||||
- If no explicit value provided, inherit the value from the controlling parameter
|
||||
- If the controlling parameter also has no value, use the default from the type definition
|
||||
- Validate that no circular dependencies exist in `controlled-by` relationships
|
||||
6. **Validation**: Validate the complete parameter set (user-provided, defaults, and inherited) against type definitions
|
||||
7. **Storage**: Store the complete resolved parameter set with the flow instance for auditability
|
||||
8. **Template Substitution**: Replace parameter placeholders in processor parameters with resolved values
|
||||
9. **Processor Instantiation**: Create processors with substituted parameters
|
||||
|
||||
**Important Implementation Notes:**
|
||||
- The flow service MUST merge user-provided parameters with defaults from parameter type definitions
|
||||
- The complete parameter set (including applied defaults) MUST be stored with the flow for traceability
|
||||
- Parameter resolution happens at flow start time, not at processor instantiation time
|
||||
- Missing required parameters without defaults MUST cause flow start to fail with a clear error message
|
||||
|
||||
#### Parameter Inheritance with controlled-by
|
||||
|
||||
The `controlled-by` field enables parameter value inheritance, particularly useful for simplifying user interfaces while maintaining flexibility:
|
||||
|
||||
**Example Scenario**:
|
||||
- `llm-model` parameter controls the primary LLM model
|
||||
- `llm-rag-model` parameter has `"controlled-by": "llm-model"`
|
||||
- In simple mode, setting `llm-model` to "gpt-4" automatically sets `llm-rag-model` to "gpt-4" as well
|
||||
- In advanced mode, users can override `llm-rag-model` with a different value
|
||||
|
||||
**Resolution Rules**:
|
||||
1. If a parameter has an explicitly provided value, use that value
|
||||
2. If no explicit value and `controlled-by` is set, use the controlling parameter's value
|
||||
3. If the controlling parameter has no value, fall back to the default from the type definition
|
||||
4. Circular dependencies in `controlled-by` relationships result in a validation error
|
||||
|
||||
**UI Behavior**:
|
||||
- In basic/simple mode: Parameters with `controlled-by` may be hidden or shown as read-only with inherited value
|
||||
- In advanced mode: All parameters are shown and can be individually configured
|
||||
- When a controlling parameter changes, dependent parameters update automatically unless explicitly overridden
|
||||
|
||||
#### Pulsar Integration
|
||||
|
||||
1. **Start-Flow Operation**
|
||||
- The Pulsar start-flow operation needs to accept a `parameters` field containing a map of parameter values
|
||||
- The Pulsar schema for the start-flow request must be updated to include the optional `parameters` field
|
||||
- Example request:
|
||||
```json
|
||||
{
|
||||
"flow_class": "document-analysis",
|
||||
"flow_id": "customer-A-flow",
|
||||
"parameters": {
|
||||
"model": "claude-3",
|
||||
"size": "12b",
|
||||
"temp": 0.5,
|
||||
"chunk": 1024
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
2. **Get-Flow Operation**
|
||||
- The Pulsar schema for the get-flow response must be updated to include the `parameters` field
|
||||
- This allows clients to retrieve the parameter values that were used when the flow was started
|
||||
- Example response:
|
||||
```json
|
||||
{
|
||||
"flow_id": "customer-A-flow",
|
||||
"flow_class": "document-analysis",
|
||||
"status": "running",
|
||||
"parameters": {
|
||||
"model": "claude-3",
|
||||
"size": "12b",
|
||||
"temp": 0.5,
|
||||
"chunk": 1024
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Flow Service Implementation
|
||||
|
||||
The flow configuration service (`trustgraph-flow/trustgraph/config/service/flow.py`) requires the following enhancements:
|
||||
|
||||
1. **Parameter Resolution Function**
|
||||
```python
|
||||
async def resolve_parameters(self, flow_class, user_params):
|
||||
"""
|
||||
Resolve parameters by merging user-provided values with defaults.
|
||||
|
||||
Args:
|
||||
flow_class: The flow class definition dict
|
||||
user_params: User-provided parameters dict
|
||||
|
||||
Returns:
|
||||
Complete parameter dict with user values and defaults merged
|
||||
"""
|
||||
```
|
||||
|
||||
This function should:
|
||||
- Extract parameter metadata from the flow class's `parameters` section
|
||||
- For each parameter, fetch its type definition from config store
|
||||
- Apply defaults for any parameters not provided by the user
|
||||
- Handle `controlled-by` inheritance relationships
|
||||
- Return the complete parameter set
|
||||
|
||||
2. **Modified `handle_start_flow` Method**
|
||||
- Call `resolve_parameters` after loading the flow class
|
||||
- Use the complete resolved parameter set for template substitution
|
||||
- Store the complete parameter set (not just user-provided) with the flow
|
||||
- Validate that all required parameters have values
|
||||
|
||||
3. **Parameter Type Fetching**
|
||||
- Parameter type definitions are stored in config with type "parameter-types"
|
||||
- Each type definition contains schema, default value, and validation rules
|
||||
- Cache frequently-used parameter types to reduce config lookups
|
||||
|
||||
#### Config System Integration
|
||||
|
||||
3. **Flow Object Storage**
|
||||
- When a flow is added to the config system by the flow component in the config manager, the flow object must include the resolved parameter values
|
||||
- The config manager needs to store both the original user-provided parameters and the resolved values (with defaults applied)
|
||||
- Flow objects in the config system should include:
|
||||
- `parameters`: The final resolved parameter values used for the flow
|
||||
|
||||
#### CLI Integration
|
||||
|
||||
4. **Library CLI Commands**
|
||||
- CLI commands that start flows need parameter support:
|
||||
- Accept parameter values via command-line flags or configuration files
|
||||
- Validate parameters against flow class definitions before submission
|
||||
- Support parameter file input (JSON/YAML) for complex parameter sets
|
||||
|
||||
- CLI commands that show flows need to display parameter information:
|
||||
- Show parameter values used when the flow was started
|
||||
- Display available parameters for a flow class
|
||||
- Show parameter validation schemas and defaults
|
||||
|
||||
#### Processor Base Class Integration
|
||||
|
||||
5. **ParameterSpec Support**
|
||||
- Processor base classes need to support parameter substitution through the existing ParametersSpec mechanism
|
||||
- The ParametersSpec class (located in the same module as ConsumerSpec and ProducerSpec) should be enhanced if necessary to support parameter template substitution
|
||||
- Processors should be able to invoke ParametersSpec to configure their parameters with parameter values resolved at flow launch time
|
||||
- The ParametersSpec implementation needs to:
|
||||
- Accept parameters configurations that contain parameter placeholders (e.g., `{model}`, `{temperature}`)
|
||||
- Support runtime parameter substitution when the processor is instantiated
|
||||
- Validate that substituted values match expected types and constraints
|
||||
- Provide error handling for missing or invalid parameter references
|
||||
|
||||
#### Substitution Rules
|
||||
|
||||
- Parameters use the format `{parameter-name}` in processor parameters
|
||||
- Parameter names in parameters match the keys in the flow's `parameters` section
|
||||
- Substitution occurs alongside `{id}` and `{class}` replacement
|
||||
- Invalid parameter references result in launch-time errors
|
||||
- Type validation happens based on the centrally-stored parameter definition
|
||||
- **IMPORTANT**: All parameter values are stored and transmitted as strings
|
||||
- Numbers are converted to strings (e.g., `0.7` becomes `"0.7"`)
|
||||
- Booleans are converted to lowercase strings (e.g., `true` becomes `"true"`)
|
||||
- This is required by the Pulsar schema which defines `parameters = Map(String())`
|
||||
|
||||
Example resolution:
|
||||
```
|
||||
Flow parameter mapping: "model": "llm-model"
|
||||
Processor parameter: "model": "{model}"
|
||||
User provides: "model": "gemma3:8b"
|
||||
Final parameter: "model": "gemma3:8b"
|
||||
|
||||
Example with type conversion:
|
||||
Parameter type default: 0.7 (number)
|
||||
Stored in flow: "0.7" (string)
|
||||
Substituted in processor: "0.7" (string)
|
||||
```
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
- Unit tests for parameter schema validation
|
||||
- Integration tests for parameter substitution in processor parameters
|
||||
- End-to-end tests for launching flows with different parameter values
|
||||
- UI tests for parameter form generation and validation
|
||||
- Performance tests for flows with many parameters
|
||||
- Edge cases: missing parameters, invalid types, undefined parameter references
|
||||
|
||||
## Migration Plan
|
||||
|
||||
1. The system should continue to support flow classes with no parameters
|
||||
declared.
|
||||
2. The system should continue to support flows no parameters specified:
|
||||
This works for flows with no parameters, and flows with parameters
|
||||
(they have defaults).
|
||||
|
||||
## Open Questions
|
||||
|
||||
Q: Should parameters support complex nested objects or keep to simple types?
|
||||
A: The parameter values will be string encoded, we're probably going to want
|
||||
to stick to strings.
|
||||
|
||||
Q: Should parameter placeholders be allowed in queue names or only in
|
||||
parameters?
|
||||
A: Only in parameters to remove strange injections and edge-cases.
|
||||
|
||||
Q: How to handle conflicts between parameter names and system variables like
|
||||
`id` and `class`?
|
||||
A: It is not valid to specify id and class when launching a flow
|
||||
|
||||
Q: Should we support computed parameters (derived from other parameters)?
|
||||
A: Just string substitution to remove strange injections and edge-cases.
|
||||
|
||||
## References
|
||||
|
||||
- JSON Schema Specification: https://json-schema.org/
|
||||
- Flow Class Definition Spec: docs/tech-specs/flow-class-definition.md
|
||||
629
docs/tech-specs/graphrag-performance-optimization.md
Normal file
629
docs/tech-specs/graphrag-performance-optimization.md
Normal file
|
|
@ -0,0 +1,629 @@
|
|||
# GraphRAG Performance Optimisation Technical Specification
|
||||
|
||||
## Overview
|
||||
|
||||
This specification describes comprehensive performance optimisations for the GraphRAG (Graph Retrieval-Augmented Generation) algorithm in TrustGraph. The current implementation suffers from significant performance bottlenecks that limit scalability and response times. This specification addresses four primary optimisation areas:
|
||||
|
||||
1. **Graph Traversal Optimisation**: Eliminate inefficient recursive database queries and implement batched graph exploration
|
||||
2. **Label Resolution Optimisation**: Replace sequential label fetching with parallel/batched operations
|
||||
3. **Caching Strategy Enhancement**: Implement intelligent caching with LRU eviction and prefetching
|
||||
4. **Query Optimisation**: Add result memoisation and embedding caching for improved response times
|
||||
|
||||
## Goals
|
||||
|
||||
- **Reduce Database Query Volume**: Achieve 50-80% reduction in total database queries through batching and caching
|
||||
- **Improve Response Times**: Target 3-5x faster subgraph construction and 2-3x faster label resolution
|
||||
- **Enhance Scalability**: Support larger knowledge graphs with better memory management
|
||||
- **Maintain Accuracy**: Preserve existing GraphRAG functionality and result quality
|
||||
- **Enable Concurrency**: Improve parallel processing capabilities for multiple concurrent requests
|
||||
- **Reduce Memory Footprint**: Implement efficient data structures and memory management
|
||||
- **Add Observability**: Include performance metrics and monitoring capabilities
|
||||
- **Ensure Reliability**: Add proper error handling and timeout mechanisms
|
||||
|
||||
## Background
|
||||
|
||||
The current GraphRAG implementation in `trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py` exhibits several critical performance issues that severely impact system scalability:
|
||||
|
||||
### Current Performance Problems
|
||||
|
||||
**1. Inefficient Graph Traversal (`follow_edges` function, lines 79-127)**
|
||||
- Makes 3 separate database queries per entity per depth level
|
||||
- Query pattern: subject-based, predicate-based, and object-based queries for each entity
|
||||
- No batching: Each query processes only one entity at a time
|
||||
- No cycle detection: Can revisit the same nodes multiple times
|
||||
- Recursive implementation without memoisation leads to exponential complexity
|
||||
- Time complexity: O(entities × max_path_length × triple_limit³)
|
||||
|
||||
**2. Sequential Label Resolution (`get_labelgraph` function, lines 144-171)**
|
||||
- Processes each triple component (subject, predicate, object) sequentially
|
||||
- Each `maybe_label` call potentially triggers a database query
|
||||
- No parallel execution or batching of label queries
|
||||
- Results in up to 3 × subgraph_size individual database calls
|
||||
|
||||
**3. Primitive Caching Strategy (`maybe_label` function, lines 62-77)**
|
||||
- Simple dictionary cache without size limits or TTL
|
||||
- No cache eviction policy leads to unbounded memory growth
|
||||
- Cache misses trigger individual database queries
|
||||
- No prefetching or intelligent cache warming
|
||||
|
||||
**4. Suboptimal Query Patterns**
|
||||
- Entity vector similarity queries not cached between similar requests
|
||||
- No result memoisation for repeated query patterns
|
||||
- Missing query optimisation for common access patterns
|
||||
|
||||
**5. Critical Object Lifetime Issues (`rag.py:96-102`)**
|
||||
- **GraphRag object recreated per request**: Fresh instance created for every query, losing all cache benefits
|
||||
- **Query object extremely short-lived**: Created and destroyed within single query execution (lines 201-207)
|
||||
- **Label cache reset per request**: Cache warming and accumulated knowledge lost between requests
|
||||
- **Client recreation overhead**: Database clients potentially re-established for each request
|
||||
- **No cross-request optimisation**: Cannot benefit from query patterns or result sharing
|
||||
|
||||
### Performance Impact Analysis
|
||||
|
||||
Current worst-case scenario for a typical query:
|
||||
- **Entity Retrieval**: 1 vector similarity query
|
||||
- **Graph Traversal**: entities × max_path_length × 3 × triple_limit queries
|
||||
- **Label Resolution**: subgraph_size × 3 individual label queries
|
||||
|
||||
For default parameters (50 entities, path length 2, 30 triple limit, 150 subgraph size):
|
||||
- **Minimum queries**: 1 + (50 × 2 × 3 × 30) + (150 × 3) = **9,451 database queries**
|
||||
- **Response time**: 15-30 seconds for moderate-sized graphs
|
||||
- **Memory usage**: Unbounded cache growth over time
|
||||
- **Cache effectiveness**: 0% - caches reset on every request
|
||||
- **Object creation overhead**: GraphRag + Query objects created/destroyed per request
|
||||
|
||||
This specification addresses these gaps by implementing batched queries, intelligent caching, and parallel processing. By optimizing query patterns and data access, TrustGraph can:
|
||||
- Support enterprise-scale knowledge graphs with millions of entities
|
||||
- Provide sub-second response times for typical queries
|
||||
- Handle hundreds of concurrent GraphRAG requests
|
||||
- Scale efficiently with graph size and complexity
|
||||
|
||||
## Technical Design
|
||||
|
||||
### Architecture
|
||||
|
||||
The GraphRAG performance optimisation requires the following technical components:
|
||||
|
||||
#### 1. **Object Lifetime Architectural Refactor**
|
||||
- **Make GraphRag long-lived**: Move GraphRag instance to Processor level for persistence across requests
|
||||
- **Preserve caches**: Maintain label cache, embedding cache, and query result cache between requests
|
||||
- **Optimize Query object**: Refactor Query as lightweight execution context, not data container
|
||||
- **Connection persistence**: Maintain database client connections across requests
|
||||
|
||||
Module: `trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py` (modified)
|
||||
|
||||
#### 2. **Optimized Graph Traversal Engine**
|
||||
- Replace recursive `follow_edges` with iterative breadth-first search
|
||||
- Implement batched entity processing at each traversal level
|
||||
- Add cycle detection using visited node tracking
|
||||
- Include early termination when limits are reached
|
||||
|
||||
Module: `trustgraph-flow/trustgraph/retrieval/graph_rag/optimized_traversal.py`
|
||||
|
||||
#### 3. **Parallel Label Resolution System**
|
||||
- Batch label queries for multiple entities simultaneously
|
||||
- Implement async/await patterns for concurrent database access
|
||||
- Add intelligent prefetching for common label patterns
|
||||
- Include label cache warming strategies
|
||||
|
||||
Module: `trustgraph-flow/trustgraph/retrieval/graph_rag/label_resolver.py`
|
||||
|
||||
#### 4. **Conservative Label Caching Layer**
|
||||
- LRU cache with short TTL for labels only (5min) to balance performance vs consistency
|
||||
- Cache metrics and hit ratio monitoring
|
||||
- **No embedding caching**: Already cached per-query, no cross-query benefit
|
||||
- **No query result caching**: Due to graph mutation consistency concerns
|
||||
|
||||
Module: `trustgraph-flow/trustgraph/retrieval/graph_rag/cache_manager.py`
|
||||
|
||||
#### 5. **Query Optimisation Framework**
|
||||
- Query pattern analysis and optimisation suggestions
|
||||
- Batch query coordinator for database access
|
||||
- Connection pooling and query timeout management
|
||||
- Performance monitoring and metrics collection
|
||||
|
||||
Module: `trustgraph-flow/trustgraph/retrieval/graph_rag/query_optimizer.py`
|
||||
|
||||
### Data Models
|
||||
|
||||
#### Optimized Graph Traversal State
|
||||
|
||||
The traversal engine maintains state to avoid redundant operations:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class TraversalState:
|
||||
visited_entities: Set[str]
|
||||
current_level_entities: Set[str]
|
||||
next_level_entities: Set[str]
|
||||
subgraph: Set[Tuple[str, str, str]]
|
||||
depth: int
|
||||
query_batch: List[TripleQuery]
|
||||
```
|
||||
|
||||
This approach allows:
|
||||
- Efficient cycle detection through visited entity tracking
|
||||
- Batched query preparation at each traversal level
|
||||
- Memory-efficient state management
|
||||
- Early termination when size limits are reached
|
||||
|
||||
#### Enhanced Cache Structure
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CacheEntry:
|
||||
value: Any
|
||||
timestamp: float
|
||||
access_count: int
|
||||
ttl: Optional[float]
|
||||
|
||||
class CacheManager:
|
||||
label_cache: LRUCache[str, CacheEntry]
|
||||
embedding_cache: LRUCache[str, CacheEntry]
|
||||
query_result_cache: LRUCache[str, CacheEntry]
|
||||
cache_stats: CacheStatistics
|
||||
```
|
||||
|
||||
#### Batch Query Structures
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class BatchTripleQuery:
|
||||
entities: List[str]
|
||||
query_type: QueryType # SUBJECT, PREDICATE, OBJECT
|
||||
limit_per_entity: int
|
||||
|
||||
@dataclass
|
||||
class BatchLabelQuery:
|
||||
entities: List[str]
|
||||
predicate: str = LABEL
|
||||
```
|
||||
|
||||
### APIs
|
||||
|
||||
#### New APIs:
|
||||
|
||||
**GraphTraversal API**
|
||||
```python
|
||||
async def optimized_follow_edges_batch(
|
||||
entities: List[str],
|
||||
max_depth: int,
|
||||
triple_limit: int,
|
||||
max_subgraph_size: int
|
||||
) -> Set[Tuple[str, str, str]]
|
||||
```
|
||||
|
||||
**Batch Label Resolution API**
|
||||
```python
|
||||
async def resolve_labels_batch(
|
||||
entities: List[str],
|
||||
cache_manager: CacheManager
|
||||
) -> Dict[str, str]
|
||||
```
|
||||
|
||||
**Cache Management API**
|
||||
```python
|
||||
class CacheManager:
|
||||
async def get_or_fetch_label(self, entity: str) -> str
|
||||
async def get_or_fetch_embeddings(self, query: str) -> List[float]
|
||||
async def cache_query_result(self, query_hash: str, result: Any, ttl: int)
|
||||
def get_cache_statistics(self) -> CacheStatistics
|
||||
```
|
||||
|
||||
#### Modified APIs:
|
||||
|
||||
**GraphRag.query()** - Enhanced with performance optimisations:
|
||||
- Add cache_manager parameter for cache control
|
||||
- Include performance_metrics return value
|
||||
- Add query_timeout parameter for reliability
|
||||
|
||||
**Query class** - Refactored for batch processing:
|
||||
- Replace individual entity processing with batch operations
|
||||
- Add async context managers for resource cleanup
|
||||
- Include progress callbacks for long-running operations
|
||||
|
||||
### Implementation Details
|
||||
|
||||
#### Phase 0: Critical Architectural Lifetime Refactor
|
||||
|
||||
**Current Problematic Implementation:**
|
||||
```python
|
||||
# INEFFICIENT: GraphRag recreated every request
|
||||
class Processor(FlowProcessor):
|
||||
async def on_request(self, msg, consumer, flow):
|
||||
# PROBLEM: New GraphRag instance per request!
|
||||
self.rag = GraphRag(
|
||||
embeddings_client = flow("embeddings-request"),
|
||||
graph_embeddings_client = flow("graph-embeddings-request"),
|
||||
triples_client = flow("triples-request"),
|
||||
prompt_client = flow("prompt-request"),
|
||||
verbose=True,
|
||||
)
|
||||
# Cache starts empty every time - no benefit from previous requests
|
||||
response = await self.rag.query(...)
|
||||
|
||||
# VERY SHORT-LIVED: Query object created/destroyed per request
|
||||
class GraphRag:
|
||||
async def query(self, query, user="trustgraph", collection="default", ...):
|
||||
q = Query(rag=self, user=user, collection=collection, ...) # Created
|
||||
kg = await q.get_labelgraph(query) # Used briefly
|
||||
# q automatically destroyed when function exits
|
||||
```
|
||||
|
||||
**Optimized Long-Lived Architecture:**
|
||||
```python
|
||||
class Processor(FlowProcessor):
|
||||
def __init__(self, **params):
|
||||
super().__init__(**params)
|
||||
self.rag_instance = None # Will be initialized once
|
||||
self.client_connections = {}
|
||||
|
||||
async def initialize_rag(self, flow):
|
||||
"""Initialize GraphRag once, reuse for all requests"""
|
||||
if self.rag_instance is None:
|
||||
self.rag_instance = LongLivedGraphRag(
|
||||
embeddings_client=flow("embeddings-request"),
|
||||
graph_embeddings_client=flow("graph-embeddings-request"),
|
||||
triples_client=flow("triples-request"),
|
||||
prompt_client=flow("prompt-request"),
|
||||
verbose=True,
|
||||
)
|
||||
return self.rag_instance
|
||||
|
||||
async def on_request(self, msg, consumer, flow):
|
||||
# REUSE the same GraphRag instance - caches persist!
|
||||
rag = await self.initialize_rag(flow)
|
||||
|
||||
# Query object becomes lightweight execution context
|
||||
response = await rag.query_with_context(
|
||||
query=v.query,
|
||||
execution_context=QueryContext(
|
||||
user=v.user,
|
||||
collection=v.collection,
|
||||
entity_limit=entity_limit,
|
||||
# ... other params
|
||||
)
|
||||
)
|
||||
|
||||
class LongLivedGraphRag:
|
||||
def __init__(self, ...):
|
||||
# CONSERVATIVE caches - balance performance vs consistency
|
||||
self.label_cache = LRUCacheWithTTL(max_size=5000, ttl=300) # 5min TTL for freshness
|
||||
# Note: No embedding cache - already cached per-query, no cross-query benefit
|
||||
# Note: No query result cache due to consistency concerns
|
||||
self.performance_metrics = PerformanceTracker()
|
||||
|
||||
async def query_with_context(self, query: str, context: QueryContext):
|
||||
# Use lightweight QueryExecutor instead of heavyweight Query object
|
||||
executor = QueryExecutor(self, context) # Minimal object
|
||||
return await executor.execute(query)
|
||||
|
||||
@dataclass
|
||||
class QueryContext:
|
||||
"""Lightweight execution context - no heavy operations"""
|
||||
user: str
|
||||
collection: str
|
||||
entity_limit: int
|
||||
triple_limit: int
|
||||
max_subgraph_size: int
|
||||
max_path_length: int
|
||||
|
||||
class QueryExecutor:
|
||||
"""Lightweight execution context - replaces old Query class"""
|
||||
def __init__(self, rag: LongLivedGraphRag, context: QueryContext):
|
||||
self.rag = rag
|
||||
self.context = context
|
||||
# No heavy initialization - just references
|
||||
|
||||
async def execute(self, query: str):
|
||||
# All heavy lifting uses persistent rag caches
|
||||
return await self.rag.execute_optimized_query(query, self.context)
|
||||
```
|
||||
|
||||
This architectural change provides:
|
||||
- **10-20% database query reduction** for graphs with common relationships (vs 0% currently)
|
||||
- **Eliminated object creation overhead** for every request
|
||||
- **Persistent connection pooling** and client reuse
|
||||
- **Cross-request optimization** within cache TTL windows
|
||||
|
||||
**Important Cache Consistency Limitation:**
|
||||
Long-term caching introduces staleness risk when entities/labels are deleted or modified in the underlying graph. The LRU cache with TTL provides a balance between performance gains and data freshness, but cannot detect real-time graph changes.
|
||||
|
||||
#### Phase 1: Graph Traversal Optimisation
|
||||
|
||||
**Current Implementation Problems:**
|
||||
```python
|
||||
# INEFFICIENT: 3 queries per entity per level
|
||||
async def follow_edges(self, ent, subgraph, path_length):
|
||||
# Query 1: s=ent, p=None, o=None
|
||||
res = await self.rag.triples_client.query(s=ent, p=None, o=None, limit=self.triple_limit)
|
||||
# Query 2: s=None, p=ent, o=None
|
||||
res = await self.rag.triples_client.query(s=None, p=ent, o=None, limit=self.triple_limit)
|
||||
# Query 3: s=None, p=None, o=ent
|
||||
res = await self.rag.triples_client.query(s=None, p=None, o=ent, limit=self.triple_limit)
|
||||
```
|
||||
|
||||
**Optimized Implementation:**
|
||||
```python
|
||||
async def optimized_traversal(self, entities: List[str], max_depth: int) -> Set[Triple]:
|
||||
visited = set()
|
||||
current_level = set(entities)
|
||||
subgraph = set()
|
||||
|
||||
for depth in range(max_depth):
|
||||
if not current_level or len(subgraph) >= self.max_subgraph_size:
|
||||
break
|
||||
|
||||
# Batch all queries for current level
|
||||
batch_queries = []
|
||||
for entity in current_level:
|
||||
if entity not in visited:
|
||||
batch_queries.extend([
|
||||
TripleQuery(s=entity, p=None, o=None),
|
||||
TripleQuery(s=None, p=entity, o=None),
|
||||
TripleQuery(s=None, p=None, o=entity)
|
||||
])
|
||||
|
||||
# Execute all queries concurrently
|
||||
results = await self.execute_batch_queries(batch_queries)
|
||||
|
||||
# Process results and prepare next level
|
||||
next_level = set()
|
||||
for result in results:
|
||||
subgraph.update(result.triples)
|
||||
next_level.update(result.new_entities)
|
||||
|
||||
visited.update(current_level)
|
||||
current_level = next_level - visited
|
||||
|
||||
return subgraph
|
||||
```
|
||||
|
||||
#### Phase 2: Parallel Label Resolution
|
||||
|
||||
**Current Sequential Implementation:**
|
||||
```python
|
||||
# INEFFICIENT: Sequential processing
|
||||
for edge in subgraph:
|
||||
s = await self.maybe_label(edge[0]) # Individual query
|
||||
p = await self.maybe_label(edge[1]) # Individual query
|
||||
o = await self.maybe_label(edge[2]) # Individual query
|
||||
```
|
||||
|
||||
**Optimized Parallel Implementation:**
|
||||
```python
|
||||
async def resolve_labels_parallel(self, subgraph: List[Triple]) -> List[Triple]:
|
||||
# Collect all unique entities needing labels
|
||||
entities_to_resolve = set()
|
||||
for s, p, o in subgraph:
|
||||
entities_to_resolve.update([s, p, o])
|
||||
|
||||
# Remove already cached entities
|
||||
uncached_entities = [e for e in entities_to_resolve if e not in self.label_cache]
|
||||
|
||||
# Batch query for all uncached labels
|
||||
if uncached_entities:
|
||||
label_results = await self.batch_label_query(uncached_entities)
|
||||
self.label_cache.update(label_results)
|
||||
|
||||
# Apply labels to subgraph
|
||||
return [
|
||||
(self.label_cache.get(s, s), self.label_cache.get(p, p), self.label_cache.get(o, o))
|
||||
for s, p, o in subgraph
|
||||
]
|
||||
```
|
||||
|
||||
#### Phase 3: Advanced Caching Strategy
|
||||
|
||||
**LRU Cache with TTL:**
|
||||
```python
|
||||
class LRUCacheWithTTL:
|
||||
def __init__(self, max_size: int, default_ttl: int = 3600):
|
||||
self.cache = OrderedDict()
|
||||
self.max_size = max_size
|
||||
self.default_ttl = default_ttl
|
||||
self.access_times = {}
|
||||
|
||||
async def get(self, key: str) -> Optional[Any]:
|
||||
if key in self.cache:
|
||||
# Check TTL expiration
|
||||
if time.time() - self.access_times[key] > self.default_ttl:
|
||||
del self.cache[key]
|
||||
del self.access_times[key]
|
||||
return None
|
||||
|
||||
# Move to end (most recently used)
|
||||
self.cache.move_to_end(key)
|
||||
return self.cache[key]
|
||||
return None
|
||||
|
||||
async def put(self, key: str, value: Any):
|
||||
if key in self.cache:
|
||||
self.cache.move_to_end(key)
|
||||
else:
|
||||
if len(self.cache) >= self.max_size:
|
||||
# Remove least recently used
|
||||
oldest_key = next(iter(self.cache))
|
||||
del self.cache[oldest_key]
|
||||
del self.access_times[oldest_key]
|
||||
|
||||
self.cache[key] = value
|
||||
self.access_times[key] = time.time()
|
||||
```
|
||||
|
||||
#### Phase 4: Query Optimisation and Monitoring
|
||||
|
||||
**Performance Metrics Collection:**
|
||||
```python
|
||||
@dataclass
|
||||
class PerformanceMetrics:
|
||||
total_queries: int
|
||||
cache_hits: int
|
||||
cache_misses: int
|
||||
avg_response_time: float
|
||||
subgraph_construction_time: float
|
||||
label_resolution_time: float
|
||||
total_entities_processed: int
|
||||
memory_usage_mb: float
|
||||
```
|
||||
|
||||
**Query Timeout and Circuit Breaker:**
|
||||
```python
|
||||
async def execute_with_timeout(self, query_func, timeout: int = 30):
|
||||
try:
|
||||
return await asyncio.wait_for(query_func(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Query timeout after {timeout}s")
|
||||
raise GraphRagTimeoutError(f"Query exceeded timeout of {timeout}s")
|
||||
```
|
||||
|
||||
## Cache Consistency Considerations
|
||||
|
||||
**Data Staleness Trade-offs:**
|
||||
- **Label cache (5min TTL)**: Risk of serving deleted/renamed entity labels
|
||||
- **No embedding caching**: Not needed - embeddings already cached per-query
|
||||
- **No result caching**: Prevents stale subgraph results from deleted entities/relationships
|
||||
|
||||
**Mitigation Strategies:**
|
||||
- **Conservative TTL values**: Balance performance gains (10-20%) with data freshness
|
||||
- **Cache invalidation hooks**: Optional integration with graph mutation events
|
||||
- **Monitoring dashboards**: Track cache hit rates vs staleness incidents
|
||||
- **Configurable cache policies**: Allow per-deployment tuning based on mutation frequency
|
||||
|
||||
**Recommended Cache Configuration by Graph Mutation Rate:**
|
||||
- **High mutation (>100 changes/hour)**: TTL=60s, smaller cache sizes
|
||||
- **Medium mutation (10-100 changes/hour)**: TTL=300s (default)
|
||||
- **Low mutation (<10 changes/hour)**: TTL=600s, larger cache sizes
|
||||
|
||||
## Security Considerations
|
||||
|
||||
**Query Injection Prevention:**
|
||||
- Validate all entity identifiers and query parameters
|
||||
- Use parameterized queries for all database interactions
|
||||
- Implement query complexity limits to prevent DoS attacks
|
||||
|
||||
**Resource Protection:**
|
||||
- Enforce maximum subgraph size limits
|
||||
- Implement query timeouts to prevent resource exhaustion
|
||||
- Add memory usage monitoring and limits
|
||||
|
||||
**Access Control:**
|
||||
- Maintain existing user and collection isolation
|
||||
- Add audit logging for performance-impacting operations
|
||||
- Implement rate limiting for expensive operations
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
### Expected Performance Improvements
|
||||
|
||||
**Query Reduction:**
|
||||
- Current: ~9,000+ queries for typical request
|
||||
- Optimized: ~50-100 batched queries (98% reduction)
|
||||
|
||||
**Response Time Improvements:**
|
||||
- Graph traversal: 15-20s → 3-5s (4-5x faster)
|
||||
- Label resolution: 8-12s → 2-4s (3x faster)
|
||||
- Overall query: 25-35s → 6-10s (3-4x improvement)
|
||||
|
||||
**Memory Efficiency:**
|
||||
- Bounded cache sizes prevent memory leaks
|
||||
- Efficient data structures reduce memory footprint by ~40%
|
||||
- Better garbage collection through proper resource cleanup
|
||||
|
||||
**Realistic Performance Expectations:**
|
||||
- **Label cache**: 10-20% query reduction for graphs with common relationships
|
||||
- **Batching optimization**: 50-80% query reduction (primary optimization)
|
||||
- **Object lifetime optimization**: Eliminate per-request creation overhead
|
||||
- **Overall improvement**: 3-4x response time improvement primarily from batching
|
||||
|
||||
**Scalability Improvements:**
|
||||
- Support for 3-5x larger knowledge graphs (limited by cache consistency needs)
|
||||
- 3-5x higher concurrent request capacity
|
||||
- Better resource utilization through connection reuse
|
||||
|
||||
### Performance Monitoring
|
||||
|
||||
**Real-time Metrics:**
|
||||
- Query execution times by operation type
|
||||
- Cache hit ratios and effectiveness
|
||||
- Database connection pool utilisation
|
||||
- Memory usage and garbage collection impact
|
||||
|
||||
**Performance Benchmarking:**
|
||||
- Automated performance regression testing
|
||||
- Load testing with realistic data volumes
|
||||
- Comparison benchmarks against current implementation
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
### Unit Testing
|
||||
- Individual component testing for traversal, caching, and label resolution
|
||||
- Mock database interactions for performance testing
|
||||
- Cache eviction and TTL expiration testing
|
||||
- Error handling and timeout scenarios
|
||||
|
||||
### Integration Testing
|
||||
- End-to-end GraphRAG query testing with optimisations
|
||||
- Database interaction testing with real data
|
||||
- Concurrent request handling and resource management
|
||||
- Memory leak detection and resource cleanup verification
|
||||
|
||||
### Performance Testing
|
||||
- Benchmark testing against current implementation
|
||||
- Load testing with varying graph sizes and complexities
|
||||
- Stress testing for memory and connection limits
|
||||
- Regression testing for performance improvements
|
||||
|
||||
### Compatibility Testing
|
||||
- Verify existing GraphRAG API compatibility
|
||||
- Test with various graph database backends
|
||||
- Validate result accuracy compared to current implementation
|
||||
|
||||
## Implementation Plan
|
||||
|
||||
### Direct Implementation Approach
|
||||
Since APIs are allowed to change, implement optimizations directly without migration complexity:
|
||||
|
||||
1. **Replace `follow_edges` method**: Rewrite with iterative batched traversal
|
||||
2. **Optimize `get_labelgraph`**: Implement parallel label resolution
|
||||
3. **Add long-lived GraphRag**: Modify Processor to maintain persistent instance
|
||||
4. **Implement label caching**: Add LRU cache with TTL to GraphRag class
|
||||
|
||||
### Scope of Changes
|
||||
- **Query class**: Replace ~50 lines in `follow_edges`, add ~30 lines batch handling
|
||||
- **GraphRag class**: Add caching layer (~40 lines)
|
||||
- **Processor class**: Modify to use persistent GraphRag instance (~20 lines)
|
||||
- **Total**: ~140 lines of focused changes, mostly within existing classes
|
||||
|
||||
## Timeline
|
||||
|
||||
**Week 1: Core Implementation**
|
||||
- Replace `follow_edges` with batched iterative traversal
|
||||
- Implement parallel label resolution in `get_labelgraph`
|
||||
- Add long-lived GraphRag instance to Processor
|
||||
- Implement label caching layer
|
||||
|
||||
**Week 2: Testing and Integration**
|
||||
- Unit tests for new traversal and caching logic
|
||||
- Performance benchmarking against current implementation
|
||||
- Integration testing with real graph data
|
||||
- Code review and optimization
|
||||
|
||||
**Week 3: Deployment**
|
||||
- Deploy optimized implementation
|
||||
- Monitor performance improvements
|
||||
- Fine-tune cache TTL and batch sizes based on real usage
|
||||
|
||||
## Open Questions
|
||||
|
||||
- **Database Connection Pooling**: Should we implement custom connection pooling or rely on existing database client pooling?
|
||||
- **Cache Persistence**: Should label and embedding caches persist across service restarts?
|
||||
- **Distributed Caching**: For multi-instance deployments, should we implement distributed caching with Redis/Memcached?
|
||||
- **Query Result Format**: Should we optimize the internal triple representation for better memory efficiency?
|
||||
- **Monitoring Integration**: Which metrics should be exposed to existing monitoring systems (Prometheus, etc.)?
|
||||
|
||||
## References
|
||||
|
||||
- [GraphRAG Original Implementation](trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py)
|
||||
- [TrustGraph Architecture Principles](architecture-principles.md)
|
||||
- [Collection Management Specification](collection-management.md)
|
||||
|
|
@ -29,23 +29,25 @@ class TestEndToEndConfigurationFlow:
|
|||
'CASSANDRA_USERNAME': 'integration-user',
|
||||
'CASSANDRA_PASSWORD': 'integration-pass'
|
||||
}
|
||||
|
||||
|
||||
mock_cluster_instance = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_cluster_instance.connect.return_value = mock_session
|
||||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = TriplesWriter(taskgroup=MagicMock())
|
||||
|
||||
|
||||
# Create a mock message to trigger TrustGraph creation
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
# This should create TrustGraph with environment config
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Mock collection_exists to return True
|
||||
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
|
||||
# This should create TrustGraph with environment config
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify Cluster was created with correct hosts
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
@ -145,8 +147,10 @@ class TestConfigurationPriorityEndToEnd:
|
|||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Mock collection_exists to return True
|
||||
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Should use CLI parameters, not environment
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
@ -243,8 +247,10 @@ class TestNoBackwardCompatibilityEndToEnd:
|
|||
mock_message.metadata.user = 'legacy_user'
|
||||
mock_message.metadata.collection = 'legacy_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Mock collection_exists to return True
|
||||
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Should use defaults since old parameters are not recognized
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
@ -299,8 +305,10 @@ class TestNoBackwardCompatibilityEndToEnd:
|
|||
mock_message.metadata.user = 'precedence_user'
|
||||
mock_message.metadata.collection = 'precedence_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Mock collection_exists to return True
|
||||
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Should use new parameters, not old ones
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
@ -349,8 +357,10 @@ class TestMultipleHostsHandling:
|
|||
mock_message.metadata.user = 'single_user'
|
||||
mock_message.metadata.collection = 'single_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Mock collection_exists to return True
|
||||
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Single host should be converted to list
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
|
|||
276
tests/integration/test_dynamic_llm_parameters.py
Normal file
276
tests/integration/test_dynamic_llm_parameters.py
Normal file
|
|
@ -0,0 +1,276 @@
|
|||
"""
|
||||
Integration tests for Dynamic LLM Parameters
|
||||
Testing end-to-end flow of runtime parameter changes in LLM processors
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from trustgraph.model.text_completion.openai.llm import Processor as OpenAIProcessor
|
||||
from trustgraph.base import LlmResult
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestDynamicLlmParameters:
|
||||
"""Integration tests for dynamic parameter configuration"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_client(self):
|
||||
"""Mock OpenAI client that returns realistic responses"""
|
||||
client = MagicMock()
|
||||
|
||||
# Default mock response
|
||||
usage = CompletionUsage(prompt_tokens=25, completion_tokens=15, total_tokens=40)
|
||||
message = ChatCompletionMessage(role="assistant", content="Dynamic parameter test response")
|
||||
choice = Choice(index=0, message=message, finish_reason="stop")
|
||||
|
||||
completion = ChatCompletion(
|
||||
id="chatcmpl-test-dynamic",
|
||||
choices=[choice],
|
||||
created=1234567890,
|
||||
model="gpt-4", # Will be overridden based on test
|
||||
object="chat.completion",
|
||||
usage=usage
|
||||
)
|
||||
|
||||
client.chat.completions.create.return_value = completion
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def base_processor_config(self):
|
||||
"""Base configuration for test processors"""
|
||||
return {
|
||||
"api_key": "test-api-key",
|
||||
"url": "https://api.openai.com/v1",
|
||||
"temperature": 0.0, # Default temperature
|
||||
"max_output": 1024,
|
||||
}
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_runtime_temperature_override(self, mock_llm_init, mock_async_init,
|
||||
mock_openai_class, mock_openai_client, base_processor_config):
|
||||
"""Test that temperature can be overridden at runtime"""
|
||||
# Arrange
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = base_processor_config | {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"concurrency": 1,
|
||||
"taskgroup": AsyncMock(),
|
||||
"id": "test-processor"
|
||||
}
|
||||
|
||||
processor = OpenAIProcessor(**config)
|
||||
|
||||
# Act - Call with different temperature than configured default (0.0)
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Dynamic parameter test response"
|
||||
|
||||
# Verify the OpenAI API was called with the overridden temperature
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
|
||||
assert call_args.kwargs['temperature'] == 0.9 # Should use runtime parameter
|
||||
assert call_args.kwargs['model'] == "gpt-3.5-turbo" # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_runtime_model_override(self, mock_llm_init, mock_async_init,
|
||||
mock_openai_class, mock_openai_client, base_processor_config):
|
||||
"""Test that model can be overridden at runtime"""
|
||||
# Arrange
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = base_processor_config | {
|
||||
"model": "gpt-3.5-turbo", # Default model
|
||||
"concurrency": 1,
|
||||
"taskgroup": AsyncMock(),
|
||||
"id": "test-processor"
|
||||
}
|
||||
|
||||
processor = OpenAIProcessor(**config)
|
||||
|
||||
# Act - Call with different model than configured default
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gpt-4", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
|
||||
# Verify the OpenAI API was called with the overridden model
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
|
||||
assert call_args.kwargs['model'] == "gpt-4" # Should use runtime parameter
|
||||
assert call_args.kwargs['temperature'] == 0.0 # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_both_parameters_override(self, mock_llm_init, mock_async_init,
|
||||
mock_openai_class, mock_openai_client, base_processor_config):
|
||||
"""Test that both model and temperature can be overridden simultaneously"""
|
||||
# Arrange
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = base_processor_config | {
|
||||
"model": "gpt-3.5-turbo", # Default model
|
||||
"concurrency": 1,
|
||||
"taskgroup": AsyncMock(),
|
||||
"id": "test-processor"
|
||||
}
|
||||
|
||||
processor = OpenAIProcessor(**config)
|
||||
|
||||
# Act - Override both parameters
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gpt-4", # Override model
|
||||
temperature=0.5 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
|
||||
# Verify both parameters were overridden
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
|
||||
assert call_args.kwargs['model'] == "gpt-4" # Should use runtime parameter
|
||||
assert call_args.kwargs['temperature'] == 0.5 # Should use runtime parameter
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_fallback_to_defaults_when_no_override(self, mock_llm_init, mock_async_init,
|
||||
mock_openai_class, mock_openai_client, base_processor_config):
|
||||
"""Test that processor falls back to configured defaults when no parameters are provided"""
|
||||
# Arrange
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = base_processor_config | {
|
||||
"model": "gpt-3.5-turbo", # Default model
|
||||
"temperature": 0.2, # Default temperature
|
||||
"concurrency": 1,
|
||||
"taskgroup": AsyncMock(),
|
||||
"id": "test-processor"
|
||||
}
|
||||
|
||||
processor = OpenAIProcessor(**config)
|
||||
|
||||
# Act - Call with no parameter overrides
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default
|
||||
temperature=None # Use default
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
|
||||
# Verify defaults were used
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
|
||||
assert call_args.kwargs['model'] == "gpt-3.5-turbo" # Should use processor default
|
||||
assert call_args.kwargs['temperature'] == 0.2 # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_multiple_concurrent_calls_different_parameters(self, mock_llm_init, mock_async_init,
|
||||
mock_openai_class, mock_openai_client, base_processor_config):
|
||||
"""Test multiple concurrent calls with different parameters don't interfere"""
|
||||
# Arrange
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = base_processor_config | {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"concurrency": 1,
|
||||
"taskgroup": AsyncMock(),
|
||||
"id": "test-processor"
|
||||
}
|
||||
|
||||
processor = OpenAIProcessor(**config)
|
||||
|
||||
# Reset the mock to track multiple calls
|
||||
mock_openai_client.reset_mock()
|
||||
|
||||
# Act - Make multiple calls with different parameters concurrently
|
||||
import asyncio
|
||||
tasks = [
|
||||
processor.generate_content("System 1", "Prompt 1", model="gpt-3.5-turbo", temperature=0.1),
|
||||
processor.generate_content("System 2", "Prompt 2", model="gpt-4", temperature=0.8),
|
||||
processor.generate_content("System 3", "Prompt 3", model="gpt-3.5-turbo", temperature=0.5)
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Assert
|
||||
assert len(results) == 3
|
||||
for result in results:
|
||||
assert isinstance(result, LlmResult)
|
||||
|
||||
# Verify all calls were made with correct parameters
|
||||
assert mock_openai_client.chat.completions.create.call_count == 3
|
||||
|
||||
# Get all call arguments
|
||||
call_args_list = mock_openai_client.chat.completions.create.call_args_list
|
||||
|
||||
# Verify each call had the expected parameters
|
||||
expected_params = [
|
||||
("gpt-3.5-turbo", 0.1),
|
||||
("gpt-4", 0.8),
|
||||
("gpt-3.5-turbo", 0.5)
|
||||
]
|
||||
|
||||
for i, (expected_model, expected_temp) in enumerate(expected_params):
|
||||
call_kwargs = call_args_list[i].kwargs
|
||||
assert call_kwargs['model'] == expected_model
|
||||
assert call_kwargs['temperature'] == expected_temp
|
||||
|
||||
async def test_parameter_boundary_values(self, mock_openai_client, base_processor_config):
|
||||
"""Test parameter boundary values (edge cases)"""
|
||||
# This would test extreme values like temperature=0.0, temperature=2.0, etc.
|
||||
# Implementation depends on specific validation requirements
|
||||
pass
|
||||
|
||||
async def test_invalid_parameter_types_handling(self, mock_openai_client, base_processor_config):
|
||||
"""Test handling of invalid parameter types"""
|
||||
# This would test what happens with invalid temperature values, non-existent models, etc.
|
||||
# Implementation depends on error handling requirements
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -22,7 +22,36 @@ class TestObjectsCassandraIntegration:
|
|||
def mock_cassandra_session(self):
|
||||
"""Mock Cassandra session for integration tests"""
|
||||
session = MagicMock()
|
||||
session.execute = MagicMock()
|
||||
|
||||
# Track if keyspaces have been created
|
||||
created_keyspaces = set()
|
||||
|
||||
# Mock the execute method to return a valid result for keyspace checks
|
||||
def execute_mock(query, *args, **kwargs):
|
||||
result = MagicMock()
|
||||
query_str = str(query)
|
||||
|
||||
# Track keyspace creation
|
||||
if "CREATE KEYSPACE" in query_str:
|
||||
# Extract keyspace name from query
|
||||
import re
|
||||
match = re.search(r'CREATE KEYSPACE IF NOT EXISTS (\w+)', query_str)
|
||||
if match:
|
||||
created_keyspaces.add(match.group(1))
|
||||
|
||||
# For keyspace existence checks
|
||||
if "system_schema.keyspaces" in query_str:
|
||||
# Check if this keyspace was created
|
||||
if args and args[0] in created_keyspaces:
|
||||
result.one.return_value = MagicMock() # Exists
|
||||
else:
|
||||
result.one.return_value = None # Doesn't exist
|
||||
else:
|
||||
result.one.return_value = None
|
||||
|
||||
return result
|
||||
|
||||
session.execute = MagicMock(side_effect=execute_mock)
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -57,7 +86,8 @@ class TestObjectsCassandraIntegration:
|
|||
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
|
||||
processor.create_collection = Processor.create_collection.__get__(processor, Processor)
|
||||
|
||||
return processor, mock_cassandra_cluster, mock_cassandra_session
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -85,7 +115,10 @@ class TestObjectsCassandraIntegration:
|
|||
|
||||
await processor.on_schema_config(config, version=1)
|
||||
assert "customer_records" in processor.schemas
|
||||
|
||||
|
||||
# Step 1.5: Create the collection first (simulate tg-set-collection)
|
||||
await processor.create_collection("test_user", "import_2024")
|
||||
|
||||
# Step 2: Process an ExtractedObject
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
|
|
@ -104,10 +137,10 @@ class TestObjectsCassandraIntegration:
|
|||
confidence=0.95,
|
||||
source_span="Customer: John Doe..."
|
||||
)
|
||||
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify Cassandra interactions
|
||||
|
|
@ -178,7 +211,11 @@ class TestObjectsCassandraIntegration:
|
|||
|
||||
await processor.on_schema_config(config, version=1)
|
||||
assert len(processor.schemas) == 2
|
||||
|
||||
|
||||
# Create collections first
|
||||
await processor.create_collection("shop", "catalog")
|
||||
await processor.create_collection("shop", "sales")
|
||||
|
||||
# Process objects for different schemas
|
||||
product_obj = ExtractedObject(
|
||||
metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]),
|
||||
|
|
@ -187,7 +224,7 @@ class TestObjectsCassandraIntegration:
|
|||
confidence=0.9,
|
||||
source_span="Product..."
|
||||
)
|
||||
|
||||
|
||||
order_obj = ExtractedObject(
|
||||
metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]),
|
||||
schema_name="orders",
|
||||
|
|
@ -195,7 +232,7 @@ class TestObjectsCassandraIntegration:
|
|||
confidence=0.85,
|
||||
source_span="Order..."
|
||||
)
|
||||
|
||||
|
||||
# Process both objects
|
||||
for obj in [product_obj, order_obj]:
|
||||
msg = MagicMock()
|
||||
|
|
@ -225,6 +262,9 @@ class TestObjectsCassandraIntegration:
|
|||
]
|
||||
)
|
||||
|
||||
# Create collection first
|
||||
await processor.create_collection("test", "test")
|
||||
|
||||
# Create object missing required field
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
|
||||
|
|
@ -233,10 +273,10 @@ class TestObjectsCassandraIntegration:
|
|||
confidence=0.8,
|
||||
source_span="Test"
|
||||
)
|
||||
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
|
||||
# Should still process (Cassandra doesn't enforce NOT NULL)
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
|
|
@ -261,6 +301,9 @@ class TestObjectsCassandraIntegration:
|
|||
]
|
||||
)
|
||||
|
||||
# Create collection first
|
||||
await processor.create_collection("logger", "app_events")
|
||||
|
||||
# Process object
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="e1", user="logger", collection="app_events", metadata=[]),
|
||||
|
|
@ -269,10 +312,10 @@ class TestObjectsCassandraIntegration:
|
|||
confidence=1.0,
|
||||
source_span="Event"
|
||||
)
|
||||
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify synthetic_id was added
|
||||
|
|
@ -325,8 +368,10 @@ class TestObjectsCassandraIntegration:
|
|||
)
|
||||
|
||||
# Make insert fail
|
||||
mock_result = MagicMock()
|
||||
mock_result.one.return_value = MagicMock() # Keyspace exists
|
||||
mock_session.execute.side_effect = [
|
||||
None, # keyspace creation succeeds
|
||||
mock_result, # keyspace existence check succeeds
|
||||
None, # table creation succeeds
|
||||
Exception("Connection timeout") # insert fails
|
||||
]
|
||||
|
|
@ -359,7 +404,11 @@ class TestObjectsCassandraIntegration:
|
|||
|
||||
# Process objects from different collections
|
||||
collections = ["import_jan", "import_feb", "import_mar"]
|
||||
|
||||
|
||||
# Create all collections first
|
||||
for coll in collections:
|
||||
await processor.create_collection("analytics", coll)
|
||||
|
||||
for coll in collections:
|
||||
obj = ExtractedObject(
|
||||
metadata=Metadata(id=f"{coll}-1", user="analytics", collection=coll, metadata=[]),
|
||||
|
|
@ -368,7 +417,7 @@ class TestObjectsCassandraIntegration:
|
|||
confidence=0.9,
|
||||
source_span="Data"
|
||||
)
|
||||
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = obj
|
||||
await processor.on_object(msg, None, None)
|
||||
|
|
@ -436,9 +485,12 @@ class TestObjectsCassandraIntegration:
|
|||
source_span="Multiple customers extracted from document"
|
||||
)
|
||||
|
||||
# Create collection first
|
||||
await processor.create_collection("test_user", "batch_import")
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = batch_obj
|
||||
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify table creation
|
||||
|
|
@ -479,6 +531,9 @@ class TestObjectsCassandraIntegration:
|
|||
fields=[Field(name="id", type="string", size=50, primary=True)]
|
||||
)
|
||||
|
||||
# Create collection first
|
||||
await processor.create_collection("test", "empty")
|
||||
|
||||
# Process empty batch object
|
||||
empty_obj = ExtractedObject(
|
||||
metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]),
|
||||
|
|
@ -487,10 +542,10 @@ class TestObjectsCassandraIntegration:
|
|||
confidence=1.0,
|
||||
source_span="No objects found"
|
||||
)
|
||||
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = empty_obj
|
||||
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Should still create table
|
||||
|
|
@ -517,6 +572,9 @@ class TestObjectsCassandraIntegration:
|
|||
]
|
||||
)
|
||||
|
||||
# Create collection first
|
||||
await processor.create_collection("test", "mixed")
|
||||
|
||||
# Single object (backward compatibility)
|
||||
single_obj = ExtractedObject(
|
||||
metadata=Metadata(id="single", user="test", collection="mixed", metadata=[]),
|
||||
|
|
@ -525,7 +583,7 @@ class TestObjectsCassandraIntegration:
|
|||
confidence=0.9,
|
||||
source_span="Single object"
|
||||
)
|
||||
|
||||
|
||||
# Batch object
|
||||
batch_obj = ExtractedObject(
|
||||
metadata=Metadata(id="batch", user="test", collection="mixed", metadata=[]),
|
||||
|
|
@ -537,7 +595,7 @@ class TestObjectsCassandraIntegration:
|
|||
confidence=0.85,
|
||||
source_span="Batch objects"
|
||||
)
|
||||
|
||||
|
||||
# Process both
|
||||
for obj in [single_obj, batch_obj]:
|
||||
msg = MagicMock()
|
||||
|
|
|
|||
|
|
@ -60,13 +60,13 @@ class TestTextCompletionIntegration:
|
|||
"""Create text completion processor with test configuration"""
|
||||
# Create a minimal processor instance for testing generate_content
|
||||
processor = MagicMock()
|
||||
processor.model = processor_config["model"]
|
||||
processor.default_model = processor_config["model"]
|
||||
processor.temperature = processor_config["temperature"]
|
||||
processor.max_output = processor_config["max_output"]
|
||||
|
||||
|
||||
# Add the actual generate_content method from Processor class
|
||||
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
|
||||
|
||||
|
||||
return processor
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -112,11 +112,11 @@ class TestTextCompletionIntegration:
|
|||
for config in test_configs:
|
||||
# Arrange - Create minimal processor mock
|
||||
processor = MagicMock()
|
||||
processor.model = config['model']
|
||||
processor.default_model = config['model']
|
||||
processor.temperature = config['temperature']
|
||||
processor.max_output = config['max_output']
|
||||
processor.openai = mock_openai_client
|
||||
|
||||
|
||||
# Add the actual generate_content method
|
||||
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
|
||||
|
||||
|
|
@ -242,7 +242,7 @@ class TestTextCompletionIntegration:
|
|||
processors = []
|
||||
for i in range(5):
|
||||
processor = MagicMock()
|
||||
processor.model = processor_config["model"]
|
||||
processor.default_model = processor_config["model"]
|
||||
processor.temperature = processor_config["temperature"]
|
||||
processor.max_output = processor_config["max_output"]
|
||||
processor.openai = mock_openai_client
|
||||
|
|
@ -348,7 +348,7 @@ class TestTextCompletionIntegration:
|
|||
"""Test that model parameters are correctly passed to OpenAI API"""
|
||||
# Arrange
|
||||
processor = MagicMock()
|
||||
processor.model = "gpt-4"
|
||||
processor.default_model = "gpt-4"
|
||||
processor.temperature = 0.8
|
||||
processor.max_output = 2048
|
||||
processor.openai = mock_openai_client
|
||||
|
|
|
|||
238
tests/unit/test_base/test_flow_parameter_specs.py
Normal file
238
tests/unit/test_base/test_flow_parameter_specs.py
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
"""
|
||||
Unit tests for Flow Parameter Specification functionality
|
||||
Testing parameter specification registration and handling in flow processors
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
from trustgraph.base.flow_processor import FlowProcessor
|
||||
from trustgraph.base import ParameterSpec, ConsumerSpec, ProducerSpec
|
||||
|
||||
|
||||
class MockAsyncProcessor:
|
||||
def __init__(self, **params):
|
||||
self.config_handlers = []
|
||||
self.id = params.get('id', 'test-service')
|
||||
self.specifications = []
|
||||
|
||||
|
||||
class TestFlowParameterSpecs(IsolatedAsyncioTestCase):
|
||||
"""Test flow processor parameter specification functionality"""
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_parameter_spec_registration(self):
|
||||
"""Test that parameter specs can be registered with flow processors"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
# Create test parameter specs
|
||||
model_spec = ParameterSpec(name="model")
|
||||
temperature_spec = ParameterSpec(name="temperature")
|
||||
|
||||
# Act
|
||||
processor.register_specification(model_spec)
|
||||
processor.register_specification(temperature_spec)
|
||||
|
||||
# Assert
|
||||
assert len(processor.specifications) >= 2
|
||||
|
||||
param_specs = [spec for spec in processor.specifications
|
||||
if isinstance(spec, ParameterSpec)]
|
||||
assert len(param_specs) >= 2
|
||||
|
||||
param_names = [spec.name for spec in param_specs]
|
||||
assert "model" in param_names
|
||||
assert "temperature" in param_names
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_mixed_specification_types(self):
|
||||
"""Test registration of mixed specification types (parameters, consumers, producers)"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
# Create different spec types
|
||||
param_spec = ParameterSpec(name="model")
|
||||
consumer_spec = ConsumerSpec(name="input", schema=MagicMock(), handler=MagicMock())
|
||||
producer_spec = ProducerSpec(name="output", schema=MagicMock())
|
||||
|
||||
# Act
|
||||
processor.register_specification(param_spec)
|
||||
processor.register_specification(consumer_spec)
|
||||
processor.register_specification(producer_spec)
|
||||
|
||||
# Assert
|
||||
assert len(processor.specifications) == 3
|
||||
|
||||
# Count each type
|
||||
param_specs = [s for s in processor.specifications if isinstance(s, ParameterSpec)]
|
||||
consumer_specs = [s for s in processor.specifications if isinstance(s, ConsumerSpec)]
|
||||
producer_specs = [s for s in processor.specifications if isinstance(s, ProducerSpec)]
|
||||
|
||||
assert len(param_specs) == 1
|
||||
assert len(consumer_specs) == 1
|
||||
assert len(producer_specs) == 1
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_parameter_spec_metadata(self):
|
||||
"""Test parameter specification metadata handling"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
# Create parameter specs with metadata (if supported)
|
||||
model_spec = ParameterSpec(name="model")
|
||||
temperature_spec = ParameterSpec(name="temperature")
|
||||
|
||||
# Act
|
||||
processor.register_specification(model_spec)
|
||||
processor.register_specification(temperature_spec)
|
||||
|
||||
# Assert
|
||||
param_specs = [spec for spec in processor.specifications
|
||||
if isinstance(spec, ParameterSpec)]
|
||||
|
||||
model_spec_registered = next((s for s in param_specs if s.name == "model"), None)
|
||||
temperature_spec_registered = next((s for s in param_specs if s.name == "temperature"), None)
|
||||
|
||||
assert model_spec_registered is not None
|
||||
assert temperature_spec_registered is not None
|
||||
assert model_spec_registered.name == "model"
|
||||
assert temperature_spec_registered.name == "temperature"
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_duplicate_parameter_spec_handling(self):
|
||||
"""Test handling of duplicate parameter spec registration"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
# Create duplicate parameter specs
|
||||
model_spec1 = ParameterSpec(name="model")
|
||||
model_spec2 = ParameterSpec(name="model")
|
||||
|
||||
# Act
|
||||
processor.register_specification(model_spec1)
|
||||
processor.register_specification(model_spec2)
|
||||
|
||||
# Assert - Should allow duplicates (or handle appropriately)
|
||||
param_specs = [spec for spec in processor.specifications
|
||||
if isinstance(spec, ParameterSpec) and spec.name == "model"]
|
||||
|
||||
# Either should have 2 duplicates or the system should handle deduplication
|
||||
assert len(param_specs) >= 1 # At least one should be registered
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
async def test_parameter_specs_available_to_flows(self, mock_flow_class):
|
||||
"""Test that parameter specs are available when flows are created"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
processor.id = 'test-processor'
|
||||
|
||||
# Register parameter specs
|
||||
model_spec = ParameterSpec(name="model")
|
||||
temperature_spec = ParameterSpec(name="temperature")
|
||||
processor.register_specification(model_spec)
|
||||
processor.register_specification(temperature_spec)
|
||||
|
||||
mock_flow = AsyncMock()
|
||||
mock_flow_class.return_value = mock_flow
|
||||
|
||||
flow_name = 'test-flow'
|
||||
flow_defn = {'config': 'test-config'}
|
||||
|
||||
# Act
|
||||
await processor.start_flow(flow_name, flow_defn)
|
||||
|
||||
# Assert - Flow should be created with access to processor specifications
|
||||
mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn)
|
||||
|
||||
# The flow should have access to the processor's specifications
|
||||
# (The exact mechanism depends on Flow implementation)
|
||||
assert len(processor.specifications) >= 2
|
||||
|
||||
|
||||
class TestParameterSpecValidation(IsolatedAsyncioTestCase):
|
||||
"""Test parameter specification validation functionality"""
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_parameter_spec_name_validation(self):
|
||||
"""Test parameter spec name validation"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
# Act & Assert - Valid parameter names
|
||||
valid_specs = [
|
||||
ParameterSpec(name="model"),
|
||||
ParameterSpec(name="temperature"),
|
||||
ParameterSpec(name="max_tokens"),
|
||||
ParameterSpec(name="api_key")
|
||||
]
|
||||
|
||||
for spec in valid_specs:
|
||||
# Should not raise any exceptions
|
||||
processor.register_specification(spec)
|
||||
|
||||
assert len([s for s in processor.specifications if isinstance(s, ParameterSpec)]) >= 4
|
||||
|
||||
def test_parameter_spec_creation_validation(self):
|
||||
"""Test parameter spec creation with various inputs"""
|
||||
# Test valid parameter spec creation
|
||||
valid_specs = [
|
||||
ParameterSpec(name="model"),
|
||||
ParameterSpec(name="temperature"),
|
||||
ParameterSpec(name="max_output"),
|
||||
]
|
||||
|
||||
for spec in valid_specs:
|
||||
assert spec.name is not None
|
||||
assert isinstance(spec.name, str)
|
||||
|
||||
# Test edge cases (if parameter specs have validation)
|
||||
# This depends on the actual ParameterSpec implementation
|
||||
try:
|
||||
empty_name_spec = ParameterSpec(name="")
|
||||
# May or may not be valid depending on implementation
|
||||
except Exception:
|
||||
# If validation exists, it should catch invalid names
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
264
tests/unit/test_base/test_llm_service_parameters.py
Normal file
264
tests/unit/test_base/test_llm_service_parameters.py
Normal file
|
|
@ -0,0 +1,264 @@
|
|||
"""
|
||||
Unit tests for LLM Service Parameter Specifications
|
||||
Testing the new parameter-aware functionality added to the LLM base service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
from trustgraph.base.llm_service import LlmService, LlmResult
|
||||
from trustgraph.base import ParameterSpec, ConsumerSpec, ProducerSpec
|
||||
from trustgraph.schema import TextCompletionRequest, TextCompletionResponse
|
||||
|
||||
|
||||
class MockAsyncProcessor:
|
||||
def __init__(self, **params):
|
||||
self.config_handlers = []
|
||||
self.id = params.get('id', 'test-service')
|
||||
self.specifications = []
|
||||
|
||||
|
||||
|
||||
|
||||
class TestLlmServiceParameters(IsolatedAsyncioTestCase):
|
||||
"""Test LLM service parameter specification functionality"""
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_parameter_specs_registration(self):
|
||||
"""Test that LLM service registers model and temperature parameter specs"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-llm-service',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock() # Add required taskgroup
|
||||
}
|
||||
|
||||
# Act
|
||||
service = LlmService(**config)
|
||||
|
||||
# Assert
|
||||
param_specs = {spec.name: spec for spec in service.specifications
|
||||
if isinstance(spec, ParameterSpec)}
|
||||
|
||||
assert "model" in param_specs
|
||||
assert "temperature" in param_specs
|
||||
assert len(param_specs) >= 2 # May have other parameter specs
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_model_parameter_spec_properties(self):
|
||||
"""Test that model parameter spec has correct properties"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-llm-service',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
# Act
|
||||
service = LlmService(**config)
|
||||
|
||||
# Assert
|
||||
model_spec = None
|
||||
for spec in service.specifications:
|
||||
if isinstance(spec, ParameterSpec) and spec.name == "model":
|
||||
model_spec = spec
|
||||
break
|
||||
|
||||
assert model_spec is not None
|
||||
assert model_spec.name == "model"
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_temperature_parameter_spec_properties(self):
|
||||
"""Test that temperature parameter spec has correct properties"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-llm-service',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
# Act
|
||||
service = LlmService(**config)
|
||||
|
||||
# Assert
|
||||
temperature_spec = None
|
||||
for spec in service.specifications:
|
||||
if isinstance(spec, ParameterSpec) and spec.name == "temperature":
|
||||
temperature_spec = spec
|
||||
break
|
||||
|
||||
assert temperature_spec is not None
|
||||
assert temperature_spec.name == "temperature"
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_request_extracts_parameters_from_flow(self):
|
||||
"""Test that on_request method extracts model and temperature from flow"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-llm-service',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
service = LlmService(**config)
|
||||
|
||||
# Mock the metrics
|
||||
service.text_completion_model_metric = MagicMock()
|
||||
service.text_completion_model_metric.labels.return_value.info = AsyncMock()
|
||||
|
||||
# Mock the generate_content method to capture parameters
|
||||
service.generate_content = AsyncMock(return_value=LlmResult(
|
||||
text="test response",
|
||||
in_token=10,
|
||||
out_token=5,
|
||||
model="gpt-4"
|
||||
))
|
||||
|
||||
# Mock message and flow
|
||||
mock_message = MagicMock()
|
||||
mock_message.value.return_value = MagicMock()
|
||||
mock_message.value.return_value.system = "system prompt"
|
||||
mock_message.value.return_value.prompt = "user prompt"
|
||||
mock_message.properties.return_value = {"id": "test-id"}
|
||||
|
||||
mock_consumer = MagicMock()
|
||||
mock_consumer.name = "request"
|
||||
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.name = "test-flow"
|
||||
mock_flow.return_value = "test-model" # flow("model") returns this
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"model": "gpt-4",
|
||||
"temperature": 0.7
|
||||
}.get(param, f"mock-{param}")
|
||||
|
||||
mock_producer = AsyncMock()
|
||||
mock_flow.producer = {"response": mock_producer}
|
||||
|
||||
# Act
|
||||
await service.on_request(mock_message, mock_consumer, mock_flow)
|
||||
|
||||
# Assert
|
||||
# Verify that generate_content was called with parameters from flow
|
||||
service.generate_content.assert_called_once()
|
||||
call_args = service.generate_content.call_args
|
||||
|
||||
assert call_args[0][0] == "system prompt" # system
|
||||
assert call_args[0][1] == "user prompt" # prompt
|
||||
assert call_args[0][2] == "gpt-4" # model
|
||||
assert call_args[0][3] == 0.7 # temperature
|
||||
|
||||
# Verify flow was queried for both parameters
|
||||
mock_flow.assert_any_call("model")
|
||||
mock_flow.assert_any_call("temperature")
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_request_handles_missing_parameters_gracefully(self):
|
||||
"""Test that on_request handles missing parameters gracefully"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-llm-service',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
service = LlmService(**config)
|
||||
|
||||
# Mock the metrics
|
||||
service.text_completion_model_metric = MagicMock()
|
||||
service.text_completion_model_metric.labels.return_value.info = AsyncMock()
|
||||
|
||||
# Mock the generate_content method
|
||||
service.generate_content = AsyncMock(return_value=LlmResult(
|
||||
text="test response",
|
||||
in_token=10,
|
||||
out_token=5,
|
||||
model="default-model"
|
||||
))
|
||||
|
||||
# Mock message and flow where flow returns None for parameters
|
||||
mock_message = MagicMock()
|
||||
mock_message.value.return_value = MagicMock()
|
||||
mock_message.value.return_value.system = "system prompt"
|
||||
mock_message.value.return_value.prompt = "user prompt"
|
||||
mock_message.properties.return_value = {"id": "test-id"}
|
||||
|
||||
mock_consumer = MagicMock()
|
||||
mock_consumer.name = "request"
|
||||
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.name = "test-flow"
|
||||
mock_flow.return_value = None # Both parameters return None
|
||||
|
||||
mock_producer = AsyncMock()
|
||||
mock_flow.producer = {"response": mock_producer}
|
||||
|
||||
# Act
|
||||
await service.on_request(mock_message, mock_consumer, mock_flow)
|
||||
|
||||
# Assert
|
||||
# Should still call generate_content, with None values that will use processor defaults
|
||||
service.generate_content.assert_called_once()
|
||||
call_args = service.generate_content.call_args
|
||||
|
||||
assert call_args[0][0] == "system prompt" # system
|
||||
assert call_args[0][1] == "user prompt" # prompt
|
||||
assert call_args[0][2] is None # model (will use processor default)
|
||||
assert call_args[0][3] is None # temperature (will use processor default)
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_request_error_handling_preserves_behavior(self):
|
||||
"""Test that parameter extraction doesn't break existing error handling"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-llm-service',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
service = LlmService(**config)
|
||||
|
||||
# Mock the metrics
|
||||
service.text_completion_model_metric = MagicMock()
|
||||
service.text_completion_model_metric.labels.return_value.info = AsyncMock()
|
||||
|
||||
# Mock generate_content to raise an exception
|
||||
service.generate_content = AsyncMock(side_effect=Exception("Test error"))
|
||||
|
||||
# Mock message and flow
|
||||
mock_message = MagicMock()
|
||||
mock_message.value.return_value = MagicMock()
|
||||
mock_message.value.return_value.system = "system prompt"
|
||||
mock_message.value.return_value.prompt = "user prompt"
|
||||
mock_message.properties.return_value = {"id": "test-id"}
|
||||
|
||||
mock_consumer = MagicMock()
|
||||
mock_consumer.name = "request"
|
||||
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.name = "test-flow"
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"model": "gpt-4",
|
||||
"temperature": 0.7
|
||||
}.get(param, f"mock-{param}")
|
||||
|
||||
mock_producer = AsyncMock()
|
||||
mock_flow.producer = {"response": mock_producer}
|
||||
|
||||
# Act
|
||||
await service.on_request(mock_message, mock_consumer, mock_flow)
|
||||
|
||||
# Assert
|
||||
# Should have sent error response
|
||||
mock_producer.send.assert_called_once()
|
||||
error_response = mock_producer.send.call_args[0][0]
|
||||
|
||||
assert error_response.error is not None
|
||||
assert error_response.error.type == "llm-error"
|
||||
assert "Test error" in error_response.error.message
|
||||
assert error_response.response is None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -1,211 +1,236 @@
|
|||
"""
|
||||
Unit tests for trustgraph.chunking.recursive
|
||||
Testing parameter override functionality for chunk-size and chunk-overlap
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, Mock, patch, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.chunking.recursive.chunker import Processor
|
||||
from trustgraph.schema import TextDocument, Chunk, Metadata
|
||||
from trustgraph.chunking.recursive.chunker import Processor as RecursiveChunker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_flow():
|
||||
output_mock = AsyncMock()
|
||||
flow_mock = Mock(return_value=output_mock)
|
||||
return flow_mock, output_mock
|
||||
class MockAsyncProcessor:
|
||||
def __init__(self, **params):
|
||||
self.config_handlers = []
|
||||
self.id = params.get('id', 'test-service')
|
||||
self.specifications = []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_consumer():
|
||||
consumer = Mock()
|
||||
consumer.id = "test-consumer"
|
||||
consumer.flow = "test-flow"
|
||||
return consumer
|
||||
class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Recursive chunker functionality"""
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_processor_initialization_basic(self):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 1500,
|
||||
'chunk_overlap': 150,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_chunk_size == 1500
|
||||
assert processor.default_chunk_overlap == 150
|
||||
assert hasattr(processor, 'text_splitter')
|
||||
|
||||
# Verify parameter specs are registered
|
||||
param_specs = [spec for spec in processor.specifications
|
||||
if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
|
||||
assert len(param_specs) == 2
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_chunk_size_override(self):
|
||||
"""Test chunk_document with chunk-size parameter override"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 1000, # Default chunk size
|
||||
'chunk_overlap': 100,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock message and flow
|
||||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"chunk-size": 2000, # Override chunk size
|
||||
"chunk-overlap": None # Use default chunk overlap
|
||||
}.get(param)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
mock_message, mock_consumer, mock_flow, 1000, 100
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert chunk_size == 2000 # Should use overridden value
|
||||
assert chunk_overlap == 100 # Should use default value
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_chunk_overlap_override(self):
|
||||
"""Test chunk_document with chunk-overlap parameter override"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 1000,
|
||||
'chunk_overlap': 100, # Default chunk overlap
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock message and flow
|
||||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"chunk-size": None, # Use default chunk size
|
||||
"chunk-overlap": 200 # Override chunk overlap
|
||||
}.get(param)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
mock_message, mock_consumer, mock_flow, 1000, 100
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert chunk_size == 1000 # Should use default value
|
||||
assert chunk_overlap == 200 # Should use overridden value
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_both_parameters_override(self):
|
||||
"""Test chunk_document with both chunk-size and chunk-overlap overrides"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 1000,
|
||||
'chunk_overlap': 100,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock message and flow
|
||||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"chunk-size": 1500, # Override chunk size
|
||||
"chunk-overlap": 150 # Override chunk overlap
|
||||
}.get(param)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
mock_message, mock_consumer, mock_flow, 1000, 100
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert chunk_size == 1500 # Should use overridden value
|
||||
assert chunk_overlap == 150 # Should use overridden value
|
||||
|
||||
@patch('trustgraph.chunking.recursive.chunker.RecursiveCharacterTextSplitter')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_uses_flow_parameters(self, mock_splitter_class):
|
||||
"""Test that on_message method uses parameters from flow"""
|
||||
# Arrange
|
||||
mock_splitter = MagicMock()
|
||||
mock_document = MagicMock()
|
||||
mock_document.page_content = "Test chunk content"
|
||||
mock_splitter.create_documents.return_value = [mock_document]
|
||||
mock_splitter_class.return_value = mock_splitter
|
||||
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 1000,
|
||||
'chunk_overlap': 100,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock message with TextDocument
|
||||
mock_message = MagicMock()
|
||||
mock_text_doc = MagicMock()
|
||||
mock_text_doc.metadata = Metadata(
|
||||
id="test-doc-123",
|
||||
metadata=[],
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
mock_text_doc.text = b"This is test document content"
|
||||
mock_message.value.return_value = mock_text_doc
|
||||
|
||||
# Mock consumer and flow with parameter overrides
|
||||
mock_consumer = MagicMock()
|
||||
mock_producer = AsyncMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"chunk-size": 1500,
|
||||
"chunk-overlap": 150,
|
||||
"output": mock_producer
|
||||
}.get(param)
|
||||
|
||||
# Act
|
||||
await processor.on_message(mock_message, mock_consumer, mock_flow)
|
||||
|
||||
# Assert
|
||||
# Verify RecursiveCharacterTextSplitter was called with overridden parameters (last call)
|
||||
actual_last_call = mock_splitter_class.call_args_list[-1]
|
||||
assert actual_last_call.kwargs['chunk_size'] == 1500
|
||||
assert actual_last_call.kwargs['chunk_overlap'] == 150
|
||||
assert actual_last_call.kwargs['length_function'] == len
|
||||
assert actual_last_call.kwargs['is_separator_regex'] == False
|
||||
|
||||
# Verify chunk was sent to output
|
||||
mock_producer.send.assert_called_once()
|
||||
sent_chunk = mock_producer.send.call_args[0][0]
|
||||
assert isinstance(sent_chunk, Chunk)
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_no_overrides(self):
|
||||
"""Test chunk_document when no parameters are overridden (flow returns None)"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 1000,
|
||||
'chunk_overlap': 100,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock message and flow that returns None for all parameters
|
||||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.return_value = None # No overrides
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
mock_message, mock_consumer, mock_flow, 1000, 100
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert chunk_size == 1000 # Should use default value
|
||||
assert chunk_overlap == 100 # Should use default value
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_document():
|
||||
metadata = Metadata(
|
||||
id="test-doc-1",
|
||||
metadata=[],
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
text = "This is a test document. " * 100 # Create text long enough to be chunked
|
||||
return TextDocument(
|
||||
metadata=metadata,
|
||||
text=text.encode("utf-8")
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def short_document():
|
||||
metadata = Metadata(
|
||||
id="test-doc-2",
|
||||
metadata=[],
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
text = "This is a very short document."
|
||||
return TextDocument(
|
||||
metadata=metadata,
|
||||
text=text.encode("utf-8")
|
||||
)
|
||||
|
||||
|
||||
class TestRecursiveChunker:
|
||||
|
||||
def test_init_default_params(self, mock_async_processor_init):
|
||||
processor = RecursiveChunker()
|
||||
assert processor.text_splitter._chunk_size == 2000
|
||||
assert processor.text_splitter._chunk_overlap == 100
|
||||
|
||||
def test_init_custom_params(self, mock_async_processor_init):
|
||||
processor = RecursiveChunker(chunk_size=500, chunk_overlap=50)
|
||||
assert processor.text_splitter._chunk_size == 500
|
||||
assert processor.text_splitter._chunk_overlap == 50
|
||||
|
||||
def test_init_with_id(self, mock_async_processor_init):
|
||||
processor = RecursiveChunker(id="custom-chunker")
|
||||
assert processor.id == "custom-chunker"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_single_chunk(self, mock_async_processor_init, mock_flow, mock_consumer, short_document):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = RecursiveChunker(chunk_size=2000, chunk_overlap=100)
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = short_document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Should produce exactly one chunk for short text
|
||||
assert output_mock.send.call_count == 1
|
||||
|
||||
# Verify the chunk was created correctly
|
||||
chunk_call = output_mock.send.call_args[0][0]
|
||||
assert isinstance(chunk_call, Chunk)
|
||||
assert chunk_call.metadata == short_document.metadata
|
||||
assert chunk_call.chunk.decode("utf-8") == short_document.text.decode("utf-8")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_multiple_chunks(self, mock_async_processor_init, mock_flow, mock_consumer, sample_document):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = RecursiveChunker(chunk_size=100, chunk_overlap=20)
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = sample_document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Should produce multiple chunks
|
||||
assert output_mock.send.call_count > 1
|
||||
|
||||
# Verify all chunks have correct metadata
|
||||
for call in output_mock.send.call_args_list:
|
||||
chunk = call[0][0]
|
||||
assert isinstance(chunk, Chunk)
|
||||
assert chunk.metadata == sample_document.metadata
|
||||
assert len(chunk.chunk) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_chunk_overlap(self, mock_async_processor_init, mock_flow, mock_consumer):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = RecursiveChunker(chunk_size=50, chunk_overlap=10)
|
||||
|
||||
# Create a document with predictable content
|
||||
metadata = Metadata(id="test", metadata=[], user="test-user", collection="test-collection")
|
||||
text = "ABCDEFGHIJ" * 10 # 100 characters
|
||||
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Collect all chunks
|
||||
chunks = []
|
||||
for call in output_mock.send.call_args_list:
|
||||
chunk_text = call[0][0].chunk.decode("utf-8")
|
||||
chunks.append(chunk_text)
|
||||
|
||||
# Verify chunks have expected overlap
|
||||
for i in range(len(chunks) - 1):
|
||||
# The end of chunk i should overlap with the beginning of chunk i+1
|
||||
# Check if there's some overlap (exact overlap depends on text splitter logic)
|
||||
assert len(chunks[i]) <= 50 + 10 # chunk_size + some tolerance
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_empty_document(self, mock_async_processor_init, mock_flow, mock_consumer):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = RecursiveChunker()
|
||||
|
||||
metadata = Metadata(id="empty", metadata=[], user="test-user", collection="test-collection")
|
||||
document = TextDocument(metadata=metadata, text=b"")
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Empty documents typically don't produce chunks with langchain splitters
|
||||
# This behavior is expected - no chunks should be produced
|
||||
assert output_mock.send.call_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_unicode_handling(self, mock_async_processor_init, mock_flow, mock_consumer):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = RecursiveChunker(chunk_size=500, chunk_overlap=20) # Fixed overlap < chunk_size
|
||||
|
||||
metadata = Metadata(id="unicode", metadata=[], user="test-user", collection="test-collection")
|
||||
text = "Hello 世界! 🌍 This is a test with émojis and spëcial characters."
|
||||
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Verify unicode is preserved correctly
|
||||
all_chunks = []
|
||||
for call in output_mock.send.call_args_list:
|
||||
chunk_text = call[0][0].chunk.decode("utf-8")
|
||||
all_chunks.append(chunk_text)
|
||||
|
||||
# Reconstruct text (approximately, due to overlap)
|
||||
reconstructed = "".join(all_chunks)
|
||||
assert "世界" in reconstructed
|
||||
assert "🌍" in reconstructed
|
||||
assert "émojis" in reconstructed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_recorded(self, mock_async_processor_init, mock_flow, mock_consumer, sample_document):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = RecursiveChunker(chunk_size=100)
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = sample_document
|
||||
|
||||
# Mock the metric
|
||||
with patch.object(RecursiveChunker.chunk_metric, 'labels') as mock_labels:
|
||||
mock_observe = Mock()
|
||||
mock_labels.return_value.observe = mock_observe
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Verify metrics were recorded
|
||||
mock_labels.assert_called_with(id="test-consumer", flow="test-flow")
|
||||
assert mock_observe.call_count > 0
|
||||
|
||||
# Verify chunk sizes were observed
|
||||
for call in mock_observe.call_args_list:
|
||||
chunk_size = call[0][0]
|
||||
assert chunk_size > 0
|
||||
|
||||
def test_add_args(self):
|
||||
parser = Mock()
|
||||
RecursiveChunker.add_args(parser)
|
||||
|
||||
# Verify arguments were added
|
||||
calls = parser.add_argument.call_args_list
|
||||
arg_names = [call[0][0] for call in calls]
|
||||
|
||||
assert '-z' in arg_names or '--chunk-size' in arg_names
|
||||
assert '-v' in arg_names or '--chunk-overlap' in arg_names
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -1,275 +1,256 @@
|
|||
"""
|
||||
Unit tests for trustgraph.chunking.token
|
||||
Testing parameter override functionality for chunk-size and chunk-overlap
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.chunking.token.chunker import Processor
|
||||
from trustgraph.schema import TextDocument, Chunk, Metadata
|
||||
from trustgraph.chunking.token.chunker import Processor as TokenChunker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_flow():
|
||||
output_mock = AsyncMock()
|
||||
flow_mock = Mock(return_value=output_mock)
|
||||
return flow_mock, output_mock
|
||||
class MockAsyncProcessor:
|
||||
def __init__(self, **params):
|
||||
self.config_handlers = []
|
||||
self.id = params.get('id', 'test-service')
|
||||
self.specifications = []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_consumer():
|
||||
consumer = Mock()
|
||||
consumer.id = "test-consumer"
|
||||
consumer.flow = "test-flow"
|
||||
return consumer
|
||||
class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Token chunker functionality"""
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_processor_initialization_basic(self):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 300,
|
||||
'chunk_overlap': 20,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_chunk_size == 300
|
||||
assert processor.default_chunk_overlap == 20
|
||||
assert hasattr(processor, 'text_splitter')
|
||||
|
||||
# Verify parameter specs are registered
|
||||
param_specs = [spec for spec in processor.specifications
|
||||
if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
|
||||
assert len(param_specs) == 2
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_chunk_size_override(self):
|
||||
"""Test chunk_document with chunk-size parameter override"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 250, # Default chunk size
|
||||
'chunk_overlap': 15,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock message and flow
|
||||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"chunk-size": 400, # Override chunk size
|
||||
"chunk-overlap": None # Use default chunk overlap
|
||||
}.get(param)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
mock_message, mock_consumer, mock_flow, 250, 15
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert chunk_size == 400 # Should use overridden value
|
||||
assert chunk_overlap == 15 # Should use default value
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_chunk_overlap_override(self):
|
||||
"""Test chunk_document with chunk-overlap parameter override"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 250,
|
||||
'chunk_overlap': 15, # Default chunk overlap
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock message and flow
|
||||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"chunk-size": None, # Use default chunk size
|
||||
"chunk-overlap": 25 # Override chunk overlap
|
||||
}.get(param)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
mock_message, mock_consumer, mock_flow, 250, 15
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert chunk_size == 250 # Should use default value
|
||||
assert chunk_overlap == 25 # Should use overridden value
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_both_parameters_override(self):
|
||||
"""Test chunk_document with both chunk-size and chunk-overlap overrides"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 250,
|
||||
'chunk_overlap': 15,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock message and flow
|
||||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"chunk-size": 350, # Override chunk size
|
||||
"chunk-overlap": 30 # Override chunk overlap
|
||||
}.get(param)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
mock_message, mock_consumer, mock_flow, 250, 15
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert chunk_size == 350 # Should use overridden value
|
||||
assert chunk_overlap == 30 # Should use overridden value
|
||||
|
||||
@patch('trustgraph.chunking.token.chunker.TokenTextSplitter')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_uses_flow_parameters(self, mock_splitter_class):
|
||||
"""Test that on_message method uses parameters from flow"""
|
||||
# Arrange
|
||||
mock_splitter = MagicMock()
|
||||
mock_document = MagicMock()
|
||||
mock_document.page_content = "Test token chunk content"
|
||||
mock_splitter.create_documents.return_value = [mock_document]
|
||||
mock_splitter_class.return_value = mock_splitter
|
||||
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 250,
|
||||
'chunk_overlap': 15,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock message with TextDocument
|
||||
mock_message = MagicMock()
|
||||
mock_text_doc = MagicMock()
|
||||
mock_text_doc.metadata = Metadata(
|
||||
id="test-doc-456",
|
||||
metadata=[],
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
mock_text_doc.text = b"This is test document content for token chunking"
|
||||
mock_message.value.return_value = mock_text_doc
|
||||
|
||||
# Mock consumer and flow with parameter overrides
|
||||
mock_consumer = MagicMock()
|
||||
mock_producer = AsyncMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"chunk-size": 400,
|
||||
"chunk-overlap": 40,
|
||||
"output": mock_producer
|
||||
}.get(param)
|
||||
|
||||
# Act
|
||||
await processor.on_message(mock_message, mock_consumer, mock_flow)
|
||||
|
||||
# Assert
|
||||
# Verify TokenTextSplitter was called with overridden parameters (last call)
|
||||
expected_call = [
|
||||
('encoding_name', 'cl100k_base'),
|
||||
('chunk_size', 400),
|
||||
('chunk_overlap', 40)
|
||||
]
|
||||
actual_last_call = mock_splitter_class.call_args_list[-1]
|
||||
assert actual_last_call.kwargs['encoding_name'] == "cl100k_base"
|
||||
assert actual_last_call.kwargs['chunk_size'] == 400
|
||||
assert actual_last_call.kwargs['chunk_overlap'] == 40
|
||||
|
||||
# Verify chunk was sent to output
|
||||
mock_producer.send.assert_called_once()
|
||||
sent_chunk = mock_producer.send.call_args[0][0]
|
||||
assert isinstance(sent_chunk, Chunk)
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_no_overrides(self):
|
||||
"""Test chunk_document when no parameters are overridden (flow returns None)"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 250,
|
||||
'chunk_overlap': 15,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock message and flow that returns None for all parameters
|
||||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.return_value = None # No overrides
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
mock_message, mock_consumer, mock_flow, 250, 15
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert chunk_size == 250 # Should use default value
|
||||
assert chunk_overlap == 15 # Should use default value
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_token_chunker_uses_different_defaults(self):
|
||||
"""Test that token chunker has different defaults than recursive chunker"""
|
||||
# Arrange & Act
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert - Token chunker should have different defaults
|
||||
assert processor.default_chunk_size == 250 # Token chunker default
|
||||
assert processor.default_chunk_overlap == 15 # Token chunker default
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_document():
|
||||
metadata = Metadata(
|
||||
id="test-doc-1",
|
||||
metadata=[],
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
# Create text that will result in multiple token chunks
|
||||
text = "The quick brown fox jumps over the lazy dog. " * 50
|
||||
return TextDocument(
|
||||
metadata=metadata,
|
||||
text=text.encode("utf-8")
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def short_document():
|
||||
metadata = Metadata(
|
||||
id="test-doc-2",
|
||||
metadata=[],
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
text = "Short text."
|
||||
return TextDocument(
|
||||
metadata=metadata,
|
||||
text=text.encode("utf-8")
|
||||
)
|
||||
|
||||
|
||||
class TestTokenChunker:
|
||||
|
||||
def test_init_default_params(self, mock_async_processor_init):
|
||||
processor = TokenChunker()
|
||||
assert processor.text_splitter._chunk_size == 250
|
||||
assert processor.text_splitter._chunk_overlap == 15
|
||||
# Just verify the text splitter was created (encoding verification is complex)
|
||||
assert processor.text_splitter is not None
|
||||
assert hasattr(processor.text_splitter, 'split_text')
|
||||
|
||||
def test_init_custom_params(self, mock_async_processor_init):
|
||||
processor = TokenChunker(chunk_size=100, chunk_overlap=10)
|
||||
assert processor.text_splitter._chunk_size == 100
|
||||
assert processor.text_splitter._chunk_overlap == 10
|
||||
|
||||
def test_init_with_id(self, mock_async_processor_init):
|
||||
processor = TokenChunker(id="custom-token-chunker")
|
||||
assert processor.id == "custom-token-chunker"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_single_chunk(self, mock_async_processor_init, mock_flow, mock_consumer, short_document):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = TokenChunker(chunk_size=250, chunk_overlap=15)
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = short_document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Short text should produce exactly one chunk
|
||||
assert output_mock.send.call_count == 1
|
||||
|
||||
# Verify the chunk was created correctly
|
||||
chunk_call = output_mock.send.call_args[0][0]
|
||||
assert isinstance(chunk_call, Chunk)
|
||||
assert chunk_call.metadata == short_document.metadata
|
||||
assert chunk_call.chunk.decode("utf-8") == short_document.text.decode("utf-8")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_multiple_chunks(self, mock_async_processor_init, mock_flow, mock_consumer, sample_document):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = TokenChunker(chunk_size=50, chunk_overlap=5)
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = sample_document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Should produce multiple chunks
|
||||
assert output_mock.send.call_count > 1
|
||||
|
||||
# Verify all chunks have correct metadata
|
||||
for call in output_mock.send.call_args_list:
|
||||
chunk = call[0][0]
|
||||
assert isinstance(chunk, Chunk)
|
||||
assert chunk.metadata == sample_document.metadata
|
||||
assert len(chunk.chunk) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_token_overlap(self, mock_async_processor_init, mock_flow, mock_consumer):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = TokenChunker(chunk_size=20, chunk_overlap=5)
|
||||
|
||||
# Create a document with repeated pattern
|
||||
metadata = Metadata(id="test", metadata=[], user="test-user", collection="test-collection")
|
||||
text = "one two three four five six seven eight nine ten " * 5
|
||||
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Collect all chunks
|
||||
chunks = []
|
||||
for call in output_mock.send.call_args_list:
|
||||
chunk_text = call[0][0].chunk.decode("utf-8")
|
||||
chunks.append(chunk_text)
|
||||
|
||||
# Should have multiple chunks
|
||||
assert len(chunks) > 1
|
||||
|
||||
# Verify chunks are not empty
|
||||
for chunk in chunks:
|
||||
assert len(chunk) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_empty_document(self, mock_async_processor_init, mock_flow, mock_consumer):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = TokenChunker()
|
||||
|
||||
metadata = Metadata(id="empty", metadata=[], user="test-user", collection="test-collection")
|
||||
document = TextDocument(metadata=metadata, text=b"")
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Empty documents typically don't produce chunks with langchain splitters
|
||||
# This behavior is expected - no chunks should be produced
|
||||
assert output_mock.send.call_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_unicode_handling(self, mock_async_processor_init, mock_flow, mock_consumer):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = TokenChunker(chunk_size=50)
|
||||
|
||||
metadata = Metadata(id="unicode", metadata=[], user="test-user", collection="test-collection")
|
||||
# Test with various unicode characters
|
||||
text = "Hello 世界! 🌍 Test émojis café naïve résumé. Greek: αβγδε Hebrew: אבגדה"
|
||||
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Verify unicode is preserved correctly
|
||||
all_chunks = []
|
||||
for call in output_mock.send.call_args_list:
|
||||
chunk_text = call[0][0].chunk.decode("utf-8")
|
||||
all_chunks.append(chunk_text)
|
||||
|
||||
# Reconstruct text
|
||||
reconstructed = "".join(all_chunks)
|
||||
assert "世界" in reconstructed
|
||||
assert "🌍" in reconstructed
|
||||
assert "émojis" in reconstructed
|
||||
assert "αβγδε" in reconstructed
|
||||
assert "אבגדה" in reconstructed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_token_boundary_preservation(self, mock_async_processor_init, mock_flow, mock_consumer):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = TokenChunker(chunk_size=10, chunk_overlap=2)
|
||||
|
||||
metadata = Metadata(id="boundary", metadata=[], user="test-user", collection="test-collection")
|
||||
# Text with clear word boundaries
|
||||
text = "This is a test of token boundaries and proper splitting."
|
||||
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Collect all chunks
|
||||
chunks = []
|
||||
for call in output_mock.send.call_args_list:
|
||||
chunk_text = call[0][0].chunk.decode("utf-8")
|
||||
chunks.append(chunk_text)
|
||||
|
||||
# Token chunker should respect token boundaries
|
||||
for chunk in chunks:
|
||||
# Chunks should not start or end with partial words (in most cases)
|
||||
assert len(chunk.strip()) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_recorded(self, mock_async_processor_init, mock_flow, mock_consumer, sample_document):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = TokenChunker(chunk_size=50)
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = sample_document
|
||||
|
||||
# Mock the metric
|
||||
with patch.object(TokenChunker.chunk_metric, 'labels') as mock_labels:
|
||||
mock_observe = Mock()
|
||||
mock_labels.return_value.observe = mock_observe
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Verify metrics were recorded
|
||||
mock_labels.assert_called_with(id="test-consumer", flow="test-flow")
|
||||
assert mock_observe.call_count > 0
|
||||
|
||||
# Verify chunk sizes were observed
|
||||
for call in mock_observe.call_args_list:
|
||||
chunk_size = call[0][0]
|
||||
assert chunk_size > 0
|
||||
|
||||
def test_add_args(self):
|
||||
parser = Mock()
|
||||
TokenChunker.add_args(parser)
|
||||
|
||||
# Verify arguments were added
|
||||
calls = parser.add_argument.call_args_list
|
||||
arg_names = [call[0][0] for call in calls]
|
||||
|
||||
assert '-z' in arg_names or '--chunk-size' in arg_names
|
||||
assert '-v' in arg_names or '--chunk-overlap' in arg_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_encoding_specific_behavior(self, mock_async_processor_init, mock_flow, mock_consumer):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = TokenChunker(chunk_size=10, chunk_overlap=0)
|
||||
|
||||
metadata = Metadata(id="encoding", metadata=[], user="test-user", collection="test-collection")
|
||||
# Test text that might tokenize differently with cl100k_base encoding
|
||||
text = "GPT-4 is an AI model. It uses tokens."
|
||||
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Verify chunking happened
|
||||
assert output_mock.send.call_count >= 1
|
||||
|
||||
# Collect all chunks
|
||||
chunks = []
|
||||
for call in output_mock.send.call_args_list:
|
||||
chunk_text = call[0][0].chunk.decode("utf-8")
|
||||
chunks.append(chunk_text)
|
||||
|
||||
# Verify all text is preserved (allowing for overlap)
|
||||
all_text = " ".join(chunks)
|
||||
assert "GPT-4" in all_text
|
||||
assert "AI model" in all_text
|
||||
assert "tokens" in all_text
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -34,7 +34,9 @@ class TestGraphRag:
|
|||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.verbose is False # Default value
|
||||
assert graph_rag.label_cache == {} # Empty cache initially
|
||||
# Verify label_cache is an LRUCacheWithTTL instance
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
||||
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
||||
|
||||
def test_graph_rag_initialization_with_verbose(self):
|
||||
"""Test GraphRag initialization with verbose enabled"""
|
||||
|
|
@ -59,7 +61,9 @@ class TestGraphRag:
|
|||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.verbose is True
|
||||
assert graph_rag.label_cache == {} # Empty cache initially
|
||||
# Verify label_cache is an LRUCacheWithTTL instance
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
||||
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
||||
|
||||
|
||||
class TestQuery:
|
||||
|
|
@ -228,8 +232,11 @@ class TestQuery:
|
|||
"""Test Query.maybe_label method with cached label"""
|
||||
# Create mock GraphRag with label cache
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.label_cache = {"entity1": "Entity One Label"}
|
||||
|
||||
# Create mock LRUCacheWithTTL
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get.return_value = "Entity One Label"
|
||||
mock_rag.label_cache = mock_cache
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -237,27 +244,32 @@ class TestQuery:
|
|||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
|
||||
# Call maybe_label with cached entity
|
||||
result = await query.maybe_label("entity1")
|
||||
|
||||
|
||||
# Verify cached label is returned
|
||||
assert result == "Entity One Label"
|
||||
# Verify cache was checked with proper key format (user:collection:entity)
|
||||
mock_cache.get.assert_called_once_with("test_user:test_collection:entity1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_label_with_label_lookup(self):
|
||||
"""Test Query.maybe_label method with database label lookup"""
|
||||
# Create mock GraphRag with triples client
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.label_cache = {} # Empty cache
|
||||
# Create mock LRUCacheWithTTL that returns None (cache miss)
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get.return_value = None
|
||||
mock_rag.label_cache = mock_cache
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
|
||||
# Mock triple result with label
|
||||
mock_triple = MagicMock()
|
||||
mock_triple.o = "Human Readable Label"
|
||||
mock_triples_client.query.return_value = [mock_triple]
|
||||
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -265,10 +277,10 @@ class TestQuery:
|
|||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
|
||||
# Call maybe_label
|
||||
result = await query.maybe_label("http://example.com/entity")
|
||||
|
||||
|
||||
# Verify triples client was called correctly
|
||||
mock_triples_client.query.assert_called_once_with(
|
||||
s="http://example.com/entity",
|
||||
|
|
@ -278,17 +290,21 @@ class TestQuery:
|
|||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify result and cache update
|
||||
|
||||
# Verify result and cache update with proper key
|
||||
assert result == "Human Readable Label"
|
||||
assert mock_rag.label_cache["http://example.com/entity"] == "Human Readable Label"
|
||||
cache_key = "test_user:test_collection:http://example.com/entity"
|
||||
mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_label_with_no_label_found(self):
|
||||
"""Test Query.maybe_label method when no label is found"""
|
||||
# Create mock GraphRag with triples client
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.label_cache = {} # Empty cache
|
||||
# Create mock LRUCacheWithTTL that returns None (cache miss)
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get.return_value = None
|
||||
mock_rag.label_cache = mock_cache
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
|
|
@ -318,7 +334,8 @@ class TestQuery:
|
|||
|
||||
# Verify result is entity itself and cache is updated
|
||||
assert result == "unlabeled_entity"
|
||||
assert mock_rag.label_cache["unlabeled_entity"] == "unlabeled_entity"
|
||||
cache_key = "test_user:test_collection:unlabeled_entity"
|
||||
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_basic_functionality(self):
|
||||
|
|
@ -441,40 +458,40 @@ class TestQuery:
|
|||
@pytest.mark.asyncio
|
||||
async def test_get_subgraph_method(self):
|
||||
"""Test Query.get_subgraph method orchestrates entity and edge discovery"""
|
||||
# Create mock Query that patches get_entities and follow_edges
|
||||
# Create mock Query that patches get_entities and follow_edges_batch
|
||||
mock_rag = MagicMock()
|
||||
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_path_length=1
|
||||
)
|
||||
|
||||
|
||||
# Mock get_entities to return test entities
|
||||
query.get_entities = AsyncMock(return_value=["entity1", "entity2"])
|
||||
|
||||
# Mock follow_edges to add triples to subgraph
|
||||
async def mock_follow_edges(ent, subgraph, path_length):
|
||||
subgraph.add((ent, "predicate", "object"))
|
||||
|
||||
query.follow_edges = AsyncMock(side_effect=mock_follow_edges)
|
||||
|
||||
|
||||
# Mock follow_edges_batch to return test triples
|
||||
query.follow_edges_batch = AsyncMock(return_value={
|
||||
("entity1", "predicate1", "object1"),
|
||||
("entity2", "predicate2", "object2")
|
||||
})
|
||||
|
||||
# Call get_subgraph
|
||||
result = await query.get_subgraph("test query")
|
||||
|
||||
|
||||
# Verify get_entities was called
|
||||
query.get_entities.assert_called_once_with("test query")
|
||||
|
||||
# Verify follow_edges was called for each entity
|
||||
assert query.follow_edges.call_count == 2
|
||||
query.follow_edges.assert_any_call("entity1", unittest.mock.ANY, 1)
|
||||
query.follow_edges.assert_any_call("entity2", unittest.mock.ANY, 1)
|
||||
|
||||
# Verify result is list format
|
||||
|
||||
# Verify follow_edges_batch was called with entities and max_path_length
|
||||
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
|
||||
|
||||
# Verify result is list format and contains expected triples
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
assert ("entity1", "predicate1", "object1") in result
|
||||
assert ("entity2", "predicate2", "object2") in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_labelgraph_method(self):
|
||||
|
|
|
|||
|
|
@ -178,37 +178,24 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
assert calls[2][1]['vectors'][0]['metadata']['doc'] == "This is the second document chunk"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_index_creation(self, processor):
|
||||
"""Test automatic index creation when index doesn't exist"""
|
||||
async def test_store_document_embeddings_index_validation(self, processor):
|
||||
"""Test that writing to non-existent index raises ValueError"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Test document content",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
# Mock index doesn't exist initially
|
||||
|
||||
# Mock index doesn't exist
|
||||
processor.pinecone.has_index.return_value = False
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
# Mock index creation
|
||||
processor.pinecone.describe_index.return_value.status = {"ready": True}
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify index creation was called
|
||||
expected_index_name = "d-test_user-test_collection"
|
||||
processor.pinecone.create_index.assert_called_once()
|
||||
create_call = processor.pinecone.create_index.call_args
|
||||
assert create_call[1]['name'] == expected_index_name
|
||||
assert create_call[1]['dimension'] == 3
|
||||
assert create_call[1]['metric'] == "cosine"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_empty_chunk(self, processor):
|
||||
|
|
@ -357,47 +344,44 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
mock_index.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_index_creation_failure(self, processor):
|
||||
"""Test handling of index creation failure"""
|
||||
async def test_store_document_embeddings_validation_before_creation(self, processor):
|
||||
"""Test that validation error occurs before creation attempts"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Test document content",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
# Mock index doesn't exist and creation fails
|
||||
|
||||
# Mock index doesn't exist
|
||||
processor.pinecone.has_index.return_value = False
|
||||
processor.pinecone.create_index.side_effect = Exception("Index creation failed")
|
||||
|
||||
with pytest.raises(Exception, match="Index creation failed"):
|
||||
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_index_creation_timeout(self, processor):
|
||||
"""Test handling of index creation timeout"""
|
||||
async def test_store_document_embeddings_validates_before_timeout(self, processor):
|
||||
"""Test that validation error occurs before timeout checks"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Test document content",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
# Mock index doesn't exist and never becomes ready
|
||||
|
||||
# Mock index doesn't exist
|
||||
processor.pinecone.has_index.return_value = False
|
||||
processor.pinecone.describe_index.return_value.status = {"ready": False}
|
||||
|
||||
with patch('time.sleep'): # Speed up the test
|
||||
with pytest.raises(RuntimeError, match="Gave up waiting for index creation"):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_unicode_content(self, processor):
|
||||
|
|
|
|||
|
|
@ -43,8 +43,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
# Verify processor attributes
|
||||
assert hasattr(processor, 'qdrant')
|
||||
assert processor.qdrant == mock_qdrant_instance
|
||||
assert hasattr(processor, 'last_collection')
|
||||
assert processor.last_collection is None
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
|
|
@ -245,8 +243,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True # Collection exists
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
|
|
@ -255,36 +254,37 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
|
||||
# Create mock message with empty chunk
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'empty_user'
|
||||
mock_message.metadata.collection = 'empty_collection'
|
||||
|
||||
|
||||
mock_chunk_empty = MagicMock()
|
||||
mock_chunk_empty.chunk.decode.return_value = "" # Empty string
|
||||
mock_chunk_empty.vectors = [[0.1, 0.2]]
|
||||
|
||||
|
||||
mock_message.chunks = [mock_chunk_empty]
|
||||
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should not call upsert for empty chunks
|
||||
mock_qdrant_instance.upsert.assert_not_called()
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
# But collection_exists should be called for validation
|
||||
mock_qdrant_instance.collection_exists.assert_called_once()
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_collection_creation_when_not_exists(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test collection creation when it doesn't exist"""
|
||||
"""Test that writing to non-existent collection raises ValueError"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
|
|
@ -293,46 +293,32 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'new_user'
|
||||
mock_message.metadata.collection = 'new_collection'
|
||||
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'test chunk'
|
||||
mock_chunk.vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] # 5 dimensions
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
expected_collection = 'd_new_user_new_collection'
|
||||
|
||||
# Verify collection existence check and creation
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||
mock_qdrant_instance.create_collection.assert_called_once()
|
||||
|
||||
# Verify create_collection was called with correct parameters
|
||||
create_call_args = mock_qdrant_instance.create_collection.call_args
|
||||
assert create_call_args[1]['collection_name'] == expected_collection
|
||||
|
||||
# Verify upsert was still called after collection creation
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_collection_creation_exception(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test collection creation handles exceptions"""
|
||||
"""Test that validation error occurs before connection errors"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False
|
||||
mock_qdrant_instance.create_collection.side_effect = Exception("Qdrant connection failed")
|
||||
mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
|
|
@ -341,32 +327,35 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'error_user'
|
||||
mock_message.metadata.collection = 'error_collection'
|
||||
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'test chunk'
|
||||
mock_chunk.vectors = [[0.1, 0.2]]
|
||||
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_collection_caching_behavior(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test collection caching with last_collection"""
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||
async def test_collection_validation_on_write(self, mock_uuid, mock_base_init, mock_qdrant_client):
|
||||
"""Test collection validation checks collection exists before writing"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
mock_uuid.uuid4.return_value = MagicMock()
|
||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
|
|
@ -375,46 +364,45 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
|
||||
# Create first mock message
|
||||
mock_message1 = MagicMock()
|
||||
mock_message1.metadata.user = 'cache_user'
|
||||
mock_message1.metadata.collection = 'cache_collection'
|
||||
|
||||
|
||||
mock_chunk1 = MagicMock()
|
||||
mock_chunk1.chunk.decode.return_value = 'first chunk'
|
||||
mock_chunk1.vectors = [[0.1, 0.2, 0.3]]
|
||||
|
||||
|
||||
mock_message1.chunks = [mock_chunk1]
|
||||
|
||||
|
||||
# First call
|
||||
await processor.store_document_embeddings(mock_message1)
|
||||
|
||||
|
||||
# Reset mock to track second call
|
||||
mock_qdrant_instance.reset_mock()
|
||||
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
|
||||
# Create second mock message with same dimensions
|
||||
mock_message2 = MagicMock()
|
||||
mock_message2.metadata.user = 'cache_user'
|
||||
mock_message2.metadata.collection = 'cache_collection'
|
||||
|
||||
|
||||
mock_chunk2 = MagicMock()
|
||||
mock_chunk2.chunk.decode.return_value = 'second chunk'
|
||||
mock_chunk2.vectors = [[0.4, 0.5, 0.6]] # Same dimension (3)
|
||||
|
||||
|
||||
mock_message2.chunks = [mock_chunk2]
|
||||
|
||||
|
||||
# Act - Second call with same collection
|
||||
await processor.store_document_embeddings(mock_message2)
|
||||
|
||||
# Assert
|
||||
expected_collection = 'd_cache_user_cache_collection'
|
||||
assert processor.last_collection == expected_collection
|
||||
|
||||
# Verify second call skipped existence check (cached)
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
mock_qdrant_instance.create_collection.assert_not_called()
|
||||
|
||||
|
||||
# Verify collection existence is checked on each write
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||
|
||||
# But upsert should still be called
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
||||
|
|
|
|||
|
|
@ -178,37 +178,24 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
assert calls[2][1]['vectors'][0]['metadata']['entity'] == "entity2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_index_creation(self, processor):
|
||||
"""Test automatic index creation when index doesn't exist"""
|
||||
async def test_store_graph_embeddings_index_validation(self, processor):
|
||||
"""Test that writing to non-existent index raises ValueError"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value="test_entity", is_uri=False),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
# Mock index doesn't exist initially
|
||||
|
||||
# Mock index doesn't exist
|
||||
processor.pinecone.has_index.return_value = False
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
# Mock index creation
|
||||
processor.pinecone.describe_index.return_value.status = {"ready": True}
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify index creation was called
|
||||
expected_index_name = "t-test_user-test_collection"
|
||||
processor.pinecone.create_index.assert_called_once()
|
||||
create_call = processor.pinecone.create_index.call_args
|
||||
assert create_call[1]['name'] == expected_index_name
|
||||
assert create_call[1]['dimension'] == 3
|
||||
assert create_call[1]['metric'] == "cosine"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_empty_entity_value(self, processor):
|
||||
|
|
@ -328,47 +315,44 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
mock_index.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_index_creation_failure(self, processor):
|
||||
"""Test handling of index creation failure"""
|
||||
async def test_store_graph_embeddings_validation_before_creation(self, processor):
|
||||
"""Test that validation error occurs before any creation attempts"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value="test_entity", is_uri=False),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
# Mock index doesn't exist and creation fails
|
||||
|
||||
# Mock index doesn't exist
|
||||
processor.pinecone.has_index.return_value = False
|
||||
processor.pinecone.create_index.side_effect = Exception("Index creation failed")
|
||||
|
||||
with pytest.raises(Exception, match="Index creation failed"):
|
||||
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_index_creation_timeout(self, processor):
|
||||
"""Test handling of index creation timeout"""
|
||||
async def test_store_graph_embeddings_validates_before_timeout(self, processor):
|
||||
"""Test that validation error occurs before timeout checks"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value="test_entity", is_uri=False),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
# Mock index doesn't exist and never becomes ready
|
||||
|
||||
# Mock index doesn't exist
|
||||
processor.pinecone.has_index.return_value = False
|
||||
processor.pinecone.describe_index.return_value.status = {"ready": False}
|
||||
|
||||
with patch('time.sleep'): # Speed up the test
|
||||
with pytest.raises(RuntimeError, match="Gave up waiting for index creation"):
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
|
|
|
|||
|
|
@ -43,19 +43,17 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
# Verify processor attributes
|
||||
assert hasattr(processor, 'qdrant')
|
||||
assert processor.qdrant == mock_qdrant_instance
|
||||
assert hasattr(processor, 'last_collection')
|
||||
assert processor.last_collection is None
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_creates_new_collection(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection creates a new collection when it doesn't exist"""
|
||||
async def test_get_collection_validates_existence(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection validates that collection exists"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
|
|
@ -64,22 +62,10 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
collection_name = processor.get_collection(dim=512, user='test_user', collection='test_collection')
|
||||
|
||||
# Assert
|
||||
expected_name = 't_test_user_test_collection'
|
||||
assert collection_name == expected_name
|
||||
assert processor.last_collection == expected_name
|
||||
|
||||
# Verify collection existence check and creation
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
|
||||
mock_qdrant_instance.create_collection.assert_called_once()
|
||||
|
||||
# Verify create_collection was called with correct parameters
|
||||
create_call_args = mock_qdrant_instance.create_collection.call_args
|
||||
assert create_call_args[1]['collection_name'] == expected_name
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
processor.get_collection(user='test_user', collection='test_collection')
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
|
||||
|
|
@ -142,7 +128,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True # Collection exists
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
|
|
@ -151,15 +137,14 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
|
||||
# Act
|
||||
collection_name = processor.get_collection(dim=256, user='existing_user', collection='existing_collection')
|
||||
collection_name = processor.get_collection(user='existing_user', collection='existing_collection')
|
||||
|
||||
# Assert
|
||||
expected_name = 't_existing_user_existing_collection'
|
||||
assert collection_name == expected_name
|
||||
assert processor.last_collection == expected_name
|
||||
|
||||
|
||||
# Verify collection existence check was performed
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
|
||||
# Verify create_collection was NOT called
|
||||
|
|
@ -167,14 +152,14 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_caches_last_collection(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection skips checks when using same collection"""
|
||||
async def test_get_collection_validates_on_each_call(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection validates collection existence on each call"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
|
|
@ -183,36 +168,36 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
|
||||
# First call
|
||||
collection_name1 = processor.get_collection(dim=128, user='cache_user', collection='cache_collection')
|
||||
|
||||
collection_name1 = processor.get_collection(user='cache_user', collection='cache_collection')
|
||||
|
||||
# Reset mock to track second call
|
||||
mock_qdrant_instance.reset_mock()
|
||||
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
|
||||
# Act - Second call with same parameters
|
||||
collection_name2 = processor.get_collection(dim=128, user='cache_user', collection='cache_collection')
|
||||
collection_name2 = processor.get_collection(user='cache_user', collection='cache_collection')
|
||||
|
||||
# Assert
|
||||
expected_name = 't_cache_user_cache_collection'
|
||||
assert collection_name1 == expected_name
|
||||
assert collection_name2 == expected_name
|
||||
|
||||
# Verify second call skipped existence check (cached)
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
|
||||
# Verify collection existence check happens on each call
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
|
||||
mock_qdrant_instance.create_collection.assert_not_called()
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_creation_exception(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection handles collection creation exceptions"""
|
||||
"""Test get_collection raises ValueError when collection doesn't exist"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False
|
||||
mock_qdrant_instance.create_collection.side_effect = Exception("Qdrant connection failed")
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
|
|
@ -221,10 +206,10 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
processor.get_collection(dim=512, user='error_user', collection='error_collection')
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
processor.get_collection(user='error_user', collection='error_collection')
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ class TestMemgraphUserCollectionIsolation:
|
|||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
|
|
@ -55,28 +55,30 @@ class TestMemgraphUserCollectionIsolation:
|
|||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
# Create mock triple with URI object
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject"
|
||||
triple.p.value = "http://example.com/predicate"
|
||||
triple.o.value = "http://example.com/object"
|
||||
triple.o.is_uri = True
|
||||
|
||||
|
||||
# Create mock message with metadata
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple]
|
||||
mock_message.metadata.user = "test_user"
|
||||
mock_message.metadata.collection = "test_collection"
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify user/collection parameters were passed to all operations
|
||||
# Should have: create_node (subject), create_node (object), relate_node = 3 calls
|
||||
assert mock_driver.execute_query.call_count == 3
|
||||
|
||||
|
||||
# Check that user and collection were included in all calls
|
||||
for call in mock_driver.execute_query.call_args_list:
|
||||
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
|
||||
|
|
@ -93,7 +95,7 @@ class TestMemgraphUserCollectionIsolation:
|
|||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
|
|
@ -101,24 +103,26 @@ class TestMemgraphUserCollectionIsolation:
|
|||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
# Create mock triple
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject"
|
||||
triple.p.value = "http://example.com/predicate"
|
||||
triple.o.value = "literal_value"
|
||||
triple.o.is_uri = False
|
||||
|
||||
|
||||
# Create mock message without user/collection metadata
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple]
|
||||
mock_message.metadata.user = None
|
||||
mock_message.metadata.collection = None
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify defaults were used
|
||||
for call in mock_driver.execute_query.call_args_list:
|
||||
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
|
||||
|
|
@ -295,7 +299,7 @@ class TestMemgraphUserCollectionRegression:
|
|||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
|
|
@ -303,23 +307,25 @@ class TestMemgraphUserCollectionRegression:
|
|||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
# Store data for user1
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject"
|
||||
triple.p.value = "http://example.com/predicate"
|
||||
triple.o.value = "user1_data"
|
||||
triple.o.is_uri = False
|
||||
|
||||
|
||||
message_user1 = MagicMock()
|
||||
message_user1.triples = [triple]
|
||||
message_user1.metadata.user = "user1"
|
||||
message_user1.metadata.collection = "collection1"
|
||||
|
||||
await processor.store_triples(message_user1)
|
||||
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(message_user1)
|
||||
|
||||
# Verify that all storage operations included user1/collection1 parameters
|
||||
for call in mock_driver.execute_query.call_args_list:
|
||||
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
|
||||
|
|
|
|||
|
|
@ -75,8 +75,10 @@ class TestNeo4jUserCollectionIsolation:
|
|||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_driver.execute_query.return_value.summary = mock_summary
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify nodes and relationships were created with user/collection properties
|
||||
expected_calls = [
|
||||
|
|
@ -141,8 +143,10 @@ class TestNeo4jUserCollectionIsolation:
|
|||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_driver.execute_query.return_value.summary = mock_summary
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify defaults were used
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
|
|
@ -273,10 +277,12 @@ class TestNeo4jUserCollectionIsolation:
|
|||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_driver.execute_query.return_value.summary = mock_summary
|
||||
|
||||
# Store data for both users
|
||||
await processor.store_triples(message_user1)
|
||||
await processor.store_triples(message_user2)
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
# Store data for both users
|
||||
await processor.store_triples(message_user1)
|
||||
await processor.store_triples(message_user2)
|
||||
|
||||
# Verify user1 data was stored with user1/coll1
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
|
|
@ -446,9 +452,11 @@ class TestNeo4jUserCollectionRegression:
|
|||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_driver.execute_query.return_value.summary = mock_summary
|
||||
|
||||
await processor.store_triples(message_user1)
|
||||
await processor.store_triples(message_user2)
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(message_user1)
|
||||
await processor.store_triples(message_user2)
|
||||
|
||||
# Verify two separate nodes were created with same URI but different user/collection
|
||||
user1_node_call = call(
|
||||
|
|
|
|||
|
|
@ -251,6 +251,8 @@ class TestObjectsCassandraStorageLogic:
|
|||
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
|
||||
processor.session = MagicMock()
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
|
||||
processor.known_tables = {"test_user": set()} # Pre-populate
|
||||
|
||||
# Create test object
|
||||
test_obj = ExtractedObject(
|
||||
|
|
@ -291,18 +293,19 @@ class TestObjectsCassandraStorageLogic:
|
|||
"""Test that secondary indexes are created for indexed fields"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.known_keyspaces = set()
|
||||
processor.known_tables = {}
|
||||
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
|
||||
processor.known_tables = {"test_user": set()} # Pre-populate
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
|
||||
def mock_ensure_keyspace(keyspace):
|
||||
processor.known_keyspaces.add(keyspace)
|
||||
processor.known_tables[keyspace] = set()
|
||||
if keyspace not in processor.known_tables:
|
||||
processor.known_tables[keyspace] = set()
|
||||
processor.ensure_keyspace = mock_ensure_keyspace
|
||||
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
|
||||
|
||||
|
||||
# Create schema with indexed field
|
||||
schema = RowSchema(
|
||||
name="products",
|
||||
|
|
@ -313,10 +316,10 @@ class TestObjectsCassandraStorageLogic:
|
|||
Field(name="price", type="float", size=8, indexed=True)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# Call ensure_table
|
||||
processor.ensure_table("test_user", "products", schema)
|
||||
|
||||
|
||||
# Should have 3 calls: create table + 2 indexes
|
||||
assert processor.session.execute.call_count == 3
|
||||
|
||||
|
|
@ -346,9 +349,10 @@ class TestObjectsCassandraStorageBatchLogic:
|
|||
]
|
||||
)
|
||||
}
|
||||
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
|
||||
processor.ensure_table = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
|
||||
processor.session = MagicMock()
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
|
|
@ -415,6 +419,8 @@ class TestObjectsCassandraStorageBatchLogic:
|
|||
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
|
||||
processor.session = MagicMock()
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
|
||||
processor.known_tables = {"test_user": set()} # Pre-populate
|
||||
|
||||
# Create empty batch object
|
||||
empty_batch_obj = ExtractedObject(
|
||||
|
|
@ -461,6 +467,8 @@ class TestObjectsCassandraStorageBatchLogic:
|
|||
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
|
||||
processor.session = MagicMock()
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
|
||||
processor.known_tables = {"test_user": set()} # Pre-populate
|
||||
|
||||
# Create single-item batch object (backward compatibility case)
|
||||
single_batch_obj = ExtractedObject(
|
||||
|
|
|
|||
|
|
@ -194,7 +194,13 @@ class TestFalkorDBStorageProcessor:
|
|||
mock_result.run_time_ms = 10
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
await processor.store_triples(message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
|
||||
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify queries were called in the correct order
|
||||
expected_calls = [
|
||||
|
|
@ -225,7 +231,13 @@ class TestFalkorDBStorageProcessor:
|
|||
mock_result.run_time_ms = 10
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
|
||||
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify queries were called in the correct order
|
||||
expected_calls = [
|
||||
|
|
@ -273,7 +285,13 @@ class TestFalkorDBStorageProcessor:
|
|||
mock_result.run_time_ms = 10
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
await processor.store_triples(message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
|
||||
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify total number of queries (3 per triple)
|
||||
assert processor.io.query.call_count == 6
|
||||
|
|
@ -299,7 +317,13 @@ class TestFalkorDBStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
message.triples = []
|
||||
|
||||
await processor.store_triples(message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
|
||||
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify no queries were made
|
||||
processor.io.query.assert_not_called()
|
||||
|
|
@ -329,7 +353,13 @@ class TestFalkorDBStorageProcessor:
|
|||
mock_result.run_time_ms = 10
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
await processor.store_triples(message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
|
||||
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify total number of queries (3 per triple)
|
||||
assert processor.io.query.call_count == 6
|
||||
|
|
|
|||
|
|
@ -308,7 +308,13 @@ class TestMemgraphStorageProcessor:
|
|||
# Reset the mock to clear initialization calls
|
||||
processor.io.execute_query.reset_mock()
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
|
||||
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify execute_query was called for create_node, create_literal, and relate_literal
|
||||
# (since mock_message has a literal object)
|
||||
|
|
@ -352,7 +358,13 @@ class TestMemgraphStorageProcessor:
|
|||
)
|
||||
message.triples = [triple1, triple2]
|
||||
|
||||
await processor.store_triples(message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
|
||||
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify execute_query was called:
|
||||
# Triple1: create_node(s) + create_literal(o) + relate_literal = 3 calls
|
||||
|
|
@ -381,7 +393,13 @@ class TestMemgraphStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
message.triples = []
|
||||
|
||||
await processor.store_triples(message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
|
||||
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify no session calls were made (no triples to process)
|
||||
processor.io.session.assert_not_called()
|
||||
|
|
|
|||
|
|
@ -268,7 +268,9 @@ class TestNeo4jStorageProcessor:
|
|||
mock_message.metadata.user = "test_user"
|
||||
mock_message.metadata.collection = "test_collection"
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify create_node was called for subject and object
|
||||
# Verify relate_node was called
|
||||
|
|
@ -336,7 +338,9 @@ class TestNeo4jStorageProcessor:
|
|||
mock_message.metadata.user = "test_user"
|
||||
mock_message.metadata.collection = "test_collection"
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify create_node was called for subject
|
||||
# Verify create_literal was called for object
|
||||
|
|
@ -411,7 +415,9 @@ class TestNeo4jStorageProcessor:
|
|||
mock_message.metadata.user = "test_user"
|
||||
mock_message.metadata.collection = "test_collection"
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Should have processed both triples
|
||||
# Triple1: 2 nodes + 1 relationship = 3 calls
|
||||
|
|
@ -437,7 +443,9 @@ class TestNeo4jStorageProcessor:
|
|||
mock_message.metadata.user = "test_user"
|
||||
mock_message.metadata.collection = "test_collection"
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Should not have made any execute_query calls beyond index creation
|
||||
# Only index creation calls should have been made during initialization
|
||||
|
|
@ -552,7 +560,9 @@ class TestNeo4jStorageProcessor:
|
|||
mock_message.metadata.user = "test_user"
|
||||
mock_message.metadata.collection = "test_collection"
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify the triple was processed with special characters preserved
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-4'
|
||||
assert processor.default_model == 'gpt-4'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4192
|
||||
assert hasattr(processor, 'openai')
|
||||
|
|
@ -254,7 +254,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-35-turbo'
|
||||
assert processor.default_model == 'gpt-35-turbo'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 2048
|
||||
mock_azure_openai_class.assert_called_once_with(
|
||||
|
|
@ -289,7 +289,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-4'
|
||||
assert processor.default_model == 'gpt-4'
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 4192 # default_max_output
|
||||
mock_azure_openai_class.assert_called_once_with(
|
||||
|
|
@ -402,6 +402,156 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert call_args[1]['max_tokens'] == 1024
|
||||
assert call_args[1]['top_p'] == 1
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = 'Response with custom temperature'
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 12
|
||||
|
||||
mock_azure_client.chat.completions.create.return_value = mock_response
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Azure OpenAI API was called with overridden temperature
|
||||
call_args = mock_azure_client.chat.completions.create.call_args
|
||||
assert call_args[1]['temperature'] == 0.8 # Should use runtime override
|
||||
assert call_args[1]['model'] == 'gpt-4'
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = 'Response with custom model'
|
||||
mock_response.usage.prompt_tokens = 18
|
||||
mock_response.usage.completion_tokens = 14
|
||||
|
||||
mock_azure_client.chat.completions.create.return_value = mock_response
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4', # Default model
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.1, # Default temperature
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gpt-4o", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Azure OpenAI API was called with overridden model
|
||||
call_args = mock_azure_client.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == 'gpt-4o' # Should use runtime override
|
||||
assert call_args[1]['temperature'] == 0.1 # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = 'Response with both overrides'
|
||||
mock_response.usage.prompt_tokens = 22
|
||||
mock_response.usage.completion_tokens = 16
|
||||
|
||||
mock_azure_client.chat.completions.create.return_value = mock_response
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4', # Default model
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gpt-4o-mini", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Azure OpenAI API was called with both overrides
|
||||
call_args = mock_azure_client.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == 'gpt-4o-mini' # Should use runtime override
|
||||
assert call_args[1]['temperature'] == 0.9 # Should use runtime override
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -43,7 +43,7 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert processor.token == 'test-token'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4192
|
||||
assert processor.model == 'AzureAI'
|
||||
assert processor.default_model == 'AzureAI'
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -261,7 +261,7 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert processor.token == 'custom-token'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 2048
|
||||
assert processor.model == 'AzureAI'
|
||||
assert processor.default_model == 'AzureAI'
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -289,7 +289,7 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert processor.token == 'test-token'
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 4192 # default_max_output
|
||||
assert processor.model == 'AzureAI' # default_model
|
||||
assert processor.default_model == 'AzureAI' # default_model
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -459,5 +459,150 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
|
|||
)
|
||||
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test generate_content with model parameter override"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response with model override'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 15,
|
||||
'completion_tokens': 10
|
||||
}
|
||||
}
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model
|
||||
result = await processor.generate_content("System", "Prompt", model="custom-azure-model")
|
||||
|
||||
# Assert
|
||||
assert result.model == "custom-azure-model" # Should use overridden model
|
||||
assert result.text == "Response with model override"
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test generate_content with temperature parameter override"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response with temperature override'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 15,
|
||||
'completion_tokens': 10
|
||||
}
|
||||
}
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature
|
||||
result = await processor.generate_content("System", "Prompt", temperature=0.8)
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with temperature override"
|
||||
|
||||
# Verify the request was made with the overridden temperature
|
||||
mock_requests.post.assert_called_once()
|
||||
call_args = mock_requests.post.call_args
|
||||
|
||||
import json
|
||||
request_body = json.loads(call_args[1]['data'])
|
||||
assert request_body['temperature'] == 0.8
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test generate_content with both model and temperature overrides"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response with both parameters override'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 18,
|
||||
'completion_tokens': 12
|
||||
}
|
||||
}
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters
|
||||
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.9)
|
||||
|
||||
# Assert
|
||||
assert result.model == "override-model"
|
||||
assert result.text == "Response with both parameters override"
|
||||
|
||||
# Verify the request was made with overridden temperature
|
||||
mock_requests.post.assert_called_once()
|
||||
call_args = mock_requests.post.call_args
|
||||
|
||||
import json
|
||||
request_body = json.loads(call_args[1]['data'])
|
||||
assert request_body['temperature'] == 0.9
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
280
tests/unit/test_text_completion/test_bedrock_processor.py
Normal file
280
tests/unit/test_text_completion/test_bedrock_processor.py
Normal file
|
|
@ -0,0 +1,280 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.bedrock
|
||||
Following the same successful pattern as other processor tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
import json
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.bedrock.llm import Processor, Mistral, Anthropic
|
||||
from trustgraph.base import LlmResult
|
||||
|
||||
|
||||
class TestBedrockProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Bedrock processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_bedrock = MagicMock()
|
||||
mock_session.client.return_value = mock_bedrock
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral.mistral-large-2407-v1:0',
|
||||
'temperature': 0.1,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'mistral.mistral-large-2407-v1:0'
|
||||
assert processor.temperature == 0.1
|
||||
assert hasattr(processor, 'bedrock')
|
||||
mock_session_class.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success_mistral(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test successful content generation with Mistral model"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_bedrock = MagicMock()
|
||||
mock_session.client.return_value = mock_bedrock
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = {
|
||||
'body': MagicMock(),
|
||||
'ResponseMetadata': {
|
||||
'HTTPHeaders': {
|
||||
'x-amzn-bedrock-input-token-count': '15',
|
||||
'x-amzn-bedrock-output-token-count': '8'
|
||||
}
|
||||
}
|
||||
}
|
||||
mock_response['body'].read.return_value = json.dumps({
|
||||
'outputs': [{'text': 'Generated response from Bedrock'}]
|
||||
})
|
||||
mock_bedrock.invoke_model.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral.mistral-large-2407-v1:0',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Bedrock"
|
||||
assert result.in_token == 15
|
||||
assert result.out_token == 8
|
||||
assert result.model == 'mistral.mistral-large-2407-v1:0'
|
||||
mock_bedrock.invoke_model.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_bedrock = MagicMock()
|
||||
mock_session.client.return_value = mock_bedrock
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = {
|
||||
'body': MagicMock(),
|
||||
'ResponseMetadata': {
|
||||
'HTTPHeaders': {
|
||||
'x-amzn-bedrock-input-token-count': '20',
|
||||
'x-amzn-bedrock-output-token-count': '12'
|
||||
}
|
||||
}
|
||||
}
|
||||
mock_response['body'].read.return_value = json.dumps({
|
||||
'outputs': [{'text': 'Response with custom temperature'}]
|
||||
})
|
||||
mock_bedrock.invoke_model.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral.mistral-large-2407-v1:0',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify the model variant was created with overridden temperature
|
||||
# The cache key should include the temperature
|
||||
cache_key = f"mistral.mistral-large-2407-v1:0:0.8"
|
||||
assert cache_key in processor.model_variants
|
||||
variant = processor.model_variants[cache_key]
|
||||
assert variant.temperature == 0.8
|
||||
|
||||
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_bedrock = MagicMock()
|
||||
mock_session.client.return_value = mock_bedrock
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = {
|
||||
'body': MagicMock(),
|
||||
'ResponseMetadata': {
|
||||
'HTTPHeaders': {
|
||||
'x-amzn-bedrock-input-token-count': '18',
|
||||
'x-amzn-bedrock-output-token-count': '14'
|
||||
}
|
||||
}
|
||||
}
|
||||
mock_response['body'].read.return_value = json.dumps({
|
||||
'content': [{'text': 'Response with custom model'}]
|
||||
})
|
||||
mock_bedrock.invoke_model.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral.mistral-large-2407-v1:0', # Default model
|
||||
'temperature': 0.1, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="anthropic.claude-3-sonnet-20240229-v1:0", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Bedrock API was called with overridden model
|
||||
mock_bedrock.invoke_model.assert_called_once()
|
||||
call_args = mock_bedrock.invoke_model.call_args
|
||||
assert call_args[1]['modelId'] == "anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
|
||||
# Verify the correct model variant (Anthropic) was used
|
||||
cache_key = f"anthropic.claude-3-sonnet-20240229-v1:0:0.1"
|
||||
assert cache_key in processor.model_variants
|
||||
variant = processor.model_variants[cache_key]
|
||||
assert isinstance(variant, Anthropic)
|
||||
|
||||
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_bedrock = MagicMock()
|
||||
mock_session.client.return_value = mock_bedrock
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = {
|
||||
'body': MagicMock(),
|
||||
'ResponseMetadata': {
|
||||
'HTTPHeaders': {
|
||||
'x-amzn-bedrock-input-token-count': '22',
|
||||
'x-amzn-bedrock-output-token-count': '16'
|
||||
}
|
||||
}
|
||||
}
|
||||
mock_response['body'].read.return_value = json.dumps({
|
||||
'generation': 'Response with both overrides'
|
||||
})
|
||||
mock_bedrock.invoke_model.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral.mistral-large-2407-v1:0', # Default model
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="meta.llama3-70b-instruct-v1:0", # Override model (Meta/Llama)
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Bedrock API was called with both overrides
|
||||
mock_bedrock.invoke_model.assert_called_once()
|
||||
call_args = mock_bedrock.invoke_model.call_args
|
||||
assert call_args[1]['modelId'] == "meta.llama3-70b-instruct-v1:0"
|
||||
|
||||
# Verify the correct model variant (Meta) was used with correct temperature
|
||||
cache_key = f"meta.llama3-70b-instruct-v1:0:0.9"
|
||||
assert cache_key in processor.model_variants
|
||||
variant = processor.model_variants[cache_key]
|
||||
assert variant.temperature == 0.9
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -42,7 +42,7 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'claude-3-5-sonnet-20240620'
|
||||
assert processor.default_model == 'claude-3-5-sonnet-20240620'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 8192
|
||||
assert hasattr(processor, 'claude')
|
||||
|
|
@ -217,7 +217,7 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'claude-3-haiku-20240307'
|
||||
assert processor.default_model == 'claude-3-haiku-20240307'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 4096
|
||||
mock_anthropic_class.assert_called_once_with(api_key='custom-api-key')
|
||||
|
|
@ -246,7 +246,7 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'claude-3-5-sonnet-20240620' # default_model
|
||||
assert processor.default_model == 'claude-3-5-sonnet-20240620' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 8192 # default_max_output
|
||||
mock_anthropic_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
|
@ -433,7 +433,157 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
# Verify processor has the client
|
||||
assert processor.claude == mock_claude_client
|
||||
assert processor.model == 'claude-3-opus-20240229'
|
||||
assert processor.default_model == 'claude-3-opus-20240229'
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Response with custom temperature"
|
||||
mock_response.usage.input_tokens = 20
|
||||
mock_response.usage.output_tokens = 12
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Claude API was called with overridden temperature
|
||||
mock_claude_client.messages.create.assert_called_once()
|
||||
call_kwargs = mock_claude_client.messages.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['temperature'] == 0.9 # Should use runtime override
|
||||
assert call_kwargs['model'] == 'claude-3-5-sonnet-20240620' # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Response with custom model"
|
||||
mock_response.usage.input_tokens = 18
|
||||
mock_response.usage.output_tokens = 14
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.2, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="claude-3-haiku-20240307", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Claude API was called with overridden model
|
||||
mock_claude_client.messages.create.assert_called_once()
|
||||
call_kwargs = mock_claude_client.messages.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['model'] == 'claude-3-haiku-20240307' # Should use runtime override
|
||||
assert call_kwargs['temperature'] == 0.2 # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Response with both overrides"
|
||||
mock_response.usage.input_tokens = 22
|
||||
mock_response.usage.output_tokens = 16
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="claude-3-opus-20240229", # Override model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Claude API was called with both overrides
|
||||
mock_claude_client.messages.create.assert_called_once()
|
||||
call_kwargs = mock_claude_client.messages.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['model'] == 'claude-3-opus-20240229' # Should use runtime override
|
||||
assert call_kwargs['temperature'] == 0.8 # Should use runtime override
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'c4ai-aya-23-8b'
|
||||
assert processor.default_model == 'c4ai-aya-23-8b'
|
||||
assert processor.temperature == 0.0
|
||||
assert hasattr(processor, 'cohere')
|
||||
mock_cohere_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
|
@ -201,7 +201,7 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'command-light'
|
||||
assert processor.default_model == 'command-light'
|
||||
assert processor.temperature == 0.7
|
||||
mock_cohere_class.assert_called_once_with(api_key='custom-api-key')
|
||||
|
||||
|
|
@ -229,7 +229,7 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'c4ai-aya-23-8b' # default_model
|
||||
assert processor.default_model == 'c4ai-aya-23-8b' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
mock_cohere_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
|
|
@ -395,7 +395,7 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
# Verify processor has the client
|
||||
assert processor.cohere == mock_cohere_client
|
||||
assert processor.model == 'command-r'
|
||||
assert processor.default_model == 'command-r'
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -442,6 +442,162 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert call_args[1]['prompt_truncation'] == 'auto'
|
||||
assert call_args[1]['connectors'] == []
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = 'Response with custom temperature'
|
||||
mock_output.meta.billed_units.input_tokens = 20
|
||||
mock_output.meta.billed_units.output_tokens = 12
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Cohere API was called with overridden temperature
|
||||
mock_cohere_client.chat.assert_called_once_with(
|
||||
model='c4ai-aya-23-8b',
|
||||
message='User prompt',
|
||||
preamble='System prompt',
|
||||
temperature=0.8, # Should use runtime override
|
||||
chat_history=[],
|
||||
prompt_truncation='auto',
|
||||
connectors=[]
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = 'Response with custom model'
|
||||
mock_output.meta.billed_units.input_tokens = 18
|
||||
mock_output.meta.billed_units.output_tokens = 14
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.1, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="command-r-plus", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Cohere API was called with overridden model
|
||||
mock_cohere_client.chat.assert_called_once_with(
|
||||
model='command-r-plus', # Should use runtime override
|
||||
message='User prompt',
|
||||
preamble='System prompt',
|
||||
temperature=0.1, # Should use processor default
|
||||
chat_history=[],
|
||||
prompt_truncation='auto',
|
||||
connectors=[]
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = 'Response with both overrides'
|
||||
mock_output.meta.billed_units.input_tokens = 22
|
||||
mock_output.meta.billed_units.output_tokens = 16
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="command-r", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Cohere API was called with both overrides
|
||||
mock_cohere_client.chat.assert_called_once_with(
|
||||
model='command-r', # Should use runtime override
|
||||
message='User prompt',
|
||||
preamble='System prompt',
|
||||
temperature=0.9, # Should use runtime override
|
||||
chat_history=[],
|
||||
prompt_truncation='auto',
|
||||
connectors=[]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -42,7 +42,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-2.0-flash-001'
|
||||
assert processor.default_model == 'gemini-2.0-flash-001'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 8192
|
||||
assert hasattr(processor, 'client')
|
||||
|
|
@ -205,7 +205,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-1.5-pro'
|
||||
assert processor.default_model == 'gemini-1.5-pro'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 4096
|
||||
mock_genai_class.assert_called_once_with(api_key='custom-api-key')
|
||||
|
|
@ -234,7 +234,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-2.0-flash-001' # default_model
|
||||
assert processor.default_model == 'gemini-2.0-flash-001' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 8192 # default_max_output
|
||||
mock_genai_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
|
@ -431,7 +431,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
# Verify processor has the client
|
||||
assert processor.client == mock_genai_client
|
||||
assert processor.model == 'gemini-1.5-flash'
|
||||
assert processor.default_model == 'gemini-1.5-flash'
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -477,6 +477,156 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
# The system instruction should be in the config object
|
||||
assert call_args[1]['contents'] == "Explain quantum computing"
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = 'Response with custom temperature'
|
||||
mock_response.usage_metadata.prompt_token_count = 20
|
||||
mock_response.usage_metadata.candidates_token_count = 12
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify the generation config was created with overridden temperature
|
||||
cache_key = f"gemini-2.0-flash-001:0.8"
|
||||
assert cache_key in processor.generation_configs
|
||||
config_obj = processor.generation_configs[cache_key]
|
||||
assert config_obj.temperature == 0.8
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = 'Response with custom model'
|
||||
mock_response.usage_metadata.prompt_token_count = 18
|
||||
mock_response.usage_metadata.candidates_token_count = 14
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.1, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gemini-1.5-pro", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Google AI Studio API was called with overridden model
|
||||
call_args = mock_genai_client.models.generate_content.call_args
|
||||
assert call_args[1]['model'] == 'gemini-1.5-pro' # Should use runtime override
|
||||
|
||||
# Verify the generation config was created for the correct model
|
||||
cache_key = f"gemini-1.5-pro:0.1"
|
||||
assert cache_key in processor.generation_configs
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = 'Response with both overrides'
|
||||
mock_response.usage_metadata.prompt_token_count = 22
|
||||
mock_response.usage_metadata.candidates_token_count = 16
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gemini-1.5-flash", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Google AI Studio API was called with both overrides
|
||||
call_args = mock_genai_client.models.generate_content.call_args
|
||||
assert call_args[1]['model'] == 'gemini-1.5-flash' # Should use runtime override
|
||||
|
||||
# Verify the generation config was created with both overrides
|
||||
cache_key = f"gemini-1.5-flash:0.9"
|
||||
assert cache_key in processor.generation_configs
|
||||
config_obj = processor.generation_configs[cache_key]
|
||||
assert config_obj.temperature == 0.9
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -42,7 +42,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'LLaMA_CPP'
|
||||
assert processor.default_model == 'LLaMA_CPP'
|
||||
assert processor.llamafile == 'http://localhost:8080/v1'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4096
|
||||
|
|
@ -91,7 +91,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert result.text == "Generated response from LlamaFile"
|
||||
assert result.in_token == 20
|
||||
assert result.out_token == 12
|
||||
assert result.model == 'llama.cpp' # Note: model in result is hardcoded to 'llama.cpp'
|
||||
assert result.model == 'LLaMA_CPP' # Uses the default model name
|
||||
|
||||
# Verify the OpenAI API call structure
|
||||
mock_openai_client.chat.completions.create.assert_called_once_with(
|
||||
|
|
@ -99,7 +99,15 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
messages=[{
|
||||
"role": "user",
|
||||
"content": "System prompt\n\nUser prompt"
|
||||
}]
|
||||
}],
|
||||
temperature=0.0,
|
||||
max_tokens=4096,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
presence_penalty=0,
|
||||
response_format={
|
||||
"type": "text"
|
||||
}
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
|
|
@ -157,7 +165,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'custom-llama'
|
||||
assert processor.default_model == 'custom-llama'
|
||||
assert processor.llamafile == 'http://custom-host:8080/v1'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 2048
|
||||
|
|
@ -189,7 +197,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'LLaMA_CPP' # default_model
|
||||
assert processor.default_model == 'LLaMA_CPP' # default_model
|
||||
assert processor.llamafile == 'http://localhost:8080/v1' # default_llamafile
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 4096 # default_max_output
|
||||
|
|
@ -237,7 +245,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'llama.cpp'
|
||||
assert result.model == 'LLaMA_CPP'
|
||||
|
||||
# Verify the combined prompt is sent correctly
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
|
|
@ -408,8 +416,8 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
result = await processor.generate_content("System", "User")
|
||||
|
||||
# Assert
|
||||
assert result.model == 'llama.cpp' # Should always be 'llama.cpp', not 'custom-model-name'
|
||||
assert processor.model == 'custom-model-name' # But processor.model should still be custom
|
||||
assert result.model == 'custom-model-name' # Uses the actual model name passed to generate_content
|
||||
assert processor.default_model == 'custom-model-name' # But processor.model should still be custom
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -450,5 +458,132 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
# No specific rate limit error handling tested since SLM presumably has no rate limits
|
||||
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with model parameter override"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response from overridden model"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model
|
||||
result = await processor.generate_content("System", "Prompt", model="custom-llamafile-model")
|
||||
|
||||
# Assert
|
||||
assert result.model == "custom-llamafile-model" # Should use overridden model
|
||||
assert result.text == "Response from overridden model"
|
||||
|
||||
# Verify the API call was made with overridden model
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == "custom-llamafile-model"
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with temperature parameter override"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with temperature override"
|
||||
mock_response.usage.prompt_tokens = 18
|
||||
mock_response.usage.completion_tokens = 12
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature
|
||||
result = await processor.generate_content("System", "Prompt", temperature=0.7)
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with temperature override"
|
||||
|
||||
# Verify the API call was made with overridden temperature
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
assert call_args[1]['temperature'] == 0.7
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with both model and temperature overrides"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with both parameters override"
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 15
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters
|
||||
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
|
||||
|
||||
# Assert
|
||||
assert result.model == "override-model"
|
||||
assert result.text == "Response with both parameters override"
|
||||
|
||||
# Verify the API call was made with overridden parameters
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == "override-model"
|
||||
assert call_args[1]['temperature'] == 0.8
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
229
tests/unit/test_text_completion/test_lmstudio_processor.py
Normal file
229
tests/unit/test_text_completion/test_lmstudio_processor.py
Normal file
|
|
@ -0,0 +1,229 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.lmstudio
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.lmstudio.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestLMStudioProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test LMStudio processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_openai = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemma3:9b',
|
||||
'url': 'http://localhost:1234/',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'gemma3:9b'
|
||||
assert processor.url == 'http://localhost:1234/v1/'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4096
|
||||
assert hasattr(processor, 'openai')
|
||||
mock_openai_class.assert_called_once_with(
|
||||
base_url='http://localhost:1234/v1/',
|
||||
api_key='sk-no-key-required'
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_openai = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Generated response from LMStudio'
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 12
|
||||
|
||||
mock_openai.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemma3:9b',
|
||||
'url': 'http://localhost:1234/',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from LMStudio"
|
||||
assert result.in_token == 20
|
||||
assert result.out_token == 12
|
||||
assert result.model == 'gemma3:9b'
|
||||
|
||||
# Verify the API call was made correctly
|
||||
mock_openai.chat.completions.create.assert_called_once()
|
||||
call_args = mock_openai.chat.completions.create.call_args
|
||||
|
||||
# Check model and temperature
|
||||
assert call_args[1]['model'] == 'gemma3:9b'
|
||||
assert call_args[1]['temperature'] == 0.0
|
||||
assert call_args[1]['max_tokens'] == 4096
|
||||
|
||||
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with model parameter override"""
|
||||
# Arrange
|
||||
mock_openai = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response from overridden model'
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemma3:9b',
|
||||
'url': 'http://localhost:1234/',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model
|
||||
result = await processor.generate_content("System", "Prompt", model="custom-lmstudio-model")
|
||||
|
||||
# Assert
|
||||
assert result.model == "custom-lmstudio-model" # Should use overridden model
|
||||
assert result.text == "Response from overridden model"
|
||||
|
||||
# Verify the API call was made with overridden model
|
||||
call_args = mock_openai.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == "custom-lmstudio-model"
|
||||
|
||||
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with temperature parameter override"""
|
||||
# Arrange
|
||||
mock_openai = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with temperature override'
|
||||
mock_response.usage.prompt_tokens = 18
|
||||
mock_response.usage.completion_tokens = 12
|
||||
|
||||
mock_openai.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemma3:9b',
|
||||
'url': 'http://localhost:1234/',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature
|
||||
result = await processor.generate_content("System", "Prompt", temperature=0.7)
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with temperature override"
|
||||
|
||||
# Verify the API call was made with overridden temperature
|
||||
call_args = mock_openai.chat.completions.create.call_args
|
||||
assert call_args[1]['temperature'] == 0.7
|
||||
|
||||
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with both model and temperature overrides"""
|
||||
# Arrange
|
||||
mock_openai = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with both parameters override'
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 15
|
||||
|
||||
mock_openai.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemma3:9b',
|
||||
'url': 'http://localhost:1234/',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters
|
||||
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
|
||||
|
||||
# Assert
|
||||
assert result.model == "override-model"
|
||||
assert result.text == "Response with both parameters override"
|
||||
|
||||
# Verify the API call was made with overridden parameters
|
||||
call_args = mock_openai.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == "override-model"
|
||||
assert call_args[1]['temperature'] == 0.8
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
275
tests/unit/test_text_completion/test_mistral_processor.py
Normal file
275
tests/unit/test_text_completion/test_mistral_processor.py
Normal file
|
|
@ -0,0 +1,275 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.mistral
|
||||
Following the same successful pattern as other processor tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.mistral.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
|
||||
|
||||
class TestMistralProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Mistral processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.1,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'ministral-8b-latest'
|
||||
assert processor.temperature == 0.1
|
||||
assert processor.max_output == 2048
|
||||
assert hasattr(processor, 'mistral')
|
||||
mock_mistral_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Generated response from Mistral'
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 8
|
||||
mock_mistral_client.chat.complete.return_value = mock_response
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Mistral"
|
||||
assert result.in_token == 15
|
||||
assert result.out_token == 8
|
||||
assert result.model == 'ministral-8b-latest'
|
||||
mock_mistral_client.chat.complete.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with custom temperature'
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 12
|
||||
mock_mistral_client.chat.complete.return_value = mock_response
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Mistral API was called with overridden temperature
|
||||
call_args = mock_mistral_client.chat.complete.call_args
|
||||
assert call_args[1]['temperature'] == 0.8 # Should use runtime override
|
||||
assert call_args[1]['model'] == 'ministral-8b-latest'
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with custom model'
|
||||
mock_response.usage.prompt_tokens = 18
|
||||
mock_response.usage.completion_tokens = 14
|
||||
mock_mistral_client.chat.complete.return_value = mock_response
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.1, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="mistral-large-latest", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Mistral API was called with overridden model
|
||||
call_args = mock_mistral_client.chat.complete.call_args
|
||||
assert call_args[1]['model'] == 'mistral-large-latest' # Should use runtime override
|
||||
assert call_args[1]['temperature'] == 0.1 # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with both overrides'
|
||||
mock_response.usage.prompt_tokens = 22
|
||||
mock_response.usage.completion_tokens = 16
|
||||
mock_mistral_client.chat.complete.return_value = mock_response
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="mistral-large-latest", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Mistral API was called with both overrides
|
||||
call_args = mock_mistral_client.chat.complete.call_args
|
||||
assert call_args[1]['model'] == 'mistral-large-latest' # Should use runtime override
|
||||
assert call_args[1]['temperature'] == 0.9 # Should use runtime override
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test prompt construction with system and user prompts"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with system instructions'
|
||||
mock_response.usage.prompt_tokens = 25
|
||||
mock_response.usage.completion_tokens = 15
|
||||
mock_mistral_client.chat.complete.return_value = mock_response
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with system instructions"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
|
||||
# Verify the combined prompt structure
|
||||
call_args = mock_mistral_client.chat.complete.call_args
|
||||
messages = call_args[1]['messages']
|
||||
assert len(messages) == 1
|
||||
assert messages[0]['role'] == 'user'
|
||||
assert messages[0]['content'][0]['type'] == 'text'
|
||||
assert messages[0]['content'][0]['text'] == "You are a helpful assistant\n\nWhat is AI?"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -40,7 +40,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'llama2'
|
||||
assert processor.default_model == 'llama2'
|
||||
assert hasattr(processor, 'llm')
|
||||
mock_client_class.assert_called_once_with(host='http://localhost:11434')
|
||||
|
||||
|
|
@ -81,7 +81,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert result.in_token == 15
|
||||
assert result.out_token == 8
|
||||
assert result.model == 'llama2'
|
||||
mock_client.generate.assert_called_once_with('llama2', "System prompt\n\nUser prompt")
|
||||
mock_client.generate.assert_called_once_with('llama2', "System prompt\n\nUser prompt", options={'temperature': 0.0})
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -134,7 +134,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'mistral'
|
||||
assert processor.default_model == 'mistral'
|
||||
mock_client_class.assert_called_once_with(host='http://192.168.1.100:11434')
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
|
|
@ -160,7 +160,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemma2:9b' # default_model
|
||||
assert processor.default_model == 'gemma2:9b' # default_model
|
||||
# Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env)
|
||||
mock_client_class.assert_called_once()
|
||||
|
||||
|
|
@ -203,7 +203,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert result.model == 'llama2'
|
||||
|
||||
# The prompt should be "" + "\n\n" + "" = "\n\n"
|
||||
mock_client.generate.assert_called_once_with('llama2', "\n\n")
|
||||
mock_client.generate.assert_called_once_with('llama2', "\n\n", options={'temperature': 0.0})
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -310,7 +310,151 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert result.out_token == 15
|
||||
|
||||
# Verify the combined prompt
|
||||
mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?")
|
||||
mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?", options={'temperature': 0.0})
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Response with custom temperature',
|
||||
'prompt_eval_count': 20,
|
||||
'eval_count': 12
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Ollama API was called with overridden temperature
|
||||
mock_client.generate.assert_called_once_with(
|
||||
'llama2',
|
||||
"System prompt\n\nUser prompt",
|
||||
options={'temperature': 0.8} # Should use runtime override
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Response with custom model',
|
||||
'prompt_eval_count': 18,
|
||||
'eval_count': 14
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2', # Default model
|
||||
'ollama': 'http://localhost:11434',
|
||||
'temperature': 0.1, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="mistral", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Ollama API was called with overridden model
|
||||
mock_client.generate.assert_called_once_with(
|
||||
'mistral', # Should use runtime override
|
||||
"System prompt\n\nUser prompt",
|
||||
options={'temperature': 0.1} # Should use processor default
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Response with both overrides',
|
||||
'prompt_eval_count': 22,
|
||||
'eval_count': 16
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2', # Default model
|
||||
'ollama': 'http://localhost:11434',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="codellama", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Ollama API was called with both overrides
|
||||
mock_client.generate.assert_called_once_with(
|
||||
'codellama', # Should use runtime override
|
||||
"System prompt\n\nUser prompt",
|
||||
options={'temperature': 0.9} # Should use runtime override
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-3.5-turbo'
|
||||
assert processor.default_model == 'gpt-3.5-turbo'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4096
|
||||
assert hasattr(processor, 'openai')
|
||||
|
|
@ -222,7 +222,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-4'
|
||||
assert processor.default_model == 'gpt-4'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 2048
|
||||
mock_openai_class.assert_called_once_with(base_url='https://custom-openai-url.com/v1', api_key='custom-api-key')
|
||||
|
|
@ -251,7 +251,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-3.5-turbo' # default_model
|
||||
assert processor.default_model == 'gpt-3.5-turbo' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 4096 # default_max_output
|
||||
mock_openai_class.assert_called_once_with(base_url='https://api.openai.com/v1', api_key='test-api-key')
|
||||
|
|
@ -391,5 +391,210 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert call_args[1]['response_format'] == {"type": "text"}
|
||||
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with custom temperature"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify the OpenAI API was called with overridden temperature
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['temperature'] == 0.9 # Should use runtime override
|
||||
assert call_kwargs['model'] == 'gpt-3.5-turbo' # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with custom model"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.2,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gpt-4", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify the OpenAI API was called with overridden model
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['model'] == 'gpt-4' # Should use runtime override
|
||||
assert call_kwargs['temperature'] == 0.2 # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with both overrides"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gpt-4", # Override model
|
||||
temperature=0.7 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify the OpenAI API was called with both overrides
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['model'] == 'gpt-4' # Should use runtime override
|
||||
assert call_kwargs['temperature'] == 0.7 # Should use runtime override
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_no_override_uses_defaults(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test that when no parameters are overridden, processor defaults are used"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with defaults"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.5, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Don't override any parameters (pass None)
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with defaults"
|
||||
|
||||
# Verify the OpenAI API was called with processor defaults
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['model'] == 'gpt-4' # Should use processor default
|
||||
assert call_kwargs['temperature'] == 0.5 # Should use processor default
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
186
tests/unit/test_text_completion/test_parameter_caching.py
Normal file
186
tests/unit/test_text_completion/test_parameter_caching.py
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
"""
|
||||
Unit tests for Parameter-Based Caching in LLM Processors
|
||||
Testing processors that cache based on temperature parameters (Bedrock, GoogleAIStudio)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
from trustgraph.model.text_completion.googleaistudio.llm import Processor as GoogleAIProcessor
|
||||
from trustgraph.base import LlmResult
|
||||
|
||||
|
||||
class TestParameterCaching(IsolatedAsyncioTestCase):
|
||||
"""Test parameter-based caching functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_googleai_temperature_cache_keys(self, mock_llm_init, mock_async_init, mock_genai):
|
||||
"""Test that GoogleAI processor creates separate cache entries for different temperatures"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Generated response"
|
||||
mock_response.usage_metadata.prompt_token_count = 10
|
||||
mock_response.usage_metadata.candidates_token_count = 5
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = GoogleAIProcessor(**config)
|
||||
|
||||
# Act - Call with different temperatures
|
||||
await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.0)
|
||||
await processor.generate_content("System", "Prompt 2", model="gemini-2.0-flash-001", temperature=0.5)
|
||||
await processor.generate_content("System", "Prompt 3", model="gemini-2.0-flash-001", temperature=1.0)
|
||||
|
||||
# Assert - Should have 3 different cache entries
|
||||
cache_keys = list(processor.generation_configs.keys())
|
||||
|
||||
assert len(cache_keys) == 3
|
||||
assert "gemini-2.0-flash-001:0.0" in cache_keys
|
||||
assert "gemini-2.0-flash-001:0.5" in cache_keys
|
||||
assert "gemini-2.0-flash-001:1.0" in cache_keys
|
||||
|
||||
# Verify each cached config has the correct temperature
|
||||
assert processor.generation_configs["gemini-2.0-flash-001:0.0"].temperature == 0.0
|
||||
assert processor.generation_configs["gemini-2.0-flash-001:0.5"].temperature == 0.5
|
||||
assert processor.generation_configs["gemini-2.0-flash-001:1.0"].temperature == 1.0
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_googleai_cache_reuse_same_parameters(self, mock_llm_init, mock_async_init, mock_genai):
|
||||
"""Test that GoogleAI processor reuses cache for identical model+temperature combinations"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Generated response"
|
||||
mock_response.usage_metadata.prompt_token_count = 10
|
||||
mock_response.usage_metadata.candidates_token_count = 5
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = GoogleAIProcessor(**config)
|
||||
|
||||
# Act - Call multiple times with same parameters
|
||||
await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.7)
|
||||
await processor.generate_content("System", "Prompt 2", model="gemini-2.0-flash-001", temperature=0.7)
|
||||
await processor.generate_content("System", "Prompt 3", model="gemini-2.0-flash-001", temperature=0.7)
|
||||
|
||||
# Assert - Should have only 1 cache entry for the repeated parameters
|
||||
cache_keys = list(processor.generation_configs.keys())
|
||||
assert len(cache_keys) == 1
|
||||
assert "gemini-2.0-flash-001:0.7" in cache_keys
|
||||
|
||||
# The same config object should be reused
|
||||
config_obj = processor.generation_configs["gemini-2.0-flash-001:0.7"]
|
||||
assert config_obj.temperature == 0.7
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_googleai_different_models_separate_caches(self, mock_llm_init, mock_async_init, mock_genai):
|
||||
"""Test that different models create separate cache entries even with same temperature"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Generated response"
|
||||
mock_response.usage_metadata.prompt_token_count = 10
|
||||
mock_response.usage_metadata.candidates_token_count = 5
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = GoogleAIProcessor(**config)
|
||||
|
||||
# Act - Call with different models, same temperature
|
||||
await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.5)
|
||||
await processor.generate_content("System", "Prompt 2", model="gemini-1.5-flash-001", temperature=0.5)
|
||||
|
||||
# Assert - Should have separate cache entries for different models
|
||||
cache_keys = list(processor.generation_configs.keys())
|
||||
assert len(cache_keys) == 2
|
||||
assert "gemini-2.0-flash-001:0.5" in cache_keys
|
||||
assert "gemini-1.5-flash-001:0.5" in cache_keys
|
||||
|
||||
# Note: Bedrock tests would be similar but testing the Bedrock processor's caching behavior
|
||||
# The Bedrock processor caches model variants with temperature in the cache key
|
||||
|
||||
async def test_bedrock_temperature_cache_keys(self):
|
||||
"""Test Bedrock processor temperature-aware caching"""
|
||||
# This would test the Bedrock processor's _get_or_create_variant method
|
||||
# with different temperature values to ensure proper cache key generation
|
||||
|
||||
# Implementation would follow similar pattern to GoogleAI tests above
|
||||
# but using the Bedrock processor and testing model_variants cache
|
||||
pass
|
||||
|
||||
async def test_bedrock_cache_isolation_different_temperatures(self):
|
||||
"""Test that Bedrock processor isolates cache entries by temperature"""
|
||||
pass
|
||||
|
||||
async def test_cache_memory_efficiency(self):
|
||||
"""Test that caches don't grow unbounded with many different parameter combinations"""
|
||||
# This could test cache size limits or cleanup behavior if implemented
|
||||
pass
|
||||
|
||||
|
||||
class TestCachePerformance(IsolatedAsyncioTestCase):
|
||||
"""Test caching performance characteristics"""
|
||||
|
||||
async def test_cache_hit_performance(self):
|
||||
"""Test that cache hits are faster than cache misses"""
|
||||
# This would measure timing differences between cache hits and misses
|
||||
pass
|
||||
|
||||
async def test_concurrent_cache_access(self):
|
||||
"""Test concurrent access to cached configurations"""
|
||||
# This would test thread-safety of cache access
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
271
tests/unit/test_text_completion/test_tgi_processor.py
Normal file
271
tests/unit/test_text_completion/test_tgi_processor.py
Normal file
|
|
@ -0,0 +1,271 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.tgi
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.tgi.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestTGIProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test TGI processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'tgi',
|
||||
'url': 'http://tgi-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'tgi'
|
||||
assert processor.base_url == 'http://tgi-service:8899/v1'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 2048
|
||||
assert hasattr(processor, 'session')
|
||||
mock_session_class.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Generated response from TGI'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 20,
|
||||
'completion_tokens': 12
|
||||
}
|
||||
})
|
||||
|
||||
# Mock the async context manager
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'tgi',
|
||||
'url': 'http://tgi-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from TGI"
|
||||
assert result.in_token == 20
|
||||
assert result.out_token == 12
|
||||
assert result.model == 'tgi'
|
||||
|
||||
# Verify the API call was made correctly
|
||||
mock_session.post.assert_called_once()
|
||||
call_args = mock_session.post.call_args
|
||||
|
||||
# Check URL
|
||||
assert call_args[0][0] == 'http://tgi-service:8899/v1/chat/completions'
|
||||
|
||||
# Check request structure
|
||||
request_body = call_args[1]['json']
|
||||
assert request_body['model'] == 'tgi'
|
||||
assert request_body['temperature'] == 0.0
|
||||
assert request_body['max_tokens'] == 2048
|
||||
|
||||
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test generate_content with model parameter override"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response from overridden model'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 15,
|
||||
'completion_tokens': 10
|
||||
}
|
||||
})
|
||||
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'tgi',
|
||||
'url': 'http://tgi-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model
|
||||
result = await processor.generate_content("System", "Prompt", model="custom-tgi-model")
|
||||
|
||||
# Assert
|
||||
assert result.model == "custom-tgi-model" # Should use overridden model
|
||||
assert result.text == "Response from overridden model"
|
||||
|
||||
# Verify the API call was made with overridden model
|
||||
call_args = mock_session.post.call_args
|
||||
assert call_args[1]['json']['model'] == "custom-tgi-model"
|
||||
|
||||
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test generate_content with temperature parameter override"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response with temperature override'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 18,
|
||||
'completion_tokens': 12
|
||||
}
|
||||
})
|
||||
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'tgi',
|
||||
'url': 'http://tgi-service:8899/v1',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature
|
||||
result = await processor.generate_content("System", "Prompt", temperature=0.7)
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with temperature override"
|
||||
|
||||
# Verify the API call was made with overridden temperature
|
||||
call_args = mock_session.post.call_args
|
||||
assert call_args[1]['json']['temperature'] == 0.7
|
||||
|
||||
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test generate_content with both model and temperature overrides"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response with both parameters override'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 20,
|
||||
'completion_tokens': 15
|
||||
}
|
||||
})
|
||||
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'tgi',
|
||||
'url': 'http://tgi-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters
|
||||
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
|
||||
|
||||
# Assert
|
||||
assert result.model == "override-model"
|
||||
assert result.text == "Response with both parameters override"
|
||||
|
||||
# Verify the API call was made with overridden parameters
|
||||
call_args = mock_session.post.call_args
|
||||
assert call_args[1]['json']['model'] == "override-model"
|
||||
assert call_args[1]['json']['temperature'] == 0.8
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -47,10 +47,10 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-2.0-flash-001' # It's stored as 'model', not 'model_name'
|
||||
assert hasattr(processor, 'generation_config')
|
||||
assert processor.default_model == 'gemini-2.0-flash-001' # It's stored as 'model', not 'model_name'
|
||||
assert hasattr(processor, 'generation_configs') # Now a cache dictionary
|
||||
assert hasattr(processor, 'safety_settings')
|
||||
assert hasattr(processor, 'llm')
|
||||
assert hasattr(processor, 'model_clients') # LLM clients are now cached
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('private.json')
|
||||
mock_vertexai.init.assert_called_once()
|
||||
|
||||
|
|
@ -102,7 +102,8 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
mock_model.generate_content.assert_called_once()
|
||||
# Verify the call was made with the expected parameters
|
||||
call_args = mock_model.generate_content.call_args
|
||||
assert call_args[1]['generation_config'] == processor.generation_config
|
||||
# Generation config is now created dynamically per model
|
||||
assert 'generation_config' in call_args[1]
|
||||
assert call_args[1]['safety_settings'] == processor.safety_settings
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
|
|
@ -223,7 +224,7 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-2.0-flash-001'
|
||||
assert processor.default_model == 'gemini-2.0-flash-001'
|
||||
mock_auth_default.assert_called_once()
|
||||
mock_vertexai.init.assert_called_once_with(
|
||||
location='us-central1',
|
||||
|
|
@ -296,11 +297,11 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-1.5-pro'
|
||||
assert processor.default_model == 'gemini-1.5-pro'
|
||||
|
||||
# Verify that generation_config object exists (can't easily check internal values)
|
||||
assert hasattr(processor, 'generation_config')
|
||||
assert processor.generation_config is not None
|
||||
assert hasattr(processor, 'generation_configs') # Now a cache dictionary
|
||||
assert processor.generation_configs == {} # Empty cache initially
|
||||
|
||||
# Verify that safety settings are configured
|
||||
assert len(processor.safety_settings) == 4
|
||||
|
|
@ -353,8 +354,8 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
project='test-project-123'
|
||||
)
|
||||
|
||||
# Verify GenerativeModel was created with the right model name
|
||||
mock_generative_model.assert_called_once_with('gemini-2.0-flash-001')
|
||||
# GenerativeModel is now created lazily on first use, not at initialization
|
||||
mock_generative_model.assert_not_called()
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
|
|
@ -440,8 +441,8 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'claude-3-sonnet@20240229'
|
||||
assert processor.is_anthropic == True
|
||||
assert processor.default_model == 'claude-3-sonnet@20240229'
|
||||
# is_anthropic logic is now determined dynamically per request
|
||||
|
||||
# Verify service account was called with private key
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('anthropic-key.json')
|
||||
|
|
@ -459,6 +460,180 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert processor.api_params["top_p"] == 1.0
|
||||
assert processor.api_params["top_k"] == 32
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Response with custom temperature"
|
||||
mock_response.usage_metadata.prompt_token_count = 20
|
||||
mock_response.usage_metadata.candidates_token_count = 12
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Gemini API was called with overridden temperature
|
||||
mock_model.generate_content.assert_called_once()
|
||||
call_args = mock_model.generate_content.call_args
|
||||
|
||||
# Check that generation_config was created (we can't directly access temperature from mock)
|
||||
generation_config = call_args.kwargs['generation_config']
|
||||
assert generation_config is not None # Should use overridden temperature configuration
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
# Mock different models
|
||||
mock_model_default = MagicMock()
|
||||
mock_model_override = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Response with custom model"
|
||||
mock_response.usage_metadata.prompt_token_count = 18
|
||||
mock_response.usage_metadata.candidates_token_count = 14
|
||||
mock_model_override.generate_content.return_value = mock_response
|
||||
|
||||
# GenerativeModel should return different models based on input
|
||||
def model_factory(model_name):
|
||||
if model_name == 'gemini-1.5-pro':
|
||||
return mock_model_override
|
||||
return mock_model_default
|
||||
|
||||
mock_generative_model.side_effect = model_factory
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001', # Default model
|
||||
'temperature': 0.2, # Default temperature
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gemini-1.5-pro", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify the overridden model was used
|
||||
mock_model_override.generate_content.assert_called_once()
|
||||
# Verify GenerativeModel was called with the override model
|
||||
mock_generative_model.assert_called_with('gemini-1.5-pro')
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Response with both overrides"
|
||||
mock_response.usage_metadata.prompt_token_count = 22
|
||||
mock_response.usage_metadata.candidates_token_count = 16
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001', # Default model
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gemini-1.5-flash-001", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify both overrides were used
|
||||
mock_model.generate_content.assert_called_once()
|
||||
call_args = mock_model.generate_content.call_args
|
||||
|
||||
# Verify model override
|
||||
mock_generative_model.assert_called_with('gemini-1.5-flash-001') # Should use runtime override
|
||||
|
||||
# Verify temperature override (we can't directly access temperature from mock)
|
||||
generation_config = call_args.kwargs['generation_config']
|
||||
assert generation_config is not None # Should use overridden temperature configuration
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -42,7 +42,7 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'TheBloke/Mistral-7B-v0.1-AWQ'
|
||||
assert processor.default_model == 'TheBloke/Mistral-7B-v0.1-AWQ'
|
||||
assert processor.base_url == 'http://vllm-service:8899/v1'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 2048
|
||||
|
|
@ -199,7 +199,7 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'custom-model'
|
||||
assert processor.default_model == 'custom-model'
|
||||
assert processor.base_url == 'http://custom-vllm:8080/v1'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 1024
|
||||
|
|
@ -228,7 +228,7 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'TheBloke/Mistral-7B-v0.1-AWQ' # default_model
|
||||
assert processor.default_model == 'TheBloke/Mistral-7B-v0.1-AWQ' # default_model
|
||||
assert processor.base_url == 'http://vllm-service:8899/v1' # default_base_url
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 2048 # default_max_output
|
||||
|
|
@ -485,5 +485,148 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert call_args[1]['json']['prompt'] == expected_prompt
|
||||
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test generate_content with model parameter override"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'text': 'Response from overridden model'
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 12,
|
||||
'completion_tokens': 8
|
||||
}
|
||||
})
|
||||
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model
|
||||
result = await processor.generate_content("System", "Prompt", model="custom-vllm-model")
|
||||
|
||||
# Assert
|
||||
assert result.model == "custom-vllm-model" # Should use overridden model
|
||||
assert result.text == "Response from overridden model"
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test generate_content with temperature parameter override"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'text': 'Response with temperature override'
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 15,
|
||||
'completion_tokens': 10
|
||||
}
|
||||
})
|
||||
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature
|
||||
result = await processor.generate_content("System", "Prompt", temperature=0.7)
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with temperature override"
|
||||
|
||||
# Verify the request was made with overridden temperature
|
||||
call_args = mock_session.post.call_args
|
||||
assert call_args[1]['json']['temperature'] == 0.7
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test generate_content with both model and temperature overrides"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'text': 'Response with both parameters override'
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 18,
|
||||
'completion_tokens': 12
|
||||
}
|
||||
})
|
||||
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters
|
||||
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
|
||||
|
||||
# Assert
|
||||
assert result.model == "override-model"
|
||||
assert result.text == "Response with both parameters override"
|
||||
|
||||
# Verify the request was made with overridden temperature
|
||||
call_args = mock_session.post.call_args
|
||||
assert call_args[1]['json']['temperature'] == 0.8
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -12,6 +12,7 @@ requires-python = ">=3.8"
|
|||
dependencies = [
|
||||
"pulsar-client",
|
||||
"prometheus-client",
|
||||
"requests",
|
||||
]
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ class Flow:
|
|||
|
||||
return json.loads(self.request(request = input)["flow"])
|
||||
|
||||
def start(self, class_name, id, description):
|
||||
def start(self, class_name, id, description, parameters=None):
|
||||
|
||||
# The input consists of system and prompt strings
|
||||
input = {
|
||||
|
|
@ -97,6 +97,9 @@ class Flow:
|
|||
"description": description,
|
||||
}
|
||||
|
||||
if parameters:
|
||||
input["parameters"] = parameters
|
||||
|
||||
self.request(request = input)
|
||||
|
||||
def stop(self, id):
|
||||
|
|
|
|||
|
|
@ -8,11 +8,12 @@ from . subscriber import Subscriber
|
|||
from . metrics import ProcessorMetrics, ConsumerMetrics, ProducerMetrics
|
||||
from . flow_processor import FlowProcessor
|
||||
from . consumer_spec import ConsumerSpec
|
||||
from . setting_spec import SettingSpec
|
||||
from . parameter_spec import ParameterSpec
|
||||
from . producer_spec import ProducerSpec
|
||||
from . subscriber_spec import SubscriberSpec
|
||||
from . request_response_spec import RequestResponseSpec
|
||||
from . llm_service import LlmService, LlmResult
|
||||
from . chunking_service import ChunkingService
|
||||
from . embeddings_service import EmbeddingsService
|
||||
from . embeddings_client import EmbeddingsClientSpec
|
||||
from . text_completion_client import TextCompletionClientSpec
|
||||
|
|
|
|||
62
trustgraph-base/trustgraph/base/chunking_service.py
Normal file
62
trustgraph-base/trustgraph/base/chunking_service.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
"""
|
||||
Base chunking service that provides parameter specification functionality
|
||||
for chunk-size and chunk-overlap parameters
|
||||
"""
|
||||
|
||||
import logging
|
||||
from .flow_processor import FlowProcessor
|
||||
from .parameter_spec import ParameterSpec
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ChunkingService(FlowProcessor):
|
||||
"""Base service for chunking processors with parameter specification support"""
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
# Call parent constructor
|
||||
super(ChunkingService, self).__init__(**params)
|
||||
|
||||
# Register parameter specifications for chunk-size and chunk-overlap
|
||||
self.register_specification(
|
||||
ParameterSpec(name="chunk-size")
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ParameterSpec(name="chunk-overlap")
|
||||
)
|
||||
|
||||
logger.debug("ChunkingService initialized with parameter specifications")
|
||||
|
||||
async def chunk_document(self, msg, consumer, flow, default_chunk_size, default_chunk_overlap):
|
||||
"""
|
||||
Extract chunk parameters from flow and return effective values
|
||||
|
||||
Args:
|
||||
msg: The message containing the document to chunk
|
||||
consumer: The consumer spec
|
||||
flow: The flow context
|
||||
default_chunk_size: Default chunk size from processor config
|
||||
default_chunk_overlap: Default chunk overlap from processor config
|
||||
|
||||
Returns:
|
||||
tuple: (chunk_size, chunk_overlap) - effective values to use
|
||||
"""
|
||||
# Extract parameters from flow (flow-configurable parameters)
|
||||
chunk_size = flow("chunk-size")
|
||||
chunk_overlap = flow("chunk-overlap")
|
||||
|
||||
# Use provided values or fall back to defaults
|
||||
effective_chunk_size = chunk_size if chunk_size is not None else default_chunk_size
|
||||
effective_chunk_overlap = chunk_overlap if chunk_overlap is not None else default_chunk_overlap
|
||||
|
||||
logger.debug(f"Using chunk-size: {effective_chunk_size}")
|
||||
logger.debug(f"Using chunk-overlap: {effective_chunk_overlap}")
|
||||
|
||||
return effective_chunk_size, effective_chunk_overlap
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add chunking service arguments to parser"""
|
||||
FlowProcessor.add_args(parser)
|
||||
|
|
@ -12,7 +12,7 @@ class Flow:
|
|||
# Consumers and publishers. Is this a bit untidy?
|
||||
self.consumer = {}
|
||||
|
||||
self.setting = {}
|
||||
self.parameter = {}
|
||||
|
||||
for spec in processor.specifications:
|
||||
spec.add(self, processor, defn)
|
||||
|
|
@ -28,5 +28,5 @@ class Flow:
|
|||
def __call__(self, key):
|
||||
if key in self.producer: return self.producer[key]
|
||||
if key in self.consumer: return self.consumer[key]
|
||||
if key in self.setting: return self.setting[key].value
|
||||
if key in self.parameter: return self.parameter[key].value
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ class FlowProcessor(AsyncProcessor):
|
|||
|
||||
# These can be overriden by a derived class:
|
||||
|
||||
# Array of specifications: ConsumerSpec, ProducerSpec, SettingSpec
|
||||
# Array of specifications: ConsumerSpec, ProducerSpec, ParameterSpec
|
||||
self.specifications = []
|
||||
|
||||
logger.info("Service initialised.")
|
||||
|
|
|
|||
|
|
@ -5,11 +5,11 @@ LLM text completion base class
|
|||
|
||||
import time
|
||||
import logging
|
||||
from prometheus_client import Histogram
|
||||
from prometheus_client import Histogram, Info
|
||||
|
||||
from .. schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
from .. exceptions import TooManyRequests
|
||||
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec, ParameterSpec
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -32,7 +32,7 @@ class LlmService(FlowProcessor):
|
|||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
id = params.get("id", default_ident)
|
||||
concurrency = params.get("concurrency", 1)
|
||||
|
||||
super(LlmService, self).__init__(**params | {
|
||||
|
|
@ -56,6 +56,18 @@ class LlmService(FlowProcessor):
|
|||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ParameterSpec(
|
||||
name = "model",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ParameterSpec(
|
||||
name = "temperature",
|
||||
)
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "text_completion_metric"):
|
||||
__class__.text_completion_metric = Histogram(
|
||||
'text_completion_duration',
|
||||
|
|
@ -70,6 +82,13 @@ class LlmService(FlowProcessor):
|
|||
]
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "text_completion_model_metric"):
|
||||
__class__.text_completion_model_metric = Info(
|
||||
'text_completion_model',
|
||||
'Text completion model',
|
||||
["processor", "flow"]
|
||||
)
|
||||
|
||||
async def on_request(self, msg, consumer, flow):
|
||||
|
||||
try:
|
||||
|
|
@ -85,10 +104,21 @@ class LlmService(FlowProcessor):
|
|||
flow=f"{flow.name}-{consumer.name}",
|
||||
).time():
|
||||
|
||||
model = flow("model")
|
||||
temperature = flow("temperature")
|
||||
|
||||
response = await self.generate_content(
|
||||
request.system, request.prompt
|
||||
request.system, request.prompt, model, temperature
|
||||
)
|
||||
|
||||
__class__.text_completion_model_metric.labels(
|
||||
processor = self.id,
|
||||
flow = flow.name
|
||||
).info({
|
||||
"model": str(model) if model is not None else "",
|
||||
"temperature": str(temperature) if temperature is not None else "",
|
||||
})
|
||||
|
||||
await flow("response").send(
|
||||
TextCompletionResponse(
|
||||
error=None,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
from . spec import Spec
|
||||
|
||||
class Setting:
|
||||
class Parameter:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
async def start():
|
||||
|
|
@ -9,11 +9,13 @@ class Setting:
|
|||
async def stop():
|
||||
pass
|
||||
|
||||
class SettingSpec(Spec):
|
||||
class ParameterSpec(Spec):
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def add(self, flow, processor, definition):
|
||||
|
||||
flow.config[self.name] = Setting(definition[self.name])
|
||||
value = definition.get(self.name, None)
|
||||
|
||||
flow.parameter[self.name] = Parameter(value)
|
||||
|
||||
|
|
@ -49,8 +49,6 @@ class RequestResponse(Subscriber):
|
|||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
logger.debug(f"Sending request {id}...")
|
||||
|
||||
q = await self.subscribe(id)
|
||||
|
||||
try:
|
||||
|
|
@ -75,8 +73,6 @@ class RequestResponse(Subscriber):
|
|||
timeout=timeout
|
||||
)
|
||||
|
||||
logger.debug("Received response")
|
||||
|
||||
if recipient is None:
|
||||
|
||||
# If no recipient handler, just return the first
|
||||
|
|
|
|||
|
|
@ -12,12 +12,13 @@ class FlowRequestTranslator(MessageTranslator):
|
|||
class_name=data.get("class-name"),
|
||||
class_definition=data.get("class-definition"),
|
||||
description=data.get("description"),
|
||||
flow_id=data.get("flow-id")
|
||||
flow_id=data.get("flow-id"),
|
||||
parameters=data.get("parameters")
|
||||
)
|
||||
|
||||
def from_pulsar(self, obj: FlowRequest) -> Dict[str, Any]:
|
||||
result = {}
|
||||
|
||||
|
||||
if obj.operation is not None:
|
||||
result["operation"] = obj.operation
|
||||
if obj.class_name is not None:
|
||||
|
|
@ -28,7 +29,9 @@ class FlowRequestTranslator(MessageTranslator):
|
|||
result["description"] = obj.description
|
||||
if obj.flow_id is not None:
|
||||
result["flow-id"] = obj.flow_id
|
||||
|
||||
if obj.parameters is not None:
|
||||
result["parameters"] = obj.parameters
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -40,7 +43,7 @@ class FlowResponseTranslator(MessageTranslator):
|
|||
|
||||
def from_pulsar(self, obj: FlowResponse) -> Dict[str, Any]:
|
||||
result = {}
|
||||
|
||||
|
||||
if obj.class_names is not None:
|
||||
result["class-names"] = obj.class_names
|
||||
if obj.flow_ids is not None:
|
||||
|
|
@ -51,7 +54,9 @@ class FlowResponseTranslator(MessageTranslator):
|
|||
result["flow"] = obj.flow
|
||||
if obj.description is not None:
|
||||
result["description"] = obj.description
|
||||
|
||||
if obj.parameters is not None:
|
||||
result["parameters"] = obj.parameters
|
||||
|
||||
return result
|
||||
|
||||
def from_response_with_completion(self, obj: FlowResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
|
|
|
|||
|
|
@ -35,6 +35,9 @@ class FlowRequest(Record):
|
|||
# get_flow, start_flow, stop_flow
|
||||
flow_id = String()
|
||||
|
||||
# start_flow - optional parameters for flow customization
|
||||
parameters = Map(String())
|
||||
|
||||
class FlowResponse(Record):
|
||||
|
||||
# list_classes
|
||||
|
|
@ -52,6 +55,9 @@ class FlowResponse(Record):
|
|||
# get_flow
|
||||
description = String()
|
||||
|
||||
# get_flow - parameters used when flow was started
|
||||
parameters = Map(String())
|
||||
|
||||
# Everything
|
||||
error = Error()
|
||||
|
||||
|
|
|
|||
|
|
@ -183,13 +183,13 @@ class Processor(LlmService):
|
|||
}
|
||||
)
|
||||
|
||||
self.model = model
|
||||
# Store default configuration
|
||||
self.default_model = model
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
|
||||
self.variant = self.determine_variant(self.model)()
|
||||
self.variant.set_temperature(temperature)
|
||||
self.variant.set_max_output(max_output)
|
||||
# Cache for model variants to avoid re-initialization
|
||||
self.model_variants = {}
|
||||
|
||||
self.session = boto3.Session(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
|
|
@ -208,47 +208,75 @@ class Processor(LlmService):
|
|||
# FIXME: Missing, Amazon models, Deepseek
|
||||
|
||||
# This set of conditions deals with normal bedrock on-demand usage
|
||||
if self.model.startswith("mistral"):
|
||||
if model.startswith("mistral"):
|
||||
return Mistral
|
||||
elif self.model.startswith("meta"):
|
||||
elif model.startswith("meta"):
|
||||
return Meta
|
||||
elif self.model.startswith("anthropic"):
|
||||
elif model.startswith("anthropic"):
|
||||
return Anthropic
|
||||
elif self.model.startswith("ai21"):
|
||||
elif model.startswith("ai21"):
|
||||
return Ai21
|
||||
elif self.model.startswith("cohere"):
|
||||
elif model.startswith("cohere"):
|
||||
return Cohere
|
||||
|
||||
# The inference profiles
|
||||
if self.model.startswith("us.meta"):
|
||||
if model.startswith("us.meta"):
|
||||
return Meta
|
||||
elif self.model.startswith("us.anthropic"):
|
||||
elif model.startswith("us.anthropic"):
|
||||
return Anthropic
|
||||
elif self.model.startswith("eu.meta"):
|
||||
elif model.startswith("eu.meta"):
|
||||
return Meta
|
||||
elif self.model.startswith("eu.anthropic"):
|
||||
elif model.startswith("eu.anthropic"):
|
||||
return Anthropic
|
||||
|
||||
return Default
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
def _get_or_create_variant(self, model_name, temperature=None):
|
||||
"""Get cached model variant or create new one"""
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
# Create a cache key that includes temperature to avoid conflicts
|
||||
cache_key = f"{model_name}:{effective_temperature}"
|
||||
|
||||
if cache_key not in self.model_variants:
|
||||
logger.info(f"Creating model variant for '{model_name}' with temperature {effective_temperature}")
|
||||
variant_class = self.determine_variant(model_name)
|
||||
variant = variant_class()
|
||||
variant.set_temperature(effective_temperature)
|
||||
variant.set_max_output(self.max_output)
|
||||
self.model_variants[cache_key] = variant
|
||||
|
||||
return self.model_variants[cache_key]
|
||||
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
try:
|
||||
# Get the appropriate variant for this model
|
||||
variant = self._get_or_create_variant(model_name, effective_temperature)
|
||||
|
||||
promptbody = self.variant.encode_request(system, prompt)
|
||||
promptbody = variant.encode_request(system, prompt)
|
||||
|
||||
accept = 'application/json'
|
||||
contentType = 'application/json'
|
||||
|
||||
response = self.bedrock.invoke_model(
|
||||
body=promptbody,
|
||||
modelId=self.model,
|
||||
modelId=model_name,
|
||||
accept=accept,
|
||||
contentType=contentType
|
||||
)
|
||||
|
||||
# Response structure decode
|
||||
outputtext = self.variant.decode_response(response)
|
||||
outputtext = variant.decode_response(response)
|
||||
|
||||
metadata = response['ResponseMetadata']['HTTPHeaders']
|
||||
inputtokens = int(metadata['x-amzn-bedrock-input-token-count'])
|
||||
|
|
@ -262,7 +290,7 @@ class Processor(LlmService):
|
|||
text = outputtext,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model
|
||||
model = model_name
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ tg-show-kg-cores = "trustgraph.cli.show_kg_cores:main"
|
|||
tg-show-library-documents = "trustgraph.cli.show_library_documents:main"
|
||||
tg-show-library-processing = "trustgraph.cli.show_library_processing:main"
|
||||
tg-show-mcp-tools = "trustgraph.cli.show_mcp_tools:main"
|
||||
tg-show-parameter-types = "trustgraph.cli.show_parameter_types:main"
|
||||
tg-show-processor-state = "trustgraph.cli.show_processor_state:main"
|
||||
tg-show-prompts = "trustgraph.cli.show_prompts:main"
|
||||
tg-show-token-costs = "trustgraph.cli.show_token_costs:main"
|
||||
|
|
|
|||
|
|
@ -599,8 +599,7 @@ def _send_to_trustgraph(objects, api_url, flow, batch_size=1000):
|
|||
imported_count += 1
|
||||
|
||||
if imported_count % 100 == 0:
|
||||
logger.info(f"Imported {imported_count}/{len(objects)} records...")
|
||||
print(f"✅ Imported {imported_count}/{len(objects)} records...")
|
||||
logger.debug(f"Imported {imported_count}/{len(objects)} records...")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send record {imported_count + 1}: {e}")
|
||||
|
|
|
|||
|
|
@ -5,38 +5,93 @@ Shows all defined flow classes.
|
|||
import argparse
|
||||
import os
|
||||
import tabulate
|
||||
from trustgraph.api import Api
|
||||
from trustgraph.api import Api, ConfigKey
|
||||
import json
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
|
||||
def format_parameters(params_metadata, config_api):
|
||||
"""
|
||||
Format parameter metadata for display
|
||||
|
||||
Args:
|
||||
params_metadata: Parameter definitions from flow class
|
||||
config_api: API client to get parameter type information
|
||||
|
||||
Returns:
|
||||
Formatted string describing parameters
|
||||
"""
|
||||
if not params_metadata:
|
||||
return "None"
|
||||
|
||||
param_list = []
|
||||
|
||||
# Sort parameters by order if available
|
||||
sorted_params = sorted(
|
||||
params_metadata.items(),
|
||||
key=lambda x: x[1].get("order", 999)
|
||||
)
|
||||
|
||||
for param_name, param_meta in sorted_params:
|
||||
description = param_meta.get("description", param_name)
|
||||
param_type = param_meta.get("type", "unknown")
|
||||
|
||||
# Get type information if available
|
||||
type_info = param_type
|
||||
if config_api:
|
||||
try:
|
||||
key = ConfigKey("parameter-types", param_type)
|
||||
type_def_value = config_api.get([key])[0].value
|
||||
param_type_def = json.loads(type_def_value)
|
||||
|
||||
# Add default value if available
|
||||
default = param_type_def.get("default")
|
||||
if default is not None:
|
||||
type_info = f"{param_type} (default: {default})"
|
||||
|
||||
except:
|
||||
# If we can't get type definition, just show the type name
|
||||
pass
|
||||
|
||||
param_list.append(f" {param_name}: {description} [{type_info}]")
|
||||
|
||||
return "\n".join(param_list)
|
||||
|
||||
def show_flow_classes(url):
|
||||
|
||||
api = Api(url).flow()
|
||||
api = Api(url)
|
||||
flow_api = api.flow()
|
||||
config_api = api.config()
|
||||
|
||||
class_names = api.list_classes()
|
||||
class_names = flow_api.list_classes()
|
||||
|
||||
if len(class_names) == 0:
|
||||
print("No flows.")
|
||||
print("No flow classes.")
|
||||
return
|
||||
|
||||
classes = []
|
||||
|
||||
for class_name in class_names:
|
||||
cls = api.get_class(class_name)
|
||||
classes.append((
|
||||
class_name,
|
||||
cls.get("description", ""),
|
||||
", ".join(cls.get("tags", [])),
|
||||
))
|
||||
cls = flow_api.get_class(class_name)
|
||||
|
||||
print(tabulate.tabulate(
|
||||
classes,
|
||||
tablefmt="pretty",
|
||||
maxcolwidths=[None, 40, 20],
|
||||
stralign="left",
|
||||
headers = ["flow class", "description", "tags"],
|
||||
))
|
||||
table = []
|
||||
table.append(("name", class_name))
|
||||
table.append(("description", cls.get("description", "")))
|
||||
|
||||
tags = cls.get("tags", [])
|
||||
if tags:
|
||||
table.append(("tags", ", ".join(tags)))
|
||||
|
||||
# Show parameters if they exist
|
||||
parameters = cls.get("parameters", {})
|
||||
if parameters:
|
||||
param_str = format_parameters(parameters, config_api)
|
||||
table.append(("parameters", param_str))
|
||||
|
||||
print(tabulate.tabulate(
|
||||
table,
|
||||
tablefmt="pretty",
|
||||
stralign="left",
|
||||
))
|
||||
print()
|
||||
|
||||
def main():
|
||||
|
||||
|
|
|
|||
|
|
@ -45,6 +45,89 @@ def describe_interfaces(intdefs, flow):
|
|||
|
||||
return "\n".join(lst)
|
||||
|
||||
def get_enum_description(param_value, param_type_def):
|
||||
"""
|
||||
Get the human-readable description for an enum value
|
||||
|
||||
Args:
|
||||
param_value: The actual parameter value (e.g., "gpt-4")
|
||||
param_type_def: The parameter type definition containing enum objects
|
||||
|
||||
Returns:
|
||||
Human-readable description or the original value if not found
|
||||
"""
|
||||
enum_list = param_type_def.get("enum", [])
|
||||
|
||||
# Handle both old format (strings) and new format (objects with id/description)
|
||||
for enum_item in enum_list:
|
||||
if isinstance(enum_item, dict):
|
||||
if enum_item.get("id") == param_value:
|
||||
return enum_item.get("description", param_value)
|
||||
elif enum_item == param_value:
|
||||
return param_value
|
||||
|
||||
# If not found in enum, return original value
|
||||
return param_value
|
||||
|
||||
def format_parameters(flow_params, class_params_metadata, config_api):
|
||||
"""
|
||||
Format flow parameters with their human-readable descriptions
|
||||
|
||||
Args:
|
||||
flow_params: The actual parameter values used in the flow
|
||||
class_params_metadata: The parameter metadata from the flow class definition
|
||||
config_api: API client to retrieve parameter type definitions
|
||||
|
||||
Returns:
|
||||
Formatted string of parameters with descriptions
|
||||
"""
|
||||
if not flow_params:
|
||||
return "None"
|
||||
|
||||
param_list = []
|
||||
|
||||
# Sort parameters by order if available
|
||||
sorted_params = sorted(
|
||||
class_params_metadata.items(),
|
||||
key=lambda x: x[1].get("order", 999)
|
||||
)
|
||||
|
||||
for param_name, param_meta in sorted_params:
|
||||
if param_name in flow_params:
|
||||
value = flow_params[param_name]
|
||||
description = param_meta.get("description", param_name)
|
||||
param_type = param_meta.get("type", "")
|
||||
controlled_by = param_meta.get("controlled-by", None)
|
||||
|
||||
# Try to get enum description if this parameter has a type definition
|
||||
display_value = value
|
||||
if param_type and config_api:
|
||||
try:
|
||||
from trustgraph.api import ConfigKey
|
||||
key = ConfigKey("parameter-types", param_type)
|
||||
type_def_value = config_api.get([key])[0].value
|
||||
param_type_def = json.loads(type_def_value)
|
||||
display_value = get_enum_description(value, param_type_def)
|
||||
except:
|
||||
# If we can't get the type definition, just use the original value
|
||||
display_value = value
|
||||
|
||||
# Format the parameter line
|
||||
line = f"• {description}: {display_value}"
|
||||
|
||||
# Add controlled-by indicator if present
|
||||
if controlled_by:
|
||||
line += f" (controlled by {controlled_by})"
|
||||
|
||||
param_list.append(line)
|
||||
|
||||
# Add any parameters that aren't in the class metadata (shouldn't happen normally)
|
||||
for param_name, value in flow_params.items():
|
||||
if param_name not in class_params_metadata:
|
||||
param_list.append(f"• {param_name}: {value} (undefined)")
|
||||
|
||||
return "\n".join(param_list) if param_list else "None"
|
||||
|
||||
def show_flows(url):
|
||||
|
||||
api = Api(url)
|
||||
|
|
@ -74,6 +157,26 @@ def show_flows(url):
|
|||
table.append(("id", id))
|
||||
table.append(("class", flow.get("class-name", "")))
|
||||
table.append(("desc", flow.get("description", "")))
|
||||
|
||||
# Display parameters with human-readable descriptions
|
||||
parameters = flow.get("parameters", {})
|
||||
if parameters:
|
||||
# Try to get the flow class definition for parameter metadata
|
||||
class_name = flow.get("class-name", "")
|
||||
if class_name:
|
||||
try:
|
||||
flow_class = flow_api.get_class(class_name)
|
||||
class_params_metadata = flow_class.get("parameters", {})
|
||||
param_str = format_parameters(parameters, class_params_metadata, config_api)
|
||||
except Exception as e:
|
||||
# Fallback to JSON if we can't get the class definition
|
||||
param_str = json.dumps(parameters, indent=2)
|
||||
else:
|
||||
# No class name, fallback to JSON
|
||||
param_str = json.dumps(parameters, indent=2)
|
||||
|
||||
table.append(("parameters", param_str))
|
||||
|
||||
table.append(("queue", describe_interfaces(interface_defs, flow)))
|
||||
|
||||
print(tabulate.tabulate(
|
||||
|
|
|
|||
210
trustgraph-cli/trustgraph/cli/show_parameter_types.py
Normal file
210
trustgraph-cli/trustgraph/cli/show_parameter_types.py
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
"""
|
||||
Shows all defined parameter types used in flow classes.
|
||||
|
||||
Parameter types define the schema and constraints for parameters that can
|
||||
be used in flow class definitions. This includes data types, default values,
|
||||
valid enums, and validation rules.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import tabulate
|
||||
from trustgraph.api import Api, ConfigKey
|
||||
import json
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
|
||||
def format_enum_values(enum_list):
|
||||
"""
|
||||
Format enum values for display, handling both old and new formats
|
||||
|
||||
Args:
|
||||
enum_list: List of enum values (strings or objects with id/description)
|
||||
|
||||
Returns:
|
||||
Formatted string describing enum options
|
||||
"""
|
||||
if not enum_list:
|
||||
return "Any value"
|
||||
|
||||
enum_items = []
|
||||
for item in enum_list:
|
||||
if isinstance(item, dict):
|
||||
# New format: objects with id and description
|
||||
enum_id = item.get("id", "")
|
||||
description = item.get("description", "")
|
||||
if description:
|
||||
enum_items.append(f"{enum_id} ({description})")
|
||||
else:
|
||||
enum_items.append(enum_id)
|
||||
else:
|
||||
# Old format: simple strings
|
||||
enum_items.append(str(item))
|
||||
|
||||
return "\n".join(f"• {item}" for item in enum_items)
|
||||
|
||||
def format_constraints(param_type_def):
|
||||
"""
|
||||
Format validation constraints for display
|
||||
|
||||
Args:
|
||||
param_type_def: Parameter type definition
|
||||
|
||||
Returns:
|
||||
Formatted string describing constraints
|
||||
"""
|
||||
constraints = []
|
||||
|
||||
# Handle numeric constraints
|
||||
if "minimum" in param_type_def:
|
||||
constraints.append(f"min: {param_type_def['minimum']}")
|
||||
if "maximum" in param_type_def:
|
||||
constraints.append(f"max: {param_type_def['maximum']}")
|
||||
|
||||
# Handle string constraints
|
||||
if "minLength" in param_type_def:
|
||||
constraints.append(f"min length: {param_type_def['minLength']}")
|
||||
if "maxLength" in param_type_def:
|
||||
constraints.append(f"max length: {param_type_def['maxLength']}")
|
||||
if "pattern" in param_type_def:
|
||||
constraints.append(f"pattern: {param_type_def['pattern']}")
|
||||
|
||||
# Handle required field
|
||||
if param_type_def.get("required", False):
|
||||
constraints.append("required")
|
||||
|
||||
return ", ".join(constraints) if constraints else "None"
|
||||
|
||||
def show_parameter_types(url):
|
||||
"""
|
||||
Show all parameter type definitions
|
||||
"""
|
||||
api = Api(url)
|
||||
config_api = api.config()
|
||||
|
||||
# Get list of all parameter types
|
||||
try:
|
||||
param_type_names = config_api.list("parameter-types")
|
||||
except Exception as e:
|
||||
print(f"Error retrieving parameter types: {e}")
|
||||
return
|
||||
|
||||
if len(param_type_names) == 0:
|
||||
print("No parameter types defined.")
|
||||
return
|
||||
|
||||
for param_type_name in param_type_names:
|
||||
try:
|
||||
# Get the parameter type definition
|
||||
key = ConfigKey("parameter-types", param_type_name)
|
||||
type_def_value = config_api.get([key])[0].value
|
||||
param_type_def = json.loads(type_def_value)
|
||||
|
||||
table = []
|
||||
table.append(("name", param_type_name))
|
||||
table.append(("description", param_type_def.get("description", "")))
|
||||
table.append(("type", param_type_def.get("type", "unknown")))
|
||||
|
||||
# Show default value if present
|
||||
default = param_type_def.get("default")
|
||||
if default is not None:
|
||||
table.append(("default", str(default)))
|
||||
|
||||
# Show enum values if present
|
||||
enum_list = param_type_def.get("enum")
|
||||
if enum_list:
|
||||
enum_str = format_enum_values(enum_list)
|
||||
table.append(("valid values", enum_str))
|
||||
|
||||
# Show constraints
|
||||
constraints = format_constraints(param_type_def)
|
||||
if constraints != "None":
|
||||
table.append(("constraints", constraints))
|
||||
|
||||
print(tabulate.tabulate(
|
||||
table,
|
||||
tablefmt="pretty",
|
||||
stralign="left",
|
||||
))
|
||||
print()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error retrieving parameter type '{param_type_name}': {e}")
|
||||
print()
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
prog='tg-show-parameter-types',
|
||||
description=__doc__,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-u', '--api-url',
|
||||
default=default_url,
|
||||
help=f'API URL (default: {default_url})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--type',
|
||||
help='Show only the specified parameter type',
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
if args.type:
|
||||
# Show specific parameter type
|
||||
show_specific_parameter_type(args.api_url, args.type)
|
||||
else:
|
||||
# Show all parameter types
|
||||
show_parameter_types(args.api_url)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception:", e, flush=True)
|
||||
|
||||
def show_specific_parameter_type(url, param_type_name):
|
||||
"""
|
||||
Show a specific parameter type definition
|
||||
"""
|
||||
api = Api(url)
|
||||
config_api = api.config()
|
||||
|
||||
try:
|
||||
# Get the parameter type definition
|
||||
key = ConfigKey("parameter-types", param_type_name)
|
||||
type_def_value = config_api.get([key])[0].value
|
||||
param_type_def = json.loads(type_def_value)
|
||||
|
||||
table = []
|
||||
table.append(("name", param_type_name))
|
||||
table.append(("description", param_type_def.get("description", "")))
|
||||
table.append(("type", param_type_def.get("type", "unknown")))
|
||||
|
||||
# Show default value if present
|
||||
default = param_type_def.get("default")
|
||||
if default is not None:
|
||||
table.append(("default", str(default)))
|
||||
|
||||
# Show enum values if present
|
||||
enum_list = param_type_def.get("enum")
|
||||
if enum_list:
|
||||
enum_str = format_enum_values(enum_list)
|
||||
table.append(("valid values", enum_str))
|
||||
|
||||
# Show constraints
|
||||
constraints = format_constraints(param_type_def)
|
||||
if constraints != "None":
|
||||
table.append(("constraints", constraints))
|
||||
|
||||
print(tabulate.tabulate(
|
||||
table,
|
||||
tablefmt="pretty",
|
||||
stralign="left",
|
||||
))
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error retrieving parameter type '{param_type_name}': {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,5 +1,13 @@
|
|||
"""
|
||||
Starts a processing flow using a defined flow class
|
||||
Starts a processing flow using a defined flow class.
|
||||
|
||||
Parameters can be provided in three ways:
|
||||
1. As key=value pairs: --param model=gpt-4 --param temp=0.7
|
||||
2. As JSON string: -p '{"model": "gpt-4", "temp": 0.7}'
|
||||
3. As JSON file: --parameters-file params.json
|
||||
|
||||
Note: All parameter values are stored as strings internally, regardless of their
|
||||
input format. Numbers and booleans will be converted to string representation.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
|
@ -10,7 +18,7 @@ import json
|
|||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
|
||||
def start_flow(url, class_name, flow_id, description):
|
||||
def start_flow(url, class_name, flow_id, description, parameters=None):
|
||||
|
||||
api = Api(url).flow()
|
||||
|
||||
|
|
@ -18,6 +26,7 @@ def start_flow(url, class_name, flow_id, description):
|
|||
class_name = class_name,
|
||||
id = flow_id,
|
||||
description = description,
|
||||
parameters = parameters,
|
||||
)
|
||||
|
||||
def main():
|
||||
|
|
@ -51,15 +60,58 @@ def main():
|
|||
help=f'Flow description',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-p', '--parameters',
|
||||
help='Flow parameters as JSON string (e.g., \'{"model": "gpt-4", "temp": 0.7}\')',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--parameters-file',
|
||||
help='Path to JSON file containing flow parameters',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--param',
|
||||
action='append',
|
||||
help='Flow parameter as key=value pair (can be used multiple times, e.g., --param model=gpt-4 --param temp=0.7)',
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# Parse parameters from command line arguments
|
||||
parameters = None
|
||||
|
||||
if args.parameters_file:
|
||||
with open(args.parameters_file, 'r') as f:
|
||||
params_data = json.load(f)
|
||||
# Convert all values to strings
|
||||
parameters = {k: str(v) for k, v in params_data.items()}
|
||||
elif args.parameters:
|
||||
params_data = json.loads(args.parameters)
|
||||
# Convert all values to strings
|
||||
parameters = {k: str(v) for k, v in params_data.items()}
|
||||
elif args.param:
|
||||
# Parse key=value pairs
|
||||
parameters = {}
|
||||
for param in args.param:
|
||||
if '=' not in param:
|
||||
raise ValueError(f"Invalid parameter format: {param}. Expected key=value")
|
||||
|
||||
key, value = param.split('=', 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
# All parameter values must be strings for Pulsar
|
||||
# Just store everything as a string
|
||||
parameters[key] = value
|
||||
|
||||
start_flow(
|
||||
url = args.api_url,
|
||||
class_name = args.class_name,
|
||||
flow_id = args.flow_id,
|
||||
description = args.description,
|
||||
parameters = parameters,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -9,14 +9,14 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|||
from prometheus_client import Histogram
|
||||
|
||||
from ... schema import TextDocument, Chunk
|
||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import ChunkingService, ConsumerSpec, ProducerSpec
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "chunker"
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
class Processor(ChunkingService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
|
|
@ -28,6 +28,10 @@ class Processor(FlowProcessor):
|
|||
**params | { "id": id }
|
||||
)
|
||||
|
||||
# Store default values for parameter override
|
||||
self.default_chunk_size = chunk_size
|
||||
self.default_chunk_overlap = chunk_overlap
|
||||
|
||||
if not hasattr(__class__, "chunk_metric"):
|
||||
__class__.chunk_metric = Histogram(
|
||||
'chunk_size', 'Chunk size',
|
||||
|
|
@ -65,7 +69,22 @@ class Processor(FlowProcessor):
|
|||
v = msg.value()
|
||||
logger.info(f"Chunking document {v.metadata.id}...")
|
||||
|
||||
texts = self.text_splitter.create_documents(
|
||||
# Extract chunk parameters from flow (allows runtime override)
|
||||
chunk_size, chunk_overlap = await self.chunk_document(
|
||||
msg, consumer, flow,
|
||||
self.default_chunk_size,
|
||||
self.default_chunk_overlap
|
||||
)
|
||||
|
||||
# Create text splitter with effective parameters
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
length_function=len,
|
||||
is_separator_regex=False,
|
||||
)
|
||||
|
||||
texts = text_splitter.create_documents(
|
||||
[v.text.decode("utf-8")]
|
||||
)
|
||||
|
||||
|
|
@ -89,7 +108,7 @@ class Processor(FlowProcessor):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
ChunkingService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-z', '--chunk-size',
|
||||
|
|
|
|||
|
|
@ -9,14 +9,14 @@ from langchain_text_splitters import TokenTextSplitter
|
|||
from prometheus_client import Histogram
|
||||
|
||||
from ... schema import TextDocument, Chunk
|
||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import ChunkingService, ConsumerSpec, ProducerSpec
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "chunker"
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
class Processor(ChunkingService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
|
|
@ -28,6 +28,10 @@ class Processor(FlowProcessor):
|
|||
**params | { "id": id }
|
||||
)
|
||||
|
||||
# Store default values for parameter override
|
||||
self.default_chunk_size = chunk_size
|
||||
self.default_chunk_overlap = chunk_overlap
|
||||
|
||||
if not hasattr(__class__, "chunk_metric"):
|
||||
__class__.chunk_metric = Histogram(
|
||||
'chunk_size', 'Chunk size',
|
||||
|
|
@ -64,7 +68,21 @@ class Processor(FlowProcessor):
|
|||
v = msg.value()
|
||||
logger.info(f"Chunking document {v.metadata.id}...")
|
||||
|
||||
texts = self.text_splitter.create_documents(
|
||||
# Extract chunk parameters from flow (allows runtime override)
|
||||
chunk_size, chunk_overlap = await self.chunk_document(
|
||||
msg, consumer, flow,
|
||||
self.default_chunk_size,
|
||||
self.default_chunk_overlap
|
||||
)
|
||||
|
||||
# Create text splitter with effective parameters
|
||||
text_splitter = TokenTextSplitter(
|
||||
encoding_name="cl100k_base",
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
|
||||
texts = text_splitter.create_documents(
|
||||
[v.text.decode("utf-8")]
|
||||
)
|
||||
|
||||
|
|
@ -88,7 +106,7 @@ class Processor(FlowProcessor):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
ChunkingService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-z', '--chunk-size',
|
||||
|
|
|
|||
|
|
@ -10,6 +10,95 @@ class FlowConfig:
|
|||
def __init__(self, config):
|
||||
|
||||
self.config = config
|
||||
# Cache for parameter type definitions to avoid repeated lookups
|
||||
self.param_type_cache = {}
|
||||
|
||||
async def resolve_parameters(self, flow_class, user_params):
|
||||
"""
|
||||
Resolve parameters by merging user-provided values with defaults.
|
||||
|
||||
Args:
|
||||
flow_class: The flow class definition dict
|
||||
user_params: User-provided parameters dict (may be None or empty)
|
||||
|
||||
Returns:
|
||||
Complete parameter dict with user values and defaults merged (all values as strings)
|
||||
"""
|
||||
# If the flow class has no parameters section, return user params as-is (stringified)
|
||||
if "parameters" not in flow_class:
|
||||
if not user_params:
|
||||
return {}
|
||||
# Ensure all values are strings
|
||||
return {k: str(v) for k, v in user_params.items()}
|
||||
|
||||
resolved = {}
|
||||
flow_params = flow_class["parameters"]
|
||||
user_params = user_params if user_params else {}
|
||||
|
||||
# First pass: resolve parameters with explicit values or defaults
|
||||
for param_name, param_meta in flow_params.items():
|
||||
# Check if user provided a value
|
||||
if param_name in user_params:
|
||||
# Store as string
|
||||
resolved[param_name] = str(user_params[param_name])
|
||||
else:
|
||||
# Look up the parameter type definition
|
||||
param_type = param_meta.get("type")
|
||||
if param_type:
|
||||
# Check cache first
|
||||
if param_type not in self.param_type_cache:
|
||||
try:
|
||||
# Fetch parameter type definition from config store
|
||||
type_def = await self.config.get("parameter-types").get(param_type)
|
||||
if type_def:
|
||||
self.param_type_cache[param_type] = json.loads(type_def)
|
||||
else:
|
||||
logger.warning(f"Parameter type '{param_type}' not found in config")
|
||||
self.param_type_cache[param_type] = {}
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching parameter type '{param_type}': {e}")
|
||||
self.param_type_cache[param_type] = {}
|
||||
|
||||
# Apply default from type definition (as string)
|
||||
type_def = self.param_type_cache[param_type]
|
||||
if "default" in type_def:
|
||||
default_value = type_def["default"]
|
||||
# Convert to string based on type
|
||||
if isinstance(default_value, bool):
|
||||
resolved[param_name] = "true" if default_value else "false"
|
||||
else:
|
||||
resolved[param_name] = str(default_value)
|
||||
elif type_def.get("required", False):
|
||||
# Required parameter with no default and no user value
|
||||
raise RuntimeError(f"Required parameter '{param_name}' not provided and has no default")
|
||||
|
||||
# Second pass: handle controlled-by relationships
|
||||
for param_name, param_meta in flow_params.items():
|
||||
if param_name not in resolved and "controlled-by" in param_meta:
|
||||
controller = param_meta["controlled-by"]
|
||||
if controller in resolved:
|
||||
# Inherit value from controlling parameter (already a string)
|
||||
resolved[param_name] = resolved[controller]
|
||||
else:
|
||||
# Controller has no value, try to get default from type definition
|
||||
param_type = param_meta.get("type")
|
||||
if param_type and param_type in self.param_type_cache:
|
||||
type_def = self.param_type_cache[param_type]
|
||||
if "default" in type_def:
|
||||
default_value = type_def["default"]
|
||||
# Convert to string based on type
|
||||
if isinstance(default_value, bool):
|
||||
resolved[param_name] = "true" if default_value else "false"
|
||||
else:
|
||||
resolved[param_name] = str(default_value)
|
||||
|
||||
# Include any extra parameters from user that weren't in flow class definition
|
||||
# This allows for forward compatibility (ensure they're strings)
|
||||
for key, value in user_params.items():
|
||||
if key not in resolved:
|
||||
resolved[key] = str(value)
|
||||
|
||||
return resolved
|
||||
|
||||
async def handle_list_classes(self, msg):
|
||||
|
||||
|
|
@ -68,11 +157,14 @@ class FlowConfig:
|
|||
|
||||
async def handle_get_flow(self, msg):
|
||||
|
||||
flow = await self.config.get("flows").get(msg.flow_id)
|
||||
flow_data = await self.config.get("flows").get(msg.flow_id)
|
||||
flow = json.loads(flow_data)
|
||||
|
||||
return FlowResponse(
|
||||
error = None,
|
||||
flow = flow,
|
||||
flow = flow_data,
|
||||
description = flow.get("description", ""),
|
||||
parameters = flow.get("parameters", {}),
|
||||
)
|
||||
|
||||
async def handle_start_flow(self, msg):
|
||||
|
|
@ -83,45 +175,65 @@ class FlowConfig:
|
|||
if msg.flow_id is None:
|
||||
raise RuntimeError("No flow ID")
|
||||
|
||||
if msg.flow_id in await self.config.get("flows").values():
|
||||
if msg.flow_id in await self.config.get("flows").keys():
|
||||
raise RuntimeError("Flow already exists")
|
||||
|
||||
if msg.description is None:
|
||||
raise RuntimeError("No description")
|
||||
|
||||
if msg.class_name not in await self.config.get("flow-classes").values():
|
||||
if msg.class_name not in await self.config.get("flow-classes").keys():
|
||||
raise RuntimeError("Class does not exist")
|
||||
|
||||
def repl_template(tmp):
|
||||
return tmp.replace(
|
||||
"{class}", msg.class_name
|
||||
).replace(
|
||||
"{id}", msg.flow_id
|
||||
)
|
||||
|
||||
cls = json.loads(
|
||||
await self.config.get("flow-classes").get(msg.class_name)
|
||||
)
|
||||
|
||||
# Resolve parameters by merging user-provided values with defaults
|
||||
user_params = msg.parameters if msg.parameters else {}
|
||||
parameters = await self.resolve_parameters(cls, user_params)
|
||||
|
||||
# Log the resolved parameters for debugging
|
||||
logger.debug(f"User provided parameters: {user_params}")
|
||||
logger.debug(f"Resolved parameters (with defaults): {parameters}")
|
||||
|
||||
# Apply parameter substitution to template replacement function
|
||||
def repl_template_with_params(tmp):
|
||||
|
||||
result = tmp.replace(
|
||||
"{class}", msg.class_name
|
||||
).replace(
|
||||
"{id}", msg.flow_id
|
||||
)
|
||||
# Apply parameter substitutions
|
||||
for param_name, param_value in parameters.items():
|
||||
result = result.replace(f"{{{param_name}}}", str(param_value))
|
||||
|
||||
return result
|
||||
|
||||
for kind in ("class", "flow"):
|
||||
|
||||
for k, v in cls[kind].items():
|
||||
|
||||
processor, variant = k.split(":", 1)
|
||||
|
||||
variant = repl_template(variant)
|
||||
variant = repl_template_with_params(variant)
|
||||
|
||||
v = {
|
||||
repl_template(k2): repl_template(v2)
|
||||
repl_template_with_params(k2): repl_template_with_params(v2)
|
||||
for k2, v2 in v.items()
|
||||
}
|
||||
|
||||
flac = await self.config.get("flows-active").values()
|
||||
if processor in flac:
|
||||
target = json.loads(flac[processor])
|
||||
flac = await self.config.get("flows-active").get(processor)
|
||||
if flac is not None:
|
||||
target = json.loads(flac)
|
||||
else:
|
||||
target = {}
|
||||
|
||||
# The condition if variant not in target: means it only adds
|
||||
# the configuration if the variant doesn't already exist.
|
||||
# If "everything" already exists in the target with old
|
||||
# values, they won't update.
|
||||
|
||||
if variant not in target:
|
||||
target[variant] = v
|
||||
|
||||
|
|
@ -131,10 +243,10 @@ class FlowConfig:
|
|||
|
||||
def repl_interface(i):
|
||||
if isinstance(i, str):
|
||||
return repl_template(i)
|
||||
return repl_template_with_params(i)
|
||||
else:
|
||||
return {
|
||||
k: repl_template(v)
|
||||
k: repl_template_with_params(v)
|
||||
for k, v in i.items()
|
||||
}
|
||||
|
||||
|
|
@ -152,6 +264,7 @@ class FlowConfig:
|
|||
"description": msg.description,
|
||||
"class-name": msg.class_name,
|
||||
"interfaces": interfaces,
|
||||
"parameters": parameters,
|
||||
})
|
||||
)
|
||||
|
||||
|
|
@ -177,15 +290,20 @@ class FlowConfig:
|
|||
raise RuntimeError("Internal error: flow has no flow class")
|
||||
|
||||
class_name = flow["class-name"]
|
||||
parameters = flow.get("parameters", {})
|
||||
|
||||
cls = json.loads(await self.config.get("flow-classes").get(class_name))
|
||||
|
||||
def repl_template(tmp):
|
||||
return tmp.replace(
|
||||
result = tmp.replace(
|
||||
"{class}", class_name
|
||||
).replace(
|
||||
"{id}", msg.flow_id
|
||||
)
|
||||
# Apply parameter substitutions
|
||||
for param_name, param_value in parameters.items():
|
||||
result = result.replace(f"{{{param_name}}}", str(param_value))
|
||||
return result
|
||||
|
||||
for kind in ("flow",):
|
||||
|
||||
|
|
@ -195,10 +313,10 @@ class FlowConfig:
|
|||
|
||||
variant = repl_template(variant)
|
||||
|
||||
flac = await self.config.get("flows-active").values()
|
||||
flac = await self.config.get("flows-active").get(processor)
|
||||
|
||||
if processor in flac:
|
||||
target = json.loads(flac[processor])
|
||||
if flac is not None:
|
||||
target = json.loads(flac)
|
||||
else:
|
||||
target = {}
|
||||
|
||||
|
|
@ -209,7 +327,7 @@ class FlowConfig:
|
|||
processor, json.dumps(target)
|
||||
)
|
||||
|
||||
if msg.flow_id in await self.config.get("flows").values():
|
||||
if msg.flow_id in await self.config.get("flows").keys():
|
||||
await self.config.get("flows").delete(msg.flow_id)
|
||||
|
||||
await self.config.inc_version()
|
||||
|
|
|
|||
|
|
@ -24,16 +24,12 @@ class KnowledgeGraph:
|
|||
self.keyspace = keyspace
|
||||
self.username = username
|
||||
|
||||
# Multi-table schema design for optimal performance
|
||||
self.use_legacy = os.getenv('CASSANDRA_USE_LEGACY', 'false').lower() == 'true'
|
||||
|
||||
if self.use_legacy:
|
||||
self.table = "triples" # Legacy single table
|
||||
else:
|
||||
# New optimized tables
|
||||
self.subject_table = "triples_s"
|
||||
self.po_table = "triples_p"
|
||||
self.object_table = "triples_o"
|
||||
# Optimized multi-table schema with collection deletion support
|
||||
self.subject_table = "triples_s"
|
||||
self.po_table = "triples_p"
|
||||
self.object_table = "triples_o"
|
||||
self.collection_table = "triples_collection" # For SPO queries and deletion
|
||||
self.collection_metadata_table = "collection_metadata" # For tracking which collections exist
|
||||
|
||||
if username and password:
|
||||
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
|
||||
|
|
@ -47,9 +43,7 @@ class KnowledgeGraph:
|
|||
_active_clusters.append(self.cluster)
|
||||
|
||||
self.init()
|
||||
|
||||
if not self.use_legacy:
|
||||
self.prepare_statements()
|
||||
self.prepare_statements()
|
||||
|
||||
def clear(self):
|
||||
|
||||
|
|
@ -70,42 +64,13 @@ class KnowledgeGraph:
|
|||
""");
|
||||
|
||||
self.session.set_keyspace(self.keyspace)
|
||||
self.init_optimized_schema()
|
||||
|
||||
if self.use_legacy:
|
||||
self.init_legacy_schema()
|
||||
else:
|
||||
self.init_optimized_schema()
|
||||
|
||||
def init_legacy_schema(self):
|
||||
"""Initialize legacy single-table schema for backward compatibility"""
|
||||
self.session.execute(f"""
|
||||
create table if not exists {self.table} (
|
||||
collection text,
|
||||
s text,
|
||||
p text,
|
||||
o text,
|
||||
PRIMARY KEY (collection, s, p, o)
|
||||
);
|
||||
""");
|
||||
|
||||
self.session.execute(f"""
|
||||
create index if not exists {self.table}_s
|
||||
ON {self.table} (s);
|
||||
""");
|
||||
|
||||
self.session.execute(f"""
|
||||
create index if not exists {self.table}_p
|
||||
ON {self.table} (p);
|
||||
""");
|
||||
|
||||
self.session.execute(f"""
|
||||
create index if not exists {self.table}_o
|
||||
ON {self.table} (o);
|
||||
""");
|
||||
|
||||
def init_optimized_schema(self):
|
||||
"""Initialize optimized multi-table schema for performance"""
|
||||
# Table 1: Subject-centric queries (get_s, get_sp, get_spo, get_os)
|
||||
# Table 1: Subject-centric queries (get_s, get_sp, get_os)
|
||||
# Compound partition key for optimal data distribution
|
||||
self.session.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.subject_table} (
|
||||
collection text,
|
||||
|
|
@ -117,6 +82,7 @@ class KnowledgeGraph:
|
|||
""");
|
||||
|
||||
# Table 2: Predicate-Object queries (get_p, get_po) - eliminates ALLOW FILTERING!
|
||||
# Compound partition key for optimal data distribution
|
||||
self.session.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.po_table} (
|
||||
collection text,
|
||||
|
|
@ -128,6 +94,7 @@ class KnowledgeGraph:
|
|||
""");
|
||||
|
||||
# Table 3: Object-centric queries (get_o)
|
||||
# Compound partition key for optimal data distribution
|
||||
self.session.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.object_table} (
|
||||
collection text,
|
||||
|
|
@ -138,7 +105,29 @@ class KnowledgeGraph:
|
|||
);
|
||||
""");
|
||||
|
||||
logger.info("Optimized multi-table schema initialized")
|
||||
# Table 4: Collection management and SPO queries (get_spo)
|
||||
# Simple partition key enables efficient collection deletion
|
||||
self.session.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.collection_table} (
|
||||
collection text,
|
||||
s text,
|
||||
p text,
|
||||
o text,
|
||||
PRIMARY KEY (collection, s, p, o)
|
||||
);
|
||||
""");
|
||||
|
||||
# Table 5: Collection metadata tracking
|
||||
# Tracks which collections exist without polluting triple data
|
||||
self.session.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.collection_metadata_table} (
|
||||
collection text,
|
||||
created_at timestamp,
|
||||
PRIMARY KEY (collection)
|
||||
);
|
||||
""");
|
||||
|
||||
logger.info("Optimized multi-table schema initialized (5 tables)")
|
||||
|
||||
def prepare_statements(self):
|
||||
"""Prepare statements for optimal performance"""
|
||||
|
|
@ -155,6 +144,10 @@ class KnowledgeGraph:
|
|||
f"INSERT INTO {self.object_table} (collection, o, s, p) VALUES (?, ?, ?, ?)"
|
||||
)
|
||||
|
||||
self.insert_collection_stmt = self.session.prepare(
|
||||
f"INSERT INTO {self.collection_table} (collection, s, p, o) VALUES (?, ?, ?, ?)"
|
||||
)
|
||||
|
||||
# Query statements for optimized access
|
||||
self.get_all_stmt = self.session.prepare(
|
||||
f"SELECT s, p, o FROM {self.subject_table} WHERE collection = ? LIMIT ? ALLOW FILTERING"
|
||||
|
|
@ -186,158 +179,177 @@ class KnowledgeGraph:
|
|||
)
|
||||
|
||||
self.get_spo_stmt = self.session.prepare(
|
||||
f"SELECT s as x FROM {self.subject_table} WHERE collection = ? AND s = ? AND p = ? AND o = ? LIMIT ?"
|
||||
f"SELECT s as x FROM {self.collection_table} WHERE collection = ? AND s = ? AND p = ? AND o = ? LIMIT ?"
|
||||
)
|
||||
|
||||
logger.info("Prepared statements initialized for optimal performance")
|
||||
# Delete statements for collection deletion
|
||||
self.delete_subject_stmt = self.session.prepare(
|
||||
f"DELETE FROM {self.subject_table} WHERE collection = ? AND s = ? AND p = ? AND o = ?"
|
||||
)
|
||||
|
||||
self.delete_po_stmt = self.session.prepare(
|
||||
f"DELETE FROM {self.po_table} WHERE collection = ? AND p = ? AND o = ? AND s = ?"
|
||||
)
|
||||
|
||||
self.delete_object_stmt = self.session.prepare(
|
||||
f"DELETE FROM {self.object_table} WHERE collection = ? AND o = ? AND s = ? AND p = ?"
|
||||
)
|
||||
|
||||
self.delete_collection_stmt = self.session.prepare(
|
||||
f"DELETE FROM {self.collection_table} WHERE collection = ? AND s = ? AND p = ? AND o = ?"
|
||||
)
|
||||
|
||||
logger.info("Prepared statements initialized for optimal performance (4 tables)")
|
||||
|
||||
def insert(self, collection, s, p, o):
|
||||
# Batch write to all four tables for consistency
|
||||
batch = BatchStatement()
|
||||
|
||||
if self.use_legacy:
|
||||
self.session.execute(
|
||||
f"insert into {self.table} (collection, s, p, o) values (%s, %s, %s, %s)",
|
||||
(collection, s, p, o)
|
||||
)
|
||||
else:
|
||||
# Batch write to all three tables for consistency
|
||||
batch = BatchStatement()
|
||||
# Insert into subject table
|
||||
batch.add(self.insert_subject_stmt, (collection, s, p, o))
|
||||
|
||||
# Insert into subject table
|
||||
batch.add(self.insert_subject_stmt, (collection, s, p, o))
|
||||
# Insert into predicate-object table (column order: collection, p, o, s)
|
||||
batch.add(self.insert_po_stmt, (collection, p, o, s))
|
||||
|
||||
# Insert into predicate-object table (column order: collection, p, o, s)
|
||||
batch.add(self.insert_po_stmt, (collection, p, o, s))
|
||||
# Insert into object table (column order: collection, o, s, p)
|
||||
batch.add(self.insert_object_stmt, (collection, o, s, p))
|
||||
|
||||
# Insert into object table (column order: collection, o, s, p)
|
||||
batch.add(self.insert_object_stmt, (collection, o, s, p))
|
||||
# Insert into collection table for SPO queries and deletion tracking
|
||||
batch.add(self.insert_collection_stmt, (collection, s, p, o))
|
||||
|
||||
self.session.execute(batch)
|
||||
self.session.execute(batch)
|
||||
|
||||
def get_all(self, collection, limit=50):
|
||||
if self.use_legacy:
|
||||
return self.session.execute(
|
||||
f"select s, p, o from {self.table} where collection = %s limit {limit}",
|
||||
(collection,)
|
||||
)
|
||||
else:
|
||||
# Use subject table for get_all queries
|
||||
return self.session.execute(
|
||||
self.get_all_stmt,
|
||||
(collection, limit)
|
||||
)
|
||||
# Use subject table for get_all queries
|
||||
return self.session.execute(
|
||||
self.get_all_stmt,
|
||||
(collection, limit)
|
||||
)
|
||||
|
||||
def get_s(self, collection, s, limit=10):
|
||||
if self.use_legacy:
|
||||
return self.session.execute(
|
||||
f"select p, o from {self.table} where collection = %s and s = %s limit {limit}",
|
||||
(collection, s)
|
||||
)
|
||||
else:
|
||||
# Optimized: Direct partition access with (collection, s)
|
||||
return self.session.execute(
|
||||
self.get_s_stmt,
|
||||
(collection, s, limit)
|
||||
)
|
||||
# Optimized: Direct partition access with (collection, s)
|
||||
return self.session.execute(
|
||||
self.get_s_stmt,
|
||||
(collection, s, limit)
|
||||
)
|
||||
|
||||
def get_p(self, collection, p, limit=10):
|
||||
if self.use_legacy:
|
||||
return self.session.execute(
|
||||
f"select s, o from {self.table} where collection = %s and p = %s limit {limit}",
|
||||
(collection, p)
|
||||
)
|
||||
else:
|
||||
# Optimized: Use po_table for direct partition access
|
||||
return self.session.execute(
|
||||
self.get_p_stmt,
|
||||
(collection, p, limit)
|
||||
)
|
||||
# Optimized: Use po_table for direct partition access
|
||||
return self.session.execute(
|
||||
self.get_p_stmt,
|
||||
(collection, p, limit)
|
||||
)
|
||||
|
||||
def get_o(self, collection, o, limit=10):
|
||||
if self.use_legacy:
|
||||
return self.session.execute(
|
||||
f"select s, p from {self.table} where collection = %s and o = %s limit {limit}",
|
||||
(collection, o)
|
||||
)
|
||||
else:
|
||||
# Optimized: Use object_table for direct partition access
|
||||
return self.session.execute(
|
||||
self.get_o_stmt,
|
||||
(collection, o, limit)
|
||||
)
|
||||
# Optimized: Use object_table for direct partition access
|
||||
return self.session.execute(
|
||||
self.get_o_stmt,
|
||||
(collection, o, limit)
|
||||
)
|
||||
|
||||
def get_sp(self, collection, s, p, limit=10):
|
||||
if self.use_legacy:
|
||||
return self.session.execute(
|
||||
f"select o from {self.table} where collection = %s and s = %s and p = %s limit {limit}",
|
||||
(collection, s, p)
|
||||
)
|
||||
else:
|
||||
# Optimized: Use subject_table with clustering key access
|
||||
return self.session.execute(
|
||||
self.get_sp_stmt,
|
||||
(collection, s, p, limit)
|
||||
)
|
||||
# Optimized: Use subject_table with clustering key access
|
||||
return self.session.execute(
|
||||
self.get_sp_stmt,
|
||||
(collection, s, p, limit)
|
||||
)
|
||||
|
||||
def get_po(self, collection, p, o, limit=10):
|
||||
if self.use_legacy:
|
||||
return self.session.execute(
|
||||
f"select s from {self.table} where collection = %s and p = %s and o = %s limit {limit} allow filtering",
|
||||
(collection, p, o)
|
||||
)
|
||||
else:
|
||||
# CRITICAL OPTIMIZATION: Use po_table - NO MORE ALLOW FILTERING!
|
||||
return self.session.execute(
|
||||
self.get_po_stmt,
|
||||
(collection, p, o, limit)
|
||||
)
|
||||
# CRITICAL OPTIMIZATION: Use po_table - NO MORE ALLOW FILTERING!
|
||||
return self.session.execute(
|
||||
self.get_po_stmt,
|
||||
(collection, p, o, limit)
|
||||
)
|
||||
|
||||
def get_os(self, collection, o, s, limit=10):
|
||||
if self.use_legacy:
|
||||
return self.session.execute(
|
||||
f"select p from {self.table} where collection = %s and o = %s and s = %s limit {limit} allow filtering",
|
||||
(collection, o, s)
|
||||
)
|
||||
else:
|
||||
# Optimized: Use subject_table with clustering access (no more ALLOW FILTERING)
|
||||
return self.session.execute(
|
||||
self.get_os_stmt,
|
||||
(collection, s, o, limit)
|
||||
)
|
||||
# Optimized: Use subject_table with clustering access (no more ALLOW FILTERING)
|
||||
return self.session.execute(
|
||||
self.get_os_stmt,
|
||||
(collection, s, o, limit)
|
||||
)
|
||||
|
||||
def get_spo(self, collection, s, p, o, limit=10):
|
||||
if self.use_legacy:
|
||||
return self.session.execute(
|
||||
f"""select s as x from {self.table} where collection = %s and s = %s and p = %s and o = %s limit {limit}""",
|
||||
(collection, s, p, o)
|
||||
# Optimized: Use collection_table for exact key lookup
|
||||
return self.session.execute(
|
||||
self.get_spo_stmt,
|
||||
(collection, s, p, o, limit)
|
||||
)
|
||||
|
||||
def collection_exists(self, collection):
|
||||
"""Check if collection exists by querying collection_metadata table"""
|
||||
try:
|
||||
result = self.session.execute(
|
||||
f"SELECT collection FROM {self.collection_metadata_table} WHERE collection = %s LIMIT 1",
|
||||
(collection,)
|
||||
)
|
||||
else:
|
||||
# Optimized: Use subject_table for exact key lookup
|
||||
return self.session.execute(
|
||||
self.get_spo_stmt,
|
||||
(collection, s, p, o, limit)
|
||||
return bool(list(result))
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking collection existence: {e}")
|
||||
return False
|
||||
|
||||
def create_collection(self, collection):
|
||||
"""Create collection by inserting metadata row"""
|
||||
try:
|
||||
import datetime
|
||||
self.session.execute(
|
||||
f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)",
|
||||
(collection, datetime.datetime.now())
|
||||
)
|
||||
logger.info(f"Created collection metadata for {collection}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating collection: {e}")
|
||||
raise e
|
||||
|
||||
def delete_collection(self, collection):
|
||||
"""Delete all triples for a specific collection"""
|
||||
if self.use_legacy:
|
||||
self.session.execute(
|
||||
f"delete from {self.table} where collection = %s",
|
||||
(collection,)
|
||||
)
|
||||
else:
|
||||
# Delete from all three tables
|
||||
self.session.execute(
|
||||
f"delete from {self.subject_table} where collection = %s",
|
||||
(collection,)
|
||||
)
|
||||
self.session.execute(
|
||||
f"delete from {self.po_table} where collection = %s",
|
||||
(collection,)
|
||||
)
|
||||
self.session.execute(
|
||||
f"delete from {self.object_table} where collection = %s",
|
||||
(collection,)
|
||||
)
|
||||
"""Delete all triples for a specific collection
|
||||
|
||||
Uses collection_table to enumerate all triples, then deletes from all 4 tables
|
||||
using full partition keys for optimal performance with compound keys.
|
||||
"""
|
||||
# Step 1: Read all triples from collection_table (single partition read)
|
||||
rows = self.session.execute(
|
||||
f"SELECT s, p, o FROM {self.collection_table} WHERE collection = %s",
|
||||
(collection,)
|
||||
)
|
||||
|
||||
# Step 2: Delete each triple from all 4 tables using full partition keys
|
||||
# Batch deletions for efficiency
|
||||
batch = BatchStatement()
|
||||
count = 0
|
||||
|
||||
for row in rows:
|
||||
s, p, o = row.s, row.p, row.o
|
||||
|
||||
# Delete from subject table (partition key: collection, s)
|
||||
batch.add(self.delete_subject_stmt, (collection, s, p, o))
|
||||
|
||||
# Delete from predicate-object table (partition key: collection, p)
|
||||
batch.add(self.delete_po_stmt, (collection, p, o, s))
|
||||
|
||||
# Delete from object table (partition key: collection, o)
|
||||
batch.add(self.delete_object_stmt, (collection, o, s, p))
|
||||
|
||||
# Delete from collection table (partition key: collection only)
|
||||
batch.add(self.delete_collection_stmt, (collection, s, p, o))
|
||||
|
||||
count += 1
|
||||
|
||||
# Execute batch every 100 triples to avoid oversized batches
|
||||
if count % 100 == 0:
|
||||
self.session.execute(batch)
|
||||
batch = BatchStatement()
|
||||
|
||||
# Execute remaining deletions
|
||||
if count % 100 != 0:
|
||||
self.session.execute(batch)
|
||||
|
||||
# Step 3: Delete collection metadata
|
||||
self.session.execute(
|
||||
f"DELETE FROM {self.collection_metadata_table} WHERE collection = %s",
|
||||
(collection,)
|
||||
)
|
||||
|
||||
logger.info(f"Deleted {count} triples from collection {collection}")
|
||||
|
||||
def close(self):
|
||||
"""Close the Cassandra session and cluster connections properly"""
|
||||
|
|
|
|||
|
|
@ -49,6 +49,22 @@ class DocVectors:
|
|||
self.next_reload = time.time() + self.reload_time
|
||||
logger.debug(f"Reload at {self.next_reload}")
|
||||
|
||||
def collection_exists(self, user, collection):
|
||||
"""Check if collection exists (dimension-independent check)"""
|
||||
collection_name = make_safe_collection_name(user, collection, self.prefix)
|
||||
return self.client.has_collection(collection_name)
|
||||
|
||||
def create_collection(self, user, collection, dimension=384):
|
||||
"""Create collection with default dimension"""
|
||||
collection_name = make_safe_collection_name(user, collection, self.prefix)
|
||||
|
||||
if self.client.has_collection(collection_name):
|
||||
logger.info(f"Collection {collection_name} already exists")
|
||||
return
|
||||
|
||||
self.init_collection(dimension, user, collection)
|
||||
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
|
||||
|
||||
def init_collection(self, dimension, user, collection):
|
||||
|
||||
collection_name = make_safe_collection_name(user, collection, self.prefix)
|
||||
|
|
@ -128,14 +144,6 @@ class DocVectors:
|
|||
|
||||
coll = self.collections[(dim, user, collection)]
|
||||
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
"params": {
|
||||
"radius": 0.1,
|
||||
"range_filter": 0.8
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug("Loading...")
|
||||
self.client.load_collection(
|
||||
collection_name=coll,
|
||||
|
|
@ -145,10 +153,11 @@ class DocVectors:
|
|||
|
||||
res = self.client.search(
|
||||
collection_name=coll,
|
||||
anns_field="vector",
|
||||
data=[embeds],
|
||||
limit=limit,
|
||||
output_fields=fields,
|
||||
search_params=search_params,
|
||||
search_params={ "metric_type": "COSINE" },
|
||||
)[0]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -49,6 +49,22 @@ class EntityVectors:
|
|||
self.next_reload = time.time() + self.reload_time
|
||||
logger.debug(f"Reload at {self.next_reload}")
|
||||
|
||||
def collection_exists(self, user, collection):
|
||||
"""Check if collection exists (dimension-independent check)"""
|
||||
collection_name = make_safe_collection_name(user, collection, self.prefix)
|
||||
return self.client.has_collection(collection_name)
|
||||
|
||||
def create_collection(self, user, collection, dimension=384):
|
||||
"""Create collection with default dimension"""
|
||||
collection_name = make_safe_collection_name(user, collection, self.prefix)
|
||||
|
||||
if self.client.has_collection(collection_name):
|
||||
logger.info(f"Collection {collection_name} already exists")
|
||||
return
|
||||
|
||||
self.init_collection(dimension, user, collection)
|
||||
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
|
||||
|
||||
def init_collection(self, dimension, user, collection):
|
||||
|
||||
collection_name = make_safe_collection_name(user, collection, self.prefix)
|
||||
|
|
@ -128,14 +144,6 @@ class EntityVectors:
|
|||
|
||||
coll = self.collections[(dim, user, collection)]
|
||||
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
"params": {
|
||||
"radius": 0.1,
|
||||
"range_filter": 0.8
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug("Loading...")
|
||||
self.client.load_collection(
|
||||
collection_name=coll,
|
||||
|
|
@ -145,10 +153,11 @@ class EntityVectors:
|
|||
|
||||
res = self.client.search(
|
||||
collection_name=coll,
|
||||
anns_field="vector",
|
||||
data=[embeds],
|
||||
limit=limit,
|
||||
output_fields=fields,
|
||||
search_params=search_params,
|
||||
search_params={ "metric_type": "COSINE" },
|
||||
)[0]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ class CollectionManager:
|
|||
|
||||
async def ensure_collection_exists(self, user: str, collection: str):
|
||||
"""
|
||||
Ensure a collection exists, creating it if necessary (lazy creation)
|
||||
Ensure a collection exists, creating it if necessary with broadcast to storage
|
||||
|
||||
Args:
|
||||
user: User ID
|
||||
|
|
@ -74,7 +74,7 @@ class CollectionManager:
|
|||
return
|
||||
|
||||
# Create new collection with default metadata
|
||||
logger.info(f"Creating new collection {user}/{collection}")
|
||||
logger.info(f"Auto-creating collection {user}/{collection} from document submission")
|
||||
await self.table_store.create_collection(
|
||||
user=user,
|
||||
collection=collection,
|
||||
|
|
@ -83,10 +83,64 @@ class CollectionManager:
|
|||
tags=set()
|
||||
)
|
||||
|
||||
# Broadcast collection creation to all storage backends
|
||||
creation_key = (user, collection)
|
||||
logger.info(f"Broadcasting create-collection for {creation_key}")
|
||||
|
||||
self.pending_deletions[creation_key] = {
|
||||
"responses_pending": 4, # doc-embeddings, graph-embeddings, object, triples
|
||||
"responses_received": [],
|
||||
"all_successful": True,
|
||||
"error_messages": [],
|
||||
"deletion_complete": asyncio.Event()
|
||||
}
|
||||
|
||||
storage_request = StorageManagementRequest(
|
||||
operation="create-collection",
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
|
||||
# Send creation requests to all storage types
|
||||
if self.vector_storage_producer:
|
||||
await self.vector_storage_producer.send(storage_request)
|
||||
if self.object_storage_producer:
|
||||
await self.object_storage_producer.send(storage_request)
|
||||
if self.triples_storage_producer:
|
||||
await self.triples_storage_producer.send(storage_request)
|
||||
|
||||
# Wait for all storage creations to complete (with timeout)
|
||||
creation_info = self.pending_deletions[creation_key]
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
creation_info["deletion_complete"].wait(),
|
||||
timeout=30.0 # 30 second timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Timeout waiting for storage creation responses for {creation_key}")
|
||||
creation_info["all_successful"] = False
|
||||
creation_info["error_messages"].append("Timeout waiting for storage creation")
|
||||
|
||||
# Check if all creations succeeded
|
||||
if not creation_info["all_successful"]:
|
||||
error_msg = f"Storage creation failed: {'; '.join(creation_info['error_messages'])}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# Clean up metadata on failure
|
||||
await self.table_store.delete_collection(user, collection)
|
||||
|
||||
# Clean up tracking
|
||||
del self.pending_deletions[creation_key]
|
||||
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
# Clean up tracking
|
||||
del self.pending_deletions[creation_key]
|
||||
logger.info(f"Collection {creation_key} auto-created successfully in all storage backends")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error ensuring collection exists: {e}")
|
||||
# Don't fail the operation if collection creation fails
|
||||
# This maintains backward compatibility
|
||||
raise e
|
||||
|
||||
async def list_collections(self, request: CollectionManagementRequest) -> CollectionManagementResponse:
|
||||
"""
|
||||
|
|
@ -154,6 +208,67 @@ class CollectionManager:
|
|||
tags=tags
|
||||
)
|
||||
|
||||
# Broadcast collection creation to all storage backends
|
||||
creation_key = (request.user, request.collection)
|
||||
logger.info(f"Broadcasting create-collection for {creation_key}")
|
||||
|
||||
self.pending_deletions[creation_key] = {
|
||||
"responses_pending": 4, # doc-embeddings, graph-embeddings, object, triples
|
||||
"responses_received": [],
|
||||
"all_successful": True,
|
||||
"error_messages": [],
|
||||
"deletion_complete": asyncio.Event()
|
||||
}
|
||||
|
||||
storage_request = StorageManagementRequest(
|
||||
operation="create-collection",
|
||||
user=request.user,
|
||||
collection=request.collection
|
||||
)
|
||||
|
||||
# Send creation requests to all storage types
|
||||
if self.vector_storage_producer:
|
||||
await self.vector_storage_producer.send(storage_request)
|
||||
if self.object_storage_producer:
|
||||
await self.object_storage_producer.send(storage_request)
|
||||
if self.triples_storage_producer:
|
||||
await self.triples_storage_producer.send(storage_request)
|
||||
|
||||
# Wait for all storage creations to complete (with timeout)
|
||||
creation_info = self.pending_deletions[creation_key]
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
creation_info["deletion_complete"].wait(),
|
||||
timeout=30.0 # 30 second timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Timeout waiting for storage creation responses for {creation_key}")
|
||||
creation_info["all_successful"] = False
|
||||
creation_info["error_messages"].append("Timeout waiting for storage creation")
|
||||
|
||||
# Check if all creations succeeded
|
||||
if not creation_info["all_successful"]:
|
||||
error_msg = f"Storage creation failed: {'; '.join(creation_info['error_messages'])}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# Clean up metadata on failure
|
||||
await self.table_store.delete_collection(request.user, request.collection)
|
||||
|
||||
# Clean up tracking
|
||||
del self.pending_deletions[creation_key]
|
||||
|
||||
return CollectionManagementResponse(
|
||||
error=Error(
|
||||
type="storage_creation_error",
|
||||
message=error_msg
|
||||
),
|
||||
timestamp=datetime.now().isoformat()
|
||||
)
|
||||
|
||||
# Clean up tracking
|
||||
del self.pending_deletions[creation_key]
|
||||
logger.info(f"Collection {creation_key} created successfully in all storage backends")
|
||||
|
||||
# Get the newly created collection for response
|
||||
created_collection = await self.table_store.get_collection(request.user, request.collection)
|
||||
|
||||
|
|
@ -213,7 +328,7 @@ class CollectionManager:
|
|||
|
||||
# Track this deletion request
|
||||
self.pending_deletions[deletion_key] = {
|
||||
"responses_pending": 3, # vector, object, triples
|
||||
"responses_pending": 4, # doc-embeddings, graph-embeddings, object, triples
|
||||
"responses_received": [],
|
||||
"all_successful": True,
|
||||
"error_messages": [],
|
||||
|
|
@ -303,9 +418,9 @@ class CollectionManager:
|
|||
if response.error and response.error.message:
|
||||
info["all_successful"] = False
|
||||
info["error_messages"].append(response.error.message)
|
||||
logger.warning(f"Storage deletion failed for {deletion_key}: {response.error.message}")
|
||||
logger.warning(f"Storage operation failed for {deletion_key}: {response.error.message}")
|
||||
else:
|
||||
logger.debug(f"Storage deletion succeeded for {deletion_key}")
|
||||
logger.debug(f"Storage operation succeeded for {deletion_key}")
|
||||
|
||||
# If all responses received, signal completion
|
||||
if info["responses_pending"] == 0:
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class Processor(LlmService):
|
|||
token = params.get("token", default_token)
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
model = default_model
|
||||
model = params.get("model", default_model)
|
||||
|
||||
if endpoint is None:
|
||||
raise RuntimeError("Azure endpoint not specified")
|
||||
|
|
@ -53,9 +53,11 @@ class Processor(LlmService):
|
|||
self.token = token
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
|
||||
def build_prompt(self, system, content):
|
||||
def build_prompt(self, system, content, temperature=None):
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
data = {
|
||||
"messages": [
|
||||
|
|
@ -67,7 +69,7 @@ class Processor(LlmService):
|
|||
}
|
||||
],
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"temperature": effective_temperature,
|
||||
"top_p": 1
|
||||
}
|
||||
|
||||
|
|
@ -100,13 +102,22 @@ class Processor(LlmService):
|
|||
|
||||
return result
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
try:
|
||||
|
||||
prompt = self.build_prompt(
|
||||
system,
|
||||
prompt
|
||||
prompt,
|
||||
effective_temperature
|
||||
)
|
||||
|
||||
response = self.call_llm(prompt)
|
||||
|
|
@ -125,7 +136,7 @@ class Processor(LlmService):
|
|||
text = resp,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model
|
||||
model = model_name
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ class Processor(LlmService):
|
|||
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
|
||||
self.openai = AzureOpenAI(
|
||||
api_key=token,
|
||||
|
|
@ -62,14 +62,22 @@ class Processor(LlmService):
|
|||
azure_endpoint = endpoint,
|
||||
)
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
prompt = system + "\n\n" + prompt
|
||||
|
||||
try:
|
||||
|
||||
resp = self.openai.chat.completions.create(
|
||||
model=self.model,
|
||||
model=model_name,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
|
|
@ -81,7 +89,7 @@ class Processor(LlmService):
|
|||
]
|
||||
}
|
||||
],
|
||||
temperature=self.temperature,
|
||||
temperature=effective_temperature,
|
||||
max_tokens=self.max_output,
|
||||
top_p=1,
|
||||
)
|
||||
|
|
@ -97,7 +105,7 @@ class Processor(LlmService):
|
|||
text = resp.choices[0].message.content,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model
|
||||
model = model_name
|
||||
)
|
||||
|
||||
return r
|
||||
|
|
|
|||
|
|
@ -41,21 +41,29 @@ class Processor(LlmService):
|
|||
}
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
self.claude = anthropic.Anthropic(api_key=api_key)
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
|
||||
logger.info("Claude LLM service initialized")
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
try:
|
||||
|
||||
response = message = self.claude.messages.create(
|
||||
model=self.model,
|
||||
model=model_name,
|
||||
max_tokens=self.max_output,
|
||||
temperature=self.temperature,
|
||||
temperature=effective_temperature,
|
||||
system = system,
|
||||
messages=[
|
||||
{
|
||||
|
|
@ -81,7 +89,7 @@ class Processor(LlmService):
|
|||
text = resp,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model
|
||||
model = model_name
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -39,21 +39,29 @@ class Processor(LlmService):
|
|||
}
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
self.temperature = temperature
|
||||
self.cohere = cohere.Client(api_key=api_key)
|
||||
|
||||
logger.info("Cohere LLM service initialized")
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
try:
|
||||
|
||||
output = self.cohere.chat(
|
||||
model=self.model,
|
||||
output = self.cohere.chat(
|
||||
model=model_name,
|
||||
message=prompt,
|
||||
preamble = system,
|
||||
temperature=self.temperature,
|
||||
temperature=effective_temperature,
|
||||
chat_history=[],
|
||||
prompt_truncation='auto',
|
||||
connectors=[]
|
||||
|
|
@ -71,7 +79,7 @@ class Processor(LlmService):
|
|||
text = resp,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model
|
||||
model = model_name
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -53,10 +53,13 @@ class Processor(LlmService):
|
|||
)
|
||||
|
||||
self.client = genai.Client(api_key=api_key)
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
|
||||
# Cache for generation configs per model
|
||||
self.generation_configs = {}
|
||||
|
||||
block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
|
||||
|
||||
self.safety_settings = [
|
||||
|
|
@ -83,22 +86,45 @@ class Processor(LlmService):
|
|||
|
||||
logger.info("GoogleAIStudio LLM service initialized")
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
def _get_or_create_config(self, model_name, temperature=None):
|
||||
"""Get or create generation config with dynamic temperature"""
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
generation_config = types.GenerateContentConfig(
|
||||
temperature = self.temperature,
|
||||
top_p = 1,
|
||||
top_k = 40,
|
||||
max_output_tokens = self.max_output,
|
||||
response_mime_type = "text/plain",
|
||||
system_instruction = system,
|
||||
safety_settings = self.safety_settings,
|
||||
)
|
||||
# Create cache key that includes temperature to avoid conflicts
|
||||
cache_key = f"{model_name}:{effective_temperature}"
|
||||
|
||||
if cache_key not in self.generation_configs:
|
||||
logger.info(f"Creating generation config for '{model_name}' with temperature {effective_temperature}")
|
||||
self.generation_configs[cache_key] = types.GenerateContentConfig(
|
||||
temperature = effective_temperature,
|
||||
top_p = 1,
|
||||
top_k = 40,
|
||||
max_output_tokens = self.max_output,
|
||||
response_mime_type = "text/plain",
|
||||
safety_settings = self.safety_settings,
|
||||
)
|
||||
|
||||
return self.generation_configs[cache_key]
|
||||
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
generation_config = self._get_or_create_config(model_name, effective_temperature)
|
||||
# Set system instruction per request (can't be cached)
|
||||
generation_config.system_instruction = system
|
||||
|
||||
try:
|
||||
|
||||
response = self.client.models.generate_content(
|
||||
model=self.model,
|
||||
model=model_name,
|
||||
config=generation_config,
|
||||
contents=prompt,
|
||||
)
|
||||
|
|
@ -114,7 +140,7 @@ class Processor(LlmService):
|
|||
text = resp,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model
|
||||
model = model_name
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class Processor(LlmService):
|
|||
}
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
self.llamafile=llamafile
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
|
|
@ -50,25 +50,33 @@ class Processor(LlmService):
|
|||
|
||||
logger.info("Llamafile LLM service initialized")
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
prompt = system + "\n\n" + prompt
|
||||
|
||||
try:
|
||||
|
||||
resp = self.openai.chat.completions.create(
|
||||
model=self.model,
|
||||
model=model_name,
|
||||
messages=[
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
#temperature=self.temperature,
|
||||
#max_tokens=self.max_output,
|
||||
#top_p=1,
|
||||
#frequency_penalty=0,
|
||||
#presence_penalty=0,
|
||||
#response_format={
|
||||
# "type": "text"
|
||||
#}
|
||||
],
|
||||
temperature=effective_temperature,
|
||||
max_tokens=self.max_output,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
presence_penalty=0,
|
||||
response_format={
|
||||
"type": "text"
|
||||
}
|
||||
)
|
||||
|
||||
inputtokens = resp.usage.prompt_tokens
|
||||
|
|
@ -82,7 +90,7 @@ class Processor(LlmService):
|
|||
text = resp.choices[0].message.content,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = "llama.cpp",
|
||||
model = model_name,
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class Processor(LlmService):
|
|||
}
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
self.url = url + "v1/"
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
|
|
@ -50,7 +50,15 @@ class Processor(LlmService):
|
|||
|
||||
logger.info("LMStudio LLM service initialized")
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
prompt = system + "\n\n" + prompt
|
||||
|
||||
|
|
@ -59,18 +67,18 @@ class Processor(LlmService):
|
|||
logger.debug(f"Prompt: {prompt}")
|
||||
|
||||
resp = self.openai.chat.completions.create(
|
||||
model=self.model,
|
||||
model=model_name,
|
||||
messages=[
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
#temperature=self.temperature,
|
||||
#max_tokens=self.max_output,
|
||||
#top_p=1,
|
||||
#frequency_penalty=0,
|
||||
#presence_penalty=0,
|
||||
#response_format={
|
||||
# "type": "text"
|
||||
#}
|
||||
],
|
||||
temperature=effective_temperature,
|
||||
max_tokens=self.max_output,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
presence_penalty=0,
|
||||
response_format={
|
||||
"type": "text"
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(f"Full response: {resp}")
|
||||
|
|
@ -86,7 +94,7 @@ class Processor(LlmService):
|
|||
text = resp.choices[0].message.content,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model
|
||||
model = model_name
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -41,21 +41,29 @@ class Processor(LlmService):
|
|||
}
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
self.mistral = Mistral(api_key=api_key)
|
||||
|
||||
logger.info("Mistral LLM service initialized")
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
prompt = system + "\n\n" + prompt
|
||||
|
||||
try:
|
||||
|
||||
resp = self.mistral.chat.complete(
|
||||
model=self.model,
|
||||
model=model_name,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
|
|
@ -67,7 +75,7 @@ class Processor(LlmService):
|
|||
]
|
||||
}
|
||||
],
|
||||
temperature=self.temperature,
|
||||
temperature=effective_temperature,
|
||||
max_tokens=self.max_output,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
|
|
@ -87,7 +95,7 @@ class Processor(LlmService):
|
|||
text = resp.choices[0].message.content,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model
|
||||
model = model_name
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from .... base import LlmService, LlmResult
|
|||
default_ident = "text-completion"
|
||||
|
||||
default_model = 'gemma2:9b'
|
||||
default_temperature = 0.0
|
||||
default_ollama = os.getenv("OLLAMA_HOST", 'http://localhost:11434')
|
||||
|
||||
class Processor(LlmService):
|
||||
|
|
@ -24,25 +25,36 @@ class Processor(LlmService):
|
|||
def __init__(self, **params):
|
||||
|
||||
model = params.get("model", default_model)
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
ollama = params.get("ollama", default_ollama)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"model": model,
|
||||
"temperature": temperature,
|
||||
"ollama": ollama,
|
||||
}
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
self.temperature = temperature
|
||||
self.llm = Client(host=ollama)
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
prompt = system + "\n\n" + prompt
|
||||
|
||||
try:
|
||||
|
||||
response = self.llm.generate(self.model, prompt)
|
||||
response = self.llm.generate(model_name, prompt, options={'temperature': effective_temperature})
|
||||
|
||||
response_text = response['response']
|
||||
logger.debug("Sending response...")
|
||||
|
|
@ -55,7 +67,7 @@ class Processor(LlmService):
|
|||
text = response_text,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model
|
||||
model = model_name
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
@ -84,6 +96,13 @@ class Processor(LlmService):
|
|||
help=f'ollama (default: {default_ollama})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--temperature',
|
||||
type=float,
|
||||
default=default_temperature,
|
||||
help=f'LLM temperature parameter (default: {default_temperature})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ class Processor(LlmService):
|
|||
}
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
|
||||
|
|
@ -58,14 +58,22 @@ class Processor(LlmService):
|
|||
|
||||
logger.info("OpenAI LLM service initialized")
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
prompt = system + "\n\n" + prompt
|
||||
|
||||
try:
|
||||
|
||||
resp = self.openai.chat.completions.create(
|
||||
model=self.model,
|
||||
model=model_name,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
|
|
@ -77,7 +85,7 @@ class Processor(LlmService):
|
|||
]
|
||||
}
|
||||
],
|
||||
temperature=self.temperature,
|
||||
temperature=effective_temperature,
|
||||
max_tokens=self.max_output,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
|
|
@ -97,7 +105,7 @@ class Processor(LlmService):
|
|||
text = resp.choices[0].message.content,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model
|
||||
model = model_name
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -30,32 +30,43 @@ class Processor(LlmService):
|
|||
base_url = params.get("url", default_base_url)
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
model = params.get("model", "tgi")
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"temperature": temperature,
|
||||
"max_output": max_output,
|
||||
"url": base_url,
|
||||
"model": model,
|
||||
}
|
||||
)
|
||||
|
||||
self.base_url = base_url
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
self.default_model = model
|
||||
|
||||
self.session = aiohttp.ClientSession()
|
||||
|
||||
logger.info(f"Using TGI service at {base_url}")
|
||||
logger.info("TGI LLM service initialized")
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
request = {
|
||||
"model": "tgi",
|
||||
"model": model_name,
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
|
|
@ -67,7 +78,7 @@ class Processor(LlmService):
|
|||
}
|
||||
],
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"temperature": effective_temperature,
|
||||
}
|
||||
|
||||
try:
|
||||
|
|
@ -96,7 +107,7 @@ class Processor(LlmService):
|
|||
text = ans,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = "tgi",
|
||||
model = model_name,
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -45,24 +45,32 @@ class Processor(LlmService):
|
|||
self.base_url = base_url
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
self.model = model
|
||||
self.default_model = model
|
||||
|
||||
self.session = aiohttp.ClientSession()
|
||||
|
||||
logger.info(f"Using vLLM service at {base_url}")
|
||||
logger.info("vLLM LLM service initialized")
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
request = {
|
||||
"model": self.model,
|
||||
"model": model_name,
|
||||
"prompt": system + "\n\n" + prompt,
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"temperature": effective_temperature,
|
||||
}
|
||||
|
||||
try:
|
||||
|
|
@ -91,7 +99,7 @@ class Processor(LlmService):
|
|||
text = ans,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model,
|
||||
model = model_name,
|
||||
)
|
||||
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -57,21 +57,26 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
raise e
|
||||
self.last_collection = collection
|
||||
|
||||
def collection_exists(self, collection):
|
||||
"""Check if collection exists (no implicit creation)"""
|
||||
return self.qdrant.collection_exists(collection)
|
||||
|
||||
async def query_document_embeddings(self, msg):
|
||||
|
||||
try:
|
||||
|
||||
chunks = []
|
||||
|
||||
collection = (
|
||||
"d_" + msg.user + "_" + msg.collection
|
||||
)
|
||||
|
||||
# Check if collection exists - return empty if not
|
||||
if not self.collection_exists(collection):
|
||||
logger.info(f"Collection {collection} does not exist, returning empty results")
|
||||
return []
|
||||
|
||||
for vec in msg.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
collection = (
|
||||
"d_" + msg.user + "_" + msg.collection
|
||||
)
|
||||
|
||||
self.ensure_collection_exists(collection, dim)
|
||||
|
||||
search_result = self.qdrant.query_points(
|
||||
collection_name=collection,
|
||||
query=vec,
|
||||
|
|
|
|||
|
|
@ -57,6 +57,10 @@ class Processor(GraphEmbeddingsQueryService):
|
|||
raise e
|
||||
self.last_collection = collection
|
||||
|
||||
def collection_exists(self, collection):
|
||||
"""Check if collection exists (no implicit creation)"""
|
||||
return self.qdrant.collection_exists(collection)
|
||||
|
||||
def create_value(self, ent):
|
||||
if ent.startswith("http://") or ent.startswith("https://"):
|
||||
return Value(value=ent, is_uri=True)
|
||||
|
|
@ -70,12 +74,16 @@ class Processor(GraphEmbeddingsQueryService):
|
|||
entity_set = set()
|
||||
entities = []
|
||||
|
||||
for vec in msg.vectors:
|
||||
collection = (
|
||||
"t_" + msg.user + "_" + msg.collection
|
||||
)
|
||||
|
||||
dim = len(vec)
|
||||
collection = (
|
||||
"t_" + msg.user + "_" + msg.collection
|
||||
)
|
||||
# Check if collection exists - return empty if not
|
||||
if not self.collection_exists(collection):
|
||||
logger.info(f"Collection {collection} does not exist, returning empty results")
|
||||
return []
|
||||
|
||||
for vec in msg.vectors:
|
||||
|
||||
self.ensure_collection_exists(collection, dim)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,56 @@
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||
|
||||
class LRUCacheWithTTL:
|
||||
"""LRU cache with TTL for label caching
|
||||
|
||||
CRITICAL SECURITY WARNING:
|
||||
This cache is shared within a GraphRag instance but GraphRag instances
|
||||
are created per-request. Cache keys MUST include user:collection prefix
|
||||
to ensure data isolation between different security contexts.
|
||||
"""
|
||||
|
||||
def __init__(self, max_size=5000, ttl=300):
|
||||
self.cache = OrderedDict()
|
||||
self.access_times = {}
|
||||
self.max_size = max_size
|
||||
self.ttl = ttl
|
||||
|
||||
def get(self, key):
|
||||
if key not in self.cache:
|
||||
return None
|
||||
|
||||
# Check TTL expiration
|
||||
if time.time() - self.access_times[key] > self.ttl:
|
||||
del self.cache[key]
|
||||
del self.access_times[key]
|
||||
return None
|
||||
|
||||
# Move to end (most recently used)
|
||||
self.cache.move_to_end(key)
|
||||
return self.cache[key]
|
||||
|
||||
def put(self, key, value):
|
||||
if key in self.cache:
|
||||
self.cache.move_to_end(key)
|
||||
else:
|
||||
if len(self.cache) >= self.max_size:
|
||||
# Remove least recently used
|
||||
oldest_key = next(iter(self.cache))
|
||||
del self.cache[oldest_key]
|
||||
del self.access_times[oldest_key]
|
||||
|
||||
self.cache[key] = value
|
||||
self.access_times[key] = time.time()
|
||||
|
||||
class Query:
|
||||
|
||||
def __init__(
|
||||
|
|
@ -61,8 +105,14 @@ class Query:
|
|||
|
||||
async def maybe_label(self, e):
|
||||
|
||||
if e in self.rag.label_cache:
|
||||
return self.rag.label_cache[e]
|
||||
# CRITICAL SECURITY: Cache key MUST include user and collection
|
||||
# to prevent data leakage between different contexts
|
||||
cache_key = f"{self.user}:{self.collection}:{e}"
|
||||
|
||||
# Check LRU cache first with isolated key
|
||||
cached_label = self.rag.label_cache.get(cache_key)
|
||||
if cached_label is not None:
|
||||
return cached_label
|
||||
|
||||
res = await self.rag.triples_client.query(
|
||||
s=e, p=LABEL, o=None, limit=1,
|
||||
|
|
@ -70,60 +120,104 @@ class Query:
|
|||
)
|
||||
|
||||
if len(res) == 0:
|
||||
self.rag.label_cache[e] = e
|
||||
self.rag.label_cache.put(cache_key, e)
|
||||
return e
|
||||
|
||||
self.rag.label_cache[e] = str(res[0].o)
|
||||
return self.rag.label_cache[e]
|
||||
label = str(res[0].o)
|
||||
self.rag.label_cache.put(cache_key, label)
|
||||
return label
|
||||
|
||||
async def execute_batch_triple_queries(self, entities, limit_per_entity):
|
||||
"""Execute triple queries for multiple entities concurrently"""
|
||||
tasks = []
|
||||
|
||||
for entity in entities:
|
||||
# Create concurrent tasks for all 3 query types per entity
|
||||
tasks.extend([
|
||||
self.rag.triples_client.query(
|
||||
s=entity, p=None, o=None,
|
||||
limit=limit_per_entity,
|
||||
user=self.user, collection=self.collection
|
||||
),
|
||||
self.rag.triples_client.query(
|
||||
s=None, p=entity, o=None,
|
||||
limit=limit_per_entity,
|
||||
user=self.user, collection=self.collection
|
||||
),
|
||||
self.rag.triples_client.query(
|
||||
s=None, p=None, o=entity,
|
||||
limit=limit_per_entity,
|
||||
user=self.user, collection=self.collection
|
||||
)
|
||||
])
|
||||
|
||||
# Execute all queries concurrently
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Combine all results
|
||||
all_triples = []
|
||||
for result in results:
|
||||
if not isinstance(result, Exception):
|
||||
all_triples.extend(result)
|
||||
|
||||
return all_triples
|
||||
|
||||
async def follow_edges_batch(self, entities, max_depth):
|
||||
"""Optimized iterative graph traversal with batching"""
|
||||
visited = set()
|
||||
current_level = set(entities)
|
||||
subgraph = set()
|
||||
|
||||
for depth in range(max_depth):
|
||||
if not current_level or len(subgraph) >= self.max_subgraph_size:
|
||||
break
|
||||
|
||||
# Filter out already visited entities
|
||||
unvisited_entities = [e for e in current_level if e not in visited]
|
||||
if not unvisited_entities:
|
||||
break
|
||||
|
||||
# Batch query all unvisited entities at current level
|
||||
triples = await self.execute_batch_triple_queries(
|
||||
unvisited_entities, self.triple_limit
|
||||
)
|
||||
|
||||
# Process results and collect next level entities
|
||||
next_level = set()
|
||||
for triple in triples:
|
||||
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
|
||||
subgraph.add(triple_tuple)
|
||||
|
||||
# Collect entities for next level (only from s and o positions)
|
||||
if depth < max_depth - 1: # Don't collect for final depth
|
||||
s, p, o = triple_tuple
|
||||
if s not in visited:
|
||||
next_level.add(s)
|
||||
if o not in visited:
|
||||
next_level.add(o)
|
||||
|
||||
# Stop if subgraph size limit reached
|
||||
if len(subgraph) >= self.max_subgraph_size:
|
||||
return subgraph
|
||||
|
||||
# Update for next iteration
|
||||
visited.update(current_level)
|
||||
current_level = next_level
|
||||
|
||||
return subgraph
|
||||
|
||||
async def follow_edges(self, ent, subgraph, path_length):
|
||||
|
||||
# Not needed?
|
||||
"""Legacy method - replaced by follow_edges_batch"""
|
||||
# Maintain backward compatibility with early termination checks
|
||||
if path_length <= 0:
|
||||
return
|
||||
|
||||
# Stop spanning around if the subgraph is already maxed out
|
||||
if len(subgraph) >= self.max_subgraph_size:
|
||||
return
|
||||
|
||||
res = await self.rag.triples_client.query(
|
||||
s=ent, p=None, o=None,
|
||||
limit=self.triple_limit,
|
||||
user=self.user, collection=self.collection,
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(str(triple.s), str(triple.p), str(triple.o))
|
||||
)
|
||||
if path_length > 1:
|
||||
await self.follow_edges(str(triple.o), subgraph, path_length-1)
|
||||
|
||||
res = await self.rag.triples_client.query(
|
||||
s=None, p=ent, o=None,
|
||||
limit=self.triple_limit,
|
||||
user=self.user, collection=self.collection,
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(str(triple.s), str(triple.p), str(triple.o))
|
||||
)
|
||||
|
||||
res = await self.rag.triples_client.query(
|
||||
s=None, p=None, o=ent,
|
||||
limit=self.triple_limit,
|
||||
user=self.user, collection=self.collection,
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(str(triple.s), str(triple.p), str(triple.o))
|
||||
)
|
||||
if path_length > 1:
|
||||
await self.follow_edges(
|
||||
str(triple.s), subgraph, path_length-1
|
||||
)
|
||||
# For backward compatibility, convert to new approach
|
||||
batch_result = await self.follow_edges_batch([ent], path_length)
|
||||
subgraph.update(batch_result)
|
||||
|
||||
async def get_subgraph(self, query):
|
||||
|
||||
|
|
@ -132,31 +226,52 @@ class Query:
|
|||
if self.verbose:
|
||||
logger.debug("Getting subgraph...")
|
||||
|
||||
subgraph = set()
|
||||
# Use optimized batch traversal instead of sequential processing
|
||||
subgraph = await self.follow_edges_batch(entities, self.max_path_length)
|
||||
|
||||
for ent in entities:
|
||||
await self.follow_edges(ent, subgraph, self.max_path_length)
|
||||
return list(subgraph)
|
||||
|
||||
subgraph = list(subgraph)
|
||||
async def resolve_labels_batch(self, entities):
|
||||
"""Resolve labels for multiple entities in parallel"""
|
||||
tasks = []
|
||||
for entity in entities:
|
||||
tasks.append(self.maybe_label(entity))
|
||||
|
||||
return subgraph
|
||||
return await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def get_labelgraph(self, query):
|
||||
|
||||
subgraph = await self.get_subgraph(query)
|
||||
|
||||
# Filter out label triples
|
||||
filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
|
||||
|
||||
# Collect all unique entities that need label resolution
|
||||
entities_to_resolve = set()
|
||||
for s, p, o in filtered_subgraph:
|
||||
entities_to_resolve.update([s, p, o])
|
||||
|
||||
# Batch resolve labels for all entities in parallel
|
||||
entity_list = list(entities_to_resolve)
|
||||
resolved_labels = await self.resolve_labels_batch(entity_list)
|
||||
|
||||
# Create entity-to-label mapping
|
||||
label_map = {}
|
||||
for entity, label in zip(entity_list, resolved_labels):
|
||||
if not isinstance(label, Exception):
|
||||
label_map[entity] = label
|
||||
else:
|
||||
label_map[entity] = entity # Fallback to entity itself
|
||||
|
||||
# Apply labels to subgraph
|
||||
sg2 = []
|
||||
|
||||
for edge in subgraph:
|
||||
|
||||
if edge[1] == LABEL:
|
||||
continue
|
||||
|
||||
s = await self.maybe_label(edge[0])
|
||||
p = await self.maybe_label(edge[1])
|
||||
o = await self.maybe_label(edge[2])
|
||||
|
||||
sg2.append((s, p, o))
|
||||
for s, p, o in filtered_subgraph:
|
||||
labeled_triple = (
|
||||
label_map.get(s, s),
|
||||
label_map.get(p, p),
|
||||
label_map.get(o, o)
|
||||
)
|
||||
sg2.append(labeled_triple)
|
||||
|
||||
sg2 = sg2[0:self.max_subgraph_size]
|
||||
|
||||
|
|
@ -171,6 +286,13 @@ class Query:
|
|||
return sg2
|
||||
|
||||
class GraphRag:
|
||||
"""
|
||||
CRITICAL SECURITY:
|
||||
This class MUST be instantiated per-request to ensure proper isolation
|
||||
between users and collections. The cache within this instance will only
|
||||
live for the duration of a single request, preventing cross-contamination
|
||||
of data between different security contexts.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, prompt_client, embeddings_client, graph_embeddings_client,
|
||||
|
|
@ -184,7 +306,9 @@ class GraphRag:
|
|||
self.graph_embeddings_client = graph_embeddings_client
|
||||
self.triples_client = triples_client
|
||||
|
||||
self.label_cache = {}
|
||||
# Replace simple dict with LRU cache with TTL
|
||||
# CRITICAL: This cache only lives for one request due to per-request instantiation
|
||||
self.label_cache = LRUCacheWithTTL(max_size=5000, ttl=300)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("GraphRag initialized")
|
||||
|
|
|
|||
|
|
@ -45,6 +45,10 @@ class Processor(FlowProcessor):
|
|||
self.default_max_subgraph_size = max_subgraph_size
|
||||
self.default_max_path_length = max_path_length
|
||||
|
||||
# CRITICAL SECURITY: NEVER share data between users or collections
|
||||
# Each user/collection combination MUST have isolated data access
|
||||
# Caching must NEVER allow information leakage across these boundaries
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name = "request",
|
||||
|
|
@ -93,11 +97,14 @@ class Processor(FlowProcessor):
|
|||
|
||||
try:
|
||||
|
||||
self.rag = GraphRag(
|
||||
embeddings_client = flow("embeddings-request"),
|
||||
graph_embeddings_client = flow("graph-embeddings-request"),
|
||||
triples_client = flow("triples-request"),
|
||||
prompt_client = flow("prompt-request"),
|
||||
# CRITICAL SECURITY: Create new GraphRag instance per request
|
||||
# This ensures proper isolation between users and collections
|
||||
# Flow clients are request-scoped and must not be shared
|
||||
rag = GraphRag(
|
||||
embeddings_client=flow("embeddings-request"),
|
||||
graph_embeddings_client=flow("graph-embeddings-request"),
|
||||
triples_client=flow("triples-request"),
|
||||
prompt_client=flow("prompt-request"),
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
|
@ -128,7 +135,7 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
max_path_length = self.default_max_path_length
|
||||
|
||||
response = await self.rag.query(
|
||||
response = await rag.query(
|
||||
query = v.query, user = v.user, collection = v.collection,
|
||||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
|
|
|
|||
|
|
@ -60,19 +60,34 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
metrics=storage_response_metrics,
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
"""Start the processor and its storage management consumer"""
|
||||
await super().start()
|
||||
await self.storage_request_consumer.start()
|
||||
await self.storage_response_producer.start()
|
||||
|
||||
async def store_document_embeddings(self, message):
|
||||
|
||||
# Validate collection exists before accepting writes
|
||||
if not self.vecstore.collection_exists(message.metadata.user, message.metadata.collection):
|
||||
error_msg = (
|
||||
f"Collection {message.metadata.collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
for emb in message.chunks:
|
||||
|
||||
if emb.chunk is None or emb.chunk == b"": continue
|
||||
|
||||
|
||||
chunk = emb.chunk.decode("utf-8")
|
||||
if chunk == "": continue
|
||||
|
||||
for vec in emb.vectors:
|
||||
self.vecstore.insert(
|
||||
vec, chunk,
|
||||
message.metadata.user,
|
||||
vec, chunk,
|
||||
message.metadata.user,
|
||||
message.metadata.collection
|
||||
)
|
||||
|
||||
|
|
@ -87,18 +102,21 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
help=f'Milvus store URI (default: {default_store_uri})'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
|
||||
async def on_storage_management(self, message, consumer, flow):
|
||||
"""Handle storage management requests"""
|
||||
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
|
||||
request = message.value()
|
||||
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if message.operation == "delete-collection":
|
||||
await self.handle_delete_collection(message)
|
||||
if request.operation == "create-collection":
|
||||
await self.handle_create_collection(request)
|
||||
elif request.operation == "delete-collection":
|
||||
await self.handle_delete_collection(request)
|
||||
else:
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="invalid_operation",
|
||||
message=f"Unknown operation: {message.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -113,17 +131,40 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, message):
|
||||
async def handle_create_collection(self, request):
|
||||
"""Create a Milvus collection for document embeddings"""
|
||||
try:
|
||||
if self.vecstore.collection_exists(request.user, request.collection):
|
||||
logger.info(f"Collection {request.user}/{request.collection} already exists")
|
||||
else:
|
||||
self.vecstore.create_collection(request.user, request.collection)
|
||||
logger.info(f"Created collection {request.user}/{request.collection}")
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(error=None)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create collection: {e}", exc_info=True)
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="creation_error",
|
||||
message=str(e)
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, request):
|
||||
"""Delete the collection for document embeddings"""
|
||||
try:
|
||||
self.vecstore.delete_collection(message.user, message.collection)
|
||||
self.vecstore.delete_collection(request.user, request.collection)
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(
|
||||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
|
|
|||
|
|
@ -115,38 +115,36 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
"Gave up waiting for index creation"
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
"""Start the processor and its storage management consumer"""
|
||||
await super().start()
|
||||
await self.storage_request_consumer.start()
|
||||
await self.storage_response_producer.start()
|
||||
|
||||
async def store_document_embeddings(self, message):
|
||||
|
||||
index_name = (
|
||||
"d-" + message.metadata.user + "-" + message.metadata.collection
|
||||
)
|
||||
|
||||
# Validate collection exists before accepting writes
|
||||
if not self.pinecone.has_index(index_name):
|
||||
error_msg = (
|
||||
f"Collection {message.metadata.collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
for emb in message.chunks:
|
||||
|
||||
if emb.chunk is None or emb.chunk == b"": continue
|
||||
|
||||
|
||||
chunk = emb.chunk.decode("utf-8")
|
||||
if chunk == "": continue
|
||||
|
||||
for vec in emb.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
index_name = (
|
||||
"d-" + message.metadata.user + "-" + message.metadata.collection
|
||||
)
|
||||
|
||||
if index_name != self.last_index_name:
|
||||
|
||||
if not self.pinecone.has_index(index_name):
|
||||
|
||||
try:
|
||||
|
||||
self.create_index(index_name, dim)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Pinecone index creation failed")
|
||||
raise e
|
||||
|
||||
logger.info(f"Index {index_name} created")
|
||||
|
||||
self.last_index_name = index_name
|
||||
|
||||
index = self.pinecone.Index(index_name)
|
||||
|
||||
# Generate unique ID for each vector
|
||||
|
|
@ -192,18 +190,21 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
help=f'Pinecone region, (default: {default_region}'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
|
||||
async def on_storage_management(self, message, consumer, flow):
|
||||
"""Handle storage management requests"""
|
||||
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
|
||||
request = message.value()
|
||||
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if message.operation == "delete-collection":
|
||||
await self.handle_delete_collection(message)
|
||||
if request.operation == "create-collection":
|
||||
await self.handle_create_collection(request)
|
||||
elif request.operation == "delete-collection":
|
||||
await self.handle_delete_collection(request)
|
||||
else:
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="invalid_operation",
|
||||
message=f"Unknown operation: {message.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -218,10 +219,36 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, message):
|
||||
async def handle_create_collection(self, request):
|
||||
"""Create a Pinecone index for document embeddings"""
|
||||
try:
|
||||
index_name = f"d-{request.user}-{request.collection}"
|
||||
|
||||
if self.pinecone.has_index(index_name):
|
||||
logger.info(f"Pinecone index {index_name} already exists")
|
||||
else:
|
||||
# Create with default dimension - will need to be recreated if dimension doesn't match
|
||||
self.create_index(index_name, dim=384)
|
||||
logger.info(f"Created Pinecone index: {index_name}")
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(error=None)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create collection: {e}", exc_info=True)
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="creation_error",
|
||||
message=str(e)
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, request):
|
||||
"""Delete the collection for document embeddings"""
|
||||
try:
|
||||
index_name = f"d-{message.user}-{message.collection}"
|
||||
index_name = f"d-{request.user}-{request.collection}"
|
||||
|
||||
if self.pinecone.has_index(index_name):
|
||||
self.pinecone.delete_index(index_name)
|
||||
|
|
@ -234,7 +261,7 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
|
|
|||
|
|
@ -36,8 +36,6 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
}
|
||||
)
|
||||
|
||||
self.last_collection = None
|
||||
|
||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
||||
|
||||
# Set up storage management if base class attributes are available
|
||||
|
|
@ -71,8 +69,30 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
metrics=storage_response_metrics,
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
"""Start the processor and its storage management consumer"""
|
||||
await super().start()
|
||||
if hasattr(self, 'storage_request_consumer'):
|
||||
await self.storage_request_consumer.start()
|
||||
if hasattr(self, 'storage_response_producer'):
|
||||
await self.storage_response_producer.start()
|
||||
|
||||
async def store_document_embeddings(self, message):
|
||||
|
||||
# Validate collection exists before accepting writes
|
||||
collection = (
|
||||
"d_" + message.metadata.user + "_" +
|
||||
message.metadata.collection
|
||||
)
|
||||
|
||||
if not self.qdrant.collection_exists(collection):
|
||||
error_msg = (
|
||||
f"Collection {message.metadata.collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
for emb in message.chunks:
|
||||
|
||||
chunk = emb.chunk.decode("utf-8")
|
||||
|
|
@ -80,29 +100,6 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
|
||||
for vec in emb.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
collection = (
|
||||
"d_" + message.metadata.user + "_" +
|
||||
message.metadata.collection
|
||||
)
|
||||
|
||||
if collection != self.last_collection:
|
||||
|
||||
if not self.qdrant.collection_exists(collection):
|
||||
|
||||
try:
|
||||
self.qdrant.create_collection(
|
||||
collection_name=collection,
|
||||
vectors_config=VectorParams(
|
||||
size=dim, distance=Distance.COSINE
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Qdrant collection creation failed")
|
||||
raise e
|
||||
|
||||
self.last_collection = collection
|
||||
|
||||
self.qdrant.upsert(
|
||||
collection_name=collection,
|
||||
points=[
|
||||
|
|
@ -133,18 +130,21 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
help=f'Qdrant API key (default: None)'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
|
||||
async def on_storage_management(self, message, consumer, flow):
|
||||
"""Handle storage management requests"""
|
||||
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
|
||||
request = message.value()
|
||||
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if message.operation == "delete-collection":
|
||||
await self.handle_delete_collection(message)
|
||||
if request.operation == "create-collection":
|
||||
await self.handle_create_collection(request)
|
||||
elif request.operation == "delete-collection":
|
||||
await self.handle_delete_collection(request)
|
||||
else:
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="invalid_operation",
|
||||
message=f"Unknown operation: {message.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -159,10 +159,43 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, message):
|
||||
async def handle_create_collection(self, request):
|
||||
"""Create a Qdrant collection for document embeddings"""
|
||||
try:
|
||||
collection_name = f"d_{request.user}_{request.collection}"
|
||||
|
||||
if self.qdrant.collection_exists(collection_name):
|
||||
logger.info(f"Qdrant collection {collection_name} already exists")
|
||||
else:
|
||||
# Create collection with default dimension (will be recreated with correct dim on first write if needed)
|
||||
# Using a placeholder dimension - actual dimension determined by first embedding
|
||||
self.qdrant.create_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=VectorParams(
|
||||
size=384, # Default dimension, common for many models
|
||||
distance=Distance.COSINE
|
||||
)
|
||||
)
|
||||
logger.info(f"Created Qdrant collection: {collection_name}")
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(error=None)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create collection: {e}", exc_info=True)
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="creation_error",
|
||||
message=str(e)
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, request):
|
||||
"""Delete the collection for document embeddings"""
|
||||
try:
|
||||
collection_name = f"d_{message.user}_{message.collection}"
|
||||
collection_name = f"d_{request.user}_{request.collection}"
|
||||
|
||||
if self.qdrant.collection_exists(collection_name):
|
||||
self.qdrant.delete_collection(collection_name)
|
||||
|
|
@ -175,7 +208,7 @@ class Processor(DocumentEmbeddingsStoreService):
|
|||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
|
|
|||
|
|
@ -60,8 +60,23 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
metrics=storage_response_metrics,
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
"""Start the processor and its storage management consumer"""
|
||||
await super().start()
|
||||
await self.storage_request_consumer.start()
|
||||
await self.storage_response_producer.start()
|
||||
|
||||
async def store_graph_embeddings(self, message):
|
||||
|
||||
# Validate collection exists before accepting writes
|
||||
if not self.vecstore.collection_exists(message.metadata.user, message.metadata.collection):
|
||||
error_msg = (
|
||||
f"Collection {message.metadata.collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
for entity in message.entities:
|
||||
|
||||
if entity.entity.value != "" and entity.entity.value is not None:
|
||||
|
|
@ -83,18 +98,21 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
help=f'Milvus store URI (default: {default_store_uri})'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
|
||||
async def on_storage_management(self, message, consumer, flow):
|
||||
"""Handle storage management requests"""
|
||||
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
|
||||
request = message.value()
|
||||
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if message.operation == "delete-collection":
|
||||
await self.handle_delete_collection(message)
|
||||
if request.operation == "create-collection":
|
||||
await self.handle_create_collection(request)
|
||||
elif request.operation == "delete-collection":
|
||||
await self.handle_delete_collection(request)
|
||||
else:
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="invalid_operation",
|
||||
message=f"Unknown operation: {message.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -109,17 +127,40 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, message):
|
||||
async def handle_create_collection(self, request):
|
||||
"""Create a Milvus collection for graph embeddings"""
|
||||
try:
|
||||
if self.vecstore.collection_exists(request.user, request.collection):
|
||||
logger.info(f"Collection {request.user}/{request.collection} already exists")
|
||||
else:
|
||||
self.vecstore.create_collection(request.user, request.collection)
|
||||
logger.info(f"Created collection {request.user}/{request.collection}")
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(error=None)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create collection: {e}", exc_info=True)
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="creation_error",
|
||||
message=str(e)
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, request):
|
||||
"""Delete the collection for graph embeddings"""
|
||||
try:
|
||||
self.vecstore.delete_collection(message.user, message.collection)
|
||||
self.vecstore.delete_collection(request.user, request.collection)
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(
|
||||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
|
|
|||
|
|
@ -115,8 +115,27 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
"Gave up waiting for index creation"
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
"""Start the processor and its storage management consumer"""
|
||||
await super().start()
|
||||
await self.storage_request_consumer.start()
|
||||
await self.storage_response_producer.start()
|
||||
|
||||
async def store_graph_embeddings(self, message):
|
||||
|
||||
index_name = (
|
||||
"t-" + message.metadata.user + "-" + message.metadata.collection
|
||||
)
|
||||
|
||||
# Validate collection exists before accepting writes
|
||||
if not self.pinecone.has_index(index_name):
|
||||
error_msg = (
|
||||
f"Collection {message.metadata.collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
for entity in message.entities:
|
||||
|
||||
if entity.entity.value == "" or entity.entity.value is None:
|
||||
|
|
@ -124,28 +143,6 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
|
||||
for vec in entity.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
|
||||
index_name = (
|
||||
"t-" + message.metadata.user + "-" + message.metadata.collection
|
||||
)
|
||||
|
||||
if index_name != self.last_index_name:
|
||||
|
||||
if not self.pinecone.has_index(index_name):
|
||||
|
||||
try:
|
||||
|
||||
self.create_index(index_name, dim)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Pinecone index creation failed")
|
||||
raise e
|
||||
|
||||
logger.info(f"Index {index_name} created")
|
||||
|
||||
self.last_index_name = index_name
|
||||
|
||||
index = self.pinecone.Index(index_name)
|
||||
|
||||
# Generate unique ID for each vector
|
||||
|
|
@ -191,18 +188,21 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
help=f'Pinecone region, (default: {default_region}'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
|
||||
async def on_storage_management(self, message, consumer, flow):
|
||||
"""Handle storage management requests"""
|
||||
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
|
||||
request = message.value()
|
||||
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if message.operation == "delete-collection":
|
||||
await self.handle_delete_collection(message)
|
||||
if request.operation == "create-collection":
|
||||
await self.handle_create_collection(request)
|
||||
elif request.operation == "delete-collection":
|
||||
await self.handle_delete_collection(request)
|
||||
else:
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="invalid_operation",
|
||||
message=f"Unknown operation: {message.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -217,10 +217,36 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, message):
|
||||
async def handle_create_collection(self, request):
|
||||
"""Create a Pinecone index for graph embeddings"""
|
||||
try:
|
||||
index_name = f"t-{request.user}-{request.collection}"
|
||||
|
||||
if self.pinecone.has_index(index_name):
|
||||
logger.info(f"Pinecone index {index_name} already exists")
|
||||
else:
|
||||
# Create with default dimension - will need to be recreated if dimension doesn't match
|
||||
self.create_index(index_name, dim=384)
|
||||
logger.info(f"Created Pinecone index: {index_name}")
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(error=None)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create collection: {e}", exc_info=True)
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="creation_error",
|
||||
message=str(e)
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, request):
|
||||
"""Delete the collection for graph embeddings"""
|
||||
try:
|
||||
index_name = f"t-{message.user}-{message.collection}"
|
||||
index_name = f"t-{request.user}-{request.collection}"
|
||||
|
||||
if self.pinecone.has_index(index_name):
|
||||
self.pinecone.delete_index(index_name)
|
||||
|
|
@ -233,7 +259,7 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
|
|
|||
|
|
@ -36,8 +36,6 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
}
|
||||
)
|
||||
|
||||
self.last_collection = None
|
||||
|
||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
||||
|
||||
# Set up storage management if base class attributes are available
|
||||
|
|
@ -71,31 +69,30 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
metrics=storage_response_metrics,
|
||||
)
|
||||
|
||||
def get_collection(self, dim, user, collection):
|
||||
|
||||
def get_collection(self, user, collection):
|
||||
"""Get collection name and validate it exists"""
|
||||
cname = (
|
||||
"t_" + user + "_" + collection
|
||||
)
|
||||
|
||||
if cname != self.last_collection:
|
||||
|
||||
if not self.qdrant.collection_exists(cname):
|
||||
|
||||
try:
|
||||
self.qdrant.create_collection(
|
||||
collection_name=cname,
|
||||
vectors_config=VectorParams(
|
||||
size=dim, distance=Distance.COSINE
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Qdrant collection creation failed")
|
||||
raise e
|
||||
|
||||
self.last_collection = cname
|
||||
if not self.qdrant.collection_exists(cname):
|
||||
error_msg = (
|
||||
f"Collection {collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
return cname
|
||||
|
||||
async def start(self):
|
||||
"""Start the processor and its storage management consumer"""
|
||||
await super().start()
|
||||
if hasattr(self, 'storage_request_consumer'):
|
||||
await self.storage_request_consumer.start()
|
||||
if hasattr(self, 'storage_response_producer'):
|
||||
await self.storage_response_producer.start()
|
||||
|
||||
async def store_graph_embeddings(self, message):
|
||||
|
||||
for entity in message.entities:
|
||||
|
|
@ -104,10 +101,8 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
|
||||
for vec in entity.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
|
||||
collection = self.get_collection(
|
||||
dim, message.metadata.user, message.metadata.collection
|
||||
message.metadata.user, message.metadata.collection
|
||||
)
|
||||
|
||||
self.qdrant.upsert(
|
||||
|
|
@ -140,18 +135,21 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
help=f'Qdrant API key'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
|
||||
async def on_storage_management(self, message, consumer, flow):
|
||||
"""Handle storage management requests"""
|
||||
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
|
||||
request = message.value()
|
||||
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if message.operation == "delete-collection":
|
||||
await self.handle_delete_collection(message)
|
||||
if request.operation == "create-collection":
|
||||
await self.handle_create_collection(request)
|
||||
elif request.operation == "delete-collection":
|
||||
await self.handle_delete_collection(request)
|
||||
else:
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="invalid_operation",
|
||||
message=f"Unknown operation: {message.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -166,10 +164,43 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, message):
|
||||
async def handle_create_collection(self, request):
|
||||
"""Create a Qdrant collection for graph embeddings"""
|
||||
try:
|
||||
collection_name = f"t_{request.user}_{request.collection}"
|
||||
|
||||
if self.qdrant.collection_exists(collection_name):
|
||||
logger.info(f"Qdrant collection {collection_name} already exists")
|
||||
else:
|
||||
# Create collection with default dimension (will be recreated with correct dim on first write if needed)
|
||||
# Using a placeholder dimension - actual dimension determined by first embedding
|
||||
self.qdrant.create_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=VectorParams(
|
||||
size=384, # Default dimension, common for many models
|
||||
distance=Distance.COSINE
|
||||
)
|
||||
)
|
||||
logger.info(f"Created Qdrant collection: {collection_name}")
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(error=None)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create collection: {e}", exc_info=True)
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="creation_error",
|
||||
message=str(e)
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, request):
|
||||
"""Delete the collection for graph embeddings"""
|
||||
try:
|
||||
collection_name = f"t_{message.user}_{message.collection}"
|
||||
collection_name = f"t_{request.user}_{request.collection}"
|
||||
|
||||
if self.qdrant.collection_exists(collection_name):
|
||||
self.qdrant.delete_collection(collection_name)
|
||||
|
|
@ -182,7 +213,7 @@ class Processor(GraphEmbeddingsStoreService):
|
|||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
|
|
|||
|
|
@ -295,6 +295,8 @@ class Processor(FlowProcessor):
|
|||
|
||||
try:
|
||||
self.session.execute(create_table_cql)
|
||||
if keyspace not in self.known_tables:
|
||||
self.known_tables[keyspace] = set()
|
||||
self.known_tables[keyspace].add(table_key)
|
||||
logger.info(f"Ensured table exists: {safe_keyspace}.{safe_table}")
|
||||
|
||||
|
|
@ -340,18 +342,47 @@ class Processor(FlowProcessor):
|
|||
logger.warning(f"Failed to convert value {value} to type {field_type}: {e}")
|
||||
return str(value)
|
||||
|
||||
async def start(self):
|
||||
"""Start the processor and its storage management consumer"""
|
||||
await super().start()
|
||||
await self.storage_request_consumer.start()
|
||||
await self.storage_response_producer.start()
|
||||
|
||||
async def on_object(self, msg, consumer, flow):
|
||||
"""Process incoming ExtractedObject and store in Cassandra"""
|
||||
|
||||
|
||||
obj = msg.value()
|
||||
logger.info(f"Storing {len(obj.values)} objects for schema {obj.schema_name} from {obj.metadata.id}")
|
||||
|
||||
|
||||
# Validate collection/keyspace exists before accepting writes
|
||||
safe_keyspace = self.sanitize_name(obj.metadata.user)
|
||||
if safe_keyspace not in self.known_keyspaces:
|
||||
# Check if keyspace actually exists in Cassandra
|
||||
self.connect_cassandra()
|
||||
check_keyspace_cql = """
|
||||
SELECT keyspace_name FROM system_schema.keyspaces
|
||||
WHERE keyspace_name = %s
|
||||
"""
|
||||
result = self.session.execute(check_keyspace_cql, (safe_keyspace,))
|
||||
# Check if result is None (mock case) or has no rows
|
||||
if result is None or not result.one():
|
||||
error_msg = (
|
||||
f"Collection {obj.metadata.collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
# Cache it if it exists
|
||||
self.known_keyspaces.add(safe_keyspace)
|
||||
if safe_keyspace not in self.known_tables:
|
||||
self.known_tables[safe_keyspace] = set()
|
||||
|
||||
# Get schema definition
|
||||
schema = self.schemas.get(obj.schema_name)
|
||||
if not schema:
|
||||
logger.warning(f"No schema found for {obj.schema_name} - skipping")
|
||||
return
|
||||
|
||||
|
||||
# Ensure table exists
|
||||
keyspace = obj.metadata.user
|
||||
table_name = obj.schema_name
|
||||
|
|
@ -425,26 +456,36 @@ class Processor(FlowProcessor):
|
|||
|
||||
async def on_storage_management(self, msg, consumer, flow):
|
||||
"""Handle storage management requests for collection operations"""
|
||||
logger.info(f"Received storage management request: {msg.operation} for {msg.user}/{msg.collection}")
|
||||
request = msg.value()
|
||||
logger.info(f"Received storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if msg.operation == "delete-collection":
|
||||
await self.delete_collection(msg.user, msg.collection)
|
||||
if request.operation == "create-collection":
|
||||
await self.create_collection(request.user, request.collection)
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(
|
||||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {msg.user}/{msg.collection}")
|
||||
logger.info(f"Successfully created collection {request.user}/{request.collection}")
|
||||
elif request.operation == "delete-collection":
|
||||
await self.delete_collection(request.user, request.collection)
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(
|
||||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
else:
|
||||
logger.warning(f"Unknown storage management operation: {msg.operation}")
|
||||
logger.warning(f"Unknown storage management operation: {request.operation}")
|
||||
# Send error response
|
||||
from .... schema import Error
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="unknown_operation",
|
||||
message=f"Unknown operation: {msg.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -459,10 +500,28 @@ class Processor(FlowProcessor):
|
|||
message=str(e)
|
||||
)
|
||||
)
|
||||
await self.send("storage-response", response)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def create_collection(self, user: str, collection: str):
|
||||
"""Create/verify collection exists in Cassandra object store"""
|
||||
# Connect if not already connected
|
||||
self.connect_cassandra()
|
||||
|
||||
# Sanitize names for safety
|
||||
safe_keyspace = self.sanitize_name(user)
|
||||
|
||||
# Ensure keyspace exists
|
||||
if safe_keyspace not in self.known_keyspaces:
|
||||
self.ensure_keyspace(safe_keyspace)
|
||||
self.known_keyspaces.add(safe_keyspace)
|
||||
|
||||
# For Cassandra objects, collection is just a property in rows
|
||||
# No need to create separate tables per collection
|
||||
# Just mark that we've seen this collection
|
||||
logger.info(f"Collection {collection} ready for user {user} (using keyspace {safe_keyspace})")
|
||||
|
||||
async def delete_collection(self, user: str, collection: str):
|
||||
"""Delete all data for a specific collection"""
|
||||
"""Delete all data for a specific collection using schema information"""
|
||||
# Connect if not already connected
|
||||
self.connect_cassandra()
|
||||
|
||||
|
|
@ -482,40 +541,78 @@ class Processor(FlowProcessor):
|
|||
return
|
||||
self.known_keyspaces.add(safe_keyspace)
|
||||
|
||||
# Get all tables in the keyspace that might contain collection data
|
||||
get_tables_cql = """
|
||||
SELECT table_name FROM system_schema.tables
|
||||
WHERE keyspace_name = %s
|
||||
"""
|
||||
|
||||
tables = self.session.execute(get_tables_cql, (safe_keyspace,))
|
||||
# Iterate over schemas we manage to delete from relevant tables
|
||||
tables_deleted = 0
|
||||
|
||||
for row in tables:
|
||||
table_name = row.table_name
|
||||
for schema_name, schema in self.schemas.items():
|
||||
safe_table = self.sanitize_table(schema_name)
|
||||
|
||||
# Check if the table has a collection column
|
||||
check_column_cql = """
|
||||
SELECT column_name FROM system_schema.columns
|
||||
WHERE keyspace_name = %s AND table_name = %s AND column_name = 'collection'
|
||||
"""
|
||||
# Check if table exists
|
||||
table_key = f"{user}.{schema_name}"
|
||||
if table_key not in self.known_tables.get(user, set()):
|
||||
logger.debug(f"Table {safe_keyspace}.{safe_table} not in known tables, skipping")
|
||||
continue
|
||||
|
||||
result = self.session.execute(check_column_cql, (safe_keyspace, table_name))
|
||||
if result.one():
|
||||
# Table has collection column, delete data for this collection
|
||||
try:
|
||||
delete_cql = f"""
|
||||
DELETE FROM {safe_keyspace}.{table_name}
|
||||
try:
|
||||
# Get primary key fields from schema
|
||||
primary_key_fields = [field for field in schema.fields if field.primary]
|
||||
|
||||
if primary_key_fields:
|
||||
# Schema has primary keys: need to query for partition keys first
|
||||
# Build SELECT query for primary key fields
|
||||
pk_field_names = [self.sanitize_name(field.name) for field in primary_key_fields]
|
||||
select_cql = f"""
|
||||
SELECT {', '.join(pk_field_names)}
|
||||
FROM {safe_keyspace}.{safe_table}
|
||||
WHERE collection = %s
|
||||
ALLOW FILTERING
|
||||
"""
|
||||
self.session.execute(delete_cql, (collection,))
|
||||
tables_deleted += 1
|
||||
logger.info(f"Deleted collection {collection} from table {safe_keyspace}.{table_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete from table {safe_keyspace}.{table_name}: {e}")
|
||||
raise
|
||||
|
||||
logger.info(f"Deleted collection {collection} from {tables_deleted} tables in keyspace {safe_keyspace}")
|
||||
rows = self.session.execute(select_cql, (collection,))
|
||||
|
||||
# Delete each row using full partition key
|
||||
for row in rows:
|
||||
where_clauses = ["collection = %s"]
|
||||
values = [collection]
|
||||
|
||||
for field_name in pk_field_names:
|
||||
where_clauses.append(f"{field_name} = %s")
|
||||
values.append(getattr(row, field_name))
|
||||
|
||||
delete_cql = f"""
|
||||
DELETE FROM {safe_keyspace}.{safe_table}
|
||||
WHERE {' AND '.join(where_clauses)}
|
||||
"""
|
||||
|
||||
self.session.execute(delete_cql, tuple(values))
|
||||
else:
|
||||
# No primary keys, uses synthetic_id
|
||||
# Need to query for synthetic_ids first
|
||||
select_cql = f"""
|
||||
SELECT synthetic_id
|
||||
FROM {safe_keyspace}.{safe_table}
|
||||
WHERE collection = %s
|
||||
ALLOW FILTERING
|
||||
"""
|
||||
|
||||
rows = self.session.execute(select_cql, (collection,))
|
||||
|
||||
# Delete each row using collection and synthetic_id
|
||||
for row in rows:
|
||||
delete_cql = f"""
|
||||
DELETE FROM {safe_keyspace}.{safe_table}
|
||||
WHERE collection = %s AND synthetic_id = %s
|
||||
"""
|
||||
self.session.execute(delete_cql, (collection, row.synthetic_id))
|
||||
|
||||
tables_deleted += 1
|
||||
logger.info(f"Deleted collection {collection} from table {safe_keyspace}.{safe_table}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete from table {safe_keyspace}.{safe_table}: {e}")
|
||||
raise
|
||||
|
||||
logger.info(f"Deleted collection {collection} from {tables_deleted} schema-based tables in keyspace {safe_keyspace}")
|
||||
|
||||
def close(self):
|
||||
"""Clean up Cassandra connections"""
|
||||
|
|
|
|||
|
|
@ -109,6 +109,15 @@ class Processor(TriplesStoreService):
|
|||
|
||||
self.table = user
|
||||
|
||||
# Validate collection exists before accepting writes
|
||||
if not self.tg.collection_exists(message.metadata.collection):
|
||||
error_msg = (
|
||||
f"Collection {message.metadata.collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
for t in message.triples:
|
||||
self.tg.insert(
|
||||
message.metadata.collection,
|
||||
|
|
@ -117,18 +126,27 @@ class Processor(TriplesStoreService):
|
|||
t.o.value
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
|
||||
async def start(self):
|
||||
"""Start the processor and its storage management consumer"""
|
||||
await super().start()
|
||||
await self.storage_request_consumer.start()
|
||||
await self.storage_response_producer.start()
|
||||
|
||||
async def on_storage_management(self, message, consumer, flow):
|
||||
"""Handle storage management requests"""
|
||||
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
|
||||
request = message.value()
|
||||
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if message.operation == "delete-collection":
|
||||
await self.handle_delete_collection(message)
|
||||
if request.operation == "create-collection":
|
||||
await self.handle_create_collection(request)
|
||||
elif request.operation == "delete-collection":
|
||||
await self.handle_delete_collection(request)
|
||||
else:
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="invalid_operation",
|
||||
message=f"Unknown operation: {message.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -143,42 +161,85 @@ class Processor(TriplesStoreService):
|
|||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, message):
|
||||
"""Delete all data for a specific collection from the unified triples table"""
|
||||
async def handle_create_collection(self, request):
|
||||
"""Create a collection in Cassandra triple store"""
|
||||
try:
|
||||
# Create or reuse connection for this user's keyspace
|
||||
if self.table is None or self.table != message.user:
|
||||
if self.table is None or self.table != request.user:
|
||||
self.tg = None
|
||||
|
||||
try:
|
||||
if self.cassandra_username and self.cassandra_password:
|
||||
self.tg = KnowledgeGraph(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=message.user,
|
||||
keyspace=request.user,
|
||||
username=self.cassandra_username,
|
||||
password=self.cassandra_password
|
||||
)
|
||||
else:
|
||||
self.tg = KnowledgeGraph(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=message.user,
|
||||
keyspace=request.user,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Cassandra for user {message.user}: {e}")
|
||||
logger.error(f"Failed to connect to Cassandra for user {request.user}: {e}")
|
||||
raise
|
||||
|
||||
self.table = message.user
|
||||
self.table = request.user
|
||||
|
||||
# Delete all triples for this collection from the unified table
|
||||
# In the unified table schema, collection is the partition key
|
||||
delete_cql = """
|
||||
DELETE FROM triples
|
||||
WHERE collection = ?
|
||||
"""
|
||||
# Create collection using the built-in method
|
||||
logger.info(f"Creating collection {request.collection} for user {request.user}")
|
||||
|
||||
if self.tg.collection_exists(request.collection):
|
||||
logger.info(f"Collection {request.collection} already exists")
|
||||
else:
|
||||
self.tg.create_collection(request.collection)
|
||||
logger.info(f"Created collection {request.collection}")
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(error=None)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create collection: {e}", exc_info=True)
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="creation_error",
|
||||
message=str(e)
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, request):
|
||||
"""Delete all data for a specific collection from the unified triples table"""
|
||||
try:
|
||||
# Create or reuse connection for this user's keyspace
|
||||
if self.table is None or self.table != request.user:
|
||||
self.tg = None
|
||||
|
||||
try:
|
||||
if self.cassandra_username and self.cassandra_password:
|
||||
self.tg = KnowledgeGraph(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=request.user,
|
||||
username=self.cassandra_username,
|
||||
password=self.cassandra_password
|
||||
)
|
||||
else:
|
||||
self.tg = KnowledgeGraph(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=request.user,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Cassandra for user {request.user}: {e}")
|
||||
raise
|
||||
|
||||
self.table = request.user
|
||||
|
||||
# Delete all triples for this collection using the built-in method
|
||||
try:
|
||||
self.tg.session.execute(delete_cql, (message.collection,))
|
||||
logger.info(f"Deleted all triples for collection {message.collection} from keyspace {message.user}")
|
||||
self.tg.delete_collection(request.collection)
|
||||
logger.info(f"Deleted all triples for collection {request.collection} from keyspace {request.user}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection data: {e}")
|
||||
raise
|
||||
|
|
@ -188,7 +249,7 @@ class Processor(TriplesStoreService):
|
|||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
|
|
|||
|
|
@ -152,11 +152,43 @@ class Processor(TriplesStoreService):
|
|||
time=res.run_time_ms
|
||||
))
|
||||
|
||||
def collection_exists(self, user, collection):
|
||||
"""Check if collection metadata node exists"""
|
||||
result = self.io.query(
|
||||
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
|
||||
"RETURN c LIMIT 1",
|
||||
params={"user": user, "collection": collection}
|
||||
)
|
||||
return result.result_set is not None and len(result.result_set) > 0
|
||||
|
||||
def create_collection(self, user, collection):
|
||||
"""Create collection metadata node"""
|
||||
import datetime
|
||||
self.io.query(
|
||||
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) "
|
||||
"SET c.created_at = $created_at",
|
||||
params={
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"created_at": datetime.datetime.now().isoformat()
|
||||
}
|
||||
)
|
||||
logger.info(f"Created collection metadata node for {user}/{collection}")
|
||||
|
||||
async def store_triples(self, message):
|
||||
# Extract user and collection from metadata
|
||||
user = message.metadata.user if message.metadata.user else "default"
|
||||
collection = message.metadata.collection if message.metadata.collection else "default"
|
||||
|
||||
# Validate collection exists before accepting writes
|
||||
if not self.collection_exists(user, collection):
|
||||
error_msg = (
|
||||
f"Collection {collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
for t in message.triples:
|
||||
|
||||
self.create_node(t.s.value, user, collection)
|
||||
|
|
@ -185,18 +217,27 @@ class Processor(TriplesStoreService):
|
|||
help=f'FalkorDB database (default: {default_database})'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
|
||||
async def start(self):
|
||||
"""Start the processor and its storage management consumer"""
|
||||
await super().start()
|
||||
await self.storage_request_consumer.start()
|
||||
await self.storage_response_producer.start()
|
||||
|
||||
async def on_storage_management(self, message, consumer, flow):
|
||||
"""Handle storage management requests"""
|
||||
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
|
||||
request = message.value()
|
||||
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if message.operation == "delete-collection":
|
||||
await self.handle_delete_collection(message)
|
||||
if request.operation == "create-collection":
|
||||
await self.handle_create_collection(request)
|
||||
elif request.operation == "delete-collection":
|
||||
await self.handle_delete_collection(request)
|
||||
else:
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="invalid_operation",
|
||||
message=f"Unknown operation: {message.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -211,28 +252,57 @@ class Processor(TriplesStoreService):
|
|||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, message):
|
||||
async def handle_create_collection(self, request):
|
||||
"""Create collection metadata in FalkorDB"""
|
||||
try:
|
||||
if self.collection_exists(request.user, request.collection):
|
||||
logger.info(f"Collection {request.user}/{request.collection} already exists")
|
||||
else:
|
||||
self.create_collection(request.user, request.collection)
|
||||
logger.info(f"Created collection {request.user}/{request.collection}")
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(error=None)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create collection: {e}", exc_info=True)
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="creation_error",
|
||||
message=str(e)
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, request):
|
||||
"""Delete the collection for FalkorDB triples"""
|
||||
try:
|
||||
# Delete all nodes and literals for this user/collection
|
||||
node_result = self.io.query(
|
||||
"MATCH (n:Node {user: $user, collection: $collection}) DETACH DELETE n",
|
||||
params={"user": message.user, "collection": message.collection}
|
||||
params={"user": request.user, "collection": request.collection}
|
||||
)
|
||||
|
||||
literal_result = self.io.query(
|
||||
"MATCH (n:Literal {user: $user, collection: $collection}) DETACH DELETE n",
|
||||
params={"user": message.user, "collection": message.collection}
|
||||
params={"user": request.user, "collection": request.collection}
|
||||
)
|
||||
|
||||
logger.info(f"Deleted {node_result.nodes_deleted} nodes and {literal_result.nodes_deleted} literals for collection {message.user}/{message.collection}")
|
||||
# Delete collection metadata node
|
||||
metadata_result = self.io.query(
|
||||
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) DELETE c",
|
||||
params={"user": request.user, "collection": request.collection}
|
||||
)
|
||||
|
||||
logger.info(f"Deleted {node_result.nodes_deleted} nodes, {literal_result.nodes_deleted} literals, and {metadata_result.nodes_deleted} metadata nodes for collection {request.user}/{request.collection}")
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(
|
||||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
|
|
|||
|
|
@ -267,12 +267,43 @@ class Processor(TriplesStoreService):
|
|||
src=t.s.value, dest=t.o.value, uri=t.p.value, user=user, collection=collection,
|
||||
)
|
||||
|
||||
def collection_exists(self, user, collection):
|
||||
"""Check if collection metadata node exists"""
|
||||
with self.io.session(database=self.db) as session:
|
||||
result = session.run(
|
||||
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
|
||||
"RETURN c LIMIT 1",
|
||||
user=user, collection=collection
|
||||
)
|
||||
return bool(list(result))
|
||||
|
||||
def create_collection(self, user, collection):
|
||||
"""Create collection metadata node"""
|
||||
import datetime
|
||||
with self.io.session(database=self.db) as session:
|
||||
session.run(
|
||||
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) "
|
||||
"SET c.created_at = $created_at",
|
||||
user=user, collection=collection,
|
||||
created_at=datetime.datetime.now().isoformat()
|
||||
)
|
||||
logger.info(f"Created collection metadata node for {user}/{collection}")
|
||||
|
||||
async def store_triples(self, message):
|
||||
|
||||
# Extract user and collection from metadata
|
||||
user = message.metadata.user if message.metadata.user else "default"
|
||||
collection = message.metadata.collection if message.metadata.collection else "default"
|
||||
|
||||
# Validate collection exists before accepting writes
|
||||
if not self.collection_exists(user, collection):
|
||||
error_msg = (
|
||||
f"Collection {collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
for t in message.triples:
|
||||
|
||||
self.create_node(t.s.value, user, collection)
|
||||
|
|
@ -317,18 +348,27 @@ class Processor(TriplesStoreService):
|
|||
help=f'Memgraph database (default: {default_database})'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
|
||||
async def start(self):
|
||||
"""Start the processor and its storage management consumer"""
|
||||
await super().start()
|
||||
await self.storage_request_consumer.start()
|
||||
await self.storage_response_producer.start()
|
||||
|
||||
async def on_storage_management(self, message, consumer, flow):
|
||||
"""Handle storage management requests"""
|
||||
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
|
||||
request = message.value()
|
||||
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if message.operation == "delete-collection":
|
||||
await self.handle_delete_collection(message)
|
||||
if request.operation == "create-collection":
|
||||
await self.handle_create_collection(request)
|
||||
elif request.operation == "delete-collection":
|
||||
await self.handle_delete_collection(request)
|
||||
else:
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="invalid_operation",
|
||||
message=f"Unknown operation: {message.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -343,7 +383,30 @@ class Processor(TriplesStoreService):
|
|||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, message):
|
||||
async def handle_create_collection(self, request):
|
||||
"""Create collection metadata in Memgraph"""
|
||||
try:
|
||||
if self.collection_exists(request.user, request.collection):
|
||||
logger.info(f"Collection {request.user}/{request.collection} already exists")
|
||||
else:
|
||||
self.create_collection(request.user, request.collection)
|
||||
logger.info(f"Created collection {request.user}/{request.collection}")
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(error=None)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create collection: {e}", exc_info=True)
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="creation_error",
|
||||
message=str(e)
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, request):
|
||||
"""Delete all data for a specific collection"""
|
||||
try:
|
||||
with self.io.session(database=self.db) as session:
|
||||
|
|
@ -351,7 +414,7 @@ class Processor(TriplesStoreService):
|
|||
node_result = session.run(
|
||||
"MATCH (n:Node {user: $user, collection: $collection}) "
|
||||
"DETACH DELETE n",
|
||||
user=message.user, collection=message.collection
|
||||
user=request.user, collection=request.collection
|
||||
)
|
||||
nodes_deleted = node_result.consume().counters.nodes_deleted
|
||||
|
||||
|
|
@ -359,20 +422,28 @@ class Processor(TriplesStoreService):
|
|||
literal_result = session.run(
|
||||
"MATCH (n:Literal {user: $user, collection: $collection}) "
|
||||
"DETACH DELETE n",
|
||||
user=message.user, collection=message.collection
|
||||
user=request.user, collection=request.collection
|
||||
)
|
||||
literals_deleted = literal_result.consume().counters.nodes_deleted
|
||||
|
||||
# Delete collection metadata node
|
||||
metadata_result = session.run(
|
||||
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
|
||||
"DELETE c",
|
||||
user=request.user, collection=request.collection
|
||||
)
|
||||
metadata_deleted = metadata_result.consume().counters.nodes_deleted
|
||||
|
||||
# Note: Relationships are automatically deleted with DETACH DELETE
|
||||
|
||||
logger.info(f"Deleted {nodes_deleted} nodes and {literals_deleted} literals for {message.user}/{message.collection}")
|
||||
logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {request.user}/{request.collection}")
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(
|
||||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
|
|
|||
|
|
@ -228,6 +228,15 @@ class Processor(TriplesStoreService):
|
|||
user = message.metadata.user if message.metadata.user else "default"
|
||||
collection = message.metadata.collection if message.metadata.collection else "default"
|
||||
|
||||
# Validate collection exists before accepting writes
|
||||
if not self.collection_exists(user, collection):
|
||||
error_msg = (
|
||||
f"Collection {collection} does not exist. "
|
||||
f"Create it first with tg-set-collection."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
for t in message.triples:
|
||||
|
||||
self.create_node(t.s.value, user, collection)
|
||||
|
|
@ -268,18 +277,27 @@ class Processor(TriplesStoreService):
|
|||
help=f'Neo4j database (default: {default_database})'
|
||||
)
|
||||
|
||||
async def on_storage_management(self, message):
|
||||
async def start(self):
|
||||
"""Start the processor and its storage management consumer"""
|
||||
await super().start()
|
||||
await self.storage_request_consumer.start()
|
||||
await self.storage_response_producer.start()
|
||||
|
||||
async def on_storage_management(self, message, consumer, flow):
|
||||
"""Handle storage management requests"""
|
||||
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
|
||||
request = message.value()
|
||||
logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}")
|
||||
|
||||
try:
|
||||
if message.operation == "delete-collection":
|
||||
await self.handle_delete_collection(message)
|
||||
if request.operation == "create-collection":
|
||||
await self.handle_create_collection(request)
|
||||
elif request.operation == "delete-collection":
|
||||
await self.handle_delete_collection(request)
|
||||
else:
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="invalid_operation",
|
||||
message=f"Unknown operation: {message.operation}"
|
||||
message=f"Unknown operation: {request.operation}"
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
|
@ -294,7 +312,52 @@ class Processor(TriplesStoreService):
|
|||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, message):
|
||||
def collection_exists(self, user, collection):
|
||||
"""Check if collection metadata node exists"""
|
||||
with self.io.session(database=self.db) as session:
|
||||
result = session.run(
|
||||
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
|
||||
"RETURN c LIMIT 1",
|
||||
user=user, collection=collection
|
||||
)
|
||||
return bool(list(result))
|
||||
|
||||
def create_collection(self, user, collection):
|
||||
"""Create collection metadata node"""
|
||||
import datetime
|
||||
with self.io.session(database=self.db) as session:
|
||||
session.run(
|
||||
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) "
|
||||
"SET c.created_at = $created_at",
|
||||
user=user, collection=collection,
|
||||
created_at=datetime.datetime.now().isoformat()
|
||||
)
|
||||
logger.info(f"Created collection metadata node for {user}/{collection}")
|
||||
|
||||
async def handle_create_collection(self, request):
|
||||
"""Create collection metadata in Neo4j"""
|
||||
try:
|
||||
if self.collection_exists(request.user, request.collection):
|
||||
logger.info(f"Collection {request.user}/{request.collection} already exists")
|
||||
else:
|
||||
self.create_collection(request.user, request.collection)
|
||||
logger.info(f"Created collection {request.user}/{request.collection}")
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(error=None)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create collection: {e}", exc_info=True)
|
||||
response = StorageManagementResponse(
|
||||
error=Error(
|
||||
type="creation_error",
|
||||
message=str(e)
|
||||
)
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
|
||||
async def handle_delete_collection(self, request):
|
||||
"""Delete all data for a specific collection"""
|
||||
try:
|
||||
with self.io.session(database=self.db) as session:
|
||||
|
|
@ -302,7 +365,7 @@ class Processor(TriplesStoreService):
|
|||
node_result = session.run(
|
||||
"MATCH (n:Node {user: $user, collection: $collection}) "
|
||||
"DETACH DELETE n",
|
||||
user=message.user, collection=message.collection
|
||||
user=request.user, collection=request.collection
|
||||
)
|
||||
nodes_deleted = node_result.consume().counters.nodes_deleted
|
||||
|
||||
|
|
@ -310,20 +373,28 @@ class Processor(TriplesStoreService):
|
|||
literal_result = session.run(
|
||||
"MATCH (n:Literal {user: $user, collection: $collection}) "
|
||||
"DETACH DELETE n",
|
||||
user=message.user, collection=message.collection
|
||||
user=request.user, collection=request.collection
|
||||
)
|
||||
literals_deleted = literal_result.consume().counters.nodes_deleted
|
||||
|
||||
# Note: Relationships are automatically deleted with DETACH DELETE
|
||||
|
||||
logger.info(f"Deleted {nodes_deleted} nodes and {literals_deleted} literals for {message.user}/{message.collection}")
|
||||
# Delete collection metadata node
|
||||
metadata_result = session.run(
|
||||
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
|
||||
"DELETE c",
|
||||
user=request.user, collection=request.collection
|
||||
)
|
||||
metadata_deleted = metadata_result.consume().counters.nodes_deleted
|
||||
|
||||
logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {request.user}/{request.collection}")
|
||||
|
||||
# Send success response
|
||||
response = StorageManagementResponse(
|
||||
error=None # No error means success
|
||||
)
|
||||
await self.storage_response_producer.send(response)
|
||||
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
|
||||
logger.info(f"Successfully deleted collection {request.user}/{request.collection}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete collection: {e}")
|
||||
|
|
|
|||
|
|
@ -145,7 +145,7 @@ class ConfigTableStore:
|
|||
""")
|
||||
|
||||
self.get_all_stmt = self.cassandra.prepare("""
|
||||
SELECT class, key, value FROM config;
|
||||
SELECT class AS cls, key, value FROM config;
|
||||
""")
|
||||
|
||||
self.get_values_stmt = self.cassandra.prepare("""
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ Supports both Google's Gemini models and Anthropic's Claude models.
|
|||
|
||||
from google.oauth2 import service_account
|
||||
import google.auth
|
||||
import google.api_core.exceptions
|
||||
import vertexai
|
||||
import logging
|
||||
|
||||
|
|
@ -59,8 +60,17 @@ class Processor(LlmService):
|
|||
|
||||
super(Processor, self).__init__(**params)
|
||||
|
||||
self.model = model
|
||||
self.is_anthropic = 'claude' in self.model.lower()
|
||||
# Store default model and configuration parameters
|
||||
self.default_model = model
|
||||
self.region = region
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
self.private_key = private_key
|
||||
|
||||
# Model client caches
|
||||
self.model_clients = {} # Cache for model instances
|
||||
self.generation_configs = {} # Cache for generation configs (Gemini only)
|
||||
self.anthropic_client = None # Single Anthropic client (handles multiple models)
|
||||
|
||||
# Shared parameters for both model types
|
||||
self.api_params = {
|
||||
|
|
@ -89,75 +99,101 @@ class Processor(LlmService):
|
|||
"Ensure it's set in your environment or service account."
|
||||
)
|
||||
|
||||
# Initialize the appropriate client based on the model type
|
||||
if self.is_anthropic:
|
||||
logger.info(f"Initializing Anthropic model '{model}' via AnthropicVertex SDK")
|
||||
# Initialize AnthropicVertex with credentials if provided, otherwise use ADC
|
||||
anthropic_kwargs = {'region': region, 'project_id': project_id}
|
||||
if credentials and private_key: # Pass credentials only if from a file
|
||||
anthropic_kwargs['credentials'] = credentials
|
||||
logger.debug(f"Using service account credentials for Anthropic model")
|
||||
else:
|
||||
logger.debug(f"Using Application Default Credentials for Anthropic model")
|
||||
|
||||
self.llm = AnthropicVertex(**anthropic_kwargs)
|
||||
else:
|
||||
# For Gemini models, initialize the Vertex AI SDK
|
||||
logger.info(f"Initializing Google model '{model}' via Vertex AI SDK")
|
||||
init_kwargs = {'location': region, 'project': project_id}
|
||||
if credentials and private_key: # Pass credentials only if from a file
|
||||
init_kwargs['credentials'] = credentials
|
||||
|
||||
vertexai.init(**init_kwargs)
|
||||
# Store credentials and project info for later use
|
||||
self.credentials = credentials
|
||||
self.project_id = project_id
|
||||
|
||||
self.llm = GenerativeModel(model)
|
||||
# Initialize Vertex AI SDK for Gemini models
|
||||
init_kwargs = {'location': region, 'project': project_id}
|
||||
if credentials and private_key: # Pass credentials only if from a file
|
||||
init_kwargs['credentials'] = credentials
|
||||
|
||||
self.generation_config = GenerationConfig(
|
||||
temperature=temperature,
|
||||
top_p=1.0,
|
||||
top_k=10,
|
||||
candidate_count=1,
|
||||
max_output_tokens=max_output,
|
||||
)
|
||||
vertexai.init(**init_kwargs)
|
||||
|
||||
# Block none doesn't seem to work
|
||||
block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
|
||||
# block_level = HarmBlockThreshold.BLOCK_NONE
|
||||
|
||||
self.safety_settings = [
|
||||
SafetySetting(
|
||||
category = HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||
threshold = block_level,
|
||||
),
|
||||
SafetySetting(
|
||||
category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||
threshold = block_level,
|
||||
),
|
||||
SafetySetting(
|
||||
category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||
threshold = block_level,
|
||||
),
|
||||
SafetySetting(
|
||||
category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
threshold = block_level,
|
||||
),
|
||||
]
|
||||
# Pre-initialize Anthropic client if needed (single client handles all Claude models)
|
||||
if 'claude' in self.default_model.lower():
|
||||
self._get_anthropic_client()
|
||||
|
||||
# Safety settings for Gemini models
|
||||
block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
|
||||
self.safety_settings = [
|
||||
SafetySetting(
|
||||
category = HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||
threshold = block_level,
|
||||
),
|
||||
SafetySetting(
|
||||
category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||
threshold = block_level,
|
||||
),
|
||||
SafetySetting(
|
||||
category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||
threshold = block_level,
|
||||
),
|
||||
SafetySetting(
|
||||
category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
threshold = block_level,
|
||||
),
|
||||
]
|
||||
|
||||
logger.info("VertexAI initialization complete")
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
def _get_anthropic_client(self):
|
||||
"""Get or create the Anthropic client (single client for all Claude models)"""
|
||||
if self.anthropic_client is None:
|
||||
logger.info(f"Initializing AnthropicVertex client")
|
||||
anthropic_kwargs = {'region': self.region, 'project_id': self.project_id}
|
||||
if self.credentials and self.private_key: # Pass credentials only if from a file
|
||||
anthropic_kwargs['credentials'] = self.credentials
|
||||
logger.debug(f"Using service account credentials for Anthropic models")
|
||||
else:
|
||||
logger.debug(f"Using Application Default Credentials for Anthropic models")
|
||||
|
||||
self.anthropic_client = AnthropicVertex(**anthropic_kwargs)
|
||||
|
||||
return self.anthropic_client
|
||||
|
||||
def _get_gemini_model(self, model_name, temperature=None):
|
||||
"""Get or create a Gemini model instance"""
|
||||
if model_name not in self.model_clients:
|
||||
logger.info(f"Creating GenerativeModel instance for '{model_name}'")
|
||||
self.model_clients[model_name] = GenerativeModel(model_name)
|
||||
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
# Create generation config with the effective temperature
|
||||
generation_config = GenerationConfig(
|
||||
temperature=effective_temperature,
|
||||
top_p=1.0,
|
||||
top_k=10,
|
||||
candidate_count=1,
|
||||
max_output_tokens=self.max_output,
|
||||
)
|
||||
|
||||
return self.model_clients[model_name], generation_config
|
||||
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
try:
|
||||
if self.is_anthropic:
|
||||
if 'claude' in model_name.lower():
|
||||
# Anthropic API uses a dedicated system prompt
|
||||
logger.debug("Sending request to Anthropic model...")
|
||||
response = self.llm.messages.create(
|
||||
model=self.model,
|
||||
logger.debug(f"Sending request to Anthropic model '{model_name}'...")
|
||||
client = self._get_anthropic_client()
|
||||
|
||||
response = client.messages.create(
|
||||
model=model_name,
|
||||
system=system,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
max_tokens=self.api_params['max_output_tokens'],
|
||||
temperature=self.api_params['temperature'],
|
||||
temperature=effective_temperature,
|
||||
top_p=self.api_params['top_p'],
|
||||
top_k=self.api_params['top_k'],
|
||||
)
|
||||
|
|
@ -166,15 +202,17 @@ class Processor(LlmService):
|
|||
text=response.content[0].text,
|
||||
in_token=response.usage.input_tokens,
|
||||
out_token=response.usage.output_tokens,
|
||||
model=self.model
|
||||
model=model_name
|
||||
)
|
||||
else:
|
||||
# Gemini API combines system and user prompts
|
||||
logger.debug("Sending request to Gemini model...")
|
||||
logger.debug(f"Sending request to Gemini model '{model_name}'...")
|
||||
full_prompt = system + "\n\n" + prompt
|
||||
|
||||
response = self.llm.generate_content(
|
||||
full_prompt, generation_config = self.generation_config,
|
||||
llm, generation_config = self._get_gemini_model(model_name, effective_temperature)
|
||||
|
||||
response = llm.generate_content(
|
||||
full_prompt, generation_config = generation_config,
|
||||
safety_settings = self.safety_settings,
|
||||
)
|
||||
|
||||
|
|
@ -182,7 +220,7 @@ class Processor(LlmService):
|
|||
text = response.text,
|
||||
in_token = response.usage_metadata.prompt_token_count,
|
||||
out_token = response.usage_metadata.candidates_token_count,
|
||||
model = self.model
|
||||
model = model_name
|
||||
)
|
||||
|
||||
logger.info(f"Input Tokens: {resp.in_token}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue