Release/v1.2 (#457)

* Bump setup.py versions for 1.1

* PoC MCP server (#419)

* Very initial MCP server PoC for TrustGraph

* Put service on port 8000

* Add MCP container and packages to buildout

* Update docs for API/CLI changes in 1.0 (#421)

* Update some API basics for the 0.23/1.0 API change

* Add MCP container push (#425)

* Add command args to the MCP server (#426)

* Host and port parameters

* Added websocket arg

* More docs

* MCP client support (#427)

- MCP client service
- Tool request/response schema
- API gateway support for mcp-tool
- Message translation for tool request & response
- Make mcp-tool use the configuration service for information
  about where the MCP services are.
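
For illustration, a hypothetical shape for such a configuration entry; the actual TrustGraph schema is not shown in this change, and every field name below is an assumption:

# Hypothetical mcp-tool configuration entry; field names are illustrative only.
mcp_tool_config = {
    "weather": {                           # tool id the agent refers to
        "remote-name": "get-weather",      # assumed: tool name on the MCP server
        "url": "ws://mcp-server:8000/",    # assumed: where the MCP service lives
    }
}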

* Feature/react call mcp (#428)

Key Features

  - MCP Tool Integration: Added core MCP tool support with ToolClientSpec and ToolClient classes
  - API Enhancement: New mcp_tool method for flow-specific tool invocation
  - CLI Tooling: New tg-invoke-mcp-tool command for testing MCP integration
  - React Agent Enhancement: Fixed and improved multi-tool invocation capabilities
  - Tool Management: Enhanced CLI for tool configuration and management

Changes

  - Added MCP tool invocation to API with flow-specific integration
  - Implemented ToolClientSpec and ToolClient for tool call handling
  - Updated agent-manager-react to invoke MCP tools with configurable types
  - Enhanced CLI with new commands and improved help text
  - Added comprehensive documentation for new CLI commands
  - Improved tool configuration management

Testing

  - Added tg-invoke-mcp-tool CLI command for isolated MCP integration testing
  - Enhanced agent capability to invoke multiple tools simultaneously
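
As a sketch of what flow-specific tool invocation might look like from Python, only the mcp_tool method name comes from this change; the client class, import path, and argument names are assumptions:

# Hypothetical usage sketch; Api, flow(), and the argument names are assumed.
from trustgraph.api import Api

api = Api("http://localhost:8088/")
flow = api.flow().id("default")
result = flow.mcp_tool(                # mcp_tool: the new flow-specific method
    name="weather",                    # assumed: a configured MCP tool id
    parameters={"city": "London"},     # assumed: tool arguments
)
print(result)

The tg-invoke-mcp-tool command exercises the same path from the shell, which is what enables isolated testing of the MCP integration.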

* Test suite executed from CI pipeline (#433)

* Test strategy & test cases

* Unit tests

* Integration tests

* Extending test coverage (#434)

* Contract tests

* Testing embeddings

* Agent unit tests

* Knowledge pipeline tests

* Turn on contract tests

* Increase storage test coverage (#435)

* Fixing storage and adding tests

* PR pipeline only runs quick tests

* Empty configuration is returned as an empty list; previously it was omitted from the response (#436)

* Update config util to take files as well as command-line text (#437)

* Updated CLI invocation and config model for tools and mcp (#438)

* Updated CLI invocation and config model for tools and mcp

* CLI anomalies

* Tweaked the MCP tool implementation for new model

* Update agent implementation to match the new model

* Fix agent tools, now all tested

* Fixed integration tests

* Fix MCP delete tool params

* Update Python deps to 1.2

* Update to enable knowledge extraction using the agent framework (#439)

* Implement KG extraction agent (kg-extract-agent)

* Using ReAct framework (agent-manager-react)
 
* The ReAct manager had an issue when emitting JSON: it conflicted with the ReAct manager's own JSON messages. The ReAct manager was therefore refactored to use traditional, non-JSON ReAct message structure.
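
Traditional ReAct messages are plain text rather than JSON, along these lines (an illustrative trace, not actual agent output):

Thought: I need the account details before I can answer.
Action: lookup-account
Action Input: user=alice
Observation: account created 2024-03-01, plan=pro
Thought: I now have enough information.
Final Answer: Alice is on the pro plan.

Because the trace itself is plain text, a tool that emits JSON no longer collides with the agent's own message framing.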
 
* Minor refactor to take the prompt template client out of prompt-template so it can be more readily used by other modules. kg-extract-agent uses this framework.

* Migrate from setup.py to pyproject.toml (#440)

* Converted setup.py to pyproject.toml

* Modern packaging infrastructure, as recommended by the Python packaging docs
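
A minimal sketch of the shape of such a pyproject.toml (illustrative values only, not the actual TrustGraph package metadata):

[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "trustgraph-example"            # placeholder name
version = "1.2.0"
description = "Illustrative package metadata"
requires-python = ">=3.10"
dependencies = ["pulsar-client"]       # placeholder dependency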

* Install missing build deps (#441)

* Install missing build deps (#442)

* Implement logging strategy (#444)

* Logging strategy: converted all print() calls to logging invocations
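
The typical shape of such a conversion (a generic sketch, not the exact TrustGraph code):

import logging

logger = logging.getLogger(__name__)  # module-level logger, configured once at startup

def start_processor():
    # Before: print("Starting processor")
    logger.info("Starting processor")
    try:
        raise ConnectionError("broker unreachable")  # stand-in for real startup work
    except ConnectionError:
        # logger.exception records the stack trace, which a bare print() loses
        logger.exception("Processor startup failed")

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    start_processor()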

* Fix/startup failure (#445)

* Fix logging startup problems

* Fix logging startup problems (#446)

* Fix logging startup problems (#447)

* Fixed Mistral OCR to use current API (#448)

* Fixed Mistral OCR to use current API

* Added PDF decoder tests

* Fix Mistral OCR ident to be standard pdf-decoder (#450)

* Fix Mistral OCR ident to be standard pdf-decoder

* Correct test

* Schema structure refactor (#451)

* Write schema refactor spec

* Implemented schema refactor spec

* Structure data mvp (#452)

* Structured data tech spec

* Architecture principles

* New schemas

* Updated schemas and specs

* Object extractor

* Add .coveragerc

* New tests

* Cassandra object storage

* Trying to get object extraction working; issues remain

* Validate librarian collection (#453)

* Fix token chunker, broken API invocation (#454)

* Fix token chunker, broken API invocation (#455)

* Knowledge load utility CLI (#456)

* Knowledge loader

* More tests
cybermaggedon, 2025-08-18 20:56:09 +01:00, committed by GitHub
parent c85ba197be
commit 89be656990
509 changed files with 49632 additions and 5159 deletions

@@ -0,0 +1,3 @@
"""
Unit tests for text completion services
"""

@@ -0,0 +1,3 @@
"""
Common utilities for text completion tests
"""

@@ -0,0 +1,69 @@
"""
Base test patterns that can be reused across different text completion models
"""

from abc import ABC, abstractmethod
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase


class BaseTextCompletionTestCase(IsolatedAsyncioTestCase, ABC):
    """
    Base test class for text completion processors

    Provides common test patterns that can be reused
    """

    @abstractmethod
    def get_processor_class(self):
        """Return the processor class to test"""
        pass

    @abstractmethod
    def get_base_config(self):
        """Return base configuration for the processor"""
        pass

    @abstractmethod
    def get_mock_patches(self):
        """Return list of patch decorators for mocking dependencies"""
        pass

    def create_base_config(self, **overrides):
        """Create base config with optional overrides"""
        config = self.get_base_config()
        config.update(overrides)
        return config

    def create_mock_llm_result(self, text="Test response", in_token=10, out_token=5):
        """Create a mock LLM result"""
        from trustgraph.base import LlmResult
        return LlmResult(text=text, in_token=in_token, out_token=out_token)


class CommonTestPatterns:
    """
    Common test patterns that can be used across different models
    """

    @staticmethod
    def basic_initialization_test_pattern(test_instance):
        """
        Test pattern for basic processor initialization

        test_instance should be a BaseTextCompletionTestCase
        """
        # This would contain the common pattern for initialization testing
        pass

    @staticmethod
    def successful_generation_test_pattern(test_instance):
        """
        Test pattern for successful content generation
        """
        pass

    @staticmethod
    def error_handling_test_pattern(test_instance):
        """
        Test pattern for error handling
        """
        pass

@@ -0,0 +1,53 @@
"""
Common mocking utilities for text completion tests
"""

from unittest.mock import AsyncMock, MagicMock


class CommonMocks:
    """Common mock objects used across text completion tests"""

    @staticmethod
    def create_mock_async_processor_init():
        """Create mock for AsyncProcessor.__init__"""
        mock = MagicMock()
        mock.return_value = None
        return mock

    @staticmethod
    def create_mock_llm_service_init():
        """Create mock for LlmService.__init__"""
        mock = MagicMock()
        mock.return_value = None
        return mock

    @staticmethod
    def create_mock_response(text="Test response", prompt_tokens=10, completion_tokens=5):
        """Create a mock response object"""
        response = MagicMock()
        response.text = text
        response.usage_metadata.prompt_token_count = prompt_tokens
        response.usage_metadata.candidates_token_count = completion_tokens
        return response

    @staticmethod
    def create_basic_config():
        """Create basic config with required fields"""
        return {
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }


class MockPatches:
    """Common patch decorators for different services"""

    @staticmethod
    def get_base_patches():
        """Get patches that are common to all processors"""
        return [
            'trustgraph.base.async_processor.AsyncProcessor.__init__',
            'trustgraph.base.llm_service.LlmService.__init__'
        ]

@@ -0,0 +1,499 @@
"""
Pytest configuration and fixtures for text completion tests
"""

import pytest
from unittest.mock import MagicMock, AsyncMock

from trustgraph.base import LlmResult


# === Common Fixtures for All Text Completion Models ===

@pytest.fixture
def base_processor_config():
    """Base configuration required by all processors"""
    return {
        'concurrency': 1,
        'taskgroup': AsyncMock(),
        'id': 'test-processor'
    }


@pytest.fixture
def sample_llm_result():
    """Sample LlmResult for testing"""
    return LlmResult(
        text="Test response",
        in_token=10,
        out_token=5
    )


@pytest.fixture
def mock_async_processor_init():
    """Mock AsyncProcessor.__init__ to avoid infrastructure requirements"""
    mock = MagicMock()
    mock.return_value = None
    return mock


@pytest.fixture
def mock_llm_service_init():
    """Mock LlmService.__init__ to avoid infrastructure requirements"""
    mock = MagicMock()
    mock.return_value = None
    return mock


@pytest.fixture
def mock_prometheus_metrics():
    """Mock Prometheus metrics"""
    mock_metric = MagicMock()
    mock_metric.labels.return_value.time.return_value = MagicMock()
    return mock_metric


@pytest.fixture
def mock_pulsar_consumer():
    """Mock Pulsar consumer for integration testing"""
    return AsyncMock()


@pytest.fixture
def mock_pulsar_producer():
    """Mock Pulsar producer for integration testing"""
    return AsyncMock()


@pytest.fixture(autouse=True)
def mock_env_vars(monkeypatch):
    """Mock environment variables for testing"""
    monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "test-project")
    monkeypatch.setenv("GOOGLE_APPLICATION_CREDENTIALS", "/path/to/test-credentials.json")


@pytest.fixture
def mock_async_context_manager():
    """Mock async context manager for testing"""
    class MockAsyncContextManager:
        def __init__(self, return_value):
            self.return_value = return_value

        async def __aenter__(self):
            return self.return_value

        async def __aexit__(self, exc_type, exc_val, exc_tb):
            pass

    return MockAsyncContextManager


# === VertexAI Specific Fixtures ===

@pytest.fixture
def mock_vertexai_credentials():
    """Mock Google Cloud service account credentials"""
    return MagicMock()


@pytest.fixture
def mock_vertexai_model():
    """Mock VertexAI GenerativeModel"""
    mock_model = MagicMock()
    mock_response = MagicMock()
    mock_response.text = "Test response"
    mock_response.usage_metadata.prompt_token_count = 10
    mock_response.usage_metadata.candidates_token_count = 5
    mock_model.generate_content.return_value = mock_response
    return mock_model


@pytest.fixture
def vertexai_processor_config(base_processor_config):
    """Default configuration for VertexAI processor"""
    config = base_processor_config.copy()
    config.update({
        'region': 'us-central1',
        'model': 'gemini-2.0-flash-001',
        'temperature': 0.0,
        'max_output': 8192,
        'private_key': 'private.json'
    })
    return config


@pytest.fixture
def mock_safety_settings():
    """Mock safety settings for VertexAI"""
    safety_settings = []
    for i in range(4):  # 4 safety categories
        setting = MagicMock()
        setting.category = f"HARM_CATEGORY_{i}"
        setting.threshold = "BLOCK_MEDIUM_AND_ABOVE"
        safety_settings.append(setting)
    return safety_settings


@pytest.fixture
def mock_generation_config():
    """Mock generation configuration for VertexAI"""
    config = MagicMock()
    config.temperature = 0.0
    config.max_output_tokens = 8192
    config.top_p = 1.0
    config.top_k = 10
    config.candidate_count = 1
    return config


@pytest.fixture
def mock_vertexai_exception():
    """Mock VertexAI exceptions"""
    from google.api_core.exceptions import ResourceExhausted
    return ResourceExhausted("Test resource exhausted error")


# === Ollama Specific Fixtures ===

@pytest.fixture
def ollama_processor_config(base_processor_config):
    """Default configuration for Ollama processor"""
    config = base_processor_config.copy()
    config.update({
        'model': 'llama2',
        'temperature': 0.0,
        'max_output': 8192,
        'host': 'localhost',
        'port': 11434
    })
    return config


@pytest.fixture
def mock_ollama_client():
    """Mock Ollama client"""
    mock_client = MagicMock()
    mock_response = {
        'response': 'Test response from Ollama',
        'done': True,
        'eval_count': 5,
        'prompt_eval_count': 10
    }
    mock_client.generate.return_value = mock_response
    return mock_client


# === OpenAI Specific Fixtures ===

@pytest.fixture
def openai_processor_config(base_processor_config):
    """Default configuration for OpenAI processor"""
    config = base_processor_config.copy()
    config.update({
        'model': 'gpt-3.5-turbo',
        'api_key': 'test-api-key',
        'url': 'https://api.openai.com/v1',
        'temperature': 0.0,
        'max_output': 4096
    })
    return config


@pytest.fixture
def mock_openai_client():
    """Mock OpenAI client"""
    mock_client = MagicMock()
    # Mock the response structure
    mock_response = MagicMock()
    mock_response.choices = [MagicMock()]
    mock_response.choices[0].message.content = "Test response from OpenAI"
    mock_response.usage.prompt_tokens = 15
    mock_response.usage.completion_tokens = 8
    mock_client.chat.completions.create.return_value = mock_response
    return mock_client


@pytest.fixture
def mock_openai_rate_limit_error():
    """Mock OpenAI rate limit error"""
    from openai import RateLimitError
    return RateLimitError("Rate limit exceeded", response=MagicMock(), body=None)


# === Azure OpenAI Specific Fixtures ===

@pytest.fixture
def azure_openai_processor_config(base_processor_config):
    """Default configuration for Azure OpenAI processor"""
    config = base_processor_config.copy()
    config.update({
        'model': 'gpt-4',
        'endpoint': 'https://test.openai.azure.com/',
        'token': 'test-token',
        'api_version': '2024-12-01-preview',
        'temperature': 0.0,
        'max_output': 4192
    })
    return config


@pytest.fixture
def mock_azure_openai_client():
    """Mock Azure OpenAI client"""
    mock_client = MagicMock()
    # Mock the response structure
    mock_response = MagicMock()
    mock_response.choices = [MagicMock()]
    mock_response.choices[0].message.content = "Test response from Azure OpenAI"
    mock_response.usage.prompt_tokens = 20
    mock_response.usage.completion_tokens = 10
    mock_client.chat.completions.create.return_value = mock_response
    return mock_client


@pytest.fixture
def mock_azure_openai_rate_limit_error():
    """Mock Azure OpenAI rate limit error"""
    from openai import RateLimitError
    return RateLimitError("Rate limit exceeded", response=MagicMock(), body=None)


# === Azure Specific Fixtures ===

@pytest.fixture
def azure_processor_config(base_processor_config):
    """Default configuration for Azure processor"""
    config = base_processor_config.copy()
    config.update({
        'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
        'token': 'test-token',
        'temperature': 0.0,
        'max_output': 4192
    })
    return config


@pytest.fixture
def mock_azure_requests():
    """Mock requests for Azure processor"""
    mock_requests = MagicMock()
    # Mock successful response
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.json.return_value = {
        'choices': [{
            'message': {
                'content': 'Test response from Azure'
            }
        }],
        'usage': {
            'prompt_tokens': 18,
            'completion_tokens': 9
        }
    }
    mock_requests.post.return_value = mock_response
    return mock_requests


@pytest.fixture
def mock_azure_rate_limit_response():
    """Mock Azure rate limit response"""
    mock_response = MagicMock()
    mock_response.status_code = 429
    return mock_response


# === Claude Specific Fixtures ===

@pytest.fixture
def claude_processor_config(base_processor_config):
    """Default configuration for Claude processor"""
    config = base_processor_config.copy()
    config.update({
        'model': 'claude-3-5-sonnet-20240620',
        'api_key': 'test-api-key',
        'temperature': 0.0,
        'max_output': 8192
    })
    return config


@pytest.fixture
def mock_claude_client():
    """Mock Claude (Anthropic) client"""
    mock_client = MagicMock()
    # Mock the response structure
    mock_response = MagicMock()
    mock_response.content = [MagicMock()]
    mock_response.content[0].text = "Test response from Claude"
    mock_response.usage.input_tokens = 22
    mock_response.usage.output_tokens = 12
    mock_client.messages.create.return_value = mock_response
    return mock_client


@pytest.fixture
def mock_claude_rate_limit_error():
    """Mock Claude rate limit error"""
    import anthropic
    return anthropic.RateLimitError("Rate limit exceeded", response=MagicMock(), body=None)


# === vLLM Specific Fixtures ===

@pytest.fixture
def vllm_processor_config(base_processor_config):
    """Default configuration for vLLM processor"""
    config = base_processor_config.copy()
    config.update({
        'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
        'url': 'http://vllm-service:8899/v1',
        'temperature': 0.0,
        'max_output': 2048
    })
    return config


@pytest.fixture
def mock_vllm_session():
    """Mock aiohttp ClientSession for vLLM"""
    mock_session = MagicMock()
    # Mock successful response
    mock_response = MagicMock()
    mock_response.status = 200
    mock_response.json = AsyncMock(return_value={
        'choices': [{
            'text': 'Test response from vLLM'
        }],
        'usage': {
            'prompt_tokens': 16,
            'completion_tokens': 8
        }
    })
    # Mock the async context manager
    mock_session.post.return_value.__aenter__.return_value = mock_response
    mock_session.post.return_value.__aexit__.return_value = None
    return mock_session


@pytest.fixture
def mock_vllm_error_response():
    """Mock vLLM error response"""
    mock_response = MagicMock()
    mock_response.status = 500
    return mock_response


# === Cohere Specific Fixtures ===

@pytest.fixture
def cohere_processor_config(base_processor_config):
    """Default configuration for Cohere processor"""
    config = base_processor_config.copy()
    config.update({
        'model': 'c4ai-aya-23-8b',
        'api_key': 'test-api-key',
        'temperature': 0.0
    })
    return config


@pytest.fixture
def mock_cohere_client():
    """Mock Cohere client"""
    mock_client = MagicMock()
    # Mock the response structure
    mock_output = MagicMock()
    mock_output.text = "Test response from Cohere"
    mock_output.meta.billed_units.input_tokens = 18
    mock_output.meta.billed_units.output_tokens = 10
    mock_client.chat.return_value = mock_output
    return mock_client


@pytest.fixture
def mock_cohere_rate_limit_error():
    """Mock Cohere rate limit error"""
    import cohere
    return cohere.TooManyRequestsError("Rate limit exceeded")


# === Google AI Studio Specific Fixtures ===

@pytest.fixture
def googleaistudio_processor_config(base_processor_config):
    """Default configuration for Google AI Studio processor"""
    config = base_processor_config.copy()
    config.update({
        'model': 'gemini-2.0-flash-001',
        'api_key': 'test-api-key',
        'temperature': 0.0,
        'max_output': 8192
    })
    return config


@pytest.fixture
def mock_googleaistudio_client():
    """Mock Google AI Studio client"""
    mock_client = MagicMock()
    # Mock the response structure
    mock_response = MagicMock()
    mock_response.text = "Test response from Google AI Studio"
    mock_response.usage_metadata.prompt_token_count = 20
    mock_response.usage_metadata.candidates_token_count = 12
    mock_client.models.generate_content.return_value = mock_response
    return mock_client


@pytest.fixture
def mock_googleaistudio_rate_limit_error():
    """Mock Google AI Studio rate limit error"""
    from google.api_core.exceptions import ResourceExhausted
    return ResourceExhausted("Rate limit exceeded")


# === LlamaFile Specific Fixtures ===

@pytest.fixture
def llamafile_processor_config(base_processor_config):
    """Default configuration for LlamaFile processor"""
    config = base_processor_config.copy()
    config.update({
        'model': 'LLaMA_CPP',
        'llamafile': 'http://localhost:8080/v1',
        'temperature': 0.0,
        'max_output': 4096
    })
    return config


@pytest.fixture
def mock_llamafile_client():
    """Mock OpenAI client for LlamaFile"""
    mock_client = MagicMock()
    # Mock the response structure
    mock_response = MagicMock()
    mock_response.choices = [MagicMock()]
    mock_response.choices[0].message.content = "Test response from LlamaFile"
    mock_response.usage.prompt_tokens = 14
    mock_response.usage.completion_tokens = 8
    mock_client.chat.completions.create.return_value = mock_response
    return mock_client

@@ -0,0 +1,407 @@
"""
Unit tests for trustgraph.model.text_completion.azure_openai
Following the same successful pattern as previous tests
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase

# Import the service under test
from trustgraph.model.text_completion.azure_openai.llm import Processor
from trustgraph.base import LlmResult
from trustgraph.exceptions import TooManyRequests


class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
    """Test Azure OpenAI processor functionality"""

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test basic processor initialization"""
        # Arrange
        mock_azure_client = MagicMock()
        mock_azure_openai_class.return_value = mock_azure_client
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-4',
            'endpoint': 'https://test.openai.azure.com/',
            'token': 'test-token',
            'api_version': '2024-12-01-preview',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act
        processor = Processor(**config)

        # Assert
        assert processor.model == 'gpt-4'
        assert processor.temperature == 0.0
        assert processor.max_output == 4192
        assert hasattr(processor, 'openai')
        mock_azure_openai_class.assert_called_once_with(
            api_key='test-token',
            api_version='2024-12-01-preview',
            azure_endpoint='https://test.openai.azure.com/'
        )

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test successful content generation"""
        # Arrange
        mock_azure_client = MagicMock()
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Generated response from Azure OpenAI"
        mock_response.usage.prompt_tokens = 25
        mock_response.usage.completion_tokens = 15
        mock_azure_client.chat.completions.create.return_value = mock_response
        mock_azure_openai_class.return_value = mock_azure_client
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-4',
            'endpoint': 'https://test.openai.azure.com/',
            'token': 'test-token',
            'api_version': '2024-12-01-preview',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("System prompt", "User prompt")

        # Assert
        assert isinstance(result, LlmResult)
        assert result.text == "Generated response from Azure OpenAI"
        assert result.in_token == 25
        assert result.out_token == 15
        assert result.model == 'gpt-4'

        # Verify the Azure OpenAI API call
        mock_azure_client.chat.completions.create.assert_called_once_with(
            model='gpt-4',
            messages=[{
                "role": "user",
                "content": [{
                    "type": "text",
                    "text": "System prompt\n\nUser prompt"
                }]
            }],
            temperature=0.0,
            max_tokens=4192,
            top_p=1
        )

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test rate limit error handling"""
        # Arrange
        from openai import RateLimitError
        mock_azure_client = MagicMock()
        mock_azure_client.chat.completions.create.side_effect = RateLimitError("Rate limit exceeded", response=MagicMock(), body=None)
        mock_azure_openai_class.return_value = mock_azure_client
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-4',
            'endpoint': 'https://test.openai.azure.com/',
            'token': 'test-token',
            'api_version': '2024-12-01-preview',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act & Assert
        with pytest.raises(TooManyRequests):
            await processor.generate_content("System prompt", "User prompt")

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test handling of generic exceptions"""
        # Arrange
        mock_azure_client = MagicMock()
        mock_azure_client.chat.completions.create.side_effect = Exception("Azure API connection error")
        mock_azure_openai_class.return_value = mock_azure_client
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-4',
            'endpoint': 'https://test.openai.azure.com/',
            'token': 'test-token',
            'api_version': '2024-12-01-preview',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act & Assert
        with pytest.raises(Exception, match="Azure API connection error"):
            await processor.generate_content("System prompt", "User prompt")

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_without_endpoint(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test processor initialization without endpoint (should fail)"""
        # Arrange
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-4',
            'endpoint': None,  # No endpoint provided
            'token': 'test-token',
            'api_version': '2024-12-01-preview',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act & Assert
        with pytest.raises(RuntimeError, match="Azure endpoint not specified"):
            processor = Processor(**config)

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_without_token(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test processor initialization without token (should fail)"""
        # Arrange
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-4',
            'endpoint': 'https://test.openai.azure.com/',
            'token': None,  # No token provided
            'api_version': '2024-12-01-preview',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act & Assert
        with pytest.raises(RuntimeError, match="Azure token not specified"):
            processor = Processor(**config)

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test processor initialization with custom parameters"""
        # Arrange
        mock_azure_client = MagicMock()
        mock_azure_openai_class.return_value = mock_azure_client
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-35-turbo',
            'endpoint': 'https://custom.openai.azure.com/',
            'token': 'custom-token',
            'api_version': '2023-05-15',
            'temperature': 0.7,
            'max_output': 2048,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act
        processor = Processor(**config)

        # Assert
        assert processor.model == 'gpt-35-turbo'
        assert processor.temperature == 0.7
        assert processor.max_output == 2048
        mock_azure_openai_class.assert_called_once_with(
            api_key='custom-token',
            api_version='2023-05-15',
            azure_endpoint='https://custom.openai.azure.com/'
        )

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test processor initialization with default values"""
        # Arrange
        mock_azure_client = MagicMock()
        mock_azure_openai_class.return_value = mock_azure_client
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        # Only provide required fields, should use defaults
        config = {
            'endpoint': 'https://test.openai.azure.com/',
            'token': 'test-token',
            'model': 'gpt-4',  # Required for Azure
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act
        processor = Processor(**config)

        # Assert
        assert processor.model == 'gpt-4'
        assert processor.temperature == 0.0  # default_temperature
        assert processor.max_output == 4192  # default_max_output
        mock_azure_openai_class.assert_called_once_with(
            api_key='test-token',
            api_version='2024-12-01-preview',  # default_api
            azure_endpoint='https://test.openai.azure.com/'
        )

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test content generation with empty prompts"""
        # Arrange
        mock_azure_client = MagicMock()
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Default response"
        mock_response.usage.prompt_tokens = 2
        mock_response.usage.completion_tokens = 3
        mock_azure_client.chat.completions.create.return_value = mock_response
        mock_azure_openai_class.return_value = mock_azure_client
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-4',
            'endpoint': 'https://test.openai.azure.com/',
            'token': 'test-token',
            'api_version': '2024-12-01-preview',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("", "")

        # Assert
        assert isinstance(result, LlmResult)
        assert result.text == "Default response"
        assert result.in_token == 2
        assert result.out_token == 3
        assert result.model == 'gpt-4'

        # Verify the combined prompt is sent correctly
        call_args = mock_azure_client.chat.completions.create.call_args
        expected_prompt = "\n\n"  # Empty system + "\n\n" + empty user
        assert call_args[1]['messages'][0]['content'][0]['text'] == expected_prompt

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_message_structure(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test that Azure OpenAI messages are structured correctly"""
        # Arrange
        mock_azure_client = MagicMock()
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Response with proper structure"
        mock_response.usage.prompt_tokens = 30
        mock_response.usage.completion_tokens = 20
        mock_azure_client.chat.completions.create.return_value = mock_response
        mock_azure_openai_class.return_value = mock_azure_client
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-4',
            'endpoint': 'https://test.openai.azure.com/',
            'token': 'test-token',
            'api_version': '2024-12-01-preview',
            'temperature': 0.5,
            'max_output': 1024,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("You are a helpful assistant", "What is AI?")

        # Assert
        assert result.text == "Response with proper structure"
        assert result.in_token == 30
        assert result.out_token == 20

        # Verify the message structure matches Azure OpenAI Chat API format
        call_args = mock_azure_client.chat.completions.create.call_args
        messages = call_args[1]['messages']
        assert len(messages) == 1
        assert messages[0]['role'] == 'user'
        assert messages[0]['content'][0]['type'] == 'text'
        assert messages[0]['content'][0]['text'] == "You are a helpful assistant\n\nWhat is AI?"

        # Verify other parameters
        assert call_args[1]['model'] == 'gpt-4'
        assert call_args[1]['temperature'] == 0.5
        assert call_args[1]['max_tokens'] == 1024
        assert call_args[1]['top_p'] == 1


if __name__ == '__main__':
    pytest.main([__file__])

@@ -0,0 +1,463 @@
"""
Unit tests for trustgraph.model.text_completion.azure
Following the same successful pattern as previous tests
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase

# Import the service under test
from trustgraph.model.text_completion.azure.llm import Processor
from trustgraph.base import LlmResult
from trustgraph.exceptions import TooManyRequests


class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
    """Test Azure processor functionality"""

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_requests):
        """Test basic processor initialization"""
        # Arrange
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
            'token': 'test-token',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act
        processor = Processor(**config)

        # Assert
        assert processor.endpoint == 'https://test.inference.ai.azure.com/v1/chat/completions'
        assert processor.token == 'test-token'
        assert processor.temperature == 0.0
        assert processor.max_output == 4192
        assert processor.model == 'AzureAI'

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_requests):
        """Test successful content generation"""
        # Arrange
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            'choices': [{
                'message': {
                    'content': 'Generated response from Azure'
                }
            }],
            'usage': {
                'prompt_tokens': 20,
                'completion_tokens': 12
            }
        }
        mock_requests.post.return_value = mock_response
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
            'token': 'test-token',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("System prompt", "User prompt")

        # Assert
        assert isinstance(result, LlmResult)
        assert result.text == "Generated response from Azure"
        assert result.in_token == 20
        assert result.out_token == 12
        assert result.model == 'AzureAI'

        # Verify the API call was made correctly
        mock_requests.post.assert_called_once()
        call_args = mock_requests.post.call_args

        # Check URL
        assert call_args[0][0] == 'https://test.inference.ai.azure.com/v1/chat/completions'

        # Check headers
        headers = call_args[1]['headers']
        assert headers['Content-Type'] == 'application/json'
        assert headers['Authorization'] == 'Bearer test-token'

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_requests):
        """Test rate limit error handling"""
        # Arrange
        mock_response = MagicMock()
        mock_response.status_code = 429
        mock_requests.post.return_value = mock_response
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
            'token': 'test-token',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act & Assert
        with pytest.raises(TooManyRequests):
            await processor.generate_content("System prompt", "User prompt")

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_http_error(self, mock_llm_init, mock_async_init, mock_requests):
        """Test HTTP error handling"""
        # Arrange
        mock_response = MagicMock()
        mock_response.status_code = 500
        mock_requests.post.return_value = mock_response
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
            'token': 'test-token',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act & Assert
        with pytest.raises(RuntimeError, match="LLM failure"):
            await processor.generate_content("System prompt", "User prompt")

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_requests):
        """Test handling of generic exceptions"""
        # Arrange
        mock_requests.post.side_effect = Exception("Connection error")
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
            'token': 'test-token',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act & Assert
        with pytest.raises(Exception, match="Connection error"):
            await processor.generate_content("System prompt", "User prompt")

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_without_endpoint(self, mock_llm_init, mock_async_init, mock_requests):
        """Test processor initialization without endpoint (should fail)"""
        # Arrange
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'endpoint': None,  # No endpoint provided
            'token': 'test-token',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act & Assert
        with pytest.raises(RuntimeError, match="Azure endpoint not specified"):
            processor = Processor(**config)

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_without_token(self, mock_llm_init, mock_async_init, mock_requests):
        """Test processor initialization without token (should fail)"""
        # Arrange
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
            'token': None,  # No token provided
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act & Assert
        with pytest.raises(RuntimeError, match="Azure token not specified"):
            processor = Processor(**config)

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_requests):
        """Test processor initialization with custom parameters"""
        # Arrange
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'endpoint': 'https://custom.inference.ai.azure.com/v1/chat/completions',
            'token': 'custom-token',
            'temperature': 0.7,
            'max_output': 2048,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act
        processor = Processor(**config)

        # Assert
        assert processor.endpoint == 'https://custom.inference.ai.azure.com/v1/chat/completions'
        assert processor.token == 'custom-token'
        assert processor.temperature == 0.7
        assert processor.max_output == 2048
        assert processor.model == 'AzureAI'

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_requests):
        """Test processor initialization with default values"""
        # Arrange
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        # Only provide required fields, should use defaults
        config = {
            'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
            'token': 'test-token',
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        # Act
        processor = Processor(**config)

        # Assert
        assert processor.endpoint == 'https://test.inference.ai.azure.com/v1/chat/completions'
        assert processor.token == 'test-token'
        assert processor.temperature == 0.0  # default_temperature
        assert processor.max_output == 4192  # default_max_output
        assert processor.model == 'AzureAI'  # default_model

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_requests):
        """Test content generation with empty prompts"""
        # Arrange
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            'choices': [{
                'message': {
                    'content': 'Default response'
                }
            }],
            'usage': {
                'prompt_tokens': 2,
                'completion_tokens': 3
            }
        }
        mock_requests.post.return_value = mock_response
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
            'token': 'test-token',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("", "")

        # Assert
        assert isinstance(result, LlmResult)
        assert result.text == "Default response"
        assert result.in_token == 2
        assert result.out_token == 3
        assert result.model == 'AzureAI'

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_build_prompt_structure(self, mock_llm_init, mock_async_init, mock_requests):
        """Test that build_prompt creates correct message structure"""
        # Arrange
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            'choices': [{
                'message': {
                    'content': 'Response with proper structure'
                }
            }],
            'usage': {
                'prompt_tokens': 25,
                'completion_tokens': 15
            }
        }
        mock_requests.post.return_value = mock_response
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
            'token': 'test-token',
            'temperature': 0.5,
            'max_output': 1024,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = await processor.generate_content("You are a helpful assistant", "What is AI?")

        # Assert
        assert result.text == "Response with proper structure"
        assert result.in_token == 25
        assert result.out_token == 15

        # Verify the request structure
        mock_requests.post.assert_called_once()
        call_args = mock_requests.post.call_args

        # Parse the request body
        import json
        request_body = json.loads(call_args[1]['data'])

        # Verify message structure
        assert 'messages' in request_body
        assert len(request_body['messages']) == 2

        # Check system message
        assert request_body['messages'][0]['role'] == 'system'
        assert request_body['messages'][0]['content'] == 'You are a helpful assistant'

        # Check user message
        assert request_body['messages'][1]['role'] == 'user'
        assert request_body['messages'][1]['content'] == 'What is AI?'

        # Check parameters
        assert request_body['temperature'] == 0.5
        assert request_body['max_tokens'] == 1024
        assert request_body['top_p'] == 1

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_call_llm_method(self, mock_llm_init, mock_async_init, mock_requests):
        """Test the call_llm method directly"""
        # Arrange
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            'choices': [{
                'message': {
                    'content': 'Test response'
                }
            }],
            'usage': {
                'prompt_tokens': 10,
                'completion_tokens': 5
            }
        }
        mock_requests.post.return_value = mock_response
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
            'token': 'test-token',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act
        result = processor.call_llm('{"test": "body"}')

        # Assert
        assert result == mock_response.json.return_value

        # Verify the request was made correctly
        mock_requests.post.assert_called_once_with(
            'https://test.inference.ai.azure.com/v1/chat/completions',
            data='{"test": "body"}',
            headers={
                'Content-Type': 'application/json',
                'Authorization': 'Bearer test-token'
            }
        )


if __name__ == '__main__':
    pytest.main([__file__])

@ -0,0 +1,440 @@
"""
Unit tests for trustgraph.model.text_completion.claude
Following the same successful pattern as previous tests
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.model.text_completion.claude.llm import Processor
from trustgraph.base import LlmResult
from trustgraph.exceptions import TooManyRequests
class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
"""Test Claude processor functionality"""
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_anthropic_class):
"""Test basic processor initialization"""
# Arrange
mock_claude_client = MagicMock()
mock_anthropic_class.return_value = mock_claude_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'claude-3-5-sonnet-20240620',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'claude-3-5-sonnet-20240620'
assert processor.temperature == 0.0
assert processor.max_output == 8192
assert hasattr(processor, 'claude')
mock_anthropic_class.assert_called_once_with(api_key='test-api-key')
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_anthropic_class):
"""Test successful content generation"""
# Arrange
mock_claude_client = MagicMock()
mock_response = MagicMock()
mock_response.content = [MagicMock()]
mock_response.content[0].text = "Generated response from Claude"
mock_response.usage.input_tokens = 25
mock_response.usage.output_tokens = 15
mock_claude_client.messages.create.return_value = mock_response
mock_anthropic_class.return_value = mock_claude_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'claude-3-5-sonnet-20240620',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "User prompt")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Generated response from Claude"
assert result.in_token == 25
assert result.out_token == 15
assert result.model == 'claude-3-5-sonnet-20240620'
# Verify the Claude API call
mock_claude_client.messages.create.assert_called_once_with(
model='claude-3-5-sonnet-20240620',
max_tokens=8192,
temperature=0.0,
system="System prompt",
messages=[{
"role": "user",
"content": [{
"type": "text",
"text": "User prompt"
}]
}]
)
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_anthropic_class):
"""Test rate limit error handling"""
# Arrange
import anthropic
mock_claude_client = MagicMock()
mock_claude_client.messages.create.side_effect = anthropic.RateLimitError(
"Rate limit exceeded",
response=MagicMock(),
body=None
)
mock_anthropic_class.return_value = mock_claude_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'claude-3-5-sonnet-20240620',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act & Assert
with pytest.raises(TooManyRequests):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_anthropic_class):
"""Test handling of generic exceptions"""
# Arrange
mock_claude_client = MagicMock()
mock_claude_client.messages.create.side_effect = Exception("API connection error")
mock_anthropic_class.return_value = mock_claude_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'claude-3-5-sonnet-20240620',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act & Assert
with pytest.raises(Exception, match="API connection error"):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_without_api_key(self, mock_llm_init, mock_async_init, mock_anthropic_class):
"""Test processor initialization without API key (should fail)"""
# Arrange
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'claude-3-5-sonnet-20240620',
'api_key': None, # No API key provided
'temperature': 0.0,
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act & Assert
with pytest.raises(RuntimeError, match="Claude API key not specified"):
processor = Processor(**config)
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_anthropic_class):
"""Test processor initialization with custom parameters"""
# Arrange
mock_claude_client = MagicMock()
mock_anthropic_class.return_value = mock_claude_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'claude-3-haiku-20240307',
'api_key': 'custom-api-key',
'temperature': 0.7,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'claude-3-haiku-20240307'
assert processor.temperature == 0.7
assert processor.max_output == 4096
mock_anthropic_class.assert_called_once_with(api_key='custom-api-key')
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_anthropic_class):
"""Test processor initialization with default values"""
# Arrange
mock_claude_client = MagicMock()
mock_anthropic_class.return_value = mock_claude_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
# Only provide required fields, should use defaults
config = {
'api_key': 'test-api-key',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'claude-3-5-sonnet-20240620' # default_model
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 8192 # default_max_output
mock_anthropic_class.assert_called_once_with(api_key='test-api-key')
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_anthropic_class):
"""Test content generation with empty prompts"""
# Arrange
mock_claude_client = MagicMock()
mock_response = MagicMock()
mock_response.content = [MagicMock()]
mock_response.content[0].text = "Default response"
mock_response.usage.input_tokens = 2
mock_response.usage.output_tokens = 3
mock_claude_client.messages.create.return_value = mock_response
mock_anthropic_class.return_value = mock_claude_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'claude-3-5-sonnet-20240620',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("", "")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Default response"
assert result.in_token == 2
assert result.out_token == 3
assert result.model == 'claude-3-5-sonnet-20240620'
# Verify the system prompt and user content are handled correctly
call_args = mock_claude_client.messages.create.call_args
assert call_args[1]['system'] == ""
assert call_args[1]['messages'][0]['content'][0]['text'] == ""
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_message_structure(self, mock_llm_init, mock_async_init, mock_anthropic_class):
"""Test that Claude messages are structured correctly"""
# Arrange
mock_claude_client = MagicMock()
mock_response = MagicMock()
mock_response.content = [MagicMock()]
mock_response.content[0].text = "Response with proper structure"
mock_response.usage.input_tokens = 30
mock_response.usage.output_tokens = 20
mock_claude_client.messages.create.return_value = mock_response
mock_anthropic_class.return_value = mock_claude_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'claude-3-5-sonnet-20240620',
'api_key': 'test-api-key',
'temperature': 0.5,
'max_output': 1024,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
# Assert
assert result.text == "Response with proper structure"
assert result.in_token == 30
assert result.out_token == 20
# Verify the message structure matches Claude API format
call_args = mock_claude_client.messages.create.call_args
# Check system prompt
assert call_args[1]['system'] == "You are a helpful assistant"
# Check user message structure
messages = call_args[1]['messages']
assert len(messages) == 1
assert messages[0]['role'] == 'user'
assert messages[0]['content'][0]['type'] == 'text'
assert messages[0]['content'][0]['text'] == "What is AI?"
# Verify other parameters
assert call_args[1]['model'] == 'claude-3-5-sonnet-20240620'
assert call_args[1]['temperature'] == 0.5
assert call_args[1]['max_tokens'] == 1024
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_multiple_content_blocks(self, mock_llm_init, mock_async_init, mock_anthropic_class):
"""Test handling of multiple content blocks in response"""
# Arrange
mock_claude_client = MagicMock()
mock_response = MagicMock()
# Mock multiple content blocks (Claude can return multiple)
mock_content_1 = MagicMock()
mock_content_1.text = "First part of response"
mock_content_2 = MagicMock()
mock_content_2.text = "Second part of response"
mock_response.content = [mock_content_1, mock_content_2]
mock_response.usage.input_tokens = 40
mock_response.usage.output_tokens = 30
mock_claude_client.messages.create.return_value = mock_response
mock_anthropic_class.return_value = mock_claude_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'claude-3-5-sonnet-20240620',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "User prompt")
# Assert
assert isinstance(result, LlmResult)
# Should take the first content block
assert result.text == "First part of response"
assert result.in_token == 40
assert result.out_token == 30
assert result.model == 'claude-3-5-sonnet-20240620'
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_claude_client_initialization(self, mock_llm_init, mock_async_init, mock_anthropic_class):
"""Test that Claude client is initialized correctly"""
# Arrange
mock_claude_client = MagicMock()
mock_anthropic_class.return_value = mock_claude_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'claude-3-opus-20240229',
'api_key': 'sk-ant-test-key',
'temperature': 0.0,
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
# Verify Anthropic client was called with correct API key
mock_anthropic_class.assert_called_once_with(api_key='sk-ant-test-key')
# Verify processor has the client
assert processor.claude == mock_claude_client
assert processor.model == 'claude-3-opus-20240229'
if __name__ == '__main__':
pytest.main([__file__])
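
Note on the scaffolding shared by all of these suites: the SDK client class and both parent constructors (`AsyncProcessor.__init__`, `LlmService.__init__`) are patched out so a `Processor` can be constructed without a running broker. Stacked `@patch` decorators inject mocks bottom-up, so the first test argument is always the `LlmService` mock and the last is the SDK client mock. A minimal sketch of the pattern, assuming only what the tests above already use:

# Minimal sketch of the patching pattern shared by these suites.
# Decorators apply bottom-up, so mock arguments arrive in reverse order.
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch

from trustgraph.model.text_completion.claude.llm import Processor


class ScaffoldingSketch(IsolatedAsyncioTestCase):

    @patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__',
           return_value=None)
    @patch('trustgraph.base.llm_service.LlmService.__init__',
           return_value=None)
    async def test_construct_only(self, mock_llm_init, mock_async_init,
                                  mock_anthropic_class):
        mock_anthropic_class.return_value = MagicMock()
        # Parent constructors are no-ops, so only the Claude-specific
        # initialisation (API key check, client construction) runs.
        processor = Processor(
            model='claude-3-5-sonnet-20240620', api_key='test-api-key',
            temperature=0.0, max_output=8192, concurrency=1,
            taskgroup=AsyncMock(), id='test-processor',
        )
        assert processor.model == 'claude-3-5-sonnet-20240620'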


@@ -0,0 +1,447 @@
"""
Unit tests for trustgraph.model.text_completion.cohere
Following the same successful pattern as previous tests
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.model.text_completion.cohere.llm import Processor
from trustgraph.base import LlmResult
from trustgraph.exceptions import TooManyRequests
class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
"""Test Cohere processor functionality"""
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_cohere_class):
"""Test basic processor initialization"""
# Arrange
mock_cohere_client = MagicMock()
mock_cohere_class.return_value = mock_cohere_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'c4ai-aya-23-8b',
'api_key': 'test-api-key',
'temperature': 0.0,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'c4ai-aya-23-8b'
assert processor.temperature == 0.0
assert hasattr(processor, 'cohere')
mock_cohere_class.assert_called_once_with(api_key='test-api-key')
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_cohere_class):
"""Test successful content generation"""
# Arrange
mock_cohere_client = MagicMock()
mock_output = MagicMock()
mock_output.text = "Generated response from Cohere"
mock_output.meta.billed_units.input_tokens = 25
mock_output.meta.billed_units.output_tokens = 15
mock_cohere_client.chat.return_value = mock_output
mock_cohere_class.return_value = mock_cohere_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'c4ai-aya-23-8b',
'api_key': 'test-api-key',
'temperature': 0.0,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "User prompt")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Generated response from Cohere"
assert result.in_token == 25
assert result.out_token == 15
assert result.model == 'c4ai-aya-23-8b'
# Verify the Cohere API call
mock_cohere_client.chat.assert_called_once_with(
model='c4ai-aya-23-8b',
message="User prompt",
preamble="System prompt",
temperature=0.0,
chat_history=[],
prompt_truncation='auto',
connectors=[]
)
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_cohere_class):
"""Test rate limit error handling"""
# Arrange
import cohere
mock_cohere_client = MagicMock()
mock_cohere_client.chat.side_effect = cohere.TooManyRequestsError("Rate limit exceeded")
mock_cohere_class.return_value = mock_cohere_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'c4ai-aya-23-8b',
'api_key': 'test-api-key',
'temperature': 0.0,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act & Assert
with pytest.raises(TooManyRequests):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_cohere_class):
"""Test handling of generic exceptions"""
# Arrange
mock_cohere_client = MagicMock()
mock_cohere_client.chat.side_effect = Exception("API connection error")
mock_cohere_class.return_value = mock_cohere_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'c4ai-aya-23-8b',
'api_key': 'test-api-key',
'temperature': 0.0,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act & Assert
with pytest.raises(Exception, match="API connection error"):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_without_api_key(self, mock_llm_init, mock_async_init, mock_cohere_class):
"""Test processor initialization without API key (should fail)"""
# Arrange
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'c4ai-aya-23-8b',
'api_key': None, # No API key provided
'temperature': 0.0,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act & Assert
with pytest.raises(RuntimeError, match="Cohere API key not specified"):
Processor(**config)
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_cohere_class):
"""Test processor initialization with custom parameters"""
# Arrange
mock_cohere_client = MagicMock()
mock_cohere_class.return_value = mock_cohere_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'command-light',
'api_key': 'custom-api-key',
'temperature': 0.7,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'command-light'
assert processor.temperature == 0.7
mock_cohere_class.assert_called_once_with(api_key='custom-api-key')
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_cohere_class):
"""Test processor initialization with default values"""
# Arrange
mock_cohere_client = MagicMock()
mock_cohere_class.return_value = mock_cohere_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
# Only provide required fields, should use defaults
config = {
'api_key': 'test-api-key',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'c4ai-aya-23-8b' # default_model
assert processor.temperature == 0.0 # default_temperature
mock_cohere_class.assert_called_once_with(api_key='test-api-key')
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_cohere_class):
"""Test content generation with empty prompts"""
# Arrange
mock_cohere_client = MagicMock()
mock_output = MagicMock()
mock_output.text = "Default response"
mock_output.meta.billed_units.input_tokens = 2
mock_output.meta.billed_units.output_tokens = 3
mock_cohere_client.chat.return_value = mock_output
mock_cohere_class.return_value = mock_cohere_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'c4ai-aya-23-8b',
'api_key': 'test-api-key',
'temperature': 0.0,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("", "")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Default response"
assert result.in_token == 2
assert result.out_token == 3
assert result.model == 'c4ai-aya-23-8b'
# Verify the preamble and message are handled correctly
call_args = mock_cohere_client.chat.call_args
assert call_args[1]['preamble'] == ""
assert call_args[1]['message'] == ""
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_chat_structure(self, mock_llm_init, mock_async_init, mock_cohere_class):
"""Test that Cohere chat is structured correctly"""
# Arrange
mock_cohere_client = MagicMock()
mock_output = MagicMock()
mock_output.text = "Response with proper structure"
mock_output.meta.billed_units.input_tokens = 30
mock_output.meta.billed_units.output_tokens = 20
mock_cohere_client.chat.return_value = mock_output
mock_cohere_class.return_value = mock_cohere_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'c4ai-aya-23-8b',
'api_key': 'test-api-key',
'temperature': 0.5,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
# Assert
assert result.text == "Response with proper structure"
assert result.in_token == 30
assert result.out_token == 20
# Verify the chat structure matches Cohere API format
call_args = mock_cohere_client.chat.call_args
# Check parameters
assert call_args[1]['model'] == 'c4ai-aya-23-8b'
assert call_args[1]['message'] == "What is AI?"
assert call_args[1]['preamble'] == "You are a helpful assistant"
assert call_args[1]['temperature'] == 0.5
assert call_args[1]['chat_history'] == []
assert call_args[1]['prompt_truncation'] == 'auto'
assert call_args[1]['connectors'] == []
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_token_parsing(self, mock_llm_init, mock_async_init, mock_cohere_class):
"""Test token parsing from Cohere response"""
# Arrange
mock_cohere_client = MagicMock()
mock_output = MagicMock()
mock_output.text = "Token parsing test"
mock_output.meta.billed_units.input_tokens = 50
mock_output.meta.billed_units.output_tokens = 25
mock_cohere_client.chat.return_value = mock_output
mock_cohere_class.return_value = mock_cohere_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'c4ai-aya-23-8b',
'api_key': 'test-api-key',
'temperature': 0.0,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System", "User query")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Token parsing test"
assert result.in_token == 50
assert result.out_token == 25
assert result.model == 'c4ai-aya-23-8b'
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_cohere_client_initialization(self, mock_llm_init, mock_async_init, mock_cohere_class):
"""Test that Cohere client is initialized correctly"""
# Arrange
mock_cohere_client = MagicMock()
mock_cohere_class.return_value = mock_cohere_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'command-r',
'api_key': 'co-test-key',
'temperature': 0.0,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
# Verify Cohere client was called with correct API key
mock_cohere_class.assert_called_once_with(api_key='co-test-key')
# Verify processor has the client
assert processor.cohere == mock_cohere_client
assert processor.model == 'command-r'
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_chat_parameters(self, mock_llm_init, mock_async_init, mock_cohere_class):
"""Test that all chat parameters are passed correctly"""
# Arrange
mock_cohere_client = MagicMock()
mock_output = MagicMock()
mock_output.text = "Chat parameter test"
mock_output.meta.billed_units.input_tokens = 20
mock_output.meta.billed_units.output_tokens = 10
mock_cohere_client.chat.return_value = mock_output
mock_cohere_class.return_value = mock_cohere_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'c4ai-aya-23-8b',
'api_key': 'test-api-key',
'temperature': 0.3,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System instructions", "User question")
# Assert
assert result.text == "Chat parameter test"
# Verify all parameters are passed correctly
call_args = mock_cohere_client.chat.call_args
assert call_args[1]['model'] == 'c4ai-aya-23-8b'
assert call_args[1]['message'] == "User question"
assert call_args[1]['preamble'] == "System instructions"
assert call_args[1]['temperature'] == 0.3
assert call_args[1]['chat_history'] == []
assert call_args[1]['prompt_truncation'] == 'auto'
assert call_args[1]['connectors'] == []
if __name__ == '__main__':
pytest.main([__file__])
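
The rate-limit test above pins down a behaviour worth calling out: the processor translates the SDK-specific throttling error into TrustGraph's own `TooManyRequests`. A sketch of that contract, assuming a hypothetical wrapper function (the real handling lives inside `trustgraph.model.text_completion.cohere.llm`):

# Sketch of the error-translation contract pinned down by
# test_generate_content_rate_limit_error; illustrative, not the
# module's actual code.
import cohere

from trustgraph.exceptions import TooManyRequests


def chat_with_backpressure(client, **kwargs):
    try:
        return client.chat(**kwargs)
    except cohere.TooManyRequestsError as exc:
        # Re-raise as TrustGraph's own exception so callers can apply a
        # uniform retry/back-off policy across all LLM providers.
        raise TooManyRequests() from exc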


@@ -0,0 +1,482 @@
"""
Unit tests for trustgraph.model.text_completion.googleaistudio
Following the same successful pattern as previous tests
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.model.text_completion.googleaistudio.llm import Processor
from trustgraph.base import LlmResult
from trustgraph.exceptions import TooManyRequests
class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
"""Test Google AI Studio processor functionality"""
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test basic processor initialization"""
# Arrange
mock_genai_client = MagicMock()
mock_genai_class.return_value = mock_genai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'gemini-2.0-flash-001'
assert processor.temperature == 0.0
assert processor.max_output == 8192
assert hasattr(processor, 'client')
assert hasattr(processor, 'safety_settings')
assert len(processor.safety_settings) == 4 # 4 safety categories
mock_genai_class.assert_called_once_with(api_key='test-api-key')
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test successful content generation"""
# Arrange
mock_genai_client = MagicMock()
mock_response = MagicMock()
mock_response.text = "Generated response from Google AI Studio"
mock_response.usage_metadata.prompt_token_count = 25
mock_response.usage_metadata.candidates_token_count = 15
mock_genai_client.models.generate_content.return_value = mock_response
mock_genai_class.return_value = mock_genai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "User prompt")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Generated response from Google AI Studio"
assert result.in_token == 25
assert result.out_token == 15
assert result.model == 'gemini-2.0-flash-001'
# Verify the Google AI Studio API call
mock_genai_client.models.generate_content.assert_called_once()
call_args = mock_genai_client.models.generate_content.call_args
assert call_args[1]['model'] == 'gemini-2.0-flash-001'
assert call_args[1]['contents'] == "User prompt"
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test rate limit error handling"""
# Arrange
from google.api_core.exceptions import ResourceExhausted
mock_genai_client = MagicMock()
mock_genai_client.models.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
mock_genai_class.return_value = mock_genai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act & Assert
with pytest.raises(TooManyRequests):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test handling of generic exceptions"""
# Arrange
mock_genai_client = MagicMock()
mock_genai_client.models.generate_content.side_effect = Exception("API connection error")
mock_genai_class.return_value = mock_genai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act & Assert
with pytest.raises(Exception, match="API connection error"):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_without_api_key(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test processor initialization without API key (should fail)"""
# Arrange
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': None, # No API key provided
'temperature': 0.0,
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act & Assert
with pytest.raises(RuntimeError, match="Google AI Studio API key not specified"):
Processor(**config)
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test processor initialization with custom parameters"""
# Arrange
mock_genai_client = MagicMock()
mock_genai_class.return_value = mock_genai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-1.5-pro',
'api_key': 'custom-api-key',
'temperature': 0.7,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'gemini-1.5-pro'
assert processor.temperature == 0.7
assert processor.max_output == 4096
mock_genai_class.assert_called_once_with(api_key='custom-api-key')
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test processor initialization with default values"""
# Arrange
mock_genai_client = MagicMock()
mock_genai_class.return_value = mock_genai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
# Only provide required fields, should use defaults
config = {
'api_key': 'test-api-key',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'gemini-2.0-flash-001' # default_model
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 8192 # default_max_output
mock_genai_class.assert_called_once_with(api_key='test-api-key')
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test content generation with empty prompts"""
# Arrange
mock_genai_client = MagicMock()
mock_response = MagicMock()
mock_response.text = "Default response"
mock_response.usage_metadata.prompt_token_count = 2
mock_response.usage_metadata.candidates_token_count = 3
mock_genai_client.models.generate_content.return_value = mock_response
mock_genai_class.return_value = mock_genai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("", "")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Default response"
assert result.in_token == 2
assert result.out_token == 3
assert result.model == 'gemini-2.0-flash-001'
# Verify the system instruction and content are handled correctly
call_args = mock_genai_client.models.generate_content.call_args
assert call_args[1]['contents'] == ""
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_configuration_structure(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test that generation configuration is structured correctly"""
# Arrange
mock_genai_client = MagicMock()
mock_response = MagicMock()
mock_response.text = "Response with proper structure"
mock_response.usage_metadata.prompt_token_count = 30
mock_response.usage_metadata.candidates_token_count = 20
mock_genai_client.models.generate_content.return_value = mock_response
mock_genai_class.return_value = mock_genai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': 'test-api-key',
'temperature': 0.5,
'max_output': 1024,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
# Assert
assert result.text == "Response with proper structure"
assert result.in_token == 30
assert result.out_token == 20
# Verify the generation configuration
call_args = mock_genai_client.models.generate_content.call_args
config_arg = call_args[1]['config']
# Check that the configuration has the right structure
assert call_args[1]['model'] == 'gemini-2.0-flash-001'
assert call_args[1]['contents'] == "What is AI?"
# Config should be a GenerateContentConfig object; at minimum, one must be passed
assert config_arg is not None
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_safety_settings_initialization(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test that safety settings are initialized correctly"""
# Arrange
mock_genai_client = MagicMock()
mock_genai_class.return_value = mock_genai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert hasattr(processor, 'safety_settings')
assert len(processor.safety_settings) == 4
# Should have 4 safety categories: hate speech, harassment, sexually explicit, dangerous content
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_token_parsing(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test token parsing from Google AI Studio response"""
# Arrange
mock_genai_client = MagicMock()
mock_response = MagicMock()
mock_response.text = "Token parsing test"
mock_response.usage_metadata.prompt_token_count = 50
mock_response.usage_metadata.candidates_token_count = 25
mock_genai_client.models.generate_content.return_value = mock_response
mock_genai_class.return_value = mock_genai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System", "User query")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Token parsing test"
assert result.in_token == 50
assert result.out_token == 25
assert result.model == 'gemini-2.0-flash-001'
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_genai_client_initialization(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test that Google AI Studio client is initialized correctly"""
# Arrange
mock_genai_client = MagicMock()
mock_genai_class.return_value = mock_genai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-1.5-flash',
'api_key': 'gai-test-key',
'temperature': 0.0,
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
# Verify Google AI Studio client was called with correct API key
mock_genai_class.assert_called_once_with(api_key='gai-test-key')
# Verify processor has the client
assert processor.client == mock_genai_client
assert processor.model == 'gemini-1.5-flash'
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_system_instruction(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test that system instruction is handled correctly"""
# Arrange
mock_genai_client = MagicMock()
mock_response = MagicMock()
mock_response.text = "System instruction test"
mock_response.usage_metadata.prompt_token_count = 35
mock_response.usage_metadata.candidates_token_count = 25
mock_genai_client.models.generate_content.return_value = mock_response
mock_genai_class.return_value = mock_genai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("Be helpful and concise", "Explain quantum computing")
# Assert
assert result.text == "System instruction test"
assert result.in_token == 35
assert result.out_token == 25
# Verify the system instruction is passed in the config
call_args = mock_genai_client.models.generate_content.call_args
config_arg = call_args[1]['config']
# The system instruction should be carried in the config object
assert config_arg is not None
assert call_args[1]['contents'] == "Explain quantum computing"
if __name__ == '__main__':
pytest.main([__file__])
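
The Google AI Studio tests only assert the `model` and `contents` keyword arguments plus the presence of a `config` object. For orientation, this is the plausible shape of that config under the google-genai SDK; the field names and a hypothetical `build_config` helper are assumptions here, not something the tests assert:

# Hypothetical shape of the object the tests see in call_args[1]['config'];
# field names follow the google-genai SDK and are assumptions.
from google.genai import types


def build_config(system: str, temperature: float, max_output: int):
    return types.GenerateContentConfig(
        system_instruction=system,        # system prompt travels in config
        temperature=temperature,
        max_output_tokens=max_output,
        safety_settings=[
            # One SafetySetting per harm category; the tests only assert
            # that four such settings exist (remaining three omitted here).
            types.SafetySetting(
                category='HARM_CATEGORY_HATE_SPEECH',
                threshold='BLOCK_NONE',
            ),
        ],
    )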


@@ -0,0 +1,454 @@
"""
Unit tests for trustgraph.model.text_completion.llamafile
Following the same successful pattern as previous tests
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.model.text_completion.llamafile.llm import Processor
from trustgraph.base import LlmResult
from trustgraph.exceptions import TooManyRequests
class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
"""Test LlamaFile processor functionality"""
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test basic processor initialization"""
# Arrange
mock_openai_client = MagicMock()
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'LLaMA_CPP',
'llamafile': 'http://localhost:8080/v1',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'LLaMA_CPP'
assert processor.llamafile == 'http://localhost:8080/v1'
assert processor.temperature == 0.0
assert processor.max_output == 4096
assert hasattr(processor, 'openai')
mock_openai_class.assert_called_once_with(
base_url='http://localhost:8080/v1',
api_key='sk-no-key-required'
)
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test successful content generation"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Generated response from LlamaFile"
mock_response.usage.prompt_tokens = 20
mock_response.usage.completion_tokens = 12
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'LLaMA_CPP',
'llamafile': 'http://localhost:8080/v1',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "User prompt")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Generated response from LlamaFile"
assert result.in_token == 20
assert result.out_token == 12
assert result.model == 'llama.cpp' # Note: model in result is hardcoded to 'llama.cpp'
# Verify the OpenAI API call structure
mock_openai_client.chat.completions.create.assert_called_once_with(
model='LLaMA_CPP',
messages=[{
"role": "user",
"content": "System prompt\n\nUser prompt"
}]
)
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test handling of generic exceptions"""
# Arrange
mock_openai_client = MagicMock()
mock_openai_client.chat.completions.create.side_effect = Exception("Connection error")
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'LLaMA_CPP',
'llamafile': 'http://localhost:8080/v1',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act & Assert
with pytest.raises(Exception, match="Connection error"):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test processor initialization with custom parameters"""
# Arrange
mock_openai_client = MagicMock()
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'custom-llama',
'llamafile': 'http://custom-host:8080/v1',
'temperature': 0.7,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'custom-llama'
assert processor.llamafile == 'http://custom-host:8080/v1'
assert processor.temperature == 0.7
assert processor.max_output == 2048
mock_openai_class.assert_called_once_with(
base_url='http://custom-host:8080/v1',
api_key='sk-no-key-required'
)
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test processor initialization with default values"""
# Arrange
mock_openai_client = MagicMock()
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
# Only provide required fields, should use defaults
config = {
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'LLaMA_CPP' # default_model
assert processor.llamafile == 'http://localhost:8080/v1' # default_llamafile
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 4096 # default_max_output
mock_openai_class.assert_called_once_with(
base_url='http://localhost:8080/v1',
api_key='sk-no-key-required'
)
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test content generation with empty prompts"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Default response"
mock_response.usage.prompt_tokens = 2
mock_response.usage.completion_tokens = 3
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'LLaMA_CPP',
'llamafile': 'http://localhost:8080/v1',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("", "")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Default response"
assert result.in_token == 2
assert result.out_token == 3
assert result.model == 'llama.cpp'
# Verify the combined prompt is sent correctly
call_args = mock_openai_client.chat.completions.create.call_args
expected_prompt = "\n\n" # Empty system + "\n\n" + empty user
assert call_args[1]['messages'][0]['content'] == expected_prompt
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_message_structure(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test that LlamaFile messages are structured correctly"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response with proper structure"
mock_response.usage.prompt_tokens = 25
mock_response.usage.completion_tokens = 15
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'LLaMA_CPP',
'llamafile': 'http://localhost:8080/v1',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
# Assert
assert result.text == "Response with proper structure"
assert result.in_token == 25
assert result.out_token == 15
# Verify the message structure
call_args = mock_openai_client.chat.completions.create.call_args
messages = call_args[1]['messages']
assert len(messages) == 1
assert messages[0]['role'] == 'user'
assert messages[0]['content'] == "You are a helpful assistant\n\nWhat is AI?"
# Verify model parameter
assert call_args[1]['model'] == 'LLaMA_CPP'
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_openai_client_initialization(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test that OpenAI client is initialized correctly for LlamaFile"""
# Arrange
mock_openai_client = MagicMock()
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'llama-custom',
'llamafile': 'http://llamafile-server:8080/v1',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
# Verify OpenAI client was called with correct parameters
mock_openai_class.assert_called_once_with(
base_url='http://llamafile-server:8080/v1',
api_key='sk-no-key-required'
)
# Verify processor has the client
assert processor.openai == mock_openai_client
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test prompt construction with system and user prompts"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response with system instructions"
mock_response.usage.prompt_tokens = 30
mock_response.usage.completion_tokens = 20
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'LLaMA_CPP',
'llamafile': 'http://localhost:8080/v1',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("You are a helpful assistant", "What is machine learning?")
# Assert
assert result.text == "Response with system instructions"
assert result.in_token == 30
assert result.out_token == 20
# Verify the combined prompt
call_args = mock_openai_client.chat.completions.create.call_args
expected_prompt = "You are a helpful assistant\n\nWhat is machine learning?"
assert call_args[1]['messages'][0]['content'] == expected_prompt
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_hardcoded_model_response(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test that response model is hardcoded to 'llama.cpp'"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Test response"
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 10
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'custom-model-name', # This should be ignored in response
'llamafile': 'http://localhost:8080/v1',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System", "User")
# Assert
assert result.model == 'llama.cpp' # Should always be 'llama.cpp', not 'custom-model-name'
assert processor.model == 'custom-model-name' # But processor.model should still be custom
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_no_rate_limiting(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test that no rate limiting is implemented (SLM assumption)"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "No rate limiting test"
mock_response.usage.prompt_tokens = 10
mock_response.usage.completion_tokens = 5
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'LLaMA_CPP',
'llamafile': 'http://localhost:8080/v1',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System", "User")
# Assert
assert result.text == "No rate limiting test"
# No rate limit error handling is tested: a locally served model presumably imposes no rate limits
if __name__ == '__main__':
pytest.main([__file__])
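
Both this suite and the Ollama one below assert the same prompt convention: the system and user prompts are joined with a blank line and sent as a single user-role message. A trivial sketch of that convention (the helper name is illustrative):

def build_prompt(system: str, user: str) -> str:
    # Convention asserted by these suites: system prompt, a blank line,
    # then the user prompt, delivered as one user-role message.
    return f"{system}\n\n{user}"


# Matches the assertions above, including the empty-prompt edge case.
assert build_prompt("", "") == "\n\n"
assert build_prompt("Be brief", "What is AI?") == "Be brief\n\nWhat is AI?"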


@@ -0,0 +1,317 @@
"""
Unit tests for trustgraph.model.text_completion.ollama
Following the same successful pattern as VertexAI tests
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.model.text_completion.ollama.llm import Processor
from trustgraph.base import LlmResult
class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
"""Test Ollama processor functionality"""
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test basic processor initialization"""
# Arrange
mock_client = MagicMock()
mock_client_class.return_value = mock_client
# Mock the parent class initialization
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'llama2',
'ollama': 'http://localhost:11434',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'llama2'
assert hasattr(processor, 'llm')
mock_client_class.assert_called_once_with(host='http://localhost:11434')
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test successful content generation"""
# Arrange
mock_client = MagicMock()
mock_response = {
'response': 'Generated response from Ollama',
'prompt_eval_count': 15,
'eval_count': 8
}
mock_client.generate.return_value = mock_response
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'llama2',
'ollama': 'http://localhost:11434',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "User prompt")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Generated response from Ollama"
assert result.in_token == 15
assert result.out_token == 8
assert result.model == 'llama2'
mock_client.generate.assert_called_once_with('llama2', "System prompt\n\nUser prompt")
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test handling of generic exceptions"""
# Arrange
mock_client = MagicMock()
mock_client.generate.side_effect = Exception("Connection error")
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'llama2',
'ollama': 'http://localhost:11434',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act & Assert
with pytest.raises(Exception, match="Connection error"):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test processor initialization with custom parameters"""
# Arrange
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'mistral',
'ollama': 'http://192.168.1.100:11434',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'mistral'
mock_client_class.assert_called_once_with(host='http://192.168.1.100:11434')
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test processor initialization with default values"""
# Arrange
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
# Don't provide model or ollama - should use defaults
config = {
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'gemma2:9b' # default_model
# Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env)
mock_client_class.assert_called_once()
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test content generation with empty prompts"""
# Arrange
mock_client = MagicMock()
mock_response = {
'response': 'Default response',
'prompt_eval_count': 2,
'eval_count': 3
}
mock_client.generate.return_value = mock_response
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'llama2',
'ollama': 'http://localhost:11434',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("", "")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Default response"
assert result.in_token == 2
assert result.out_token == 3
assert result.model == 'llama2'
# The prompt should be "" + "\n\n" + "" = "\n\n"
mock_client.generate.assert_called_once_with('llama2', "\n\n")
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_token_counting(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test token counting from Ollama response"""
# Arrange
mock_client = MagicMock()
mock_response = {
'response': 'Test response',
'prompt_eval_count': 50,
'eval_count': 25
}
mock_client.generate.return_value = mock_response
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'llama2',
'ollama': 'http://localhost:11434',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System", "User prompt")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Test response"
assert result.in_token == 50
assert result.out_token == 25
assert result.model == 'llama2'
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_ollama_client_initialization(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test that Ollama client is initialized correctly"""
# Arrange
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'codellama',
'ollama': 'http://ollama-server:11434',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
# Verify Client was called with correct host
mock_client_class.assert_called_once_with(host='http://ollama-server:11434')
# Verify processor has the client
assert processor.llm == mock_client
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test prompt construction with system and user prompts"""
# Arrange
mock_client = MagicMock()
mock_response = {
'response': 'Response with system instructions',
'prompt_eval_count': 25,
'eval_count': 15
}
mock_client.generate.return_value = mock_response
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'llama2',
'ollama': 'http://localhost:11434',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
# Assert
assert result.text == "Response with system instructions"
assert result.in_token == 25
assert result.out_token == 15
# Verify the combined prompt
mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?")
if __name__ == '__main__':
pytest.main([__file__])
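
For reference, Ollama's `generate()` returns a plain dict, and the fields these assertions check map onto `LlmResult` as below. This is a sketch of the contract implied by the tests, assuming `LlmResult` accepts these fields as keyword arguments; `to_result` is a hypothetical helper, not the module's code:

# Sketch of the response-to-result mapping these assertions imply.
from trustgraph.base import LlmResult


def to_result(response: dict, model: str) -> LlmResult:
    return LlmResult(
        text=response['response'],               # generated text
        in_token=response['prompt_eval_count'],  # prompt-side tokens
        out_token=response['eval_count'],        # completion-side tokens
        model=model,                             # processor's model name
    )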


@@ -0,0 +1,395 @@
"""
Unit tests for trustgraph.model.text_completion.openai
Following the same successful pattern as VertexAI and Ollama tests
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.model.text_completion.openai.llm import Processor
from trustgraph.base import LlmResult
from trustgraph.exceptions import TooManyRequests
class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
"""Test OpenAI processor functionality"""
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test basic processor initialization"""
# Arrange
mock_openai_client = MagicMock()
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-3.5-turbo',
'api_key': 'test-api-key',
'url': 'https://api.openai.com/v1',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'gpt-3.5-turbo'
assert processor.temperature == 0.0
assert processor.max_output == 4096
assert hasattr(processor, 'openai')
mock_openai_class.assert_called_once_with(base_url='https://api.openai.com/v1', api_key='test-api-key')
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test successful content generation"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Generated response from OpenAI"
mock_response.usage.prompt_tokens = 20
mock_response.usage.completion_tokens = 12
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-3.5-turbo',
'api_key': 'test-api-key',
'url': 'https://api.openai.com/v1',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "User prompt")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Generated response from OpenAI"
assert result.in_token == 20
assert result.out_token == 12
assert result.model == 'gpt-3.5-turbo'
# Verify the OpenAI API call
mock_openai_client.chat.completions.create.assert_called_once_with(
model='gpt-3.5-turbo',
messages=[{
"role": "user",
"content": [{
"type": "text",
"text": "System prompt\n\nUser prompt"
}]
}],
temperature=0.0,
max_tokens=4096,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
response_format={"type": "text"}
)
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test rate limit error handling"""
# Arrange
from openai import RateLimitError
mock_openai_client = MagicMock()
mock_openai_client.chat.completions.create.side_effect = RateLimitError("Rate limit exceeded", response=MagicMock(), body=None)
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-3.5-turbo',
'api_key': 'test-api-key',
'url': 'https://api.openai.com/v1',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act & Assert
with pytest.raises(TooManyRequests):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test handling of generic exceptions"""
# Arrange
mock_openai_client = MagicMock()
mock_openai_client.chat.completions.create.side_effect = Exception("API connection error")
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-3.5-turbo',
'api_key': 'test-api-key',
'url': 'https://api.openai.com/v1',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act & Assert
with pytest.raises(Exception, match="API connection error"):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_without_api_key(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test processor initialization without API key (should fail)"""
# Arrange
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-3.5-turbo',
'api_key': None, # No API key provided
'url': 'https://api.openai.com/v1',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act & Assert
with pytest.raises(RuntimeError, match="OpenAI API key not specified"):
processor = Processor(**config)
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test processor initialization with custom parameters"""
# Arrange
mock_openai_client = MagicMock()
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-4',
'api_key': 'custom-api-key',
'url': 'https://custom-openai-url.com/v1',
'temperature': 0.7,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'gpt-4'
assert processor.temperature == 0.7
assert processor.max_output == 2048
mock_openai_class.assert_called_once_with(base_url='https://custom-openai-url.com/v1', api_key='custom-api-key')
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test processor initialization with default values"""
# Arrange
mock_openai_client = MagicMock()
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
# Only provide required fields, should use defaults
config = {
'api_key': 'test-api-key',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'gpt-3.5-turbo' # default_model
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 4096 # default_max_output
mock_openai_class.assert_called_once_with(base_url='https://api.openai.com/v1', api_key='test-api-key')
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test content generation with empty prompts"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Default response"
mock_response.usage.prompt_tokens = 2
mock_response.usage.completion_tokens = 3
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-3.5-turbo',
'api_key': 'test-api-key',
'url': 'https://api.openai.com/v1',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("", "")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Default response"
assert result.in_token == 2
assert result.out_token == 3
assert result.model == 'gpt-3.5-turbo'
# Verify the combined prompt is sent correctly
call_args = mock_openai_client.chat.completions.create.call_args
expected_prompt = "\n\n" # Empty system + "\n\n" + empty user
assert call_args[1]['messages'][0]['content'][0]['text'] == expected_prompt
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_openai_client_initialization_without_base_url(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test OpenAI client initialization without base_url"""
# Arrange
mock_openai_client = MagicMock()
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-3.5-turbo',
'api_key': 'test-api-key',
'url': None, # No base URL
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert - should be called without base_url when it's None
mock_openai_class.assert_called_once_with(api_key='test-api-key')
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_message_structure(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test that OpenAI messages are structured correctly"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response with proper structure"
mock_response.usage.prompt_tokens = 25
mock_response.usage.completion_tokens = 15
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-3.5-turbo',
'api_key': 'test-api-key',
'url': 'https://api.openai.com/v1',
'temperature': 0.5,
'max_output': 1024,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
# Assert
assert result.text == "Response with proper structure"
assert result.in_token == 25
assert result.out_token == 15
# Verify the message structure matches OpenAI Chat API format
call_args = mock_openai_client.chat.completions.create.call_args
messages = call_args[1]['messages']
assert len(messages) == 1
assert messages[0]['role'] == 'user'
assert messages[0]['content'][0]['type'] == 'text'
assert messages[0]['content'][0]['text'] == "You are a helpful assistant\n\nWhat is AI?"
# Verify other parameters
assert call_args[1]['model'] == 'gpt-3.5-turbo'
assert call_args[1]['temperature'] == 0.5
assert call_args[1]['max_tokens'] == 1024
assert call_args[1]['top_p'] == 1
assert call_args[1]['frequency_penalty'] == 0
assert call_args[1]['presence_penalty'] == 0
assert call_args[1]['response_format'] == {"type": "text"}
if __name__ == '__main__':
pytest.main([__file__])
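The OpenAI suite builds the same `MagicMock` response shape over and over (`choices[0].message.content` plus `usage` token counts). If the suite grows, a small factory along these lines could cut the duplication; this is a sketch, not part of the PR:

```python
# Hypothetical helper (not in the PR): builds the response shape the
# OpenAI tests mock, i.e. response.choices[0].message.content plus
# response.usage token counts.
from unittest.mock import MagicMock


def make_chat_response(text, prompt_tokens, completion_tokens):
    response = MagicMock()
    response.choices = [MagicMock()]
    response.choices[0].message.content = text
    response.usage.prompt_tokens = prompt_tokens
    response.usage.completion_tokens = completion_tokens
    return response


# Usage, mirroring test_generate_content_success:
#   mock_openai_client.chat.completions.create.return_value = \
#       make_chat_response("Generated response from OpenAI", 20, 12)
```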

@@ -0,0 +1,397 @@
"""
Unit tests for trustgraph.model.text_completion.vertexai
The original suite whose pattern the other text-completion tests follow
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.model.text_completion.vertexai.llm import Processor
from trustgraph.base import LlmResult
class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
"""Simple test for processor initialization"""
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
"""Test basic processor initialization with mocked dependencies"""
# Arrange
mock_credentials = MagicMock()
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
# Mock the parent class initialization to avoid taskgroup requirement
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'region': 'us-central1',
'model': 'gemini-2.0-flash-001',
'temperature': 0.0,
'max_output': 8192,
'private_key': 'private.json',
'concurrency': 1,
'taskgroup': AsyncMock(), # Required by AsyncProcessor
'id': 'test-processor' # Required by AsyncProcessor
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'gemini-2.0-flash-001' # It's stored as 'model', not 'model_name'
assert hasattr(processor, 'generation_config')
assert hasattr(processor, 'safety_settings')
assert hasattr(processor, 'llm')
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('private.json')
mock_vertexai.init.assert_called_once()
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
"""Test successful content generation"""
# Arrange
mock_credentials = MagicMock()
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Generated response from Gemini"
mock_response.usage_metadata.prompt_token_count = 15
mock_response.usage_metadata.candidates_token_count = 8
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'region': 'us-central1',
'model': 'gemini-2.0-flash-001',
'temperature': 0.0,
'max_output': 8192,
'private_key': 'private.json',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "User prompt")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Generated response from Gemini"
assert result.in_token == 15
assert result.out_token == 8
assert result.model == 'gemini-2.0-flash-001'
# Check that the method was called (actual prompt format may vary)
mock_model.generate_content.assert_called_once()
# Verify the call was made with the expected parameters
call_args = mock_model.generate_content.call_args
assert call_args[1]['generation_config'] == processor.generation_config
assert call_args[1]['safety_settings'] == processor.safety_settings
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
"""Test rate limit error handling"""
# Arrange
from google.api_core.exceptions import ResourceExhausted
from trustgraph.exceptions import TooManyRequests
mock_credentials = MagicMock()
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_model.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
mock_generative_model.return_value = mock_model
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'region': 'us-central1',
'model': 'gemini-2.0-flash-001',
'temperature': 0.0,
'max_output': 8192,
'private_key': 'private.json',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act & Assert
with pytest.raises(TooManyRequests):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
"""Test handling of blocked content (safety filters)"""
# Arrange
mock_credentials = MagicMock()
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = None # Blocked content returns None
mock_response.usage_metadata.prompt_token_count = 10
mock_response.usage_metadata.candidates_token_count = 0
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'region': 'us-central1',
'model': 'gemini-2.0-flash-001',
'temperature': 0.0,
'max_output': 8192,
'private_key': 'private.json',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "Blocked content")
# Assert
assert isinstance(result, LlmResult)
assert result.text is None # Should preserve None for blocked content
assert result.in_token == 10
assert result.out_token == 0
assert result.model == 'gemini-2.0-flash-001'
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
"""Test processor initialization without private key (should fail)"""
# Arrange
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'region': 'us-central1',
'model': 'gemini-2.0-flash-001',
'temperature': 0.0,
'max_output': 8192,
'private_key': None, # No private key provided
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act & Assert
with pytest.raises(RuntimeError, match="Private key file not specified"):
processor = Processor(**config)
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
"""Test handling of generic exceptions"""
# Arrange
mock_credentials = MagicMock()
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_model.generate_content.side_effect = Exception("Network error")
mock_generative_model.return_value = mock_model
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'region': 'us-central1',
'model': 'gemini-2.0-flash-001',
'temperature': 0.0,
'max_output': 8192,
'private_key': 'private.json',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act & Assert
with pytest.raises(Exception, match="Network error"):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
"""Test processor initialization with custom parameters"""
# Arrange
mock_credentials = MagicMock()
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'region': 'us-west1',
'model': 'gemini-1.5-pro',
'temperature': 0.7,
'max_output': 4096,
'private_key': 'custom-key.json',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'gemini-1.5-pro'
# Verify that generation_config object exists (can't easily check internal values)
assert hasattr(processor, 'generation_config')
assert processor.generation_config is not None
# Verify that safety settings are configured
assert len(processor.safety_settings) == 4
# Verify service account was called with custom key
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('custom-key.json')
# Verify that parameters dict has the correct values (this is accessible)
assert processor.parameters["temperature"] == 0.7
assert processor.parameters["max_output_tokens"] == 4096
assert processor.parameters["top_p"] == 1.0
assert processor.parameters["top_k"] == 32
assert processor.parameters["candidate_count"] == 1
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
"""Test that VertexAI is initialized correctly with credentials"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'region': 'europe-west1',
'model': 'gemini-2.0-flash-001',
'temperature': 0.0,
'max_output': 8192,
'private_key': 'service-account.json',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
# Verify VertexAI init was called with correct parameters
mock_vertexai.init.assert_called_once_with(
location='europe-west1',
credentials=mock_credentials,
project='test-project-123'
)
# Verify GenerativeModel was created with the right model name
mock_generative_model.assert_called_once_with('gemini-2.0-flash-001')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
"""Test content generation with empty prompts"""
# Arrange
mock_credentials = MagicMock()
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Default response"
mock_response.usage_metadata.prompt_token_count = 2
mock_response.usage_metadata.candidates_token_count = 3
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'region': 'us-central1',
'model': 'gemini-2.0-flash-001',
'temperature': 0.0,
'max_output': 8192,
'private_key': 'private.json',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("", "")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Default response"
assert result.in_token == 2
assert result.out_token == 3
assert result.model == 'gemini-2.0-flash-001'
# Verify the model was called with the combined empty prompts
mock_model.generate_content.assert_called_once()
call_args = mock_model.generate_content.call_args
# The prompt should be "" + "\n\n" + "" = "\n\n"
assert call_args[0][0] == "\n\n"
if __name__ == '__main__':
pytest.main([__file__])
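The five stacked `@patch` decorators on each VertexAI test deserve a note: `unittest.mock` applies decorators bottom-up, so the decorator closest to the function supplies the first mock argument, which is why `mock_llm_init` leads the parameter list while `mock_service_account` trails. A standalone illustration (the `os.*` targets are arbitrary examples, not TrustGraph code):

```python
# Illustration of @patch ordering: decorators apply bottom-up, so the
# decorator closest to the function supplies the first mock argument.
# The os.* targets are arbitrary examples.
from unittest.mock import patch


@patch("os.getcwd")    # outermost -> last mock argument
@patch("os.getpid")
@patch("os.getlogin")  # innermost -> first mock argument
def demo(mock_getlogin, mock_getpid, mock_getcwd):
    return mock_getlogin, mock_getpid, mock_getcwd


login, pid, cwd = demo()  # three distinct mocks, in signature order
assert login is not pid and pid is not cwd
```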

@@ -0,0 +1,489 @@
"""
Unit tests for trustgraph.model.text_completion.vllm
Following the same successful pattern as previous tests
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.model.text_completion.vllm.llm import Processor
from trustgraph.base import LlmResult
from trustgraph.exceptions import TooManyRequests
class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
"""Test vLLM processor functionality"""
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test basic processor initialization"""
# Arrange
mock_session = MagicMock()
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
'url': 'http://vllm-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'TheBloke/Mistral-7B-v0.1-AWQ'
assert processor.base_url == 'http://vllm-service:8899/v1'
assert processor.temperature == 0.0
assert processor.max_output == 2048
assert hasattr(processor, 'session')
mock_session_class.assert_called_once()
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test successful content generation"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'text': 'Generated response from vLLM'
}],
'usage': {
'prompt_tokens': 20,
'completion_tokens': 12
}
})
# Mock the async context manager
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
'url': 'http://vllm-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "User prompt")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Generated response from vLLM"
assert result.in_token == 20
assert result.out_token == 12
assert result.model == 'TheBloke/Mistral-7B-v0.1-AWQ'
# Verify the vLLM API call
mock_session.post.assert_called_once_with(
'http://vllm-service:8899/v1/completions',
headers={'Content-Type': 'application/json'},
json={
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
'prompt': 'System prompt\n\nUser prompt',
'max_tokens': 2048,
'temperature': 0.0
}
)
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_http_error(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test HTTP error handling"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 500
# Mock the async context manager
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
'url': 'http://vllm-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act & Assert
with pytest.raises(RuntimeError, match="Bad status: 500"):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test handling of generic exceptions"""
# Arrange
mock_session = MagicMock()
mock_session.post.side_effect = Exception("Connection error")
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
'url': 'http://vllm-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act & Assert
with pytest.raises(Exception, match="Connection error"):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test processor initialization with custom parameters"""
# Arrange
mock_session = MagicMock()
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'custom-model',
'url': 'http://custom-vllm:8080/v1',
'temperature': 0.7,
'max_output': 1024,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'custom-model'
assert processor.base_url == 'http://custom-vllm:8080/v1'
assert processor.temperature == 0.7
assert processor.max_output == 1024
mock_session_class.assert_called_once()
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test processor initialization with default values"""
# Arrange
mock_session = MagicMock()
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
# Only provide required fields, should use defaults
config = {
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'TheBloke/Mistral-7B-v0.1-AWQ' # default_model
assert processor.base_url == 'http://vllm-service:8899/v1' # default_base_url
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 2048 # default_max_output
mock_session_class.assert_called_once()
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test content generation with empty prompts"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'text': 'Default response'
}],
'usage': {
'prompt_tokens': 2,
'completion_tokens': 3
}
})
# Mock the async context manager
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
'url': 'http://vllm-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("", "")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Default response"
assert result.in_token == 2
assert result.out_token == 3
assert result.model == 'TheBloke/Mistral-7B-v0.1-AWQ'
# Verify the combined prompt is sent correctly
call_args = mock_session.post.call_args
expected_prompt = "\n\n" # Empty system + "\n\n" + empty user
assert call_args[1]['json']['prompt'] == expected_prompt
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_request_structure(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test that vLLM request is structured correctly"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'text': 'Response with proper structure'
}],
'usage': {
'prompt_tokens': 25,
'completion_tokens': 15
}
})
# Mock the async context manager
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
'url': 'http://vllm-service:8899/v1',
'temperature': 0.5,
'max_output': 1024,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
# Assert
assert result.text == "Response with proper structure"
assert result.in_token == 25
assert result.out_token == 15
# Verify the request structure
call_args = mock_session.post.call_args
# Check URL
assert call_args[0][0] == 'http://vllm-service:8899/v1/completions'
# Check headers
assert call_args[1]['headers']['Content-Type'] == 'application/json'
# Check request body
request_data = call_args[1]['json']
assert request_data['model'] == 'TheBloke/Mistral-7B-v0.1-AWQ'
assert request_data['prompt'] == "You are a helpful assistant\n\nWhat is AI?"
assert request_data['temperature'] == 0.5
assert request_data['max_tokens'] == 1024
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_vllm_session_initialization(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test that aiohttp session is initialized correctly"""
# Arrange
mock_session = MagicMock()
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'test-model',
'url': 'http://test-vllm:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
# Verify ClientSession was created
mock_session_class.assert_called_once()
# Verify processor has the session
assert processor.session == mock_session
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_response_parsing(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test response parsing from vLLM API"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'text': 'Parsed response text'
}],
'usage': {
'prompt_tokens': 35,
'completion_tokens': 25
}
})
# Mock the async context manager
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
'url': 'http://vllm-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System", "User query")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Parsed response text"
assert result.in_token == 35
assert result.out_token == 25
assert result.model == 'TheBloke/Mistral-7B-v0.1-AWQ'
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test prompt construction with system and user prompts"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'text': 'Response with system instructions'
}],
'usage': {
'prompt_tokens': 40,
'completion_tokens': 30
}
})
# Mock the async context manager
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
'url': 'http://vllm-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("You are a helpful assistant", "What is machine learning?")
# Assert
assert result.text == "Response with system instructions"
assert result.in_token == 40
assert result.out_token == 30
# Verify the combined prompt
call_args = mock_session.post.call_args
expected_prompt = "You are a helpful assistant\n\nWhat is machine learning?"
assert call_args[1]['json']['prompt'] == expected_prompt
if __name__ == '__main__':
pytest.main([__file__])
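Finally, the `__aenter__`/`__aexit__` wiring in the vLLM suite is the standard way to mock `aiohttp`'s `async with session.post(...)` usage: the context-manager methods are configured on `post`'s return value, not on the session itself. A self-contained sketch of the idiom (URL and payload are placeholders):

```python
# Sketch of the aiohttp mocking idiom used in the vLLM tests above:
# session.post(...) returns an async context manager, so __aenter__ and
# __aexit__ are configured on post's return value. URL and payload are
# placeholders.
import asyncio
from unittest.mock import AsyncMock, MagicMock


def make_session(payload, status=200):
    response = MagicMock()
    response.status = status
    response.json = AsyncMock(return_value=payload)
    session = MagicMock()
    session.post.return_value.__aenter__.return_value = response
    session.post.return_value.__aexit__.return_value = None
    return session


async def demo():
    session = make_session({"choices": [{"text": "hi"}]})
    async with session.post("http://example.invalid/v1/completions") as resp:
        assert resp.status == 200
        body = await resp.json()
        assert body["choices"][0]["text"] == "hi"


asyncio.run(demo())
```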