mirror of https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-01 19:32:38 +02:00

release/v1.4 -> master (#548)

This commit is contained in:
parent 3ec2cd54f9
commit 2bd68ed7f4

94 changed files with 8571 additions and 1740 deletions
@@ -29,23 +29,25 @@ class TestEndToEndConfigurationFlow:
            'CASSANDRA_USERNAME': 'integration-user',
            'CASSANDRA_PASSWORD': 'integration-pass'
        }

        mock_cluster_instance = MagicMock()
        mock_session = MagicMock()
        mock_cluster_instance.connect.return_value = mock_session
        mock_cluster.return_value = mock_cluster_instance

        with patch.dict(os.environ, env_vars, clear=True):
            processor = TriplesWriter(taskgroup=MagicMock())

            # Create a mock message to trigger TrustGraph creation
            mock_message = MagicMock()
            mock_message.metadata.user = 'test_user'
            mock_message.metadata.collection = 'test_collection'
            mock_message.triples = []

            # This should create TrustGraph with environment config
            await processor.store_triples(mock_message)

            # Mock collection_exists to return True
            with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
                # This should create TrustGraph with environment config
                await processor.store_triples(mock_message)

            # Verify Cluster was created with correct hosts
            mock_cluster.assert_called_once()

@@ -145,8 +147,10 @@ class TestConfigurationPriorityEndToEnd:
            mock_message.metadata.user = 'test_user'
            mock_message.metadata.collection = 'test_collection'
            mock_message.triples = []

            await processor.store_triples(mock_message)

            # Mock collection_exists to return True
            with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
                await processor.store_triples(mock_message)

            # Should use CLI parameters, not environment
            mock_cluster.assert_called_once()

@@ -243,8 +247,10 @@ class TestNoBackwardCompatibilityEndToEnd:
            mock_message.metadata.user = 'legacy_user'
            mock_message.metadata.collection = 'legacy_collection'
            mock_message.triples = []

            await processor.store_triples(mock_message)

            # Mock collection_exists to return True
            with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
                await processor.store_triples(mock_message)

            # Should use defaults since old parameters are not recognized
            mock_cluster.assert_called_once()

@@ -299,8 +305,10 @@ class TestNoBackwardCompatibilityEndToEnd:
            mock_message.metadata.user = 'precedence_user'
            mock_message.metadata.collection = 'precedence_collection'
            mock_message.triples = []

            await processor.store_triples(mock_message)

            # Mock collection_exists to return True
            with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
                await processor.store_triples(mock_message)

            # Should use new parameters, not old ones
            mock_cluster.assert_called_once()

@@ -349,8 +357,10 @@ class TestMultipleHostsHandling:
            mock_message.metadata.user = 'single_user'
            mock_message.metadata.collection = 'single_collection'
            mock_message.triples = []

            await processor.store_triples(mock_message)

            # Mock collection_exists to return True
            with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
                await processor.store_triples(mock_message)

            # Single host should be converted to list
            mock_cluster.assert_called_once()
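Across all five hunks above the change is the same: each store_triples call is now wrapped in a patch that forces KnowledgeGraph.collection_exists to return True. A minimal sketch of the kind of pre-write guard that patch suggests follows; the class, method bodies, argument order, and the insert call are illustrative assumptions, not TrustGraph's actual implementation.

# Hypothetical sketch only: illustrates why the tests must patch collection_exists.
class TriplesWriterSketch:
    def __init__(self, kg):
        self.kg = kg  # KnowledgeGraph-like object exposing collection_exists()

    async def store_triples(self, message):
        user = message.metadata.user
        collection = message.metadata.collection
        # Assumed new behaviour: refuse to write into a collection that has
        # not been created yet, hence return_value=True in the tests above.
        if not self.kg.collection_exists(user, collection):
            raise RuntimeError(f"collection '{collection}' does not exist for user '{user}'")
        for t in message.triples:
            self.kg.insert(user, collection, t.s, t.p, t.o)  # illustrative call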
276  tests/integration/test_dynamic_llm_parameters.py  (new file)
@@ -0,0 +1,276 @@
"""
Integration tests for Dynamic LLM Parameters
Testing end-to-end flow of runtime parameter changes in LLM processors
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
from openai.types.completion_usage import CompletionUsage

from trustgraph.model.text_completion.openai.llm import Processor as OpenAIProcessor
from trustgraph.base import LlmResult


@pytest.mark.integration
class TestDynamicLlmParameters:
    """Integration tests for dynamic parameter configuration"""

    @pytest.fixture
    def mock_openai_client(self):
        """Mock OpenAI client that returns realistic responses"""
        client = MagicMock()

        # Default mock response
        usage = CompletionUsage(prompt_tokens=25, completion_tokens=15, total_tokens=40)
        message = ChatCompletionMessage(role="assistant", content="Dynamic parameter test response")
        choice = Choice(index=0, message=message, finish_reason="stop")

        completion = ChatCompletion(
            id="chatcmpl-test-dynamic",
            choices=[choice],
            created=1234567890,
            model="gpt-4",  # Will be overridden based on test
            object="chat.completion",
            usage=usage
        )

        client.chat.completions.create.return_value = completion
        return client

    @pytest.fixture
    def base_processor_config(self):
        """Base configuration for test processors"""
        return {
            "api_key": "test-api-key",
            "url": "https://api.openai.com/v1",
            "temperature": 0.0,  # Default temperature
            "max_output": 1024,
        }

    @patch('trustgraph.model.text_completion.openai.llm.OpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_runtime_temperature_override(self, mock_llm_init, mock_async_init,
                                                mock_openai_class, mock_openai_client, base_processor_config):
        """Test that temperature can be overridden at runtime"""
        # Arrange
        mock_openai_class.return_value = mock_openai_client
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = base_processor_config | {
            "model": "gpt-3.5-turbo",
            "concurrency": 1,
            "taskgroup": AsyncMock(),
            "id": "test-processor"
        }

        processor = OpenAIProcessor(**config)

        # Act - Call with different temperature than configured default (0.0)
        result = await processor.generate_content(
            "System prompt",
            "User prompt",
            model=None,       # Use default model
            temperature=0.9   # Override temperature
        )

        # Assert
        assert isinstance(result, LlmResult)
        assert result.text == "Dynamic parameter test response"

        # Verify the OpenAI API was called with the overridden temperature
        mock_openai_client.chat.completions.create.assert_called_once()
        call_args = mock_openai_client.chat.completions.create.call_args

        assert call_args.kwargs['temperature'] == 0.9  # Should use runtime parameter
        assert call_args.kwargs['model'] == "gpt-3.5-turbo"  # Should use processor default

    @patch('trustgraph.model.text_completion.openai.llm.OpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_runtime_model_override(self, mock_llm_init, mock_async_init,
                                          mock_openai_class, mock_openai_client, base_processor_config):
        """Test that model can be overridden at runtime"""
        # Arrange
        mock_openai_class.return_value = mock_openai_client
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = base_processor_config | {
            "model": "gpt-3.5-turbo",  # Default model
            "concurrency": 1,
            "taskgroup": AsyncMock(),
            "id": "test-processor"
        }

        processor = OpenAIProcessor(**config)

        # Act - Call with different model than configured default
        result = await processor.generate_content(
            "System prompt",
            "User prompt",
            model="gpt-4",     # Override model
            temperature=None   # Use default temperature
        )

        # Assert
        assert isinstance(result, LlmResult)

        # Verify the OpenAI API was called with the overridden model
        mock_openai_client.chat.completions.create.assert_called_once()
        call_args = mock_openai_client.chat.completions.create.call_args

        assert call_args.kwargs['model'] == "gpt-4"  # Should use runtime parameter
        assert call_args.kwargs['temperature'] == 0.0  # Should use processor default

    @patch('trustgraph.model.text_completion.openai.llm.OpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_both_parameters_override(self, mock_llm_init, mock_async_init,
                                            mock_openai_class, mock_openai_client, base_processor_config):
        """Test that both model and temperature can be overridden simultaneously"""
        # Arrange
        mock_openai_class.return_value = mock_openai_client
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = base_processor_config | {
            "model": "gpt-3.5-turbo",  # Default model
            "concurrency": 1,
            "taskgroup": AsyncMock(),
            "id": "test-processor"
        }

        processor = OpenAIProcessor(**config)

        # Act - Override both parameters
        result = await processor.generate_content(
            "System prompt",
            "User prompt",
            model="gpt-4",    # Override model
            temperature=0.5   # Override temperature
        )

        # Assert
        assert isinstance(result, LlmResult)

        # Verify both parameters were overridden
        mock_openai_client.chat.completions.create.assert_called_once()
        call_args = mock_openai_client.chat.completions.create.call_args

        assert call_args.kwargs['model'] == "gpt-4"  # Should use runtime parameter
        assert call_args.kwargs['temperature'] == 0.5  # Should use runtime parameter

    @patch('trustgraph.model.text_completion.openai.llm.OpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_fallback_to_defaults_when_no_override(self, mock_llm_init, mock_async_init,
                                                         mock_openai_class, mock_openai_client, base_processor_config):
        """Test that processor falls back to configured defaults when no parameters are provided"""
        # Arrange
        mock_openai_class.return_value = mock_openai_client
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = base_processor_config | {
            "model": "gpt-3.5-turbo",  # Default model
            "temperature": 0.2,        # Default temperature
            "concurrency": 1,
            "taskgroup": AsyncMock(),
            "id": "test-processor"
        }

        processor = OpenAIProcessor(**config)

        # Act - Call with no parameter overrides
        result = await processor.generate_content(
            "System prompt",
            "User prompt",
            model=None,        # Use default
            temperature=None   # Use default
        )

        # Assert
        assert isinstance(result, LlmResult)

        # Verify defaults were used
        mock_openai_client.chat.completions.create.assert_called_once()
        call_args = mock_openai_client.chat.completions.create.call_args

        assert call_args.kwargs['model'] == "gpt-3.5-turbo"  # Should use processor default
        assert call_args.kwargs['temperature'] == 0.2  # Should use processor default

    @patch('trustgraph.model.text_completion.openai.llm.OpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_multiple_concurrent_calls_different_parameters(self, mock_llm_init, mock_async_init,
                                                                  mock_openai_class, mock_openai_client, base_processor_config):
        """Test multiple concurrent calls with different parameters don't interfere"""
        # Arrange
        mock_openai_class.return_value = mock_openai_client
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = base_processor_config | {
            "model": "gpt-3.5-turbo",
            "concurrency": 1,
            "taskgroup": AsyncMock(),
            "id": "test-processor"
        }

        processor = OpenAIProcessor(**config)

        # Reset the mock to track multiple calls
        mock_openai_client.reset_mock()

        # Act - Make multiple calls with different parameters concurrently
        import asyncio
        tasks = [
            processor.generate_content("System 1", "Prompt 1", model="gpt-3.5-turbo", temperature=0.1),
            processor.generate_content("System 2", "Prompt 2", model="gpt-4", temperature=0.8),
            processor.generate_content("System 3", "Prompt 3", model="gpt-3.5-turbo", temperature=0.5)
        ]

        results = await asyncio.gather(*tasks)

        # Assert
        assert len(results) == 3
        for result in results:
            assert isinstance(result, LlmResult)

        # Verify all calls were made with correct parameters
        assert mock_openai_client.chat.completions.create.call_count == 3

        # Get all call arguments
        call_args_list = mock_openai_client.chat.completions.create.call_args_list

        # Verify each call had the expected parameters
        expected_params = [
            ("gpt-3.5-turbo", 0.1),
            ("gpt-4", 0.8),
            ("gpt-3.5-turbo", 0.5)
        ]

        for i, (expected_model, expected_temp) in enumerate(expected_params):
            call_kwargs = call_args_list[i].kwargs
            assert call_kwargs['model'] == expected_model
            assert call_kwargs['temperature'] == expected_temp

    async def test_parameter_boundary_values(self, mock_openai_client, base_processor_config):
        """Test parameter boundary values (edge cases)"""
        # This would test extreme values like temperature=0.0, temperature=2.0, etc.
        # Implementation depends on specific validation requirements
        pass

    async def test_invalid_parameter_types_handling(self, mock_openai_client, base_processor_config):
        """Test handling of invalid parameter types"""
        # This would test what happens with invalid temperature values, non-existent models, etc.
        # Implementation depends on error handling requirements
        pass


if __name__ == '__main__':
    pytest.main([__file__])
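The file above pins down the override contract: a non-None model or temperature passed at call time wins, anything left as None falls back to the processor's configured default, and concurrent calls keep their own values. A minimal sketch of that fallback logic, assuming the attribute names used in the tests (default_model, temperature, max_output, openai) and an assumed LlmResult constructor, is:

# Hedged sketch of the fallback behaviour the tests assert on; the real
# TrustGraph Processor.generate_content may differ in detail.
async def generate_content(self, system, prompt, model=None, temperature=None):
    # Runtime override wins; otherwise fall back to configured defaults.
    resolved_model = model if model is not None else self.default_model
    resolved_temperature = temperature if temperature is not None else self.temperature

    resp = self.openai.chat.completions.create(
        model=resolved_model,
        temperature=resolved_temperature,
        max_tokens=self.max_output,  # assumed mapping of max_output
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": prompt},
        ],
    )
    # The tests only assert on result.text; this constructor call is assumed.
    return LlmResult(text=resp.choices[0].message.content)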
@@ -22,7 +22,36 @@ class TestObjectsCassandraIntegration:
    def mock_cassandra_session(self):
        """Mock Cassandra session for integration tests"""
        session = MagicMock()
        session.execute = MagicMock()

        # Track if keyspaces have been created
        created_keyspaces = set()

        # Mock the execute method to return a valid result for keyspace checks
        def execute_mock(query, *args, **kwargs):
            result = MagicMock()
            query_str = str(query)

            # Track keyspace creation
            if "CREATE KEYSPACE" in query_str:
                # Extract keyspace name from query
                import re
                match = re.search(r'CREATE KEYSPACE IF NOT EXISTS (\w+)', query_str)
                if match:
                    created_keyspaces.add(match.group(1))

            # For keyspace existence checks
            if "system_schema.keyspaces" in query_str:
                # Check if this keyspace was created
                if args and args[0] in created_keyspaces:
                    result.one.return_value = MagicMock()  # Exists
                else:
                    result.one.return_value = None  # Doesn't exist
            else:
                result.one.return_value = None

            return result

        session.execute = MagicMock(side_effect=execute_mock)
        return session

    @pytest.fixture
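In use, this fixture behaves like a tiny in-memory keyspace registry: once a CREATE KEYSPACE statement has passed through execute, the matching system_schema.keyspaces lookup returns a row, otherwise one() yields None. A short illustration, where session is the value the fixture returns and the query text and parameter passing are assumed to match what execute_mock inspects:

# Illustration of the fixture's side_effect; query shapes are assumptions.
session.execute("CREATE KEYSPACE IF NOT EXISTS analytics WITH replication = {...}")

found = session.execute(
    "SELECT keyspace_name FROM system_schema.keyspaces WHERE keyspace_name = %s",
    "analytics",
)
assert found.one() is not None   # recorded in created_keyspaces

missing = session.execute(
    "SELECT keyspace_name FROM system_schema.keyspaces WHERE keyspace_name = %s",
    "unknown",
)
assert missing.one() is None     # never created, so reported absent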
@@ -57,7 +86,8 @@ class TestObjectsCassandraIntegration:
        processor.convert_value = Processor.convert_value.__get__(processor, Processor)
        processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
        processor.on_object = Processor.on_object.__get__(processor, Processor)
        processor.create_collection = Processor.create_collection.__get__(processor, Processor)

        return processor, mock_cassandra_cluster, mock_cassandra_session

    @pytest.mark.asyncio

@@ -85,7 +115,10 @@ class TestObjectsCassandraIntegration:
        await processor.on_schema_config(config, version=1)
        assert "customer_records" in processor.schemas

        # Step 1.5: Create the collection first (simulate tg-set-collection)
        await processor.create_collection("test_user", "import_2024")

        # Step 2: Process an ExtractedObject
        test_obj = ExtractedObject(
            metadata=Metadata(

@@ -104,10 +137,10 @@ class TestObjectsCassandraIntegration:
            confidence=0.95,
            source_span="Customer: John Doe..."
        )

        msg = MagicMock()
        msg.value.return_value = test_obj

        await processor.on_object(msg, None, None)

        # Verify Cassandra interactions

@@ -178,7 +211,11 @@ class TestObjectsCassandraIntegration:
        await processor.on_schema_config(config, version=1)
        assert len(processor.schemas) == 2

        # Create collections first
        await processor.create_collection("shop", "catalog")
        await processor.create_collection("shop", "sales")

        # Process objects for different schemas
        product_obj = ExtractedObject(
            metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]),

@@ -187,7 +224,7 @@ class TestObjectsCassandraIntegration:
            confidence=0.9,
            source_span="Product..."
        )

        order_obj = ExtractedObject(
            metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]),
            schema_name="orders",

@@ -195,7 +232,7 @@ class TestObjectsCassandraIntegration:
            confidence=0.85,
            source_span="Order..."
        )

        # Process both objects
        for obj in [product_obj, order_obj]:
            msg = MagicMock()
@@ -225,6 +262,9 @@ class TestObjectsCassandraIntegration:
            ]
        )

        # Create collection first
        await processor.create_collection("test", "test")

        # Create object missing required field
        test_obj = ExtractedObject(
            metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),

@@ -233,10 +273,10 @@ class TestObjectsCassandraIntegration:
            confidence=0.8,
            source_span="Test"
        )

        msg = MagicMock()
        msg.value.return_value = test_obj

        # Should still process (Cassandra doesn't enforce NOT NULL)
        await processor.on_object(msg, None, None)

@@ -261,6 +301,9 @@ class TestObjectsCassandraIntegration:
            ]
        )

        # Create collection first
        await processor.create_collection("logger", "app_events")

        # Process object
        test_obj = ExtractedObject(
            metadata=Metadata(id="e1", user="logger", collection="app_events", metadata=[]),

@@ -269,10 +312,10 @@ class TestObjectsCassandraIntegration:
            confidence=1.0,
            source_span="Event"
        )

        msg = MagicMock()
        msg.value.return_value = test_obj

        await processor.on_object(msg, None, None)

        # Verify synthetic_id was added

@@ -325,8 +368,10 @@ class TestObjectsCassandraIntegration:
        )

        # Make insert fail
        mock_result = MagicMock()
        mock_result.one.return_value = MagicMock()  # Keyspace exists
        mock_session.execute.side_effect = [
            None,  # keyspace creation succeeds
            mock_result,  # keyspace existence check succeeds
            None,  # table creation succeeds
            Exception("Connection timeout")  # insert fails
        ]
@@ -359,7 +404,11 @@ class TestObjectsCassandraIntegration:
        # Process objects from different collections
        collections = ["import_jan", "import_feb", "import_mar"]

        # Create all collections first
        for coll in collections:
            await processor.create_collection("analytics", coll)

        for coll in collections:
            obj = ExtractedObject(
                metadata=Metadata(id=f"{coll}-1", user="analytics", collection=coll, metadata=[]),

@@ -368,7 +417,7 @@ class TestObjectsCassandraIntegration:
                confidence=0.9,
                source_span="Data"
            )

            msg = MagicMock()
            msg.value.return_value = obj
            await processor.on_object(msg, None, None)

@@ -436,9 +485,12 @@ class TestObjectsCassandraIntegration:
            source_span="Multiple customers extracted from document"
        )

        # Create collection first
        await processor.create_collection("test_user", "batch_import")

        msg = MagicMock()
        msg.value.return_value = batch_obj

        await processor.on_object(msg, None, None)

        # Verify table creation

@@ -479,6 +531,9 @@ class TestObjectsCassandraIntegration:
            fields=[Field(name="id", type="string", size=50, primary=True)]
        )

        # Create collection first
        await processor.create_collection("test", "empty")

        # Process empty batch object
        empty_obj = ExtractedObject(
            metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]),

@@ -487,10 +542,10 @@ class TestObjectsCassandraIntegration:
            confidence=1.0,
            source_span="No objects found"
        )

        msg = MagicMock()
        msg.value.return_value = empty_obj

        await processor.on_object(msg, None, None)

        # Should still create table

@@ -517,6 +572,9 @@ class TestObjectsCassandraIntegration:
            ]
        )

        # Create collection first
        await processor.create_collection("test", "mixed")

        # Single object (backward compatibility)
        single_obj = ExtractedObject(
            metadata=Metadata(id="single", user="test", collection="mixed", metadata=[]),

@@ -525,7 +583,7 @@ class TestObjectsCassandraIntegration:
            confidence=0.9,
            source_span="Single object"
        )

        # Batch object
        batch_obj = ExtractedObject(
            metadata=Metadata(id="batch", user="test", collection="mixed", metadata=[]),

@@ -537,7 +595,7 @@ class TestObjectsCassandraIntegration:
            confidence=0.85,
            source_span="Batch objects"
        )

        # Process both
        for obj in [single_obj, batch_obj]:
            msg = MagicMock()
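Taken together, these hunks add the same preamble to every scenario: register the schema config, create the collection (standing in for the tg-set-collection step), and only then feed ExtractedObject messages to on_object. A compressed, hedged sketch of that per-test setup; the helper name and grouping are not from the diff, only the individual calls are:

# Illustrative per-test preamble assembled from the calls added above.
async def prepare_and_process(processor, user, collection, config, extracted_obj):
    await processor.on_schema_config(config, version=1)    # register schemas
    await processor.create_collection(user, collection)    # collection must exist first

    msg = MagicMock()
    msg.value.return_value = extracted_obj
    await processor.on_object(msg, None, None)             # now passes the existence check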
@@ -60,13 +60,13 @@ class TestTextCompletionIntegration:
        """Create text completion processor with test configuration"""
        # Create a minimal processor instance for testing generate_content
        processor = MagicMock()
        processor.model = processor_config["model"]
        processor.default_model = processor_config["model"]
        processor.temperature = processor_config["temperature"]
        processor.max_output = processor_config["max_output"]

        # Add the actual generate_content method from Processor class
        processor.generate_content = Processor.generate_content.__get__(processor, Processor)

        return processor

    @pytest.mark.asyncio

@@ -112,11 +112,11 @@ class TestTextCompletionIntegration:
        for config in test_configs:
            # Arrange - Create minimal processor mock
            processor = MagicMock()
            processor.model = config['model']
            processor.default_model = config['model']
            processor.temperature = config['temperature']
            processor.max_output = config['max_output']
            processor.openai = mock_openai_client

            # Add the actual generate_content method
            processor.generate_content = Processor.generate_content.__get__(processor, Processor)

@@ -242,7 +242,7 @@ class TestTextCompletionIntegration:
        processors = []
        for i in range(5):
            processor = MagicMock()
            processor.model = processor_config["model"]
            processor.default_model = processor_config["model"]
            processor.temperature = processor_config["temperature"]
            processor.max_output = processor_config["max_output"]
            processor.openai = mock_openai_client

@@ -348,7 +348,7 @@ class TestTextCompletionIntegration:
        """Test that model parameters are correctly passed to OpenAI API"""
        # Arrange
        processor = MagicMock()
        processor.model = "gpt-4"
        processor.default_model = "gpt-4"
        processor.temperature = 0.8
        processor.max_output = 2048
        processor.openai = mock_openai_client
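One detail these fixtures rely on throughout is binding the real Processor method onto a MagicMock through the descriptor protocol, so the genuine generate_content logic runs against stubbed attributes. A self-contained illustration of the technique; Greeter is an invented example class, not part of TrustGraph:

from unittest.mock import MagicMock

class Greeter:
    def greet(self):
        # Reads self.who, so whatever object the method is bound to must provide it.
        return f"hello {self.who}"

fake = MagicMock()
fake.who = "world"
# __get__ binds the plain function to the mock, exactly as the fixtures
# above do with Processor.generate_content.__get__(processor, Processor).
fake.greet = Greeter.greet.__get__(fake, Greeter)
assert fake.greet() == "hello world"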