mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 16:36:21 +02:00
* Provenance tests * Embeddings tests * Test librarian * Test triples stream * Test concurrency * Entity centric graph writes * Agent tool service tests * Structured data tests * RDF tests * Addition LLM tests * Reliability tests
199 lines
7.4 KiB
Python
199 lines
7.4 KiB
Python
"""
Tests for Azure serverless endpoint streaming: model override during streaming,
HTTP 429 during streaming, SSE chunk parsing, and final token count emission.
"""
|
|
|
|
import json
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from unittest import IsolatedAsyncioTestCase
|
|
|
|
from trustgraph.model.text_completion.azure.llm import Processor
|
|
from trustgraph.base import LlmChunk
|
|
from trustgraph.exceptions import TooManyRequests
|
|
|
|
|
|
def _make_processor(mock_requests, model="AzureAI", temperature=0.0):
    """Build an Azure ``Processor`` with its base-class initialisers stubbed.

    ``AsyncProcessor.__init__`` and ``LlmService.__init__`` are patched to
    return ``None`` for the duration of construction.  ``mock_requests`` is
    accepted so tests can pass along their patched ``requests`` module; this
    helper does not read it.
    """
    settings = {
        "endpoint": "https://test.azure.com/v1/chat/completions",
        "token": "test-token",
        "temperature": temperature,
        "max_output": 4192,
        "model": model,
        "concurrency": 1,
        "taskgroup": AsyncMock(),
        "id": "test-processor",
    }

    stub_async_init = patch(
        'trustgraph.base.async_processor.AsyncProcessor.__init__',
        return_value=None,
    )
    stub_llm_init = patch(
        'trustgraph.base.llm_service.LlmService.__init__',
        return_value=None,
    )

    # Both stubs must be active while the constructor runs.
    with stub_async_init:
        with stub_llm_init:
            return Processor(**settings)
|
|
|
|
|
|
def _sse_lines(*data_items):
|
|
"""Build SSE byte lines from data items. '[DONE]' is appended."""
|
|
lines = []
|
|
for item in data_items:
|
|
if isinstance(item, dict):
|
|
lines.append(f"data: {json.dumps(item)}".encode())
|
|
else:
|
|
lines.append(f"data: {item}".encode())
|
|
lines.append(b"data: [DONE]")
|
|
return lines
|
|
|
|
|
|
class TestAzureServerlessStreaming(IsolatedAsyncioTestCase):
    """Streaming tests for the Azure ``Processor`` using a mocked ``requests``."""

    # Patch target: the ``requests`` name as looked up by the Azure LLM module.
    _REQUESTS = 'trustgraph.model.text_completion.azure.llm.requests'

    @staticmethod
    def _stub_post(mock_requests, status_code, lines=None):
        """Make ``mock_requests.post`` return a mock response.

        ``lines``, when given, becomes the return value of ``iter_lines``.
        The response mock is returned so callers can set further attributes.
        """
        response = MagicMock()
        response.status_code = status_code
        if lines is not None:
            response.iter_lines.return_value = lines
        mock_requests.post.return_value = response
        return response

    @staticmethod
    def _posted_body(mock_requests):
        """Decode the JSON passed as the ``data`` keyword of the last post."""
        return json.loads(mock_requests.post.call_args[1]["data"])

    @patch(_REQUESTS)
    async def test_streaming_yields_chunks(self, mock_requests):
        """Two content deltas plus a usage record give two chunks and a final."""
        proc = _make_processor(mock_requests)

        events = _sse_lines(
            {"choices": [{"delta": {"content": "Hello"}}]},
            {"choices": [{"delta": {"content": " world"}}]},
            {"usage": {"prompt_tokens": 10, "completion_tokens": 5}},
        )
        self._stub_post(mock_requests, 200, events)

        results = [
            chunk
            async for chunk in proc.generate_content_stream("sys", "user")
        ]

        # Content chunks + final chunk
        assert len(results) == 3
        first, second, final = results
        assert first.text == "Hello"
        assert first.is_final is False
        assert second.text == " world"
        assert second.is_final is False
        assert final.is_final is True
        assert final.in_token == 10
        assert final.out_token == 5

    @patch(_REQUESTS)
    async def test_streaming_model_override(self, mock_requests):
        """A per-call ``model`` shows up on every chunk and in the request."""
        proc = _make_processor(mock_requests, model="default-model")

        self._stub_post(
            mock_requests,
            200,
            _sse_lines(
                {"choices": [{"delta": {"content": "ok"}}]},
                {"usage": {"prompt_tokens": 5, "completion_tokens": 2}},
            ),
        )

        stream = proc.generate_content_stream(
            "sys", "user", model="override-model"
        )
        results = [chunk async for chunk in stream]

        # All chunks should carry the overridden model name
        assert all(r.model == "override-model" for r in results)

        # Verify the request body used the overridden model
        assert self._posted_body(mock_requests)["model"] == "override-model"

    @patch(_REQUESTS)
    async def test_streaming_temperature_override(self, mock_requests):
        """A per-call ``temperature`` is what gets sent in the request body."""
        proc = _make_processor(mock_requests, temperature=0.0)

        self._stub_post(
            mock_requests,
            200,
            _sse_lines(
                {"choices": [{"delta": {"content": "ok"}}]},
                {"usage": {"prompt_tokens": 5, "completion_tokens": 2}},
            ),
        )

        # Drain the stream; only the outgoing request is inspected here.
        async for _ in proc.generate_content_stream(
            "sys", "user", temperature=0.9
        ):
            pass

        assert self._posted_body(mock_requests)["temperature"] == 0.9

    @patch(_REQUESTS)
    async def test_streaming_429_raises_too_many_requests(self, mock_requests):
        """A 429 status surfaces as ``TooManyRequests`` when iterating."""
        proc = _make_processor(mock_requests)
        self._stub_post(mock_requests, 429)

        with pytest.raises(TooManyRequests):
            async for _ in proc.generate_content_stream("sys", "user"):
                pass

    @patch(_REQUESTS)
    async def test_streaming_http_error_raises_runtime(self, mock_requests):
        """A 503 status surfaces as ``RuntimeError`` mentioning 'HTTP 503'."""
        proc = _make_processor(mock_requests)

        response = self._stub_post(mock_requests, 503)
        response.text = "Service Unavailable"

        with pytest.raises(RuntimeError, match="HTTP 503"):
            async for _ in proc.generate_content_stream("sys", "user"):
                pass

    @patch(_REQUESTS)
    async def test_streaming_includes_stream_options(self, mock_requests):
        """The request body sets ``stream`` and ``stream_options.include_usage``."""
        proc = _make_processor(mock_requests)

        self._stub_post(
            mock_requests,
            200,
            _sse_lines(
                {"usage": {"prompt_tokens": 0, "completion_tokens": 0}},
            ),
        )

        async for _ in proc.generate_content_stream("sys", "user"):
            pass

        body = self._posted_body(mock_requests)
        assert body["stream"] is True
        assert body["stream_options"]["include_usage"] is True

    @patch(_REQUESTS)
    async def test_streaming_malformed_json_skipped(self, mock_requests):
        """An unparseable ``data:`` line must not stop later chunks arriving."""
        proc = _make_processor(mock_requests)

        # The first item is a raw string, so it is emitted without JSON
        # encoding and arrives as the invalid payload "{not valid json}".
        self._stub_post(
            mock_requests,
            200,
            _sse_lines(
                "{not valid json}",
                {"choices": [{"delta": {"content": "ok"}}]},
                {"usage": {"prompt_tokens": 1, "completion_tokens": 1}},
            ),
        )

        results = [
            chunk
            async for chunk in proc.generate_content_stream("sys", "user")
        ]

        # Should get the valid content chunk + final chunk
        assert any(item.text == "ok" for item in results)
        assert results[-1].is_final is True

    @patch(_REQUESTS)
    async def test_streaming_supports_streaming_flag(self, mock_requests):
        """``supports_streaming()`` returns ``True`` for this processor."""
        assert _make_processor(mock_requests).supports_streaming() is True
|