mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Fix tests, mock the correct API
This commit is contained in:
parent
6159341bd5
commit
131e04abd9
3 changed files with 149 additions and 165 deletions
|
|
@ -1,6 +1,6 @@
|
||||||
"""
|
"""
|
||||||
Unit tests for trustgraph.model.text_completion.vertexai
|
Unit tests for trustgraph.model.text_completion.vertexai
|
||||||
Starting simple with one test to get the basics working
|
Updated for google-genai SDK
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
@ -15,19 +15,20 @@ from trustgraph.base import LlmResult
|
||||||
class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
"""Simple test for processor initialization"""
|
"""Simple test for processor initialization"""
|
||||||
|
|
||||||
|
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||||
"""Test basic processor initialization with mocked dependencies"""
|
"""Test basic processor initialization with mocked dependencies"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_credentials = MagicMock()
|
mock_credentials = MagicMock()
|
||||||
|
mock_credentials.project_id = "test-project-123"
|
||||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||||
mock_model = MagicMock()
|
|
||||||
mock_generative_model.return_value = mock_model
|
mock_client = MagicMock()
|
||||||
|
mock_genai.Client.return_value = mock_client
|
||||||
|
|
||||||
# Mock the parent class initialization to avoid taskgroup requirement
|
# Mock the parent class initialization to avoid taskgroup requirement
|
||||||
mock_async_init.return_value = None
|
mock_async_init.return_value = None
|
||||||
mock_llm_init.return_value = None
|
mock_llm_init.return_value = None
|
||||||
|
|
@ -47,32 +48,38 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
processor = Processor(**config)
|
processor = Processor(**config)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert processor.default_model == 'gemini-2.0-flash-001' # It's stored as 'model', not 'model_name'
|
assert processor.default_model == 'gemini-2.0-flash-001'
|
||||||
assert hasattr(processor, 'generation_configs') # Now a cache dictionary
|
assert hasattr(processor, 'generation_configs') # Cache dictionary
|
||||||
assert hasattr(processor, 'safety_settings')
|
assert hasattr(processor, 'safety_settings')
|
||||||
assert hasattr(processor, 'model_clients') # LLM clients are now cached
|
assert hasattr(processor, 'client') # genai.Client
|
||||||
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('private.json')
|
mock_service_account.Credentials.from_service_account_file.assert_called_once()
|
||||||
mock_vertexai.init.assert_called_once()
|
mock_genai.Client.assert_called_once_with(
|
||||||
|
vertexai=True,
|
||||||
|
project="test-project-123",
|
||||||
|
location="us-central1",
|
||||||
|
credentials=mock_credentials
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||||
"""Test successful content generation"""
|
"""Test successful content generation"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_credentials = MagicMock()
|
mock_credentials = MagicMock()
|
||||||
|
mock_credentials.project_id = "test-project-123"
|
||||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||||
|
|
||||||
mock_model = MagicMock()
|
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.text = "Generated response from Gemini"
|
mock_response.text = "Generated response from Gemini"
|
||||||
mock_response.usage_metadata.prompt_token_count = 15
|
mock_response.usage_metadata.prompt_token_count = 15
|
||||||
mock_response.usage_metadata.candidates_token_count = 8
|
mock_response.usage_metadata.candidates_token_count = 8
|
||||||
mock_model.generate_content.return_value = mock_response
|
|
||||||
mock_generative_model.return_value = mock_model
|
mock_client = MagicMock()
|
||||||
|
mock_client.models.generate_content.return_value = mock_response
|
||||||
|
mock_genai.Client.return_value = mock_client
|
||||||
|
|
||||||
mock_async_init.return_value = None
|
mock_async_init.return_value = None
|
||||||
mock_llm_init.return_value = None
|
mock_llm_init.return_value = None
|
||||||
|
|
||||||
|
|
@ -98,32 +105,26 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
assert result.in_token == 15
|
assert result.in_token == 15
|
||||||
assert result.out_token == 8
|
assert result.out_token == 8
|
||||||
assert result.model == 'gemini-2.0-flash-001'
|
assert result.model == 'gemini-2.0-flash-001'
|
||||||
# Check that the method was called (actual prompt format may vary)
|
mock_client.models.generate_content.assert_called_once()
|
||||||
mock_model.generate_content.assert_called_once()
|
|
||||||
# Verify the call was made with the expected parameters
|
|
||||||
call_args = mock_model.generate_content.call_args
|
|
||||||
# Generation config is now created dynamically per model
|
|
||||||
assert 'generation_config' in call_args[1]
|
|
||||||
assert call_args[1]['safety_settings'] == processor.safety_settings
|
|
||||||
|
|
||||||
|
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||||
"""Test rate limit error handling"""
|
"""Test rate limit error handling"""
|
||||||
# Arrange
|
# Arrange
|
||||||
from google.api_core.exceptions import ResourceExhausted
|
from google.api_core.exceptions import ResourceExhausted
|
||||||
from trustgraph.exceptions import TooManyRequests
|
from trustgraph.exceptions import TooManyRequests
|
||||||
|
|
||||||
mock_credentials = MagicMock()
|
mock_credentials = MagicMock()
|
||||||
|
mock_credentials.project_id = "test-project-123"
|
||||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||||
|
|
||||||
mock_model = MagicMock()
|
mock_client = MagicMock()
|
||||||
mock_model.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
|
mock_client.models.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
|
||||||
mock_generative_model.return_value = mock_model
|
mock_genai.Client.return_value = mock_client
|
||||||
|
|
||||||
mock_async_init.return_value = None
|
mock_async_init.return_value = None
|
||||||
mock_llm_init.return_value = None
|
mock_llm_init.return_value = None
|
||||||
|
|
||||||
|
|
@ -144,25 +145,26 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
with pytest.raises(TooManyRequests):
|
with pytest.raises(TooManyRequests):
|
||||||
await processor.generate_content("System prompt", "User prompt")
|
await processor.generate_content("System prompt", "User prompt")
|
||||||
|
|
||||||
|
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||||
async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||||
"""Test handling of blocked content (safety filters)"""
|
"""Test handling of blocked content (safety filters)"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_credentials = MagicMock()
|
mock_credentials = MagicMock()
|
||||||
|
mock_credentials.project_id = "test-project-123"
|
||||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||||
|
|
||||||
mock_model = MagicMock()
|
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.text = None # Blocked content returns None
|
mock_response.text = None # Blocked content returns None
|
||||||
mock_response.usage_metadata.prompt_token_count = 10
|
mock_response.usage_metadata.prompt_token_count = 10
|
||||||
mock_response.usage_metadata.candidates_token_count = 0
|
mock_response.usage_metadata.candidates_token_count = 0
|
||||||
mock_model.generate_content.return_value = mock_response
|
|
||||||
mock_generative_model.return_value = mock_model
|
mock_client = MagicMock()
|
||||||
|
mock_client.models.generate_content.return_value = mock_response
|
||||||
|
mock_genai.Client.return_value = mock_client
|
||||||
|
|
||||||
mock_async_init.return_value = None
|
mock_async_init.return_value = None
|
||||||
mock_llm_init.return_value = None
|
mock_llm_init.return_value = None
|
||||||
|
|
||||||
|
|
@ -190,24 +192,22 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
assert result.model == 'gemini-2.0-flash-001'
|
assert result.model == 'gemini-2.0-flash-001'
|
||||||
|
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.google.auth.default')
|
@patch('trustgraph.model.text_completion.vertexai.llm.google.auth.default')
|
||||||
|
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||||
async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account, mock_auth_default):
|
async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai, mock_auth_default):
|
||||||
"""Test processor initialization without private key (uses default credentials)"""
|
"""Test processor initialization without private key (uses default credentials)"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_async_init.return_value = None
|
mock_async_init.return_value = None
|
||||||
mock_llm_init.return_value = None
|
mock_llm_init.return_value = None
|
||||||
|
|
||||||
# Mock google.auth.default() to return credentials and project ID
|
# Mock google.auth.default() to return credentials and project ID
|
||||||
mock_credentials = MagicMock()
|
mock_credentials = MagicMock()
|
||||||
mock_auth_default.return_value = (mock_credentials, "test-project-123")
|
mock_auth_default.return_value = (mock_credentials, "test-project-123")
|
||||||
|
|
||||||
# Mock GenerativeModel
|
mock_client = MagicMock()
|
||||||
mock_model = MagicMock()
|
mock_genai.Client.return_value = mock_client
|
||||||
mock_generative_model.return_value = mock_model
|
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
'region': 'us-central1',
|
'region': 'us-central1',
|
||||||
|
|
@ -222,30 +222,32 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
processor = Processor(**config)
|
processor = Processor(**config)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert processor.default_model == 'gemini-2.0-flash-001'
|
assert processor.default_model == 'gemini-2.0-flash-001'
|
||||||
mock_auth_default.assert_called_once()
|
mock_auth_default.assert_called_once()
|
||||||
mock_vertexai.init.assert_called_once_with(
|
mock_genai.Client.assert_called_once_with(
|
||||||
location='us-central1',
|
vertexai=True,
|
||||||
project='test-project-123'
|
project="test-project-123",
|
||||||
|
location="us-central1",
|
||||||
|
credentials=mock_credentials
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||||
"""Test handling of generic exceptions"""
|
"""Test handling of generic exceptions"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_credentials = MagicMock()
|
mock_credentials = MagicMock()
|
||||||
|
mock_credentials.project_id = "test-project-123"
|
||||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||||
|
|
||||||
mock_model = MagicMock()
|
mock_client = MagicMock()
|
||||||
mock_model.generate_content.side_effect = Exception("Network error")
|
mock_client.models.generate_content.side_effect = Exception("Network error")
|
||||||
mock_generative_model.return_value = mock_model
|
mock_genai.Client.return_value = mock_client
|
||||||
|
|
||||||
mock_async_init.return_value = None
|
mock_async_init.return_value = None
|
||||||
mock_llm_init.return_value = None
|
mock_llm_init.return_value = None
|
||||||
|
|
||||||
|
|
@ -266,19 +268,20 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
with pytest.raises(Exception, match="Network error"):
|
with pytest.raises(Exception, match="Network error"):
|
||||||
await processor.generate_content("System prompt", "User prompt")
|
await processor.generate_content("System prompt", "User prompt")
|
||||||
|
|
||||||
|
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||||
"""Test processor initialization with custom parameters"""
|
"""Test processor initialization with custom parameters"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_credentials = MagicMock()
|
mock_credentials = MagicMock()
|
||||||
|
mock_credentials.project_id = "test-project-123"
|
||||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||||
mock_model = MagicMock()
|
|
||||||
mock_generative_model.return_value = mock_model
|
mock_client = MagicMock()
|
||||||
|
mock_genai.Client.return_value = mock_client
|
||||||
|
|
||||||
mock_async_init.return_value = None
|
mock_async_init.return_value = None
|
||||||
mock_llm_init.return_value = None
|
mock_llm_init.return_value = None
|
||||||
|
|
||||||
|
|
@ -298,37 +301,37 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert processor.default_model == 'gemini-1.5-pro'
|
assert processor.default_model == 'gemini-1.5-pro'
|
||||||
|
|
||||||
# Verify that generation_config object exists (can't easily check internal values)
|
# Verify that generation_config cache exists
|
||||||
assert hasattr(processor, 'generation_configs') # Now a cache dictionary
|
assert hasattr(processor, 'generation_configs')
|
||||||
assert processor.generation_configs == {} # Empty cache initially
|
assert processor.generation_configs == {} # Empty cache initially
|
||||||
|
|
||||||
# Verify that safety settings are configured
|
# Verify that safety settings are configured
|
||||||
assert len(processor.safety_settings) == 4
|
assert len(processor.safety_settings) == 4
|
||||||
|
|
||||||
# Verify service account was called with custom key
|
# Verify service account was called with custom key
|
||||||
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('custom-key.json')
|
mock_service_account.Credentials.from_service_account_file.assert_called_once()
|
||||||
|
|
||||||
# Verify that api_params dict has the correct values (this is accessible)
|
# Verify that api_params dict has the correct values
|
||||||
assert processor.api_params["temperature"] == 0.7
|
assert processor.api_params["temperature"] == 0.7
|
||||||
assert processor.api_params["max_output_tokens"] == 4096
|
assert processor.api_params["max_output_tokens"] == 4096
|
||||||
assert processor.api_params["top_p"] == 1.0
|
assert processor.api_params["top_p"] == 1.0
|
||||||
assert processor.api_params["top_k"] == 32
|
assert processor.api_params["top_k"] == 32
|
||||||
|
|
||||||
|
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||||
async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||||
"""Test that VertexAI is initialized correctly with credentials"""
|
"""Test that VertexAI is initialized correctly with credentials"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_credentials = MagicMock()
|
mock_credentials = MagicMock()
|
||||||
mock_credentials.project_id = "test-project-123"
|
mock_credentials.project_id = "test-project-123"
|
||||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||||
mock_model = MagicMock()
|
|
||||||
mock_generative_model.return_value = mock_model
|
mock_client = MagicMock()
|
||||||
|
mock_genai.Client.return_value = mock_client
|
||||||
|
|
||||||
mock_async_init.return_value = None
|
mock_async_init.return_value = None
|
||||||
mock_llm_init.return_value = None
|
mock_llm_init.return_value = None
|
||||||
|
|
||||||
|
|
@ -347,35 +350,34 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
processor = Processor(**config)
|
processor = Processor(**config)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# Verify VertexAI init was called with correct parameters
|
# Verify genai.Client was called with correct parameters
|
||||||
mock_vertexai.init.assert_called_once_with(
|
mock_genai.Client.assert_called_once_with(
|
||||||
|
vertexai=True,
|
||||||
|
project='test-project-123',
|
||||||
location='europe-west1',
|
location='europe-west1',
|
||||||
credentials=mock_credentials,
|
credentials=mock_credentials
|
||||||
project='test-project-123'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# GenerativeModel is now created lazily on first use, not at initialization
|
|
||||||
mock_generative_model.assert_not_called()
|
|
||||||
|
|
||||||
|
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||||
"""Test content generation with empty prompts"""
|
"""Test content generation with empty prompts"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_credentials = MagicMock()
|
mock_credentials = MagicMock()
|
||||||
|
mock_credentials.project_id = "test-project-123"
|
||||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||||
|
|
||||||
mock_model = MagicMock()
|
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.text = "Default response"
|
mock_response.text = "Default response"
|
||||||
mock_response.usage_metadata.prompt_token_count = 2
|
mock_response.usage_metadata.prompt_token_count = 2
|
||||||
mock_response.usage_metadata.candidates_token_count = 3
|
mock_response.usage_metadata.candidates_token_count = 3
|
||||||
mock_model.generate_content.return_value = mock_response
|
|
||||||
mock_generative_model.return_value = mock_model
|
mock_client = MagicMock()
|
||||||
|
mock_client.models.generate_content.return_value = mock_response
|
||||||
|
mock_genai.Client.return_value = mock_client
|
||||||
|
|
||||||
mock_async_init.return_value = None
|
mock_async_init.return_value = None
|
||||||
mock_llm_init.return_value = None
|
mock_llm_init.return_value = None
|
||||||
|
|
||||||
|
|
@ -401,27 +403,28 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
assert result.in_token == 2
|
assert result.in_token == 2
|
||||||
assert result.out_token == 3
|
assert result.out_token == 3
|
||||||
assert result.model == 'gemini-2.0-flash-001'
|
assert result.model == 'gemini-2.0-flash-001'
|
||||||
|
|
||||||
# Verify the model was called with the combined empty prompts
|
# Verify the client was called
|
||||||
mock_model.generate_content.assert_called_once()
|
mock_client.models.generate_content.assert_called_once()
|
||||||
call_args = mock_model.generate_content.call_args
|
|
||||||
# The prompt should be "" + "\n\n" + "" = "\n\n"
|
|
||||||
assert call_args[0][0] == "\n\n"
|
|
||||||
|
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.AnthropicVertex')
|
@patch('trustgraph.model.text_completion.vertexai.llm.AnthropicVertex')
|
||||||
|
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||||
async def test_anthropic_processor_initialization_with_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_anthropic_vertex):
|
async def test_anthropic_processor_initialization_with_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai, mock_anthropic_vertex):
|
||||||
"""Test Anthropic processor initialization with private key credentials"""
|
"""Test Anthropic processor initialization with private key credentials"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_async_init.return_value = None
|
mock_async_init.return_value = None
|
||||||
mock_llm_init.return_value = None
|
mock_llm_init.return_value = None
|
||||||
|
|
||||||
mock_credentials = MagicMock()
|
mock_credentials = MagicMock()
|
||||||
mock_credentials.project_id = "test-project-456"
|
mock_credentials.project_id = "test-project-456"
|
||||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_genai.Client.return_value = mock_client
|
||||||
|
|
||||||
# Mock AnthropicVertex
|
# Mock AnthropicVertex
|
||||||
mock_anthropic_client = MagicMock()
|
mock_anthropic_client = MagicMock()
|
||||||
mock_anthropic_vertex.return_value = mock_anthropic_client
|
mock_anthropic_vertex.return_value = mock_anthropic_client
|
||||||
|
|
@ -439,45 +442,45 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
processor = Processor(**config)
|
processor = Processor(**config)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert processor.default_model == 'claude-3-sonnet@20240229'
|
assert processor.default_model == 'claude-3-sonnet@20240229'
|
||||||
# is_anthropic logic is now determined dynamically per request
|
|
||||||
|
|
||||||
# Verify service account was called with private key
|
# Verify service account was called with private key
|
||||||
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('anthropic-key.json')
|
mock_service_account.Credentials.from_service_account_file.assert_called_once()
|
||||||
|
|
||||||
# Verify AnthropicVertex was initialized with credentials
|
# Verify AnthropicVertex was initialized with credentials (because model contains 'claude')
|
||||||
mock_anthropic_vertex.assert_called_once_with(
|
mock_anthropic_vertex.assert_called_once_with(
|
||||||
region='us-west1',
|
region='us-west1',
|
||||||
project_id='test-project-456',
|
project_id='test-project-456',
|
||||||
credentials=mock_credentials
|
credentials=mock_credentials
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify api_params are set correctly
|
# Verify api_params are set correctly
|
||||||
assert processor.api_params["temperature"] == 0.5
|
assert processor.api_params["temperature"] == 0.5
|
||||||
assert processor.api_params["max_output_tokens"] == 2048
|
assert processor.api_params["max_output_tokens"] == 2048
|
||||||
assert processor.api_params["top_p"] == 1.0
|
assert processor.api_params["top_p"] == 1.0
|
||||||
assert processor.api_params["top_k"] == 32
|
assert processor.api_params["top_k"] == 32
|
||||||
|
|
||||||
|
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||||
"""Test temperature parameter override functionality"""
|
"""Test temperature parameter override functionality"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_credentials = MagicMock()
|
mock_credentials = MagicMock()
|
||||||
|
mock_credentials.project_id = "test-project-123"
|
||||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||||
|
|
||||||
mock_model = MagicMock()
|
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.text = "Response with custom temperature"
|
mock_response.text = "Response with custom temperature"
|
||||||
mock_response.usage_metadata.prompt_token_count = 20
|
mock_response.usage_metadata.prompt_token_count = 20
|
||||||
mock_response.usage_metadata.candidates_token_count = 12
|
mock_response.usage_metadata.candidates_token_count = 12
|
||||||
mock_model.generate_content.return_value = mock_response
|
|
||||||
mock_generative_model.return_value = mock_model
|
mock_client = MagicMock()
|
||||||
|
mock_client.models.generate_content.return_value = mock_response
|
||||||
|
mock_genai.Client.return_value = mock_client
|
||||||
|
|
||||||
mock_async_init.return_value = None
|
mock_async_init.return_value = None
|
||||||
mock_llm_init.return_value = None
|
mock_llm_init.return_value = None
|
||||||
|
|
@ -506,42 +509,27 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
# Assert
|
# Assert
|
||||||
assert isinstance(result, LlmResult)
|
assert isinstance(result, LlmResult)
|
||||||
assert result.text == "Response with custom temperature"
|
assert result.text == "Response with custom temperature"
|
||||||
|
mock_client.models.generate_content.assert_called_once()
|
||||||
|
|
||||||
# Verify Gemini API was called with overridden temperature
|
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||||
mock_model.generate_content.assert_called_once()
|
|
||||||
call_args = mock_model.generate_content.call_args
|
|
||||||
|
|
||||||
# Check that generation_config was created (we can't directly access temperature from mock)
|
|
||||||
generation_config = call_args.kwargs['generation_config']
|
|
||||||
assert generation_config is not None # Should use overridden temperature configuration
|
|
||||||
|
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||||
"""Test model parameter override functionality"""
|
"""Test model parameter override functionality"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_credentials = MagicMock()
|
mock_credentials = MagicMock()
|
||||||
|
mock_credentials.project_id = "test-project-123"
|
||||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||||
|
|
||||||
# Mock different models
|
|
||||||
mock_model_default = MagicMock()
|
|
||||||
mock_model_override = MagicMock()
|
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.text = "Response with custom model"
|
mock_response.text = "Response with custom model"
|
||||||
mock_response.usage_metadata.prompt_token_count = 18
|
mock_response.usage_metadata.prompt_token_count = 18
|
||||||
mock_response.usage_metadata.candidates_token_count = 14
|
mock_response.usage_metadata.candidates_token_count = 14
|
||||||
mock_model_override.generate_content.return_value = mock_response
|
|
||||||
|
|
||||||
# GenerativeModel should return different models based on input
|
mock_client = MagicMock()
|
||||||
def model_factory(model_name):
|
mock_client.models.generate_content.return_value = mock_response
|
||||||
if model_name == 'gemini-1.5-pro':
|
mock_genai.Client.return_value = mock_client
|
||||||
return mock_model_override
|
|
||||||
return mock_model_default
|
|
||||||
|
|
||||||
mock_generative_model.side_effect = model_factory
|
|
||||||
|
|
||||||
mock_async_init.return_value = None
|
mock_async_init.return_value = None
|
||||||
mock_llm_init.return_value = None
|
mock_llm_init.return_value = None
|
||||||
|
|
@ -549,7 +537,7 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
config = {
|
config = {
|
||||||
'region': 'us-central1',
|
'region': 'us-central1',
|
||||||
'model': 'gemini-2.0-flash-001', # Default model
|
'model': 'gemini-2.0-flash-001', # Default model
|
||||||
'temperature': 0.2, # Default temperature
|
'temperature': 0.2,
|
||||||
'max_output': 8192,
|
'max_output': 8192,
|
||||||
'private_key': 'private.json',
|
'private_key': 'private.json',
|
||||||
'concurrency': 1,
|
'concurrency': 1,
|
||||||
|
|
@ -571,29 +559,29 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
assert isinstance(result, LlmResult)
|
assert isinstance(result, LlmResult)
|
||||||
assert result.text == "Response with custom model"
|
assert result.text == "Response with custom model"
|
||||||
|
|
||||||
# Verify the overridden model was used
|
# Verify the call was made with the override model
|
||||||
mock_model_override.generate_content.assert_called_once()
|
call_args = mock_client.models.generate_content.call_args
|
||||||
# Verify GenerativeModel was called with the override model
|
assert call_args.kwargs['model'] == "gemini-1.5-pro"
|
||||||
mock_generative_model.assert_called_with('gemini-1.5-pro')
|
|
||||||
|
|
||||||
|
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
|
||||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||||
"""Test overriding both model and temperature parameters simultaneously"""
|
"""Test overriding both model and temperature parameters simultaneously"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_credentials = MagicMock()
|
mock_credentials = MagicMock()
|
||||||
|
mock_credentials.project_id = "test-project-123"
|
||||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||||
|
|
||||||
mock_model = MagicMock()
|
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.text = "Response with both overrides"
|
mock_response.text = "Response with both overrides"
|
||||||
mock_response.usage_metadata.prompt_token_count = 22
|
mock_response.usage_metadata.prompt_token_count = 22
|
||||||
mock_response.usage_metadata.candidates_token_count = 16
|
mock_response.usage_metadata.candidates_token_count = 16
|
||||||
mock_model.generate_content.return_value = mock_response
|
|
||||||
mock_generative_model.return_value = mock_model
|
mock_client = MagicMock()
|
||||||
|
mock_client.models.generate_content.return_value = mock_response
|
||||||
|
mock_genai.Client.return_value = mock_client
|
||||||
|
|
||||||
mock_async_init.return_value = None
|
mock_async_init.return_value = None
|
||||||
mock_llm_init.return_value = None
|
mock_llm_init.return_value = None
|
||||||
|
|
@ -622,18 +610,12 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
# Assert
|
# Assert
|
||||||
assert isinstance(result, LlmResult)
|
assert isinstance(result, LlmResult)
|
||||||
assert result.text == "Response with both overrides"
|
assert result.text == "Response with both overrides"
|
||||||
|
mock_client.models.generate_content.assert_called_once()
|
||||||
|
|
||||||
# Verify both overrides were used
|
# Verify the model override was used
|
||||||
mock_model.generate_content.assert_called_once()
|
call_args = mock_client.models.generate_content.call_args
|
||||||
call_args = mock_model.generate_content.call_args
|
assert call_args.kwargs['model'] == "gemini-1.5-flash-001"
|
||||||
|
|
||||||
# Verify model override
|
|
||||||
mock_generative_model.assert_called_with('gemini-1.5-flash-001') # Should use runtime override
|
|
||||||
|
|
||||||
# Verify temperature override (we can't directly access temperature from mock)
|
|
||||||
generation_config = call_args.kwargs['generation_config']
|
|
||||||
assert generation_config is not None # Should use overridden temperature configuration
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ dependencies = [
|
||||||
"falkordb",
|
"falkordb",
|
||||||
"fastembed",
|
"fastembed",
|
||||||
"google-genai",
|
"google-genai",
|
||||||
|
"google-api-core",
|
||||||
"ibis",
|
"ibis",
|
||||||
"jsonschema",
|
"jsonschema",
|
||||||
"langchain",
|
"langchain",
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ dependencies = [
|
||||||
"trustgraph-base>=2.0,<2.1",
|
"trustgraph-base>=2.0,<2.1",
|
||||||
"pulsar-client",
|
"pulsar-client",
|
||||||
"google-genai",
|
"google-genai",
|
||||||
|
"google-api-core",
|
||||||
"prometheus-client",
|
"prometheus-client",
|
||||||
"anthropic",
|
"anthropic",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue