Mirror of https://github.com/trustgraph-ai/trustgraph.git
Synced 2026-04-26 08:56:21 +02:00
Release/v1.2 (#457)
* Bump setup.py versions for 1.1
* PoC MCP server (#419)
  * Very initial MCP server PoC for TrustGraph
  * Put service on port 8000
  * Add MCP container and packages to buildout
* Update docs for API/CLI changes in 1.0 (#421)
  * Update some API basics for the 0.23/1.0 API change
* Add MCP container push (#425)
* Add command args to the MCP server (#426)
  * Host and port parameters
  * Added websocket arg
  * More docs
* MCP client support (#427)
  - MCP client service
  - Tool request/response schema
  - API gateway support for mcp-tool
  - Message translation for tool request & response
  - Make mcp-tool use the configuration service for information about where the MCP services are
* Feature/react call mcp (#428)
  Key features:
  - MCP tool integration: added core MCP tool support with ToolClientSpec and ToolClient classes
  - API enhancement: new mcp_tool method for flow-specific tool invocation
  - CLI tooling: new tg-invoke-mcp-tool command for testing MCP integration
  - React agent enhancement: fixed and improved multi-tool invocation capabilities
  - Tool management: enhanced CLI for tool configuration and management
  Changes:
  - Added MCP tool invocation to API with flow-specific integration
  - Implemented ToolClientSpec and ToolClient for tool call handling
  - Updated agent-manager-react to invoke MCP tools with configurable types
  - Enhanced CLI with new commands and improved help text
  - Added comprehensive documentation for new CLI commands
  - Improved tool configuration management
  Testing:
  - Added tg-invoke-mcp-tool CLI command for isolated MCP integration testing
  - Enhanced agent capability to invoke multiple tools simultaneously
* Test suite executed from CI pipeline (#433)
  * Test strategy & test cases
  * Unit tests
  * Integration tests
* Extending test coverage (#434)
  * Contract tests
  * Testing embeddings
  * Agent unit tests
  * Knowledge pipeline tests
  * Turn on contract tests
* Increase storage test coverage (#435)
  * Fixing storage and adding tests
  * PR pipeline only runs quick tests
* Empty configuration is returned as an empty list; previously it was not in the response (#436)
* Update config util to take files as well as command-line text (#437)
* Updated CLI invocation and config model for tools and MCP (#438)
  * Updated CLI invocation and config model for tools and MCP
  * CLI anomalies
  * Tweaked the MCP tool implementation for the new model
  * Updated the agent implementation to match the new model
  * Fixed agent tools, now all tested
  * Fixed integration tests
  * Fixed MCP delete tool params
  * Updated Python deps to 1.2
* Update to enable knowledge extraction using the agent framework (#439)
  * Implement KG extraction agent (kg-extract-agent) using the ReAct framework (agent-manager-react)
  * The ReAct manager had an issue when emitting JSON, which conflicted with the ReAct manager's own JSON messages, so the ReAct manager was refactored to use traditional ReAct messages with a non-JSON structure
  * Minor refactor to take the prompt template client out of prompt-template so it can be more readily used by other modules; kg-extract-agent uses this framework
* Migrate from setup.py to pyproject.toml (#440)
  * Converted setup.py to pyproject.toml
  * Modern package infrastructure as recommended by the Python packaging docs
* Install missing build deps (#441)
* Install missing build deps (#442)
* Implement logging strategy (#444)
  * Logging strategy; converted all print() calls to logging invocations
* Fix/startup failure (#445)
  * Fix logging startup problems
* Fix logging startup problems (#446)
* Fix logging startup problems (#447)
* Fixed Mistral OCR to use current API (#448)
  * Fixed Mistral OCR to use current API
  * Added PDF decoder tests
* Fix Mistral OCR ident to be standard pdf-decoder (#450)
  * Fix Mistral OCR ident to be standard pdf-decoder
  * Correct test
* Schema structure refactor (#451)
  * Wrote schema refactor spec
  * Implemented schema refactor spec
* Structured data MVP (#452)
  * Structured data tech spec
  * Architecture principles
  * New schemas
  * Updated schemas and specs
  * Object extractor
  * Add .coveragerc
  * New tests
  * Cassandra object storage
  * Trying to get object extraction working; issues remain
* Validate librarian collection (#453)
* Fix token chunker, broken API invocation (#454)
* Fix token chunker, broken API invocation (#455)
* Knowledge load utility CLI (#456)
  * Knowledge loader
  * More tests
This commit is contained in:
  parent c85ba197be
  commit 89be656990

509 changed files with 49632 additions and 5159 deletions
tests/unit/test_text_completion/__init__.py (new file, +3 lines)
@@ -0,0 +1,3 @@
"""
Unit tests for text completion services
"""
tests/unit/test_text_completion/common/__init__.py (new file, +3 lines)
@@ -0,0 +1,3 @@
"""
Common utilities for text completion tests
"""
tests/unit/test_text_completion/common/base_test_cases.py (new file, +69 lines)
@@ -0,0 +1,69 @@
"""
Base test patterns that can be reused across different text completion models
"""

from abc import ABC, abstractmethod
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase


class BaseTextCompletionTestCase(IsolatedAsyncioTestCase, ABC):
    """
    Base test class for text completion processors
    Provides common test patterns that can be reused
    """

    @abstractmethod
    def get_processor_class(self):
        """Return the processor class to test"""
        pass

    @abstractmethod
    def get_base_config(self):
        """Return base configuration for the processor"""
        pass

    @abstractmethod
    def get_mock_patches(self):
        """Return list of patch decorators for mocking dependencies"""
        pass

    def create_base_config(self, **overrides):
        """Create base config with optional overrides"""
        config = self.get_base_config()
        config.update(overrides)
        return config

    def create_mock_llm_result(self, text="Test response", in_token=10, out_token=5):
        """Create a mock LLM result"""
        from trustgraph.base import LlmResult
        return LlmResult(text=text, in_token=in_token, out_token=out_token)


class CommonTestPatterns:
    """
    Common test patterns that can be used across different models
    """

    @staticmethod
    def basic_initialization_test_pattern(test_instance):
        """
        Test pattern for basic processor initialization
        test_instance should be a BaseTextCompletionTestCase
        """
        # This would contain the common pattern for initialization testing
        pass

    @staticmethod
    def successful_generation_test_pattern(test_instance):
        """
        Test pattern for successful content generation
        """
        pass

    @staticmethod
    def error_handling_test_pattern(test_instance):
        """
        Test pattern for error handling
        """
        pass
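For orientation: a model-specific suite would subclass this base class and fill in the three abstract hooks. The sketch below is illustrative only and is not part of this commit; the Ollama module path is an assumption, and the config keys mirror the Ollama fixtures defined in conftest.py later in this diff.

# Illustrative sketch, not part of this commit.
from unittest.mock import AsyncMock

# Assumed module path, by analogy with the azure/azure_openai/claude processors below.
from trustgraph.model.text_completion.ollama.llm import Processor
from tests.unit.test_text_completion.common.base_test_cases import BaseTextCompletionTestCase


class TestOllamaCompletion(BaseTextCompletionTestCase):

    def get_processor_class(self):
        return Processor

    def get_base_config(self):
        # Mirrors ollama_processor_config in conftest.py
        return {
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor',
            'model': 'llama2',
            'temperature': 0.0,
            'max_output': 8192,
            'host': 'localhost',
            'port': 11434,
        }

    def get_mock_patches(self):
        # Same base patches as MockPatches.get_base_patches()
        return [
            'trustgraph.base.async_processor.AsyncProcessor.__init__',
            'trustgraph.base.llm_service.LlmService.__init__',
        ]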
tests/unit/test_text_completion/common/mock_helpers.py (new file, +53 lines)
@@ -0,0 +1,53 @@
"""
Common mocking utilities for text completion tests
"""

from unittest.mock import AsyncMock, MagicMock


class CommonMocks:
    """Common mock objects used across text completion tests"""

    @staticmethod
    def create_mock_async_processor_init():
        """Create mock for AsyncProcessor.__init__"""
        mock = MagicMock()
        mock.return_value = None
        return mock

    @staticmethod
    def create_mock_llm_service_init():
        """Create mock for LlmService.__init__"""
        mock = MagicMock()
        mock.return_value = None
        return mock

    @staticmethod
    def create_mock_response(text="Test response", prompt_tokens=10, completion_tokens=5):
        """Create a mock response object"""
        response = MagicMock()
        response.text = text
        response.usage_metadata.prompt_token_count = prompt_tokens
        response.usage_metadata.candidates_token_count = completion_tokens
        return response

    @staticmethod
    def create_basic_config():
        """Create basic config with required fields"""
        return {
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }


class MockPatches:
    """Common patch decorators for different services"""

    @staticmethod
    def get_base_patches():
        """Get patches that are common to all processors"""
        return [
            'trustgraph.base.async_processor.AsyncProcessor.__init__',
            'trustgraph.base.llm_service.LlmService.__init__'
        ]
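A short usage sketch of the response helper, not part of this commit; the test function name is invented for illustration:

# Illustrative usage of CommonMocks; not part of this commit.
from tests.unit.test_text_completion.common.mock_helpers import CommonMocks


def test_mock_response_shape():
    response = CommonMocks.create_mock_response(
        text="hello", prompt_tokens=3, completion_tokens=2)
    assert response.text == "hello"
    assert response.usage_metadata.prompt_token_count == 3
    assert response.usage_metadata.candidates_token_count == 2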
tests/unit/test_text_completion/conftest.py (new file, +499 lines)
@@ -0,0 +1,499 @@
"""
Pytest configuration and fixtures for text completion tests
"""

import pytest
from unittest.mock import MagicMock, AsyncMock
from trustgraph.base import LlmResult


# === Common Fixtures for All Text Completion Models ===

@pytest.fixture
def base_processor_config():
    """Base configuration required by all processors"""
    return {
        'concurrency': 1,
        'taskgroup': AsyncMock(),
        'id': 'test-processor'
    }


@pytest.fixture
def sample_llm_result():
    """Sample LlmResult for testing"""
    return LlmResult(
        text="Test response",
        in_token=10,
        out_token=5
    )


@pytest.fixture
def mock_async_processor_init():
    """Mock AsyncProcessor.__init__ to avoid infrastructure requirements"""
    mock = MagicMock()
    mock.return_value = None
    return mock


@pytest.fixture
def mock_llm_service_init():
    """Mock LlmService.__init__ to avoid infrastructure requirements"""
    mock = MagicMock()
    mock.return_value = None
    return mock


@pytest.fixture
def mock_prometheus_metrics():
    """Mock Prometheus metrics"""
    mock_metric = MagicMock()
    mock_metric.labels.return_value.time.return_value = MagicMock()
    return mock_metric


@pytest.fixture
def mock_pulsar_consumer():
    """Mock Pulsar consumer for integration testing"""
    return AsyncMock()


@pytest.fixture
def mock_pulsar_producer():
    """Mock Pulsar producer for integration testing"""
    return AsyncMock()


@pytest.fixture(autouse=True)
def mock_env_vars(monkeypatch):
    """Mock environment variables for testing"""
    monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "test-project")
    monkeypatch.setenv("GOOGLE_APPLICATION_CREDENTIALS", "/path/to/test-credentials.json")


@pytest.fixture
def mock_async_context_manager():
    """Mock async context manager for testing"""
    class MockAsyncContextManager:
        def __init__(self, return_value):
            self.return_value = return_value

        async def __aenter__(self):
            return self.return_value

        async def __aexit__(self, exc_type, exc_val, exc_tb):
            pass

    return MockAsyncContextManager


# === VertexAI Specific Fixtures ===

@pytest.fixture
def mock_vertexai_credentials():
    """Mock Google Cloud service account credentials"""
    return MagicMock()


@pytest.fixture
def mock_vertexai_model():
    """Mock VertexAI GenerativeModel"""
    mock_model = MagicMock()
    mock_response = MagicMock()
    mock_response.text = "Test response"
    mock_response.usage_metadata.prompt_token_count = 10
    mock_response.usage_metadata.candidates_token_count = 5
    mock_model.generate_content.return_value = mock_response
    return mock_model


@pytest.fixture
def vertexai_processor_config(base_processor_config):
    """Default configuration for VertexAI processor"""
    config = base_processor_config.copy()
    config.update({
        'region': 'us-central1',
        'model': 'gemini-2.0-flash-001',
        'temperature': 0.0,
        'max_output': 8192,
        'private_key': 'private.json'
    })
    return config


@pytest.fixture
def mock_safety_settings():
    """Mock safety settings for VertexAI"""
    safety_settings = []
    for i in range(4):  # 4 safety categories
        setting = MagicMock()
        setting.category = f"HARM_CATEGORY_{i}"
        setting.threshold = "BLOCK_MEDIUM_AND_ABOVE"
        safety_settings.append(setting)

    return safety_settings


@pytest.fixture
def mock_generation_config():
    """Mock generation configuration for VertexAI"""
    config = MagicMock()
    config.temperature = 0.0
    config.max_output_tokens = 8192
    config.top_p = 1.0
    config.top_k = 10
    config.candidate_count = 1
    return config


@pytest.fixture
def mock_vertexai_exception():
    """Mock VertexAI exceptions"""
    from google.api_core.exceptions import ResourceExhausted
    return ResourceExhausted("Test resource exhausted error")


# === Ollama Specific Fixtures ===

@pytest.fixture
def ollama_processor_config(base_processor_config):
    """Default configuration for Ollama processor"""
    config = base_processor_config.copy()
    config.update({
        'model': 'llama2',
        'temperature': 0.0,
        'max_output': 8192,
        'host': 'localhost',
        'port': 11434
    })
    return config


@pytest.fixture
def mock_ollama_client():
    """Mock Ollama client"""
    mock_client = MagicMock()
    mock_response = {
        'response': 'Test response from Ollama',
        'done': True,
        'eval_count': 5,
        'prompt_eval_count': 10
    }
    mock_client.generate.return_value = mock_response
    return mock_client


# === OpenAI Specific Fixtures ===

@pytest.fixture
def openai_processor_config(base_processor_config):
    """Default configuration for OpenAI processor"""
    config = base_processor_config.copy()
    config.update({
        'model': 'gpt-3.5-turbo',
        'api_key': 'test-api-key',
        'url': 'https://api.openai.com/v1',
        'temperature': 0.0,
        'max_output': 4096
    })
    return config


@pytest.fixture
def mock_openai_client():
    """Mock OpenAI client"""
    mock_client = MagicMock()

    # Mock the response structure
    mock_response = MagicMock()
    mock_response.choices = [MagicMock()]
    mock_response.choices[0].message.content = "Test response from OpenAI"
    mock_response.usage.prompt_tokens = 15
    mock_response.usage.completion_tokens = 8

    mock_client.chat.completions.create.return_value = mock_response
    return mock_client


@pytest.fixture
def mock_openai_rate_limit_error():
    """Mock OpenAI rate limit error"""
    from openai import RateLimitError
    return RateLimitError("Rate limit exceeded", response=MagicMock(), body=None)


# === Azure OpenAI Specific Fixtures ===

@pytest.fixture
def azure_openai_processor_config(base_processor_config):
    """Default configuration for Azure OpenAI processor"""
    config = base_processor_config.copy()
    config.update({
        'model': 'gpt-4',
        'endpoint': 'https://test.openai.azure.com/',
        'token': 'test-token',
        'api_version': '2024-12-01-preview',
        'temperature': 0.0,
        'max_output': 4192
    })
    return config


@pytest.fixture
def mock_azure_openai_client():
    """Mock Azure OpenAI client"""
    mock_client = MagicMock()

    # Mock the response structure
    mock_response = MagicMock()
    mock_response.choices = [MagicMock()]
    mock_response.choices[0].message.content = "Test response from Azure OpenAI"
    mock_response.usage.prompt_tokens = 20
    mock_response.usage.completion_tokens = 10

    mock_client.chat.completions.create.return_value = mock_response
    return mock_client


@pytest.fixture
def mock_azure_openai_rate_limit_error():
    """Mock Azure OpenAI rate limit error"""
    from openai import RateLimitError
    return RateLimitError("Rate limit exceeded", response=MagicMock(), body=None)


# === Azure Specific Fixtures ===

@pytest.fixture
def azure_processor_config(base_processor_config):
    """Default configuration for Azure processor"""
    config = base_processor_config.copy()
    config.update({
        'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
        'token': 'test-token',
        'temperature': 0.0,
        'max_output': 4192
    })
    return config


@pytest.fixture
def mock_azure_requests():
    """Mock requests for Azure processor"""
    mock_requests = MagicMock()

    # Mock successful response
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.json.return_value = {
        'choices': [{
            'message': {
                'content': 'Test response from Azure'
            }
        }],
        'usage': {
            'prompt_tokens': 18,
            'completion_tokens': 9
        }
    }
    mock_requests.post.return_value = mock_response
    return mock_requests


@pytest.fixture
def mock_azure_rate_limit_response():
    """Mock Azure rate limit response"""
    mock_response = MagicMock()
    mock_response.status_code = 429
    return mock_response


# === Claude Specific Fixtures ===

@pytest.fixture
def claude_processor_config(base_processor_config):
    """Default configuration for Claude processor"""
    config = base_processor_config.copy()
    config.update({
        'model': 'claude-3-5-sonnet-20240620',
        'api_key': 'test-api-key',
        'temperature': 0.0,
        'max_output': 8192
    })
    return config


@pytest.fixture
def mock_claude_client():
    """Mock Claude (Anthropic) client"""
    mock_client = MagicMock()

    # Mock the response structure
    mock_response = MagicMock()
    mock_response.content = [MagicMock()]
    mock_response.content[0].text = "Test response from Claude"
    mock_response.usage.input_tokens = 22
    mock_response.usage.output_tokens = 12

    mock_client.messages.create.return_value = mock_response
    return mock_client


@pytest.fixture
def mock_claude_rate_limit_error():
    """Mock Claude rate limit error"""
    import anthropic
    return anthropic.RateLimitError("Rate limit exceeded", response=MagicMock(), body=None)


# === vLLM Specific Fixtures ===

@pytest.fixture
def vllm_processor_config(base_processor_config):
    """Default configuration for vLLM processor"""
    config = base_processor_config.copy()
    config.update({
        'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
        'url': 'http://vllm-service:8899/v1',
        'temperature': 0.0,
        'max_output': 2048
    })
    return config


@pytest.fixture
def mock_vllm_session():
    """Mock aiohttp ClientSession for vLLM"""
    mock_session = MagicMock()

    # Mock successful response
    mock_response = MagicMock()
    mock_response.status = 200
    mock_response.json = AsyncMock(return_value={
        'choices': [{
            'text': 'Test response from vLLM'
        }],
        'usage': {
            'prompt_tokens': 16,
            'completion_tokens': 8
        }
    })

    # Mock the async context manager
    mock_session.post.return_value.__aenter__.return_value = mock_response
    mock_session.post.return_value.__aexit__.return_value = None

    return mock_session


@pytest.fixture
def mock_vllm_error_response():
    """Mock vLLM error response"""
    mock_response = MagicMock()
    mock_response.status = 500
    return mock_response


# === Cohere Specific Fixtures ===

@pytest.fixture
def cohere_processor_config(base_processor_config):
    """Default configuration for Cohere processor"""
    config = base_processor_config.copy()
    config.update({
        'model': 'c4ai-aya-23-8b',
        'api_key': 'test-api-key',
        'temperature': 0.0
    })
    return config


@pytest.fixture
def mock_cohere_client():
    """Mock Cohere client"""
    mock_client = MagicMock()

    # Mock the response structure
    mock_output = MagicMock()
    mock_output.text = "Test response from Cohere"
    mock_output.meta.billed_units.input_tokens = 18
    mock_output.meta.billed_units.output_tokens = 10

    mock_client.chat.return_value = mock_output
    return mock_client


@pytest.fixture
def mock_cohere_rate_limit_error():
    """Mock Cohere rate limit error"""
    import cohere
    return cohere.TooManyRequestsError("Rate limit exceeded")


# === Google AI Studio Specific Fixtures ===

@pytest.fixture
def googleaistudio_processor_config(base_processor_config):
    """Default configuration for Google AI Studio processor"""
    config = base_processor_config.copy()
    config.update({
        'model': 'gemini-2.0-flash-001',
        'api_key': 'test-api-key',
        'temperature': 0.0,
        'max_output': 8192
    })
    return config


@pytest.fixture
def mock_googleaistudio_client():
    """Mock Google AI Studio client"""
    mock_client = MagicMock()

    # Mock the response structure
    mock_response = MagicMock()
    mock_response.text = "Test response from Google AI Studio"
    mock_response.usage_metadata.prompt_token_count = 20
    mock_response.usage_metadata.candidates_token_count = 12

    mock_client.models.generate_content.return_value = mock_response
    return mock_client


@pytest.fixture
def mock_googleaistudio_rate_limit_error():
    """Mock Google AI Studio rate limit error"""
    from google.api_core.exceptions import ResourceExhausted
    return ResourceExhausted("Rate limit exceeded")


# === LlamaFile Specific Fixtures ===

@pytest.fixture
def llamafile_processor_config(base_processor_config):
    """Default configuration for LlamaFile processor"""
    config = base_processor_config.copy()
    config.update({
        'model': 'LLaMA_CPP',
        'llamafile': 'http://localhost:8080/v1',
        'temperature': 0.0,
        'max_output': 4096
    })
    return config


@pytest.fixture
def mock_llamafile_client():
    """Mock OpenAI client for LlamaFile"""
    mock_client = MagicMock()

    # Mock the response structure
    mock_response = MagicMock()
    mock_response.choices = [MagicMock()]
    mock_response.choices[0].message.content = "Test response from LlamaFile"
    mock_response.usage.prompt_tokens = 14
    mock_response.usage.completion_tokens = 8

    mock_client.chat.completions.create.return_value = mock_response
    return mock_client
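Since these are ordinary pytest fixtures, any test module under tests/unit/test_text_completion can request them by parameter name. A minimal illustrative sketch, not part of this commit (test names invented); the assertions restate the fixture values defined above:

# Illustrative only: tests consuming the fixtures defined above.
def test_sample_result_fields(sample_llm_result):
    assert sample_llm_result.text == "Test response"
    assert sample_llm_result.in_token == 10
    assert sample_llm_result.out_token == 5


def test_ollama_config_builds_on_base(ollama_processor_config):
    # ollama_processor_config extends base_processor_config
    assert ollama_processor_config['id'] == 'test-processor'
    assert ollama_processor_config['model'] == 'llama2'
    assert ollama_processor_config['port'] == 11434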
tests/unit/test_text_completion/test_azure_openai_processor.py (new file, +407 lines)
@@ -0,0 +1,407 @@
"""
Unit tests for trustgraph.model.text_completion.azure_openai
Following the same successful pattern as previous tests
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase

# Import the service under test
from trustgraph.model.text_completion.azure_openai.llm import Processor
from trustgraph.base import LlmResult
from trustgraph.exceptions import TooManyRequests


class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
    """Test Azure OpenAI processor functionality"""

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test basic processor initialization"""
        # Arrange
        mock_azure_client = MagicMock()
        mock_azure_openai_class.return_value = mock_azure_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-4',
            'endpoint': 'https://test.openai.azure.com/',
            'token': 'test-token',
            'api_version': '2024-12-01-preview',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act
        processor = Processor(**config)

        # Assert
        assert processor.model == 'gpt-4'
        assert processor.temperature == 0.0
        assert processor.max_output == 4192
        assert hasattr(processor, 'openai')
        mock_azure_openai_class.assert_called_once_with(
            api_key='test-token',
            api_version='2024-12-01-preview',
            azure_endpoint='https://test.openai.azure.com/'
        )

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test successful content generation"""
        # Arrange
        mock_azure_client = MagicMock()
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Generated response from Azure OpenAI"
        mock_response.usage.prompt_tokens = 25
        mock_response.usage.completion_tokens = 15

        mock_azure_client.chat.completions.create.return_value = mock_response
        mock_azure_openai_class.return_value = mock_azure_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-4',
            'endpoint': 'https://test.openai.azure.com/',
            'token': 'test-token',
            'api_version': '2024-12-01-preview',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("System prompt", "User prompt")

        # Assert
        assert isinstance(result, LlmResult)
        assert result.text == "Generated response from Azure OpenAI"
        assert result.in_token == 25
        assert result.out_token == 15
        assert result.model == 'gpt-4'

        # Verify the Azure OpenAI API call
        mock_azure_client.chat.completions.create.assert_called_once_with(
            model='gpt-4',
            messages=[{
                "role": "user",
                "content": [{
                    "type": "text",
                    "text": "System prompt\n\nUser prompt"
                }]
            }],
            temperature=0.0,
            max_tokens=4192,
            top_p=1
        )

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test rate limit error handling"""
        # Arrange
        from openai import RateLimitError

        mock_azure_client = MagicMock()
        mock_azure_client.chat.completions.create.side_effect = RateLimitError("Rate limit exceeded", response=MagicMock(), body=None)
        mock_azure_openai_class.return_value = mock_azure_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-4',
            'endpoint': 'https://test.openai.azure.com/',
            'token': 'test-token',
            'api_version': '2024-12-01-preview',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act & Assert
        with pytest.raises(TooManyRequests):
            await processor.generate_content("System prompt", "User prompt")

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test handling of generic exceptions"""
        # Arrange
        mock_azure_client = MagicMock()
        mock_azure_client.chat.completions.create.side_effect = Exception("Azure API connection error")
        mock_azure_openai_class.return_value = mock_azure_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-4',
            'endpoint': 'https://test.openai.azure.com/',
            'token': 'test-token',
            'api_version': '2024-12-01-preview',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act & Assert
        with pytest.raises(Exception, match="Azure API connection error"):
            await processor.generate_content("System prompt", "User prompt")

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_without_endpoint(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test processor initialization without endpoint (should fail)"""
        # Arrange
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-4',
            'endpoint': None,  # No endpoint provided
            'token': 'test-token',
            'api_version': '2024-12-01-preview',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act & Assert
        with pytest.raises(RuntimeError, match="Azure endpoint not specified"):
            processor = Processor(**config)

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_without_token(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test processor initialization without token (should fail)"""
        # Arrange
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-4',
            'endpoint': 'https://test.openai.azure.com/',
            'token': None,  # No token provided
            'api_version': '2024-12-01-preview',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act & Assert
        with pytest.raises(RuntimeError, match="Azure token not specified"):
            processor = Processor(**config)

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test processor initialization with custom parameters"""
        # Arrange
        mock_azure_client = MagicMock()
        mock_azure_openai_class.return_value = mock_azure_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-35-turbo',
            'endpoint': 'https://custom.openai.azure.com/',
            'token': 'custom-token',
            'api_version': '2023-05-15',
            'temperature': 0.7,
            'max_output': 2048,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act
        processor = Processor(**config)

        # Assert
        assert processor.model == 'gpt-35-turbo'
        assert processor.temperature == 0.7
        assert processor.max_output == 2048
        mock_azure_openai_class.assert_called_once_with(
            api_key='custom-token',
            api_version='2023-05-15',
            azure_endpoint='https://custom.openai.azure.com/'
        )

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test processor initialization with default values"""
        # Arrange
        mock_azure_client = MagicMock()
        mock_azure_openai_class.return_value = mock_azure_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        # Only provide required fields, should use defaults
        config = {
            'endpoint': 'https://test.openai.azure.com/',
            'token': 'test-token',
            'model': 'gpt-4',  # Required for Azure
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act
        processor = Processor(**config)

        # Assert
        assert processor.model == 'gpt-4'
        assert processor.temperature == 0.0  # default_temperature
        assert processor.max_output == 4192  # default_max_output
        mock_azure_openai_class.assert_called_once_with(
            api_key='test-token',
            api_version='2024-12-01-preview',  # default_api
            azure_endpoint='https://test.openai.azure.com/'
        )

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test content generation with empty prompts"""
        # Arrange
        mock_azure_client = MagicMock()
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Default response"
        mock_response.usage.prompt_tokens = 2
        mock_response.usage.completion_tokens = 3

        mock_azure_client.chat.completions.create.return_value = mock_response
        mock_azure_openai_class.return_value = mock_azure_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-4',
            'endpoint': 'https://test.openai.azure.com/',
            'token': 'test-token',
            'api_version': '2024-12-01-preview',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("", "")

        # Assert
        assert isinstance(result, LlmResult)
        assert result.text == "Default response"
        assert result.in_token == 2
        assert result.out_token == 3
        assert result.model == 'gpt-4'

        # Verify the combined prompt is sent correctly
        call_args = mock_azure_client.chat.completions.create.call_args
        expected_prompt = "\n\n"  # Empty system + "\n\n" + empty user
        assert call_args[1]['messages'][0]['content'][0]['text'] == expected_prompt

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_message_structure(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test that Azure OpenAI messages are structured correctly"""
        # Arrange
        mock_azure_client = MagicMock()
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Response with proper structure"
        mock_response.usage.prompt_tokens = 30
        mock_response.usage.completion_tokens = 20

        mock_azure_client.chat.completions.create.return_value = mock_response
        mock_azure_openai_class.return_value = mock_azure_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-4',
            'endpoint': 'https://test.openai.azure.com/',
            'token': 'test-token',
            'api_version': '2024-12-01-preview',
            'temperature': 0.5,
            'max_output': 1024,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("You are a helpful assistant", "What is AI?")

        # Assert
        assert result.text == "Response with proper structure"
        assert result.in_token == 30
        assert result.out_token == 20

        # Verify the message structure matches Azure OpenAI Chat API format
        call_args = mock_azure_client.chat.completions.create.call_args
        messages = call_args[1]['messages']

        assert len(messages) == 1
        assert messages[0]['role'] == 'user'
        assert messages[0]['content'][0]['type'] == 'text'
        assert messages[0]['content'][0]['text'] == "You are a helpful assistant\n\nWhat is AI?"

        # Verify other parameters
        assert call_args[1]['model'] == 'gpt-4'
        assert call_args[1]['temperature'] == 0.5
        assert call_args[1]['max_tokens'] == 1024
        assert call_args[1]['top_p'] == 1


if __name__ == '__main__':
    pytest.main([__file__])
463
tests/unit/test_text_completion/test_azure_processor.py
Normal file
463
tests/unit/test_text_completion/test_azure_processor.py
Normal file
|
|
@ -0,0 +1,463 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.azure
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.azure.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Azure processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.endpoint == 'https://test.inference.ai.azure.com/v1/chat/completions'
|
||||
assert processor.token == 'test-token'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4192
|
||||
assert processor.model == 'AzureAI'
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Generated response from Azure'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 20,
|
||||
'completion_tokens': 12
|
||||
}
|
||||
}
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Azure"
|
||||
assert result.in_token == 20
|
||||
assert result.out_token == 12
|
||||
assert result.model == 'AzureAI'
|
||||
|
||||
# Verify the API call was made correctly
|
||||
mock_requests.post.assert_called_once()
|
||||
call_args = mock_requests.post.call_args
|
||||
|
||||
# Check URL
|
||||
assert call_args[0][0] == 'https://test.inference.ai.azure.com/v1/chat/completions'
|
||||
|
||||
# Check headers
|
||||
headers = call_args[1]['headers']
|
||||
assert headers['Content-Type'] == 'application/json'
|
||||
assert headers['Authorization'] == 'Bearer test-token'
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 429
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_http_error(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test HTTP error handling"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="LLM failure"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_requests.post.side_effect = Exception("Connection error")
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_endpoint(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test processor initialization without endpoint (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': None, # No endpoint provided
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Azure endpoint not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_token(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test processor initialization without token (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': None, # No token provided
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Azure token not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://custom.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'custom-token',
|
||||
'temperature': 0.7,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.endpoint == 'https://custom.inference.ai.azure.com/v1/chat/completions'
|
||||
assert processor.token == 'custom-token'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 2048
|
||||
assert processor.model == 'AzureAI'
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Only provide required fields, should use defaults
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.endpoint == 'https://test.inference.ai.azure.com/v1/chat/completions'
|
||||
assert processor.token == 'test-token'
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 4192 # default_max_output
|
||||
assert processor.model == 'AzureAI' # default_model
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Default response'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 2,
|
||||
'completion_tokens': 3
|
||||
}
|
||||
}
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'AzureAI'
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_build_prompt_structure(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test that build_prompt creates correct message structure"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
        mock_response.status_code = 200
        mock_response.json.return_value = {
            'choices': [{
                'message': {
                    'content': 'Response with proper structure'
                }
            }],
            'usage': {
                'prompt_tokens': 25,
                'completion_tokens': 15
            }
        }
        mock_requests.post.return_value = mock_response

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
            'token': 'test-token',
            'temperature': 0.5,
            'max_output': 1024,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("You are a helpful assistant", "What is AI?")

        # Assert
        assert result.text == "Response with proper structure"
        assert result.in_token == 25
        assert result.out_token == 15

        # Verify the request structure
        mock_requests.post.assert_called_once()
        call_args = mock_requests.post.call_args

        # Parse the request body
        import json
        request_body = json.loads(call_args[1]['data'])

        # Verify message structure
        assert 'messages' in request_body
        assert len(request_body['messages']) == 2

        # Check system message
        assert request_body['messages'][0]['role'] == 'system'
        assert request_body['messages'][0]['content'] == 'You are a helpful assistant'

        # Check user message
        assert request_body['messages'][1]['role'] == 'user'
        assert request_body['messages'][1]['content'] == 'What is AI?'

        # Check parameters
        assert request_body['temperature'] == 0.5
        assert request_body['max_tokens'] == 1024
        assert request_body['top_p'] == 1

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_call_llm_method(self, mock_llm_init, mock_async_init, mock_requests):
        """Test the call_llm method directly"""
        # Arrange
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            'choices': [{
                'message': {
                    'content': 'Test response'
                }
            }],
            'usage': {
                'prompt_tokens': 10,
                'completion_tokens': 5
            }
        }
        mock_requests.post.return_value = mock_response

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
            'token': 'test-token',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = processor.call_llm('{"test": "body"}')

        # Assert
        assert result == mock_response.json.return_value

        # Verify the request was made correctly
        mock_requests.post.assert_called_once_with(
            'https://test.inference.ai.azure.com/v1/chat/completions',
            data='{"test": "body"}',
            headers={
                'Content-Type': 'application/json',
                'Authorization': 'Bearer test-token'
            }
        )


if __name__ == '__main__':
    pytest.main([__file__])
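All four provider test files below reuse the construction trick visible above: the heavyweight base-class initialisers are patched to no-ops so `Processor(**config)` can run its own setup in isolation. A minimal, self-contained sketch of the idea (class names here are illustrative stand-ins, not the real trustgraph classes):

```python
from unittest.mock import patch

class Base:
    def __init__(self, **kw):
        raise RuntimeError("would need real messaging infrastructure")

class Proc(Base):
    def __init__(self, **kw):
        super().__init__(**kw)   # normally requires a running broker
        self.model = kw.get('model', 'default')

# Patching Base.__init__ to a no-op mock lets Proc's own setup run
# without infrastructure, which is exactly what these tests do with
# AsyncProcessor.__init__ and LlmService.__init__.
with patch.object(Base, '__init__', return_value=None):
    p = Proc(model='m1')
    assert p.model == 'm1'
```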
tests/unit/test_text_completion/test_claude_processor.py (new file, 440 lines)
@@ -0,0 +1,440 @@
"""
|
||||
Unit tests for trustgraph.model.text_completion.claude
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.claude.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Claude processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'claude-3-5-sonnet-20240620'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 8192
|
||||
assert hasattr(processor, 'claude')
|
||||
mock_anthropic_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Generated response from Claude"
|
||||
mock_response.usage.input_tokens = 25
|
||||
mock_response.usage.output_tokens = 15
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Claude"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
assert result.model == 'claude-3-5-sonnet-20240620'
|
||||
|
||||
# Verify the Claude API call
|
||||
mock_claude_client.messages.create.assert_called_once_with(
|
||||
model='claude-3-5-sonnet-20240620',
|
||||
max_tokens=8192,
|
||||
temperature=0.0,
|
||||
system="System prompt",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "User prompt"
|
||||
}]
|
||||
}]
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
import anthropic
|
||||
|
||||
mock_claude_client = MagicMock()
|
||||
mock_claude_client.messages.create.side_effect = anthropic.RateLimitError(
|
||||
"Rate limit exceeded",
|
||||
response=MagicMock(),
|
||||
body=None
|
||||
)
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_claude_client.messages.create.side_effect = Exception("API connection error")
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="API connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_api_key(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test processor initialization without API key (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': None, # No API key provided
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Claude API key not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-haiku-20240307',
|
||||
'api_key': 'custom-api-key',
|
||||
'temperature': 0.7,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'claude-3-haiku-20240307'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 4096
|
||||
mock_anthropic_class.assert_called_once_with(api_key='custom-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Only provide required fields, should use defaults
|
||||
config = {
|
||||
'api_key': 'test-api-key',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'claude-3-5-sonnet-20240620' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 8192 # default_max_output
|
||||
mock_anthropic_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Default response"
|
||||
mock_response.usage.input_tokens = 2
|
||||
mock_response.usage.output_tokens = 3
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'claude-3-5-sonnet-20240620'
|
||||
|
||||
# Verify the system prompt and user content are handled correctly
|
||||
call_args = mock_claude_client.messages.create.call_args
|
||||
assert call_args[1]['system'] == ""
|
||||
assert call_args[1]['messages'][0]['content'][0]['text'] == ""
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_message_structure(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test that Claude messages are structured correctly"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Response with proper structure"
|
||||
mock_response.usage.input_tokens = 30
|
||||
mock_response.usage.output_tokens = 20
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.5,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with proper structure"
|
||||
assert result.in_token == 30
|
||||
assert result.out_token == 20
|
||||
|
||||
# Verify the message structure matches Claude API format
|
||||
call_args = mock_claude_client.messages.create.call_args
|
||||
|
||||
# Check system prompt
|
||||
assert call_args[1]['system'] == "You are a helpful assistant"
|
||||
|
||||
# Check user message structure
|
||||
messages = call_args[1]['messages']
|
||||
assert len(messages) == 1
|
||||
assert messages[0]['role'] == 'user'
|
||||
assert messages[0]['content'][0]['type'] == 'text'
|
||||
assert messages[0]['content'][0]['text'] == "What is AI?"
|
||||
|
||||
# Verify other parameters
|
||||
assert call_args[1]['model'] == 'claude-3-5-sonnet-20240620'
|
||||
assert call_args[1]['temperature'] == 0.5
|
||||
assert call_args[1]['max_tokens'] == 1024
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_multiple_content_blocks(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test handling of multiple content blocks in response"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
|
||||
# Mock multiple content blocks (Claude can return multiple)
|
||||
mock_content_1 = MagicMock()
|
||||
mock_content_1.text = "First part of response"
|
||||
mock_content_2 = MagicMock()
|
||||
mock_content_2.text = "Second part of response"
|
||||
mock_response.content = [mock_content_1, mock_content_2]
|
||||
|
||||
mock_response.usage.input_tokens = 40
|
||||
mock_response.usage.output_tokens = 30
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
# Should take the first content block
|
||||
assert result.text == "First part of response"
|
||||
assert result.in_token == 40
|
||||
assert result.out_token == 30
|
||||
assert result.model == 'claude-3-5-sonnet-20240620'
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_claude_client_initialization(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test that Claude client is initialized correctly"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-opus-20240229',
|
||||
'api_key': 'sk-ant-test-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify Anthropic client was called with correct API key
|
||||
mock_anthropic_class.assert_called_once_with(api_key='sk-ant-test-key')
|
||||
|
||||
# Verify processor has the client
|
||||
assert processor.claude == mock_claude_client
|
||||
assert processor.model == 'claude-3-opus-20240229'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
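Taken together, the assertions in these files pin down the result contract every processor must honour: a text payload, input/output token counts, and the model identifier. A dataclass stand-in consistent with what the tests check (the field names come straight from the assertions; the real `LlmResult` in `trustgraph.base` may carry more):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class LlmResult:
    text: str                    # completion text returned by the provider
    in_token: int                # prompt-side token count
    out_token: int               # completion-side token count
    model: Optional[str] = None  # identifier of the model that answered

# Mirrors an assertion sequence from the tests above.
r = LlmResult(text="hi", in_token=2, out_token=1,
              model="claude-3-5-sonnet-20240620")
assert r.in_token == 2 and r.out_token == 1
```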
tests/unit/test_text_completion/test_cohere_processor.py (new file, 447 lines)
@@ -0,0 +1,447 @@
"""
|
||||
Unit tests for trustgraph.model.text_completion.cohere
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.cohere.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Cohere processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'c4ai-aya-23-8b'
|
||||
assert processor.temperature == 0.0
|
||||
assert hasattr(processor, 'cohere')
|
||||
mock_cohere_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = "Generated response from Cohere"
|
||||
mock_output.meta.billed_units.input_tokens = 25
|
||||
mock_output.meta.billed_units.output_tokens = 15
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Cohere"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
assert result.model == 'c4ai-aya-23-8b'
|
||||
|
||||
# Verify the Cohere API call
|
||||
mock_cohere_client.chat.assert_called_once_with(
|
||||
model='c4ai-aya-23-8b',
|
||||
message="User prompt",
|
||||
preamble="System prompt",
|
||||
temperature=0.0,
|
||||
chat_history=[],
|
||||
prompt_truncation='auto',
|
||||
connectors=[]
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
import cohere
|
||||
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_cohere_client.chat.side_effect = cohere.TooManyRequestsError("Rate limit exceeded")
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_cohere_client.chat.side_effect = Exception("API connection error")
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="API connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_api_key(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test processor initialization without API key (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': None, # No API key provided
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Cohere API key not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'command-light',
|
||||
'api_key': 'custom-api-key',
|
||||
'temperature': 0.7,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'command-light'
|
||||
assert processor.temperature == 0.7
|
||||
mock_cohere_class.assert_called_once_with(api_key='custom-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Only provide required fields, should use defaults
|
||||
config = {
|
||||
'api_key': 'test-api-key',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'c4ai-aya-23-8b' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
mock_cohere_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = "Default response"
|
||||
mock_output.meta.billed_units.input_tokens = 2
|
||||
mock_output.meta.billed_units.output_tokens = 3
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'c4ai-aya-23-8b'
|
||||
|
||||
# Verify the preamble and message are handled correctly
|
||||
call_args = mock_cohere_client.chat.call_args
|
||||
assert call_args[1]['preamble'] == ""
|
||||
assert call_args[1]['message'] == ""
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_chat_structure(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test that Cohere chat is structured correctly"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = "Response with proper structure"
|
||||
mock_output.meta.billed_units.input_tokens = 30
|
||||
mock_output.meta.billed_units.output_tokens = 20
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.5,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with proper structure"
|
||||
assert result.in_token == 30
|
||||
assert result.out_token == 20
|
||||
|
||||
# Verify the chat structure matches Cohere API format
|
||||
call_args = mock_cohere_client.chat.call_args
|
||||
|
||||
# Check parameters
|
||||
assert call_args[1]['model'] == 'c4ai-aya-23-8b'
|
||||
assert call_args[1]['message'] == "What is AI?"
|
||||
assert call_args[1]['preamble'] == "You are a helpful assistant"
|
||||
assert call_args[1]['temperature'] == 0.5
|
||||
assert call_args[1]['chat_history'] == []
|
||||
assert call_args[1]['prompt_truncation'] == 'auto'
|
||||
assert call_args[1]['connectors'] == []
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_token_parsing(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test token parsing from Cohere response"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = "Token parsing test"
|
||||
mock_output.meta.billed_units.input_tokens = 50
|
||||
mock_output.meta.billed_units.output_tokens = 25
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System", "User query")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Token parsing test"
|
||||
assert result.in_token == 50
|
||||
assert result.out_token == 25
|
||||
assert result.model == 'c4ai-aya-23-8b'
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_cohere_client_initialization(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test that Cohere client is initialized correctly"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'command-r',
|
||||
'api_key': 'co-test-key',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify Cohere client was called with correct API key
|
||||
mock_cohere_class.assert_called_once_with(api_key='co-test-key')
|
||||
|
||||
# Verify processor has the client
|
||||
assert processor.cohere == mock_cohere_client
|
||||
assert processor.model == 'command-r'
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_chat_parameters(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test that all chat parameters are passed correctly"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = "Chat parameter test"
|
||||
mock_output.meta.billed_units.input_tokens = 20
|
||||
mock_output.meta.billed_units.output_tokens = 10
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.3,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System instructions", "User question")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Chat parameter test"
|
||||
|
||||
# Verify all parameters are passed correctly
|
||||
call_args = mock_cohere_client.chat.call_args
|
||||
assert call_args[1]['model'] == 'c4ai-aya-23-8b'
|
||||
assert call_args[1]['message'] == "User question"
|
||||
assert call_args[1]['preamble'] == "System instructions"
|
||||
assert call_args[1]['temperature'] == 0.3
|
||||
assert call_args[1]['chat_history'] == []
|
||||
assert call_args[1]['prompt_truncation'] == 'auto'
|
||||
assert call_args[1]['connectors'] == []
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
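A pattern worth noting across the Claude, Cohere, and (next) Google AI Studio files: each provider's native rate-limit exception is expected to surface as the framework's single `TooManyRequests`, so downstream retry logic only ever handles one type. A self-contained sketch of that translation idiom (`ProviderRateLimit` stands in for `anthropic.RateLimitError`, `cohere.TooManyRequestsError`, or `ResourceExhausted`; `TooManyRequests` mirrors `trustgraph.exceptions`):

```python
class TooManyRequests(Exception):
    """Framework-level rate-limit signal."""

class ProviderRateLimit(Exception):
    """Stand-in for a provider SDK's rate-limit error."""

def call_provider():
    raise ProviderRateLimit("Rate limit exceeded")

def generate():
    try:
        return call_provider()
    except ProviderRateLimit as e:
        # Normalize to the one exception the caller's retry loop expects.
        raise TooManyRequests() from e

try:
    generate()
except TooManyRequests:
    print("rate limited: backing off before retry")
```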
tests/unit/test_text_completion/test_googleaistudio_processor.py (new file, 482 lines)
@@ -0,0 +1,482 @@
"""
|
||||
Unit tests for trustgraph.model.text_completion.googleaistudio
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.googleaistudio.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Google AI Studio processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-2.0-flash-001'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 8192
|
||||
assert hasattr(processor, 'client')
|
||||
assert hasattr(processor, 'safety_settings')
|
||||
assert len(processor.safety_settings) == 4 # 4 safety categories
|
||||
mock_genai_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Generated response from Google AI Studio"
|
||||
mock_response.usage_metadata.prompt_token_count = 25
|
||||
mock_response.usage_metadata.candidates_token_count = 15
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Google AI Studio"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
|
||||
# Verify the Google AI Studio API call
|
||||
mock_genai_client.models.generate_content.assert_called_once()
|
||||
call_args = mock_genai_client.models.generate_content.call_args
|
||||
assert call_args[1]['model'] == 'gemini-2.0-flash-001'
|
||||
assert call_args[1]['contents'] == "User prompt"
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
|
||||
mock_genai_client = MagicMock()
|
||||
mock_genai_client.models.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_genai_client.models.generate_content.side_effect = Exception("API connection error")
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="API connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_api_key(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test processor initialization without API key (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': None, # No API key provided
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Google AI Studio API key not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-1.5-pro',
|
||||
'api_key': 'custom-api-key',
|
||||
'temperature': 0.7,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-1.5-pro'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 4096
|
||||
mock_genai_class.assert_called_once_with(api_key='custom-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Only provide required fields, should use defaults
|
||||
config = {
|
||||
'api_key': 'test-api-key',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-2.0-flash-001' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 8192 # default_max_output
|
||||
mock_genai_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Default response"
|
||||
mock_response.usage_metadata.prompt_token_count = 2
|
||||
mock_response.usage_metadata.candidates_token_count = 3
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
|
||||
# Verify the system instruction and content are handled correctly
|
||||
call_args = mock_genai_client.models.generate_content.call_args
|
||||
assert call_args[1]['contents'] == ""
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_configuration_structure(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test that generation configuration is structured correctly"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Response with proper structure"
|
||||
mock_response.usage_metadata.prompt_token_count = 30
|
||||
mock_response.usage_metadata.candidates_token_count = 20
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.5,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with proper structure"
|
||||
assert result.in_token == 30
|
||||
assert result.out_token == 20
|
||||
|
||||
# Verify the generation configuration
|
||||
call_args = mock_genai_client.models.generate_content.call_args
|
||||
config_arg = call_args[1]['config']
|
||||
|
||||
# Check that the configuration has the right structure
|
||||
assert call_args[1]['model'] == 'gemini-2.0-flash-001'
|
||||
assert call_args[1]['contents'] == "What is AI?"
|
||||
# Config should be a GenerateContentConfig object
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_safety_settings_initialization(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test that safety settings are initialized correctly"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert hasattr(processor, 'safety_settings')
|
||||
assert len(processor.safety_settings) == 4
|
||||
# Should have 4 safety categories: hate speech, harassment, sexually explicit, dangerous content
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_token_parsing(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test token parsing from Google AI Studio response"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Token parsing test"
|
||||
mock_response.usage_metadata.prompt_token_count = 50
|
||||
mock_response.usage_metadata.candidates_token_count = 25
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System", "User query")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Token parsing test"
|
||||
assert result.in_token == 50
|
||||
assert result.out_token == 25
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_genai_client_initialization(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test that Google AI Studio client is initialized correctly"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-1.5-flash',
|
||||
'api_key': 'gai-test-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify Google AI Studio client was called with correct API key
|
||||
mock_genai_class.assert_called_once_with(api_key='gai-test-key')
|
||||
|
||||
# Verify processor has the client
|
||||
assert processor.client == mock_genai_client
|
||||
assert processor.model == 'gemini-1.5-flash'
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_system_instruction(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test that system instruction is handled correctly"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "System instruction test"
|
||||
mock_response.usage_metadata.prompt_token_count = 35
|
||||
mock_response.usage_metadata.candidates_token_count = 25
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("Be helpful and concise", "Explain quantum computing")
|
||||
|
||||
# Assert
|
||||
assert result.text == "System instruction test"
|
||||
assert result.in_token == 35
|
||||
assert result.out_token == 25
|
||||
|
||||
# Verify the system instruction is passed in the config
|
||||
call_args = mock_genai_client.models.generate_content.call_args
|
||||
config_arg = call_args[1]['config']
|
||||
# The system instruction should be in the config object
|
||||
assert call_args[1]['contents'] == "Explain quantum computing"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
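One readability observation before the next file: every test rebuilds a near-identical `config` dict by hand. A small helper like the hypothetical one below (not part of this diff) would cut the repetition while keeping per-test overrides explicit:

```python
from unittest.mock import AsyncMock

def make_config(**overrides):
    """Hypothetical test helper: common processor kwargs plus overrides."""
    config = {
        'api_key': 'test-api-key',
        'temperature': 0.0,
        'concurrency': 1,
        'taskgroup': AsyncMock(),
        'id': 'test-processor',
    }
    config.update(overrides)
    return config

# e.g. processor = Processor(**make_config(model='gemini-2.0-flash-001',
#                                          max_output=1024))
```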
tests/unit/test_text_completion/test_llamafile_processor.py (new file, 454 lines)
@@ -0,0 +1,454 @@
"""
|
||||
Unit tests for trustgraph.model.text_completion.llamafile
Following the same successful pattern as previous tests
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase

# Import the service under test
from trustgraph.model.text_completion.llamafile.llm import Processor
from trustgraph.base import LlmResult
from trustgraph.exceptions import TooManyRequests


class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
    """Test LlamaFile processor functionality"""

    @patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_openai_class):
        """Test basic processor initialization"""
        # Arrange
        mock_openai_client = MagicMock()
        mock_openai_class.return_value = mock_openai_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'LLaMA_CPP',
            'llamafile': 'http://localhost:8080/v1',
            'temperature': 0.0,
            'max_output': 4096,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act
        processor = Processor(**config)

        # Assert
        assert processor.model == 'LLaMA_CPP'
        assert processor.llamafile == 'http://localhost:8080/v1'
        assert processor.temperature == 0.0
        assert processor.max_output == 4096
        assert hasattr(processor, 'openai')
        mock_openai_class.assert_called_once_with(
            base_url='http://localhost:8080/v1',
            api_key='sk-no-key-required'
        )

    @patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_openai_class):
        """Test successful content generation"""
        # Arrange
        mock_openai_client = MagicMock()
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Generated response from LlamaFile"
        mock_response.usage.prompt_tokens = 20
        mock_response.usage.completion_tokens = 12

        mock_openai_client.chat.completions.create.return_value = mock_response
        mock_openai_class.return_value = mock_openai_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'LLaMA_CPP',
            'llamafile': 'http://localhost:8080/v1',
            'temperature': 0.0,
            'max_output': 4096,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("System prompt", "User prompt")

        # Assert
        assert isinstance(result, LlmResult)
        assert result.text == "Generated response from LlamaFile"
        assert result.in_token == 20
        assert result.out_token == 12
        assert result.model == 'llama.cpp'  # Note: model in result is hardcoded to 'llama.cpp'

        # Verify the OpenAI API call structure
        mock_openai_client.chat.completions.create.assert_called_once_with(
            model='LLaMA_CPP',
            messages=[{
                "role": "user",
                "content": "System prompt\n\nUser prompt"
            }]
        )

    @patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_openai_class):
        """Test handling of generic exceptions"""
        # Arrange
        mock_openai_client = MagicMock()
        mock_openai_client.chat.completions.create.side_effect = Exception("Connection error")
        mock_openai_class.return_value = mock_openai_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'LLaMA_CPP',
            'llamafile': 'http://localhost:8080/v1',
            'temperature': 0.0,
            'max_output': 4096,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act & Assert
        with pytest.raises(Exception, match="Connection error"):
            await processor.generate_content("System prompt", "User prompt")

    @patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_openai_class):
        """Test processor initialization with custom parameters"""
        # Arrange
        mock_openai_client = MagicMock()
        mock_openai_class.return_value = mock_openai_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'custom-llama',
            'llamafile': 'http://custom-host:8080/v1',
            'temperature': 0.7,
            'max_output': 2048,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act
        processor = Processor(**config)

        # Assert
        assert processor.model == 'custom-llama'
        assert processor.llamafile == 'http://custom-host:8080/v1'
        assert processor.temperature == 0.7
        assert processor.max_output == 2048
        mock_openai_class.assert_called_once_with(
            base_url='http://custom-host:8080/v1',
            api_key='sk-no-key-required'
        )

    @patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_openai_class):
        """Test processor initialization with default values"""
        # Arrange
        mock_openai_client = MagicMock()
        mock_openai_class.return_value = mock_openai_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        # Only provide required fields, should use defaults
        config = {
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act
        processor = Processor(**config)

        # Assert
        assert processor.model == 'LLaMA_CPP'  # default_model
        assert processor.llamafile == 'http://localhost:8080/v1'  # default_llamafile
        assert processor.temperature == 0.0  # default_temperature
        assert processor.max_output == 4096  # default_max_output
        mock_openai_class.assert_called_once_with(
            base_url='http://localhost:8080/v1',
            api_key='sk-no-key-required'
        )

    @patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_openai_class):
        """Test content generation with empty prompts"""
        # Arrange
        mock_openai_client = MagicMock()
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Default response"
        mock_response.usage.prompt_tokens = 2
        mock_response.usage.completion_tokens = 3

        mock_openai_client.chat.completions.create.return_value = mock_response
        mock_openai_class.return_value = mock_openai_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'LLaMA_CPP',
            'llamafile': 'http://localhost:8080/v1',
            'temperature': 0.0,
            'max_output': 4096,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("", "")

        # Assert
        assert isinstance(result, LlmResult)
        assert result.text == "Default response"
        assert result.in_token == 2
        assert result.out_token == 3
        assert result.model == 'llama.cpp'

        # Verify the combined prompt is sent correctly
        call_args = mock_openai_client.chat.completions.create.call_args
        expected_prompt = "\n\n"  # Empty system + "\n\n" + empty user
        assert call_args[1]['messages'][0]['content'] == expected_prompt

    @patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_message_structure(self, mock_llm_init, mock_async_init, mock_openai_class):
        """Test that LlamaFile messages are structured correctly"""
        # Arrange
        mock_openai_client = MagicMock()
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Response with proper structure"
        mock_response.usage.prompt_tokens = 25
        mock_response.usage.completion_tokens = 15

        mock_openai_client.chat.completions.create.return_value = mock_response
        mock_openai_class.return_value = mock_openai_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'LLaMA_CPP',
            'llamafile': 'http://localhost:8080/v1',
            'temperature': 0.0,
            'max_output': 4096,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("You are a helpful assistant", "What is AI?")

        # Assert
        assert result.text == "Response with proper structure"
        assert result.in_token == 25
        assert result.out_token == 15

        # Verify the message structure
        call_args = mock_openai_client.chat.completions.create.call_args
        messages = call_args[1]['messages']

        assert len(messages) == 1
        assert messages[0]['role'] == 'user'
        assert messages[0]['content'] == "You are a helpful assistant\n\nWhat is AI?"

        # Verify model parameter
        assert call_args[1]['model'] == 'LLaMA_CPP'

    @patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_openai_client_initialization(self, mock_llm_init, mock_async_init, mock_openai_class):
        """Test that OpenAI client is initialized correctly for LlamaFile"""
        # Arrange
        mock_openai_client = MagicMock()
        mock_openai_class.return_value = mock_openai_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'llama-custom',
            'llamafile': 'http://llamafile-server:8080/v1',
            'temperature': 0.0,
            'max_output': 4096,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act
        processor = Processor(**config)

        # Assert
        # Verify OpenAI client was called with correct parameters
        mock_openai_class.assert_called_once_with(
            base_url='http://llamafile-server:8080/v1',
            api_key='sk-no-key-required'
        )

        # Verify processor has the client
        assert processor.openai == mock_openai_client

    @patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_openai_class):
        """Test prompt construction with system and user prompts"""
        # Arrange
        mock_openai_client = MagicMock()
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Response with system instructions"
        mock_response.usage.prompt_tokens = 30
        mock_response.usage.completion_tokens = 20

        mock_openai_client.chat.completions.create.return_value = mock_response
        mock_openai_class.return_value = mock_openai_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'LLaMA_CPP',
            'llamafile': 'http://localhost:8080/v1',
            'temperature': 0.0,
            'max_output': 4096,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("You are a helpful assistant", "What is machine learning?")

        # Assert
        assert result.text == "Response with system instructions"
        assert result.in_token == 30
        assert result.out_token == 20

        # Verify the combined prompt
        call_args = mock_openai_client.chat.completions.create.call_args
        expected_prompt = "You are a helpful assistant\n\nWhat is machine learning?"
        assert call_args[1]['messages'][0]['content'] == expected_prompt

    @patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_hardcoded_model_response(self, mock_llm_init, mock_async_init, mock_openai_class):
        """Test that response model is hardcoded to 'llama.cpp'"""
        # Arrange
        mock_openai_client = MagicMock()
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Test response"
        mock_response.usage.prompt_tokens = 15
        mock_response.usage.completion_tokens = 10

        mock_openai_client.chat.completions.create.return_value = mock_response
        mock_openai_class.return_value = mock_openai_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'custom-model-name',  # This should be ignored in response
            'llamafile': 'http://localhost:8080/v1',
            'temperature': 0.0,
            'max_output': 4096,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("System", "User")

        # Assert
        assert result.model == 'llama.cpp'  # Should always be 'llama.cpp', not 'custom-model-name'
        assert processor.model == 'custom-model-name'  # But processor.model should still be custom

    @patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_no_rate_limiting(self, mock_llm_init, mock_async_init, mock_openai_class):
        """Test that no rate limiting is implemented (SLM assumption)"""
        # Arrange
        mock_openai_client = MagicMock()
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "No rate limiting test"
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5

        mock_openai_client.chat.completions.create.return_value = mock_response
        mock_openai_class.return_value = mock_openai_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'LLaMA_CPP',
            'llamafile': 'http://localhost:8080/v1',
            'temperature': 0.0,
            'max_output': 4096,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("System", "User")

        # Assert
        assert result.text == "No rate limiting test"
        # No specific rate limit error handling tested since SLM presumably has no rate limits


if __name__ == '__main__':
    pytest.main([__file__])
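A side note on the stubbing style used throughout these tests: bare MagicMock objects work as response stubs because MagicMock auto-creates nested attributes on first access, so attribute chains like `response.usage.prompt_tokens` can simply be assigned. A minimal standalone illustration (not part of the PR):

    from unittest.mock import MagicMock

    resp = MagicMock()
    resp.usage.prompt_tokens = 20            # '.usage' is created on the fly
    resp.choices = [MagicMock()]
    resp.choices[0].message.content = "hi"   # arbitrary attribute depth is fine
    assert resp.usage.prompt_tokens == 20
    assert resp.choices[0].message.content == "hi"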
tests/unit/test_text_completion/test_ollama_processor.py (new file, 317 lines)
@@ -0,0 +1,317 @@
"""
|
||||
Unit tests for trustgraph.model.text_completion.ollama
|
||||
Following the same successful pattern as VertexAI tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.ollama.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
|
||||
|
||||
class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Ollama processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Mock the parent class initialization
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'llama2'
|
||||
assert hasattr(processor, 'llm')
|
||||
mock_client_class.assert_called_once_with(host='http://localhost:11434')
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Generated response from Ollama',
|
||||
'prompt_eval_count': 15,
|
||||
'eval_count': 8
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Ollama"
|
||||
assert result.in_token == 15
|
||||
assert result.out_token == 8
|
||||
assert result.model == 'llama2'
|
||||
mock_client.generate.assert_called_once_with('llama2', "System prompt\n\nUser prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client.generate.side_effect = Exception("Connection error")
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral',
|
||||
'ollama': 'http://192.168.1.100:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'mistral'
|
||||
mock_client_class.assert_called_once_with(host='http://192.168.1.100:11434')
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Don't provide model or ollama - should use defaults
|
||||
config = {
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemma2:9b' # default_model
|
||||
# Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env)
|
||||
mock_client_class.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Default response',
|
||||
'prompt_eval_count': 2,
|
||||
'eval_count': 3
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'llama2'
|
||||
|
||||
# The prompt should be "" + "\n\n" + "" = "\n\n"
|
||||
mock_client.generate.assert_called_once_with('llama2', "\n\n")
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_token_counting(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test token counting from Ollama response"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Test response',
|
||||
'prompt_eval_count': 50,
|
||||
'eval_count': 25
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Test response"
|
||||
assert result.in_token == 50
|
||||
assert result.out_token == 25
|
||||
assert result.model == 'llama2'
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_ollama_client_initialization(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test that Ollama client is initialized correctly"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'codellama',
|
||||
'ollama': 'http://ollama-server:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify Client was called with correct host
|
||||
mock_client_class.assert_called_once_with(host='http://ollama-server:11434')
|
||||
|
||||
# Verify processor has the client
|
||||
assert processor.llm == mock_client
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test prompt construction with system and user prompts"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Response with system instructions',
|
||||
'prompt_eval_count': 25,
|
||||
'eval_count': 15
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with system instructions"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
|
||||
# Verify the combined prompt
|
||||
mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
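Every test above repeats the same config dict and parent-class patches. A hypothetical helper (illustrative only, not in this PR; the names make_processor and client_patch_target are invented) sketches how that boilerplate could be factored out:

    from contextlib import ExitStack
    from unittest.mock import AsyncMock, MagicMock, patch

    def make_processor(processor_cls, client_patch_target, **overrides):
        # Build a Processor with the client class and parent __init__s mocked out.
        config = {
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor',
            **overrides,
        }
        with ExitStack() as stack:
            client_cls = stack.enter_context(patch(client_patch_target))
            stack.enter_context(patch(
                'trustgraph.base.async_processor.AsyncProcessor.__init__',
                return_value=None))
            stack.enter_context(patch(
                'trustgraph.base.llm_service.LlmService.__init__',
                return_value=None))
            client_cls.return_value = MagicMock()
            return processor_cls(**config), client_cls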
tests/unit/test_text_completion/test_openai_processor.py (new file, 395 lines)
@@ -0,0 +1,395 @@
"""
|
||||
Unit tests for trustgraph.model.text_completion.openai
|
||||
Following the same successful pattern as VertexAI and Ollama tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.openai.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test OpenAI processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-3.5-turbo'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4096
|
||||
assert hasattr(processor, 'openai')
|
||||
mock_openai_class.assert_called_once_with(base_url='https://api.openai.com/v1', api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Generated response from OpenAI"
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 12
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from OpenAI"
|
||||
assert result.in_token == 20
|
||||
assert result.out_token == 12
|
||||
assert result.model == 'gpt-3.5-turbo'
|
||||
|
||||
# Verify the OpenAI API call
|
||||
mock_openai_client.chat.completions.create.assert_called_once_with(
|
||||
model='gpt-3.5-turbo',
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "System prompt\n\nUser prompt"
|
||||
}]
|
||||
}],
|
||||
temperature=0.0,
|
||||
max_tokens=4096,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
presence_penalty=0,
|
||||
response_format={"type": "text"}
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
from openai import RateLimitError
|
||||
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_client.chat.completions.create.side_effect = RateLimitError("Rate limit exceeded", response=MagicMock(), body=None)
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_client.chat.completions.create.side_effect = Exception("API connection error")
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="API connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_api_key(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test processor initialization without API key (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': None, # No API key provided
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="OpenAI API key not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'api_key': 'custom-api-key',
|
||||
'url': 'https://custom-openai-url.com/v1',
|
||||
'temperature': 0.7,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-4'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 2048
|
||||
mock_openai_class.assert_called_once_with(base_url='https://custom-openai-url.com/v1', api_key='custom-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Only provide required fields, should use defaults
|
||||
config = {
|
||||
'api_key': 'test-api-key',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-3.5-turbo' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 4096 # default_max_output
|
||||
mock_openai_class.assert_called_once_with(base_url='https://api.openai.com/v1', api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Default response"
|
||||
mock_response.usage.prompt_tokens = 2
|
||||
mock_response.usage.completion_tokens = 3
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'gpt-3.5-turbo'
|
||||
|
||||
# Verify the combined prompt is sent correctly
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
expected_prompt = "\n\n" # Empty system + "\n\n" + empty user
|
||||
assert call_args[1]['messages'][0]['content'][0]['text'] == expected_prompt
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_openai_client_initialization_without_base_url(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test OpenAI client initialization without base_url"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': None, # No base URL
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert - should be called without base_url when it's None
|
||||
mock_openai_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_message_structure(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test that OpenAI messages are structured correctly"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with proper structure"
|
||||
mock_response.usage.prompt_tokens = 25
|
||||
mock_response.usage.completion_tokens = 15
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.5,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with proper structure"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
|
||||
# Verify the message structure matches OpenAI Chat API format
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
messages = call_args[1]['messages']
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0]['role'] == 'user'
|
||||
assert messages[0]['content'][0]['type'] == 'text'
|
||||
assert messages[0]['content'][0]['text'] == "You are a helpful assistant\n\nWhat is AI?"
|
||||
|
||||
# Verify other parameters
|
||||
assert call_args[1]['model'] == 'gpt-3.5-turbo'
|
||||
assert call_args[1]['temperature'] == 0.5
|
||||
assert call_args[1]['max_tokens'] == 1024
|
||||
assert call_args[1]['top_p'] == 1
|
||||
assert call_args[1]['frequency_penalty'] == 0
|
||||
assert call_args[1]['presence_penalty'] == 0
|
||||
assert call_args[1]['response_format'] == {"type": "text"}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
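On the call_args[1] indexing these assertions use: Mock records each call as an (args, kwargs) pair, so index 0 is the positional tuple and index 1 the keyword dict. Since Python 3.8, the equivalent call_args.args / call_args.kwargs accessors are also available and read more clearly. A small standalone example (not part of the PR):

    from unittest.mock import MagicMock

    m = MagicMock()
    m(model='gpt-3.5-turbo', temperature=0.5)
    assert m.call_args[1]['model'] == 'gpt-3.5-turbo'    # index style, as in the tests
    assert m.call_args.kwargs['temperature'] == 0.5      # named style, Python 3.8+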
tests/unit/test_text_completion/test_vertexai_processor.py (new file, 397 lines)
@@ -0,0 +1,397 @@
"""
|
||||
Unit tests for trustgraph.model.text_completion.vertexai
|
||||
Starting simple with one test to get the basics working
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.vertexai.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
|
||||
|
||||
class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Simple test for processor initialization"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test basic processor initialization with mocked dependencies"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
mock_model = MagicMock()
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
# Mock the parent class initialization to avoid taskgroup requirement
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(), # Required by AsyncProcessor
|
||||
'id': 'test-processor' # Required by AsyncProcessor
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-2.0-flash-001' # It's stored as 'model', not 'model_name'
|
||||
assert hasattr(processor, 'generation_config')
|
||||
assert hasattr(processor, 'safety_settings')
|
||||
assert hasattr(processor, 'llm')
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('private.json')
|
||||
mock_vertexai.init.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Generated response from Gemini"
|
||||
mock_response.usage_metadata.prompt_token_count = 15
|
||||
mock_response.usage_metadata.candidates_token_count = 8
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Gemini"
|
||||
assert result.in_token == 15
|
||||
assert result.out_token == 8
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
# Check that the method was called (actual prompt format may vary)
|
||||
mock_model.generate_content.assert_called_once()
|
||||
# Verify the call was made with the expected parameters
|
||||
call_args = mock_model.generate_content.call_args
|
||||
assert call_args[1]['generation_config'] == processor.generation_config
|
||||
assert call_args[1]['safety_settings'] == processor.safety_settings
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_model.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test handling of blocked content (safety filters)"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = None # Blocked content returns None
|
||||
mock_response.usage_metadata.prompt_token_count = 10
|
||||
mock_response.usage_metadata.candidates_token_count = 0
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "Blocked content")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text is None # Should preserve None for blocked content
|
||||
assert result.in_token == 10
|
||||
assert result.out_token == 0
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test processor initialization without private key (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': None, # No private key provided
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Private key file not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_model.generate_content.side_effect = Exception("Network error")
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Network error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
mock_model = MagicMock()
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-west1',
|
||||
'model': 'gemini-1.5-pro',
|
||||
'temperature': 0.7,
|
||||
'max_output': 4096,
|
||||
'private_key': 'custom-key.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-1.5-pro'
|
||||
|
||||
# Verify that generation_config object exists (can't easily check internal values)
|
||||
assert hasattr(processor, 'generation_config')
|
||||
assert processor.generation_config is not None
|
||||
|
||||
# Verify that safety settings are configured
|
||||
assert len(processor.safety_settings) == 4
|
||||
|
||||
# Verify service account was called with custom key
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('custom-key.json')
|
||||
|
||||
# Verify that parameters dict has the correct values (this is accessible)
|
||||
assert processor.parameters["temperature"] == 0.7
|
||||
assert processor.parameters["max_output_tokens"] == 4096
|
||||
assert processor.parameters["top_p"] == 1.0
|
||||
assert processor.parameters["top_k"] == 32
|
||||
assert processor.parameters["candidate_count"] == 1
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test that VertexAI is initialized correctly with credentials"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
mock_model = MagicMock()
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'europe-west1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'service-account.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify VertexAI init was called with correct parameters
|
||||
mock_vertexai.init.assert_called_once_with(
|
||||
location='europe-west1',
|
||||
credentials=mock_credentials,
|
||||
project='test-project-123'
|
||||
)
|
||||
|
||||
# Verify GenerativeModel was created with the right model name
|
||||
mock_generative_model.assert_called_once_with('gemini-2.0-flash-001')

    @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
    @patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
    @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
        """Test content generation with empty prompts"""
        # Arrange
        mock_credentials = MagicMock()
        mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials

        mock_model = MagicMock()
        mock_response = MagicMock()
        mock_response.text = "Default response"
        mock_response.usage_metadata.prompt_token_count = 2
        mock_response.usage_metadata.candidates_token_count = 3
        mock_model.generate_content.return_value = mock_response
        mock_generative_model.return_value = mock_model

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'region': 'us-central1',
            'model': 'gemini-2.0-flash-001',
            'temperature': 0.0,
            'max_output': 8192,
            'private_key': 'private.json',
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("", "")

        # Assert
        assert isinstance(result, LlmResult)
        assert result.text == "Default response"
        assert result.in_token == 2
        assert result.out_token == 3
        assert result.model == 'gemini-2.0-flash-001'

        # Verify the model was called with the combined empty prompts
        mock_model.generate_content.assert_called_once()
        call_args = mock_model.generate_content.call_args
        # The prompt should be "" + "\n\n" + "" = "\n\n"
        assert call_args[0][0] == "\n\n"
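        # (LlmResult's in_token/out_token are mapped from the response's
        # usage_metadata.prompt_token_count and candidates_token_count, so
        # token accounting comes straight from the Vertex AI response object.)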


if __name__ == '__main__':
    pytest.main([__file__])
489 tests/unit/test_text_completion/test_vllm_processor.py Normal file

@@ -0,0 +1,489 @@
"""
|
||||
Unit tests for trustgraph.model.text_completion.vllm
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.vllm.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
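# (TooManyRequests is imported but never exercised below; presumably the
# processor raises it for HTTP 429 responses, which these tests do not cover.)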


class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
    """Test vLLM processor functionality"""

    @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_session_class):
        """Test basic processor initialization"""
        # Arrange
        mock_session = MagicMock()
        mock_session_class.return_value = mock_session

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
            'url': 'http://vllm-service:8899/v1',
            'temperature': 0.0,
            'max_output': 2048,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act
        processor = Processor(**config)

        # Assert
        assert processor.model == 'TheBloke/Mistral-7B-v0.1-AWQ'
        assert processor.base_url == 'http://vllm-service:8899/v1'
        assert processor.temperature == 0.0
        assert processor.max_output == 2048
        assert hasattr(processor, 'session')
        mock_session_class.assert_called_once()

    @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_session_class):
        """Test successful content generation"""
        # Arrange
        mock_session = MagicMock()
        mock_response = MagicMock()
        mock_response.status = 200
        mock_response.json = AsyncMock(return_value={
            'choices': [{
                'text': 'Generated response from vLLM'
            }],
            'usage': {
                'prompt_tokens': 20,
                'completion_tokens': 12
            }
        })

        # Mock the async context manager
        mock_session.post.return_value.__aenter__.return_value = mock_response
        mock_session.post.return_value.__aexit__.return_value = None
        mock_session_class.return_value = mock_session
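        # (aiohttp's session.post() returns an async context manager rather
        # than a plain awaitable, hence stubbing __aenter__/__aexit__ on the
        # return value instead of making post itself an AsyncMock.)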

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
            'url': 'http://vllm-service:8899/v1',
            'temperature': 0.0,
            'max_output': 2048,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("System prompt", "User prompt")

        # Assert
        assert isinstance(result, LlmResult)
        assert result.text == "Generated response from vLLM"
        assert result.in_token == 20
        assert result.out_token == 12
        assert result.model == 'TheBloke/Mistral-7B-v0.1-AWQ'

        # Verify the vLLM API call
        mock_session.post.assert_called_once_with(
            'http://vllm-service:8899/v1/completions',
            headers={'Content-Type': 'application/json'},
            json={
                'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
                'prompt': 'System prompt\n\nUser prompt',
                'max_tokens': 2048,
                'temperature': 0.0
            }
        )
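        # (The request shape matches vLLM's OpenAI-compatible /v1/completions
        # endpoint: model, prompt, max_tokens and temperature in a JSON body.)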

    @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_http_error(self, mock_llm_init, mock_async_init, mock_session_class):
        """Test HTTP error handling"""
        # Arrange
        mock_session = MagicMock()
        mock_response = MagicMock()
        mock_response.status = 500

        # Mock the async context manager
        mock_session.post.return_value.__aenter__.return_value = mock_response
        mock_session.post.return_value.__aexit__.return_value = None
        mock_session_class.return_value = mock_session

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
            'url': 'http://vllm-service:8899/v1',
            'temperature': 0.0,
            'max_output': 2048,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act & Assert
        with pytest.raises(RuntimeError, match="Bad status: 500"):
            await processor.generate_content("System prompt", "User prompt")

    @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_session_class):
        """Test handling of generic exceptions"""
        # Arrange
        mock_session = MagicMock()
        mock_session.post.side_effect = Exception("Connection error")
        mock_session_class.return_value = mock_session

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
            'url': 'http://vllm-service:8899/v1',
            'temperature': 0.0,
            'max_output': 2048,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act & Assert
        with pytest.raises(Exception, match="Connection error"):
            await processor.generate_content("System prompt", "User prompt")
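        # (Exceptions raised by session.post propagate to the caller unwrapped,
        # so transport failures remain distinguishable from HTTP-status errors
        # like the RuntimeError above.)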

    @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_session_class):
        """Test processor initialization with custom parameters"""
        # Arrange
        mock_session = MagicMock()
        mock_session_class.return_value = mock_session

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'custom-model',
            'url': 'http://custom-vllm:8080/v1',
            'temperature': 0.7,
            'max_output': 1024,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act
        processor = Processor(**config)

        # Assert
        assert processor.model == 'custom-model'
        assert processor.base_url == 'http://custom-vllm:8080/v1'
        assert processor.temperature == 0.7
        assert processor.max_output == 1024
        mock_session_class.assert_called_once()

    @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_session_class):
        """Test processor initialization with default values"""
        # Arrange
        mock_session = MagicMock()
        mock_session_class.return_value = mock_session

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        # Only provide required fields, should use defaults
        config = {
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act
        processor = Processor(**config)

        # Assert
        assert processor.model == 'TheBloke/Mistral-7B-v0.1-AWQ'  # default_model
        assert processor.base_url == 'http://vllm-service:8899/v1'  # default_base_url
        assert processor.temperature == 0.0  # default_temperature
        assert processor.max_output == 2048  # default_max_output
        mock_session_class.assert_called_once()

    @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_session_class):
        """Test content generation with empty prompts"""
        # Arrange
        mock_session = MagicMock()
        mock_response = MagicMock()
        mock_response.status = 200
        mock_response.json = AsyncMock(return_value={
            'choices': [{
                'text': 'Default response'
            }],
            'usage': {
                'prompt_tokens': 2,
                'completion_tokens': 3
            }
        })

        # Mock the async context manager
        mock_session.post.return_value.__aenter__.return_value = mock_response
        mock_session.post.return_value.__aexit__.return_value = None
        mock_session_class.return_value = mock_session

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
            'url': 'http://vllm-service:8899/v1',
            'temperature': 0.0,
            'max_output': 2048,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("", "")

        # Assert
        assert isinstance(result, LlmResult)
        assert result.text == "Default response"
        assert result.in_token == 2
        assert result.out_token == 3
        assert result.model == 'TheBloke/Mistral-7B-v0.1-AWQ'

        # Verify the combined prompt is sent correctly
        call_args = mock_session.post.call_args
        expected_prompt = "\n\n"  # Empty system + "\n\n" + empty user
        assert call_args[1]['json']['prompt'] == expected_prompt

    @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_request_structure(self, mock_llm_init, mock_async_init, mock_session_class):
        """Test that vLLM request is structured correctly"""
        # Arrange
        mock_session = MagicMock()
        mock_response = MagicMock()
        mock_response.status = 200
        mock_response.json = AsyncMock(return_value={
            'choices': [{
                'text': 'Response with proper structure'
            }],
            'usage': {
                'prompt_tokens': 25,
                'completion_tokens': 15
            }
        })

        # Mock the async context manager
        mock_session.post.return_value.__aenter__.return_value = mock_response
        mock_session.post.return_value.__aexit__.return_value = None
        mock_session_class.return_value = mock_session

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
            'url': 'http://vllm-service:8899/v1',
            'temperature': 0.5,
            'max_output': 1024,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("You are a helpful assistant", "What is AI?")

        # Assert
        assert result.text == "Response with proper structure"
        assert result.in_token == 25
        assert result.out_token == 15

        # Verify the request structure
        call_args = mock_session.post.call_args

        # Check URL
        assert call_args[0][0] == 'http://vllm-service:8899/v1/completions'

        # Check headers
        assert call_args[1]['headers']['Content-Type'] == 'application/json'

        # Check request body
        request_data = call_args[1]['json']
        assert request_data['model'] == 'TheBloke/Mistral-7B-v0.1-AWQ'
        assert request_data['prompt'] == "You are a helpful assistant\n\nWhat is AI?"
        assert request_data['temperature'] == 0.5
        assert request_data['max_tokens'] == 1024
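        # (Note the rename: the processor's max_output setting is sent to the
        # completions API as max_tokens, per the assertion above.)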

    @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_vllm_session_initialization(self, mock_llm_init, mock_async_init, mock_session_class):
        """Test that aiohttp session is initialized correctly"""
        # Arrange
        mock_session = MagicMock()
        mock_session_class.return_value = mock_session

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'test-model',
            'url': 'http://test-vllm:8899/v1',
            'temperature': 0.0,
            'max_output': 2048,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act
        processor = Processor(**config)

        # Assert
        # Verify ClientSession was created
        mock_session_class.assert_called_once()

        # Verify processor has the session
        assert processor.session == mock_session
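        # (The session is created once at construction time and held on the
        # processor, so every generate_content call reuses the same
        # ClientSession, and with it aiohttp's connection pooling.)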

    @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_response_parsing(self, mock_llm_init, mock_async_init, mock_session_class):
        """Test response parsing from vLLM API"""
        # Arrange
        mock_session = MagicMock()
        mock_response = MagicMock()
        mock_response.status = 200
        mock_response.json = AsyncMock(return_value={
            'choices': [{
                'text': 'Parsed response text'
            }],
            'usage': {
                'prompt_tokens': 35,
                'completion_tokens': 25
            }
        })

        # Mock the async context manager
        mock_session.post.return_value.__aenter__.return_value = mock_response
        mock_session.post.return_value.__aexit__.return_value = None
        mock_session_class.return_value = mock_session

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
            'url': 'http://vllm-service:8899/v1',
            'temperature': 0.0,
            'max_output': 2048,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("System", "User query")

        # Assert
        assert isinstance(result, LlmResult)
        assert result.text == "Parsed response text"
        assert result.in_token == 35
        assert result.out_token == 25
        assert result.model == 'TheBloke/Mistral-7B-v0.1-AWQ'

    @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_session_class):
        """Test prompt construction with system and user prompts"""
        # Arrange
        mock_session = MagicMock()
        mock_response = MagicMock()
        mock_response.status = 200
        mock_response.json = AsyncMock(return_value={
            'choices': [{
                'text': 'Response with system instructions'
            }],
            'usage': {
                'prompt_tokens': 40,
                'completion_tokens': 30
            }
        })

        # Mock the async context manager
        mock_session.post.return_value.__aenter__.return_value = mock_response
        mock_session.post.return_value.__aexit__.return_value = None
        mock_session_class.return_value = mock_session

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
            'url': 'http://vllm-service:8899/v1',
            'temperature': 0.0,
            'max_output': 2048,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("You are a helpful assistant", "What is machine learning?")

        # Assert
        assert result.text == "Response with system instructions"
        assert result.in_token == 40
        assert result.out_token == 30

        # Verify the combined prompt
        call_args = mock_session.post.call_args
        expected_prompt = "You are a helpful assistant\n\nWhat is machine learning?"
        assert call_args[1]['json']['prompt'] == expected_prompt


if __name__ == '__main__':
    pytest.main([__file__])