Merge pull request 'dev-0.9.x -> main' (#81) from dev-0.9.x into main
All checks were successful
Build and Publish Docker Image (Semantic Cache) / build (amd64, linux/amd64, docker-amd64) (push) Successful in 3m18s
Build and Publish Docker Image / build (amd64, linux/amd64, docker-amd64) (push) Successful in 1m18s
Build and Publish Docker Image (Semantic Cache) / build (arm64, linux/arm64, docker-arm64) (push) Successful in 14m25s
Build and Publish Docker Image (Semantic Cache) / merge (push) Successful in 32s
Build and Publish Docker Image / build (arm64, linux/arm64, docker-arm64) (push) Successful in 11m42s
Build and Publish Docker Image / merge (push) Successful in 1m2s
All checks were successful
Build and Publish Docker Image (Semantic Cache) / build (amd64, linux/amd64, docker-amd64) (push) Successful in 3m18s
Build and Publish Docker Image / build (amd64, linux/amd64, docker-amd64) (push) Successful in 1m18s
Build and Publish Docker Image (Semantic Cache) / build (arm64, linux/arm64, docker-arm64) (push) Successful in 14m25s
Build and Publish Docker Image (Semantic Cache) / merge (push) Successful in 32s
Build and Publish Docker Image / build (arm64, linux/arm64, docker-arm64) (push) Successful in 11m42s
Build and Publish Docker Image / merge (push) Successful in 1m2s
Reviewed-on: https://bitfreedom.net/code/code/nomyo-ai/nomyo-router/pulls/81
This commit is contained in:
commit
1df9d75cf7
18 changed files with 2895 additions and 4 deletions
39
.forgejo/workflows/pr-tests.yml
Normal file
39
.forgejo/workflows/pr-tests.yml
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
name: PR Tests
|
||||
on: [pull_request]
|
||||
jobs:
|
||||
test:
|
||||
runs-on: docker-arm64
|
||||
container:
|
||||
image: python:3.12-slim
|
||||
env:
|
||||
CMAKE_BUILD_PARALLEL_LEVEL: "4"
|
||||
steps:
|
||||
- name: Install system deps
|
||||
run: |
|
||||
apt-get update
|
||||
apt-get install -y --no-install-recommends \
|
||||
git ca-certificates \
|
||||
build-essential pkg-config
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
- name: Checkout
|
||||
run: |
|
||||
git config --global --add safe.directory "$PWD"
|
||||
git clone --depth=1 \
|
||||
"https://oauth2:${{ github.token }}@bitfreedom.net/code/${{ github.repository }}.git" .
|
||||
git fetch --depth=1 origin "+${{ github.event.pull_request.head.sha }}:pr"
|
||||
git checkout pr
|
||||
- name: Fetch action source
|
||||
run: |
|
||||
git clone --depth=1 --branch master \
|
||||
"https://oauth2:${{ github.token }}@bitfreedom.net/code/nomyo-ai/actions.git" \
|
||||
./.run-tests
|
||||
- uses: ./.run-tests/run-tests
|
||||
with:
|
||||
setup: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
pip install -r test/requirements_test.txt
|
||||
command: pytest test/ -m "not integration" --cov=router --cov=cache --cov=db --cov=enhance --cov-fail-under=45 --cov-report=term-missing --cov-report=xml --junitxml=report.xml
|
||||
artifacts-path: |
|
||||
report.xml
|
||||
coverage.xml
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
|
|
@ -66,7 +66,4 @@ config.yaml
|
|||
# SQLite
|
||||
*.db*
|
||||
|
||||
*settings.json
|
||||
|
||||
# Test suite (local only, not committed yet)
|
||||
test/
|
||||
*settings.json
|
||||
13
test/config_test.yaml
Normal file
13
test/config_test.yaml
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
endpoints:
|
||||
- http://192.168.0.51:12434
|
||||
|
||||
llama_server_endpoints:
|
||||
- http://192.168.0.51:12434/v1
|
||||
|
||||
max_concurrent_connections: 2
|
||||
|
||||
api_keys:
|
||||
"http://192.168.0.51:12434": "ollama"
|
||||
"http://192.168.0.51:12434/v1": "llama"
|
||||
|
||||
cache_enabled: false
|
||||
233
test/conftest.py
Normal file
233
test/conftest.py
Normal file
|
|
@ -0,0 +1,233 @@
|
|||
"""
|
||||
Test configuration for nomyo-router.
|
||||
|
||||
Run from project root:
|
||||
pytest test/ -v
|
||||
pytest test/ -m "not integration" # skip real-server tests
|
||||
pytest test/ -m integration -v # only real-server tests
|
||||
|
||||
Environment variables:
|
||||
NOMYO_TEST_OLLAMA Ollama endpoint (default: http://192.168.0.50:12434)
|
||||
NOMYO_TEST_LLAMA llama-server endpoint (default: http://192.168.0.50:12434/v1)
|
||||
NOMYO_TEST_MODEL_CHAT chat model to use (auto-discovered if unset)
|
||||
NOMYO_TEST_EMBED_MODEL embedding model (auto-discovered if unset)
|
||||
"""
|
||||
import asyncio
|
||||
import os
|
||||
import ssl
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import aiohttp
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
_TEST_DIR = Path(__file__).parent
|
||||
# Must be set before importing router so module-level Config.from_yaml + Config field
|
||||
# defaults pick these up. db_path is intentionally absent from config_test.yaml so the
|
||||
# env-var default wins — keeps tests portable across CI runners (Linux/macOS/Windows).
|
||||
os.environ.setdefault("NOMYO_ROUTER_CONFIG_PATH", str(_TEST_DIR / "config_test.yaml"))
|
||||
os.environ.setdefault(
|
||||
"NOMYO_ROUTER_DB_PATH",
|
||||
str(Path(tempfile.gettempdir()) / "nomyo_router_test_tokens.db"),
|
||||
)
|
||||
|
||||
sys.path.insert(0, str(_TEST_DIR.parent))
|
||||
|
||||
import router # noqa: E402
|
||||
|
||||
TEST_OLLAMA = os.getenv("NOMYO_TEST_OLLAMA", "http://192.168.0.51:12434")
|
||||
TEST_LLAMA = os.getenv("NOMYO_TEST_LLAMA", "http://192.168.0.51:12434/v1")
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
"integration: tests that require a real backend at 192.168.0.50:12434",
|
||||
)
|
||||
|
||||
|
||||
# ── Config mocks ─────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config():
|
||||
"""Minimal config pointing at TEST_OLLAMA / TEST_LLAMA."""
|
||||
cfg = MagicMock()
|
||||
cfg.endpoints = [TEST_OLLAMA]
|
||||
cfg.llama_server_endpoints = [TEST_LLAMA]
|
||||
cfg.api_keys = {TEST_OLLAMA: "ollama", TEST_LLAMA: "llama"}
|
||||
cfg.max_concurrent_connections = 2
|
||||
cfg.router_api_key = None
|
||||
cfg.cache_enabled = False
|
||||
return cfg
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_no_llama():
|
||||
"""Config with Ollama only, no llama-server."""
|
||||
cfg = MagicMock()
|
||||
cfg.endpoints = [TEST_OLLAMA]
|
||||
cfg.llama_server_endpoints = []
|
||||
cfg.api_keys = {TEST_OLLAMA: "ollama"}
|
||||
cfg.max_concurrent_connections = 2
|
||||
cfg.router_api_key = None
|
||||
cfg.cache_enabled = False
|
||||
return cfg
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_with_key():
|
||||
"""Config with router_api_key set (enables auth middleware)."""
|
||||
cfg = MagicMock()
|
||||
cfg.endpoints = [TEST_OLLAMA]
|
||||
cfg.llama_server_endpoints = []
|
||||
cfg.api_keys = {}
|
||||
cfg.max_concurrent_connections = 2
|
||||
cfg.router_api_key = "test-secret-key"
|
||||
cfg.cache_enabled = False
|
||||
return cfg
|
||||
|
||||
|
||||
# ── aiohttp session (used by fetch tests + choose_endpoint tests) ─────────────
|
||||
|
||||
@pytest.fixture
|
||||
async def aio_session():
|
||||
"""Real aiohttp session stored in app_state; intercepted by aioresponses."""
|
||||
ssl_ctx = ssl.create_default_context()
|
||||
conn = aiohttp.TCPConnector(ssl=ssl_ctx)
|
||||
session = aiohttp.ClientSession(connector=conn)
|
||||
router.app_state["session"] = session
|
||||
|
||||
# Clear caches to prevent test bleed
|
||||
router._models_cache.clear()
|
||||
router._loaded_models_cache.clear()
|
||||
router._available_error_cache.clear()
|
||||
router._loaded_error_cache.clear()
|
||||
router._inflight_available_models.clear()
|
||||
router._inflight_loaded_models.clear()
|
||||
router._bg_refresh_available.clear()
|
||||
router._bg_refresh_loaded.clear()
|
||||
|
||||
yield session
|
||||
|
||||
await session.close()
|
||||
router.app_state["session"] = None
|
||||
|
||||
|
||||
# ── Validation-only HTTP client (no real backend needed) ──────────────────────
|
||||
|
||||
@pytest.fixture
|
||||
async def client(mock_config, tmp_path):
|
||||
"""httpx client for validation/auth tests — no real backend calls made."""
|
||||
from db import TokenDatabase
|
||||
|
||||
ssl_ctx = ssl.create_default_context()
|
||||
conn = aiohttp.TCPConnector(ssl=ssl_ctx)
|
||||
session = aiohttp.ClientSession(connector=conn)
|
||||
|
||||
db_inst = TokenDatabase(str(tmp_path / "test.db"))
|
||||
await db_inst.init_db()
|
||||
|
||||
old_session = router.app_state.get("session")
|
||||
old_db = router.db
|
||||
|
||||
router.app_state["session"] = session
|
||||
router.db = db_inst
|
||||
|
||||
with patch.object(router, "config", mock_config):
|
||||
transport = httpx.ASGITransport(app=router.app)
|
||||
async with httpx.AsyncClient(
|
||||
transport=transport, base_url="http://test", timeout=10.0
|
||||
) as c:
|
||||
yield c
|
||||
|
||||
await session.close()
|
||||
router.app_state["session"] = old_session
|
||||
router.db = old_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client_auth(mock_config_with_key, tmp_path):
|
||||
"""httpx client with router_api_key configured (for auth middleware tests)."""
|
||||
from db import TokenDatabase
|
||||
|
||||
ssl_ctx = ssl.create_default_context()
|
||||
conn = aiohttp.TCPConnector(ssl=ssl_ctx)
|
||||
session = aiohttp.ClientSession(connector=conn)
|
||||
|
||||
db_inst = TokenDatabase(str(tmp_path / "test_auth.db"))
|
||||
await db_inst.init_db()
|
||||
|
||||
old_session = router.app_state.get("session")
|
||||
old_db = router.db
|
||||
|
||||
router.app_state["session"] = session
|
||||
router.db = db_inst
|
||||
|
||||
with patch.object(router, "config", mock_config_with_key):
|
||||
transport = httpx.ASGITransport(app=router.app)
|
||||
async with httpx.AsyncClient(
|
||||
transport=transport, base_url="http://test", timeout=10.0
|
||||
) as c:
|
||||
yield c
|
||||
|
||||
await session.close()
|
||||
router.app_state["session"] = old_session
|
||||
router.db = old_db
|
||||
|
||||
|
||||
# ── Integration client (full startup with real backend) ──────────────────────
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def integration_client():
|
||||
"""Full app startup pointing at the real test server."""
|
||||
await router.startup_event()
|
||||
transport = httpx.ASGITransport(app=router.app)
|
||||
async with httpx.AsyncClient(
|
||||
transport=transport,
|
||||
base_url="http://test",
|
||||
timeout=httpx.Timeout(60.0),
|
||||
) as c:
|
||||
yield c
|
||||
await router.shutdown_event()
|
||||
|
||||
|
||||
# ── Model discovery fixtures ──────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def chat_model(integration_client):
|
||||
"""Return a chat/generation model name available on the test server."""
|
||||
env_model = os.getenv("NOMYO_TEST_MODEL_CHAT")
|
||||
if env_model:
|
||||
return env_model
|
||||
resp = await integration_client.get("/api/tags")
|
||||
if resp.status_code != 200:
|
||||
pytest.skip("Cannot reach test server")
|
||||
models = resp.json().get("models", [])
|
||||
# Prefer small models for faster tests
|
||||
for m in models:
|
||||
name = m.get("name", "")
|
||||
if any(x in name.lower() for x in ["0.5b", "1b", "3b", "1.5b", "2b"]):
|
||||
return name
|
||||
if models:
|
||||
return models[0]["name"]
|
||||
pytest.skip("No chat models available on test server")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def embed_model(integration_client):
|
||||
"""Return an embedding model name available on the test server."""
|
||||
env_model = os.getenv("NOMYO_TEST_EMBED_MODEL")
|
||||
if env_model:
|
||||
return env_model
|
||||
resp = await integration_client.get("/api/tags")
|
||||
if resp.status_code != 200:
|
||||
pytest.skip("Cannot reach test server")
|
||||
models = resp.json().get("models", [])
|
||||
for m in models:
|
||||
name = m.get("name", "")
|
||||
if any(x in name.lower() for x in ["embed", "nomic", "minilm", "bge", "e5"]):
|
||||
return name
|
||||
pytest.skip("No embedding model available on test server")
|
||||
7
test/pytest.ini
Normal file
7
test/pytest.ini
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
[pytest]
|
||||
asyncio_mode = auto
|
||||
markers =
|
||||
integration: tests that require a real backend at 192.168.0.51:12434
|
||||
testpaths = .
|
||||
filterwarnings =
|
||||
ignore::pytest.PytestUnhandledThreadExceptionWarning
|
||||
4
test/requirements_test.txt
Normal file
4
test/requirements_test.txt
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
pytest>=8.0
|
||||
pytest-asyncio>=0.24
|
||||
pytest-cov>=5.0
|
||||
aioresponses>=0.7
|
||||
60
test/test.md
Normal file
60
test/test.md
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
# Testing nomyo-router
|
||||
|
||||
## Setup
|
||||
|
||||
Install test dependencies (from the project root):
|
||||
|
||||
```bash
|
||||
pip install -r test/requirements_test.txt
|
||||
```
|
||||
|
||||
## Running tests
|
||||
|
||||
All commands run from the `test/` directory:
|
||||
|
||||
```bash
|
||||
cd test
|
||||
```
|
||||
|
||||
**All non-integration tests** (no backend required):
|
||||
```bash
|
||||
pytest -m "not integration" -v
|
||||
```
|
||||
|
||||
**Integration tests only** (requires backend at `192.168.0.51:12434`):
|
||||
```bash
|
||||
pytest -m integration -v
|
||||
```
|
||||
|
||||
**Everything:**
|
||||
```bash
|
||||
pytest -v
|
||||
```
|
||||
|
||||
## Test structure
|
||||
|
||||
| File | What it covers | Backend needed |
|
||||
|---|---|---|
|
||||
| `test_unit_helpers.py` | Pure helper functions (`_mask_secrets`, `_is_fresh`, `ep2base`, etc.) | No |
|
||||
| `test_unit_transforms.py` | Message transform functions (tool calls, image stripping, etc.) | No |
|
||||
| `test_unit_context.py` | Context window trimming logic | No |
|
||||
| `test_fetch.py` | `fetch.available_models` / `fetch.loaded_models` with mocked HTTP | No |
|
||||
| `test_choose_endpoint.py` | `choose_endpoint` routing logic with mocked fetch layer | No |
|
||||
| `test_api_validation.py` | HTTP 400/401/403 validation and auth middleware (in-process app) | No |
|
||||
| `test_api_integration.py` | Full request/response against a real Ollama/llama-server backend | **Yes** |
|
||||
|
||||
## Integration test backend
|
||||
|
||||
Integration tests start the router in-process via `startup_event()` and route traffic
|
||||
through `httpx.ASGITransport` — no separately running router instance is needed.
|
||||
|
||||
They do require a reachable Ollama or llama-server backend. Override the defaults via
|
||||
environment variables:
|
||||
|
||||
```bash
|
||||
export NOMYO_TEST_OLLAMA=http://192.168.0.51:12434
|
||||
export NOMYO_TEST_EMBED_MODEL=nomic-embed-text # optional, auto-discovered otherwise
|
||||
export NOMYO_TEST_MODEL_CHAT=llama3.2 # optional, auto-discovered otherwise
|
||||
```
|
||||
|
||||
If the backend is unreachable, integration tests are automatically skipped.
|
||||
304
test/test_api_integration.py
Normal file
304
test/test_api_integration.py
Normal file
|
|
@ -0,0 +1,304 @@
|
|||
"""
|
||||
Integration tests against the real backend at 192.168.0.50:12434.
|
||||
|
||||
Run with:
|
||||
pytest test/test_api_integration.py -v -m integration
|
||||
|
||||
All tests in this file are marked @pytest.mark.integration.
|
||||
They require the test server to be reachable and to have at least one
|
||||
chat model and one embedding model available.
|
||||
|
||||
Env vars to pin specific models:
|
||||
NOMYO_TEST_MODEL_CHAT e.g. qwen2.5:1.5b
|
||||
NOMYO_TEST_EMBED_MODEL e.g. nomic-embed-text:latest
|
||||
"""
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
# ── Health / discovery routes ─────────────────────────────────────────────────
|
||||
|
||||
class TestDiscoveryRoutes:
|
||||
async def test_version(self, integration_client):
|
||||
resp = await integration_client.get("/api/version")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "version" in data
|
||||
assert isinstance(data["version"], str)
|
||||
|
||||
async def test_tags_returns_models(self, integration_client):
|
||||
resp = await integration_client.get("/api/tags")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "models" in data
|
||||
assert isinstance(data["models"], list)
|
||||
assert len(data["models"]) > 0
|
||||
|
||||
async def test_ps_returns_list(self, integration_client):
|
||||
resp = await integration_client.get("/api/ps")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "models" in data
|
||||
assert isinstance(data["models"], list)
|
||||
|
||||
async def test_v1_models_returns_data(self, integration_client):
|
||||
resp = await integration_client.get("/v1/models")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "data" in data
|
||||
assert isinstance(data["data"], list)
|
||||
|
||||
async def test_usage_returns_counts(self, integration_client):
|
||||
resp = await integration_client.get("/api/usage")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "usage_counts" in data
|
||||
assert "token_usage_counts" in data
|
||||
|
||||
async def test_config_returns_endpoints(self, integration_client):
|
||||
resp = await integration_client.get("/api/config")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "endpoints" in data
|
||||
|
||||
async def test_hostname(self, integration_client):
|
||||
resp = await integration_client.get("/api/hostname")
|
||||
assert resp.status_code == 200
|
||||
assert "hostname" in resp.json()
|
||||
|
||||
async def test_health(self, integration_client):
|
||||
resp = await integration_client.get("/health")
|
||||
assert resp.status_code in (200, 503)
|
||||
data = resp.json()
|
||||
assert data["status"] in ("ok", "error")
|
||||
assert "endpoints" in data
|
||||
|
||||
async def test_cache_stats(self, integration_client):
|
||||
resp = await integration_client.get("/api/cache/stats")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "enabled" in data
|
||||
|
||||
|
||||
# ── /api/chat ─────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestApiChat:
|
||||
async def test_non_streaming(self, integration_client, chat_model):
|
||||
resp = await integration_client.post(
|
||||
"/api/chat",
|
||||
json={
|
||||
"model": chat_model,
|
||||
"stream": False,
|
||||
"messages": [{"role": "user", "content": "Reply with exactly: OK"}],
|
||||
"options": {"num_predict": 10},
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "message" in data
|
||||
assert "content" in data["message"]
|
||||
|
||||
async def test_streaming_ndjson(self, integration_client, chat_model):
|
||||
resp = await integration_client.post(
|
||||
"/api/chat",
|
||||
json={
|
||||
"model": chat_model,
|
||||
"stream": True,
|
||||
"messages": [{"role": "user", "content": "Say hi"}],
|
||||
"options": {"num_predict": 5},
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
lines = [l for l in resp.text.strip().split("\n") if l.strip()]
|
||||
assert len(lines) >= 1
|
||||
for line in lines:
|
||||
obj = json.loads(line)
|
||||
assert "model" in obj
|
||||
|
||||
async def test_non_streaming_has_token_counts(self, integration_client, chat_model):
|
||||
resp = await integration_client.post(
|
||||
"/api/chat",
|
||||
json={
|
||||
"model": chat_model,
|
||||
"stream": False,
|
||||
"messages": [{"role": "user", "content": "Count to 3"}],
|
||||
"options": {"num_predict": 20},
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data.get("done") is True
|
||||
# Token counts should be present in the final chunk
|
||||
assert data.get("prompt_eval_count", 0) >= 0
|
||||
|
||||
async def test_system_message_honoured(self, integration_client, chat_model):
|
||||
resp = await integration_client.post(
|
||||
"/api/chat",
|
||||
json={
|
||||
"model": chat_model,
|
||||
"stream": False,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant. Always reply with exactly: PONG"},
|
||||
{"role": "user", "content": "PING"},
|
||||
],
|
||||
"options": {"num_predict": 10},
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
content = resp.json()["message"]["content"]
|
||||
assert isinstance(content, str)
|
||||
assert len(content) > 0
|
||||
|
||||
|
||||
# ── /api/generate ─────────────────────────────────────────────────────────────
|
||||
|
||||
class TestApiGenerate:
|
||||
async def test_non_streaming(self, integration_client, chat_model):
|
||||
resp = await integration_client.post(
|
||||
"/api/generate",
|
||||
json={
|
||||
"model": chat_model,
|
||||
"prompt": "Complete: The sky is",
|
||||
"stream": False,
|
||||
"options": {"num_predict": 5},
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "response" in data
|
||||
|
||||
async def test_streaming(self, integration_client, chat_model):
|
||||
resp = await integration_client.post(
|
||||
"/api/generate",
|
||||
json={
|
||||
"model": chat_model,
|
||||
"prompt": "One plus one equals",
|
||||
"stream": True,
|
||||
"options": {"num_predict": 5},
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
lines = [l for l in resp.text.strip().split("\n") if l.strip()]
|
||||
assert len(lines) >= 1
|
||||
|
||||
|
||||
# ── /api/embed ────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestApiEmbed:
|
||||
async def test_embed_single_string(self, integration_client, embed_model):
|
||||
resp = await integration_client.post(
|
||||
"/api/embed",
|
||||
json={"model": embed_model, "input": "The quick brown fox"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "embeddings" in data
|
||||
assert isinstance(data["embeddings"], list)
|
||||
assert len(data["embeddings"]) == 1
|
||||
assert len(data["embeddings"][0]) > 0
|
||||
|
||||
async def test_embed_multiple_inputs(self, integration_client, embed_model):
|
||||
resp = await integration_client.post(
|
||||
"/api/embed",
|
||||
json={"model": embed_model, "input": ["sentence one", "sentence two"]},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "embeddings" in data
|
||||
assert len(data["embeddings"]) == 2
|
||||
|
||||
|
||||
# ── /v1/chat/completions ──────────────────────────────────────────────────────
|
||||
|
||||
class TestOpenAIChatCompletions:
|
||||
async def test_non_streaming(self, integration_client, chat_model):
|
||||
model = chat_model.replace(":latest", "") if ":latest" in chat_model else chat_model
|
||||
resp = await integration_client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": "Reply OK"}],
|
||||
"max_tokens": 10,
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "choices" in data
|
||||
assert len(data["choices"]) > 0
|
||||
assert "message" in data["choices"][0]
|
||||
|
||||
async def test_streaming_sse(self, integration_client, chat_model):
|
||||
model = chat_model.replace(":latest", "") if ":latest" in chat_model else chat_model
|
||||
resp = await integration_client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": "Hi"}],
|
||||
"max_tokens": 5,
|
||||
"stream": True,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
# Response should be SSE format
|
||||
assert "data:" in resp.text or "[DONE]" in resp.text
|
||||
|
||||
async def test_non_streaming_has_usage(self, integration_client, chat_model):
|
||||
model = chat_model.replace(":latest", "") if ":latest" in chat_model else chat_model
|
||||
resp = await integration_client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": "Say yes"}],
|
||||
"max_tokens": 5,
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
if "usage" in data and data["usage"]:
|
||||
assert data["usage"].get("prompt_tokens", 0) >= 0
|
||||
|
||||
|
||||
# ── /v1/embeddings ────────────────────────────────────────────────────────────
|
||||
|
||||
class TestOpenAIEmbeddings:
|
||||
async def test_single_input(self, integration_client, embed_model):
|
||||
model = embed_model.replace(":latest", "") if ":latest" in embed_model else embed_model
|
||||
resp = await integration_client.post(
|
||||
"/v1/embeddings",
|
||||
json={"model": model, "input": "Test sentence"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "data" in data
|
||||
assert len(data["data"]) > 0
|
||||
embedding = data["data"][0].get("embedding")
|
||||
assert isinstance(embedding, list)
|
||||
assert len(embedding) > 0
|
||||
|
||||
|
||||
# ── Token counts (database-backed) ───────────────────────────────────────────
|
||||
|
||||
class TestTokenCounts:
|
||||
async def test_token_counts_endpoint(self, integration_client):
|
||||
resp = await integration_client.get("/api/token_counts")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "total_tokens" in data
|
||||
assert "breakdown" in data
|
||||
|
||||
|
||||
# ── ps_details (extended ps) ─────────────────────────────────────────────────
|
||||
|
||||
class TestPsDetails:
|
||||
async def test_ps_details_returns_models(self, integration_client):
|
||||
resp = await integration_client.get("/api/ps_details")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "models" in data
|
||||
assert isinstance(data["models"], list)
|
||||
230
test/test_api_validation.py
Normal file
230
test/test_api_validation.py
Normal file
|
|
@ -0,0 +1,230 @@
|
|||
"""
|
||||
HTTP-level validation and auth middleware tests.
|
||||
|
||||
These tests use an in-process httpx client and never reach a real backend:
|
||||
all requests are rejected at the validation or auth layer before any
|
||||
endpoint-selection or upstream HTTP calls occur.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
|
||||
class TestChatValidation:
|
||||
async def test_missing_model_returns_400(self, client):
|
||||
resp = await client.post(
|
||||
"/api/chat",
|
||||
json={"messages": [{"role": "user", "content": "hello"}]},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert "model" in resp.json()["detail"].lower()
|
||||
|
||||
async def test_missing_messages_returns_400(self, client):
|
||||
resp = await client.post("/api/chat", json={"model": "llama3.2"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_invalid_json_returns_400(self, client):
|
||||
resp = await client.post(
|
||||
"/api/chat",
|
||||
content=b"not-json",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_messages_not_list_returns_400(self, client):
|
||||
resp = await client.post(
|
||||
"/api/chat",
|
||||
json={"model": "m", "messages": "not-a-list"},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_options_not_dict_returns_400(self, client):
|
||||
resp = await client.post(
|
||||
"/api/chat",
|
||||
json={"model": "m", "messages": [{"role": "user", "content": "hi"}], "options": "bad"},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
class TestGenerateValidation:
|
||||
async def test_missing_model_returns_400(self, client):
|
||||
resp = await client.post("/api/generate", json={"prompt": "hello"})
|
||||
assert resp.status_code == 400
|
||||
assert "model" in resp.json()["detail"].lower()
|
||||
|
||||
async def test_missing_prompt_returns_400(self, client):
|
||||
resp = await client.post("/api/generate", json={"model": "m"})
|
||||
assert resp.status_code == 400
|
||||
assert "prompt" in resp.json()["detail"].lower()
|
||||
|
||||
async def test_invalid_json_returns_400(self, client):
|
||||
resp = await client.post(
|
||||
"/api/generate",
|
||||
content=b"{bad-json",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
class TestEmbedValidation:
|
||||
async def test_missing_model_returns_400(self, client):
|
||||
resp = await client.post("/api/embed", json={"input": "hello"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_missing_input_returns_400(self, client):
|
||||
resp = await client.post("/api/embed", json={"model": "nomic-embed-text"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
class TestEmbeddingsValidation:
|
||||
async def test_missing_model_returns_400(self, client):
|
||||
resp = await client.post("/api/embeddings", json={"prompt": "hello"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_missing_prompt_returns_400(self, client):
|
||||
resp = await client.post("/api/embeddings", json={"model": "nomic-embed-text"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
class TestOpenAIChatValidation:
|
||||
async def test_missing_model_returns_400(self, client):
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hello"}]},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_missing_messages_returns_400(self, client):
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"model": "gpt-4o"},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_invalid_json_returns_400(self, client):
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
content=b"}{",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_svg_image_rejected(self, client):
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "vision-model",
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "describe"},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/svg+xml;base64,abc"}},
|
||||
],
|
||||
}],
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert "svg" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
class TestOpenAICompletionsValidation:
|
||||
async def test_missing_model_returns_400(self, client):
|
||||
resp = await client.post("/v1/completions", json={"prompt": "hello"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_missing_prompt_returns_400(self, client):
|
||||
resp = await client.post("/v1/completions", json={"model": "m"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
class TestRerankValidation:
|
||||
async def test_missing_model_returns_400(self, client):
|
||||
resp = await client.post(
|
||||
"/v1/rerank",
|
||||
json={"query": "search query", "documents": ["doc1"]},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_missing_query_returns_400(self, client):
|
||||
resp = await client.post(
|
||||
"/v1/rerank",
|
||||
json={"model": "reranker", "documents": ["doc1"]},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_empty_documents_returns_400(self, client):
|
||||
resp = await client.post(
|
||||
"/v1/rerank",
|
||||
json={"model": "reranker", "query": "search", "documents": []},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
class TestShowValidation:
|
||||
async def test_missing_model_returns_400(self, client):
|
||||
resp = await client.post("/api/show", json={})
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
class TestCopyValidation:
|
||||
async def test_missing_source_returns_400(self, client):
|
||||
resp = await client.post("/api/copy", json={"destination": "dst"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_missing_destination_returns_400(self, client):
|
||||
resp = await client.post("/api/copy", json={"source": "src"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
class TestDeleteValidation:
|
||||
async def test_missing_model_returns_400(self, client):
|
||||
import json as _json
|
||||
resp = await client.request(
|
||||
"DELETE",
|
||||
"/api/delete",
|
||||
content=_json.dumps({}).encode(),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
class TestAuthMiddleware:
|
||||
async def test_no_key_returns_401(self, client_auth):
|
||||
resp = await client_auth.post(
|
||||
"/api/chat",
|
||||
json={"model": "m", "messages": [{"role": "user", "content": "hi"}]},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert "Missing" in resp.json()["detail"]
|
||||
|
||||
async def test_invalid_key_returns_403(self, client_auth):
|
||||
resp = await client_auth.post(
|
||||
"/api/chat",
|
||||
headers={"Authorization": "Bearer wrong-key"},
|
||||
json={"model": "m", "messages": [{"role": "user", "content": "hi"}]},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
assert "Invalid" in resp.json()["detail"]
|
||||
|
||||
async def test_valid_key_passes_middleware(self, client_auth):
|
||||
# /api/usage reads in-memory counters only — no backend call needed
|
||||
resp = await client_auth.get(
|
||||
"/api/usage",
|
||||
headers={"Authorization": "Bearer test-secret-key"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
async def test_key_via_query_param(self, client_auth):
|
||||
resp = await client_auth.get("/api/usage?api_key=test-secret-key")
|
||||
assert resp.status_code == 200
|
||||
|
||||
async def test_options_bypasses_auth(self, client_auth):
|
||||
resp = await client_auth.options("/api/chat")
|
||||
assert resp.status_code not in (401, 403)
|
||||
|
||||
async def test_root_path_bypasses_auth(self, client_auth):
|
||||
resp = await client_auth.get("/")
|
||||
assert resp.status_code not in (401, 403)
|
||||
|
||||
async def test_favicon_bypasses_auth(self, client_auth):
|
||||
resp = await client_auth.get("/favicon.ico")
|
||||
# Should not be blocked by auth (may 404 in test but not 401/403)
|
||||
assert resp.status_code not in (401, 403)
|
||||
333
test/test_cache.py
Normal file
333
test/test_cache.py
Normal file
|
|
@ -0,0 +1,333 @@
|
|||
"""Unit tests for cache.LLMCache in exact-match mode (no sentence-transformers needed)."""
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
|
||||
import cache as cache_mod
|
||||
from cache import (
|
||||
LLMCache,
|
||||
_bm25_weighted_text,
|
||||
get_llm_cache,
|
||||
init_llm_cache,
|
||||
openai_nonstream_to_sse,
|
||||
)
|
||||
|
||||
_CACHE_DB_PATH = str(Path(tempfile.gettempdir()) / "nomyo_test_cache.db")
|
||||
|
||||
|
||||
def _exact_cfg(backend: str = "memory") -> SimpleNamespace:
|
||||
"""Config for exact-match mode — similarity=1.0 avoids embedding deps."""
|
||||
return SimpleNamespace(
|
||||
cache_enabled=True,
|
||||
cache_backend=backend,
|
||||
cache_similarity=1.0,
|
||||
cache_history_weight=0.3,
|
||||
cache_ttl=300,
|
||||
cache_db_path=_CACHE_DB_PATH,
|
||||
cache_redis_url="redis://localhost:6379",
|
||||
)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Pure helpers
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestBM25WeightedText:
|
||||
def test_empty_history(self):
|
||||
assert _bm25_weighted_text([]) == ""
|
||||
|
||||
def test_history_without_content(self):
|
||||
assert _bm25_weighted_text([{"role": "user"}, {"role": "assistant"}]) == ""
|
||||
|
||||
def test_repeats_high_idf_terms(self):
|
||||
history = [
|
||||
{"role": "user", "content": "Tell me about quantum entanglement"},
|
||||
{"role": "assistant", "content": "Quantum entanglement is a phenomenon"},
|
||||
{"role": "user", "content": "How does entanglement work?"},
|
||||
]
|
||||
out = _bm25_weighted_text(history)
|
||||
# Rare/domain term ("entanglement") should appear; short stopwords (<=2 chars) dropped
|
||||
assert "entanglement" in out
|
||||
assert "is" not in out.split()
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# openai_nonstream_to_sse
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestOpenAINonstreamToSSE:
|
||||
def test_valid_chat_completion(self):
|
||||
chat = {
|
||||
"id": "x1",
|
||||
"created": 123,
|
||||
"model": "gpt-4o",
|
||||
"choices": [{"message": {"role": "assistant", "content": "hello"}}],
|
||||
"usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
|
||||
}
|
||||
out = openai_nonstream_to_sse(orjson.dumps(chat), "gpt-4o")
|
||||
text = out.decode()
|
||||
assert text.startswith("data: ")
|
||||
assert text.endswith("data: [DONE]\n\n")
|
||||
# First chunk contains the original content
|
||||
first = text.split("\n\n")[0][len("data: "):]
|
||||
parsed = orjson.loads(first)
|
||||
assert parsed["choices"][0]["delta"]["content"] == "hello"
|
||||
assert parsed["usage"]["total_tokens"] == 3
|
||||
|
||||
def test_corrupt_bytes_return_done_only(self):
|
||||
out = openai_nonstream_to_sse(b"not-json", "m")
|
||||
assert out == b"data: [DONE]\n\n"
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# LLMCache internal helpers
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestLLMCacheParsing:
|
||||
def test_namespace_is_stable_and_isolated(self):
|
||||
c = LLMCache(_exact_cfg())
|
||||
a = c._namespace("chat", "m1", "system A")
|
||||
b = c._namespace("chat", "m1", "system A")
|
||||
assert a == b
|
||||
assert c._namespace("chat", "m1", "system B") != a
|
||||
assert c._namespace("generate", "m1", "system A") != a
|
||||
assert len(a) == 16
|
||||
|
||||
def test_parse_messages_flat_strings(self):
|
||||
c = LLMCache(_exact_cfg())
|
||||
sys, hist, last = c._parse_messages([
|
||||
{"role": "system", "content": "be helpful"},
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "assistant", "content": "hello"},
|
||||
{"role": "user", "content": "what is 2+2?"},
|
||||
])
|
||||
assert sys == "be helpful"
|
||||
assert last == "what is 2+2?"
|
||||
assert hist == [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "assistant", "content": "hello"},
|
||||
]
|
||||
|
||||
def test_parse_messages_multimodal_content(self):
|
||||
c = LLMCache(_exact_cfg())
|
||||
sys, _hist, last = c._parse_messages([
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": [
|
||||
{"type": "text", "text": "describe"},
|
||||
{"type": "image_url", "image_url": {"url": "data:..."}},
|
||||
]},
|
||||
])
|
||||
assert sys == "sys"
|
||||
assert last == "describe"
|
||||
|
||||
def test_parse_messages_no_user_message(self):
|
||||
c = LLMCache(_exact_cfg())
|
||||
sys, hist, last = c._parse_messages([
|
||||
{"role": "system", "content": "sys only"},
|
||||
])
|
||||
assert sys == "sys only"
|
||||
assert last == ""
|
||||
assert hist == []
|
||||
|
||||
|
||||
class TestPersonalTokenExtraction:
|
||||
def test_email_extracted(self):
|
||||
c = LLMCache(_exact_cfg())
|
||||
toks = c._extract_personal_tokens("Reach me at alice@example.com please")
|
||||
assert "alice@example.com" in toks
|
||||
|
||||
def test_numeric_id_after_keyword(self):
|
||||
c = LLMCache(_exact_cfg())
|
||||
toks = c._extract_personal_tokens("User id: 123456")
|
||||
assert "123456" in toks
|
||||
|
||||
def test_identity_tag_names_extracted(self):
|
||||
c = LLMCache(_exact_cfg())
|
||||
toks = c._extract_personal_tokens(
|
||||
"[Tags: identity] User's name is Andreas Schwibbe"
|
||||
)
|
||||
# Both name tokens should be extracted lowercased; stopwords dropped
|
||||
assert "andreas" in toks
|
||||
assert "schwibbe" in toks
|
||||
assert "name" not in toks # in _IDENTITY_STOPWORDS
|
||||
assert "user" not in toks
|
||||
|
||||
def test_empty_system_returns_empty_set(self):
|
||||
c = LLMCache(_exact_cfg())
|
||||
assert c._extract_personal_tokens("") == frozenset()
|
||||
|
||||
|
||||
class TestResponseIsPersonalized:
|
||||
def _resp(self, content: str) -> bytes:
|
||||
return orjson.dumps({"choices": [{"message": {"content": content}}]})
|
||||
|
||||
def test_email_in_response_is_personalized(self):
|
||||
c = LLMCache(_exact_cfg())
|
||||
assert c._response_is_personalized(self._resp("contact bob@x.com"), "")
|
||||
|
||||
def test_uuid_in_response_is_personalized(self):
|
||||
c = LLMCache(_exact_cfg())
|
||||
uuid = "550e8400-e29b-41d4-a716-446655440000"
|
||||
assert c._response_is_personalized(self._resp(f"id={uuid}"), "")
|
||||
|
||||
def test_long_numeric_id_in_response_is_personalized(self):
|
||||
c = LLMCache(_exact_cfg())
|
||||
assert c._response_is_personalized(self._resp("account 12345678"), "")
|
||||
|
||||
def test_identity_token_from_system_echoed_in_response(self):
|
||||
c = LLMCache(_exact_cfg())
|
||||
system = "[Tags: identity] Andreas works here"
|
||||
assert c._response_is_personalized(
|
||||
self._resp("Yes, Andreas is logged in"), system
|
||||
)
|
||||
|
||||
def test_generic_response_not_personalized(self):
|
||||
c = LLMCache(_exact_cfg())
|
||||
assert not c._response_is_personalized(
|
||||
self._resp("The capital of France is Paris."), "be helpful"
|
||||
)
|
||||
|
||||
def test_ollama_message_format_parsed(self):
|
||||
c = LLMCache(_exact_cfg())
|
||||
body = orjson.dumps({"message": {"content": "alice@example.com"}})
|
||||
assert c._response_is_personalized(body, "")
|
||||
|
||||
def test_unparseable_body_with_bytes_is_conservative(self):
|
||||
c = LLMCache(_exact_cfg())
|
||||
# Can't parse → returns True (err on the side of privacy)
|
||||
assert c._response_is_personalized(b"binary-junk", "")
|
||||
|
||||
def test_empty_response_not_personalized(self):
|
||||
c = LLMCache(_exact_cfg())
|
||||
assert not c._response_is_personalized(b"", "anything")
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# End-to-end exact-match cache with the memory backend
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture
|
||||
async def memcache():
|
||||
"""LLMCache wired up with the in-memory backend (no external deps)."""
|
||||
c = LLMCache(_exact_cfg("memory"))
|
||||
await c.init()
|
||||
return c
|
||||
|
||||
|
||||
class TestExactMatchCache:
|
||||
async def test_miss_then_set_then_hit(self, memcache):
|
||||
msgs = [
|
||||
{"role": "system", "content": "be helpful"},
|
||||
{"role": "user", "content": "what is 2+2?"},
|
||||
]
|
||||
resp = orjson.dumps({"choices": [{"message": {"content": "4"}}]})
|
||||
|
||||
assert await memcache.get_chat("chat", "m1", msgs) is None
|
||||
await memcache.set_chat("chat", "m1", msgs, resp)
|
||||
hit = await memcache.get_chat("chat", "m1", msgs)
|
||||
assert hit == resp
|
||||
|
||||
async def test_namespace_isolation_by_system(self, memcache):
|
||||
resp = orjson.dumps({"choices": [{"message": {"content": "ok"}}]})
|
||||
msgs_a = [
|
||||
{"role": "system", "content": "system A"},
|
||||
{"role": "user", "content": "same question"},
|
||||
]
|
||||
msgs_b = [
|
||||
{"role": "system", "content": "system B"},
|
||||
{"role": "user", "content": "same question"},
|
||||
]
|
||||
await memcache.set_chat("chat", "m", msgs_a, resp)
|
||||
# Same question + different system prompt = different namespace = miss
|
||||
assert await memcache.get_chat("chat", "m", msgs_b) is None
|
||||
|
||||
async def test_namespace_isolation_by_route(self, memcache):
|
||||
resp = orjson.dumps({"choices": [{"message": {"content": "ok"}}]})
|
||||
msgs = [{"role": "user", "content": "ping"}]
|
||||
await memcache.set_chat("chat", "m", msgs, resp)
|
||||
assert await memcache.get_chat("openai_chat", "m", msgs) is None
|
||||
|
||||
async def test_no_user_message_is_noop(self, memcache):
|
||||
msgs = [{"role": "system", "content": "sys only"}]
|
||||
resp = orjson.dumps({"choices": [{"message": {"content": "x"}}]})
|
||||
# Both get and set should silently no-op
|
||||
assert await memcache.get_chat("chat", "m", msgs) is None
|
||||
await memcache.set_chat("chat", "m", msgs, resp)
|
||||
assert await memcache.get_chat("chat", "m", msgs) is None
|
||||
|
||||
async def test_personalized_response_generic_system_not_stored(self, memcache):
|
||||
msgs = [
|
||||
{"role": "system", "content": "be helpful"}, # generic
|
||||
{"role": "user", "content": "give me an email"},
|
||||
]
|
||||
# Response contains an email → would leak across users sharing the
|
||||
# generic namespace → must NOT be stored at all
|
||||
resp = orjson.dumps({"choices": [{"message": {"content": "bob@x.com"}}]})
|
||||
await memcache.set_chat("chat", "m", msgs, resp)
|
||||
assert await memcache.get_chat("chat", "m", msgs) is None
|
||||
|
||||
async def test_personalized_response_user_specific_system_stored(self, memcache):
|
||||
msgs = [
|
||||
{"role": "system", "content": "User id: 998877 prefers concise answers"},
|
||||
{"role": "user", "content": "what is my id?"},
|
||||
]
|
||||
resp = orjson.dumps({"choices": [{"message": {"content": "Your id is 998877"}}]})
|
||||
await memcache.set_chat("chat", "m", msgs, resp)
|
||||
# User-specific namespace → exact-match within this user is OK
|
||||
assert await memcache.get_chat("chat", "m", msgs) == resp
|
||||
|
||||
async def test_generate_convenience_wrappers(self, memcache):
|
||||
resp = orjson.dumps({"response": "blue"})
|
||||
await memcache.set_generate("m", "what color is the sky?", "", resp)
|
||||
assert await memcache.get_generate("m", "what color is the sky?") == resp
|
||||
|
||||
|
||||
class TestStatsAndClear:
|
||||
async def test_stats_tracks_hits_and_misses(self, memcache):
|
||||
msgs = [{"role": "user", "content": "hello"}]
|
||||
await memcache.get_chat("chat", "m", msgs) # miss
|
||||
resp = orjson.dumps({"choices": [{"message": {"content": "hi"}}]})
|
||||
await memcache.set_chat("chat", "m", msgs, resp)
|
||||
await memcache.get_chat("chat", "m", msgs) # hit
|
||||
s = memcache.stats()
|
||||
assert s["hits"] == 1
|
||||
assert s["misses"] == 1
|
||||
assert s["hit_rate"] == 0.5
|
||||
assert s["semantic"] is False
|
||||
assert s["backend"] == "memory"
|
||||
|
||||
async def test_clear_resets_counters_and_storage(self, memcache):
|
||||
msgs = [{"role": "user", "content": "hi"}]
|
||||
resp = orjson.dumps({"choices": [{"message": {"content": "ok"}}]})
|
||||
await memcache.set_chat("chat", "m", msgs, resp)
|
||||
await memcache.get_chat("chat", "m", msgs)
|
||||
await memcache.clear()
|
||||
s = memcache.stats()
|
||||
assert s["hits"] == 0
|
||||
assert s["misses"] == 0
|
||||
assert await memcache.get_chat("chat", "m", msgs) is None
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Module-level helpers
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestInitLLMCache:
|
||||
async def test_disabled_returns_none(self):
|
||||
cfg = _exact_cfg()
|
||||
cfg.cache_enabled = False
|
||||
result = await init_llm_cache(cfg)
|
||||
assert result is None
|
||||
|
||||
async def test_enabled_returns_initialized_cache(self):
|
||||
cfg = _exact_cfg()
|
||||
try:
|
||||
result = await init_llm_cache(cfg)
|
||||
assert result is not None
|
||||
assert get_llm_cache() is result
|
||||
finally:
|
||||
# Reset singleton between tests
|
||||
cache_mod._cache = None
|
||||
345
test/test_choose_endpoint.py
Normal file
345
test/test_choose_endpoint.py
Normal file
|
|
@ -0,0 +1,345 @@
|
|||
"""Tests for choose_endpoint routing logic with mocked fetch calls."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import router
|
||||
|
||||
EP1 = "http://ep1:11434"
|
||||
EP2 = "http://ep2:11434"
|
||||
EP3 = "http://ep3:11434"
|
||||
LLAMA_EP = "http://llama:8080/v1"
|
||||
|
||||
|
||||
def _make_cfg(endpoints, llama_eps=None, max_conn=2, endpoint_config=None, priority_routing=False):
|
||||
cfg = MagicMock()
|
||||
cfg.endpoints = endpoints
|
||||
cfg.llama_server_endpoints = llama_eps or []
|
||||
cfg.api_keys = {}
|
||||
cfg.max_concurrent_connections = max_conn
|
||||
cfg.endpoint_config = endpoint_config or {}
|
||||
cfg.priority_routing = priority_routing
|
||||
cfg.router_api_key = None
|
||||
return cfg
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_usage():
|
||||
"""Clear usage_counts between tests to prevent bleed."""
|
||||
router.usage_counts.clear()
|
||||
yield
|
||||
router.usage_counts.clear()
|
||||
|
||||
|
||||
class TestChooseEndpointBasic:
|
||||
async def test_selects_single_candidate(self):
|
||||
cfg = _make_cfg([EP1])
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(router.fetch, "available_models", AsyncMock(return_value={"llama3.2:latest"})),
|
||||
patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"llama3.2:latest"})),
|
||||
):
|
||||
ep, tracking = await router.choose_endpoint("llama3.2:latest")
|
||||
assert ep == EP1
|
||||
assert tracking == "llama3.2:latest"
|
||||
|
||||
async def test_raises_when_no_endpoint_has_model(self):
|
||||
cfg = _make_cfg([EP1, EP2])
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(router.fetch, "available_models", AsyncMock(return_value=set())),
|
||||
patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="advertise the model"):
|
||||
await router.choose_endpoint("unknown-model:latest")
|
||||
|
||||
async def test_prefers_loaded_endpoint(self):
|
||||
cfg = _make_cfg([EP1, EP2])
|
||||
async def available(ep, *_):
|
||||
return {"llama3.2:latest"}
|
||||
|
||||
async def loaded(ep):
|
||||
return {"llama3.2:latest"} if ep == EP2 else set()
|
||||
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(router.fetch, "available_models", side_effect=available),
|
||||
patch.object(router.fetch, "loaded_models", side_effect=loaded),
|
||||
):
|
||||
ep, _ = await router.choose_endpoint("llama3.2:latest")
|
||||
assert ep == EP2
|
||||
|
||||
async def test_falls_back_to_free_slot(self):
|
||||
cfg = _make_cfg([EP1, EP2])
|
||||
async def available(ep, *_):
|
||||
return {"llama3.2:latest"}
|
||||
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(router.fetch, "available_models", side_effect=available),
|
||||
patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
|
||||
):
|
||||
ep, _ = await router.choose_endpoint("llama3.2:latest")
|
||||
assert ep in (EP1, EP2)
|
||||
|
||||
async def test_saturated_picks_least_busy(self):
|
||||
cfg = _make_cfg([EP1, EP2])
|
||||
cfg.max_concurrent_connections = 1
|
||||
|
||||
async def available(ep, *_):
|
||||
return {"llama3.2:latest"}
|
||||
|
||||
# Saturate EP1 with 2 active connections, EP2 with 1
|
||||
router.usage_counts[EP1]["llama3.2:latest"] = 2
|
||||
router.usage_counts[EP2]["llama3.2:latest"] = 1
|
||||
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(router.fetch, "available_models", side_effect=available),
|
||||
patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
|
||||
):
|
||||
ep, _ = await router.choose_endpoint("llama3.2:latest")
|
||||
# Least-busy is EP2
|
||||
assert ep == EP2
|
||||
|
||||
async def test_reserve_increments_usage(self):
|
||||
cfg = _make_cfg([EP1])
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(router.fetch, "available_models", AsyncMock(return_value={"model:latest"})),
|
||||
patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"model:latest"})),
|
||||
):
|
||||
ep, tracking = await router.choose_endpoint("model:latest", reserve=True)
|
||||
assert router.usage_counts[ep][tracking] == 1
|
||||
|
||||
|
||||
class TestChooseEndpointModelNaming:
|
||||
async def test_strips_latest_for_openai_endpoints(self):
|
||||
cfg = _make_cfg(endpoints=[], llama_eps=[LLAMA_EP])
|
||||
cfg.endpoints = []
|
||||
|
||||
async def available(ep, *_):
|
||||
# llama-server advertises without :latest
|
||||
return {"gpt-4o"}
|
||||
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(router.fetch, "available_models", side_effect=available),
|
||||
patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"gpt-4o"})),
|
||||
):
|
||||
ep, _ = await router.choose_endpoint("gpt-4o:latest")
|
||||
assert ep == LLAMA_EP
|
||||
|
||||
async def test_adds_latest_for_ollama_when_bare_name(self):
|
||||
cfg = _make_cfg([EP1])
|
||||
|
||||
async def available(ep, *_):
|
||||
return {"llama3.2:latest"}
|
||||
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(router.fetch, "available_models", side_effect=available),
|
||||
patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"llama3.2:latest"})),
|
||||
):
|
||||
ep, _ = await router.choose_endpoint("llama3.2")
|
||||
assert ep == EP1
|
||||
|
||||
|
||||
class TestChooseEndpointLoadBalancing:
|
||||
async def test_random_selection_among_idle(self):
|
||||
cfg = _make_cfg([EP1, EP2, EP3])
|
||||
selected = set()
|
||||
|
||||
async def available(ep, *_):
|
||||
return {"model:latest"}
|
||||
|
||||
async def loaded(ep):
|
||||
return {"model:latest"}
|
||||
|
||||
for _ in range(20):
|
||||
router.usage_counts.clear()
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(router.fetch, "available_models", side_effect=available),
|
||||
patch.object(router.fetch, "loaded_models", side_effect=loaded),
|
||||
):
|
||||
ep, _ = await router.choose_endpoint("model:latest", reserve=False)
|
||||
selected.add(ep)
|
||||
|
||||
# With 20 draws from 3 idle endpoints, all three should appear
|
||||
assert len(selected) > 1
|
||||
|
||||
async def test_sort_by_load_ascending(self):
|
||||
cfg = _make_cfg([EP1, EP2])
|
||||
router.usage_counts[EP1]["model:latest"] = 1
|
||||
router.usage_counts[EP2]["model:latest"] = 0
|
||||
|
||||
async def available(ep, *_):
|
||||
return {"model:latest"}
|
||||
|
||||
async def loaded(ep):
|
||||
return {"model:latest"}
|
||||
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(router.fetch, "available_models", side_effect=available),
|
||||
patch.object(router.fetch, "loaded_models", side_effect=loaded),
|
||||
):
|
||||
ep, _ = await router.choose_endpoint("model:latest", reserve=False)
|
||||
# EP2 has fewer active connections → should be selected
|
||||
assert ep == EP2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_max_connections unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetMaxConnections:
|
||||
def test_returns_global_default_when_no_override(self):
|
||||
cfg = _make_cfg([EP1, EP2], max_conn=3)
|
||||
with patch.object(router, "config", cfg):
|
||||
assert router.get_max_connections(EP1) == 3
|
||||
assert router.get_max_connections(EP2) == 3
|
||||
|
||||
def test_returns_per_endpoint_override(self):
|
||||
cfg = _make_cfg(
|
||||
[EP1, EP2],
|
||||
max_conn=2,
|
||||
endpoint_config={EP1: {"max_concurrent_connections": 5}},
|
||||
)
|
||||
with patch.object(router, "config", cfg):
|
||||
assert router.get_max_connections(EP1) == 5
|
||||
assert router.get_max_connections(EP2) == 2 # falls back to global
|
||||
|
||||
def test_unrecognised_endpoint_falls_back_to_global(self):
|
||||
cfg = _make_cfg([EP1], max_conn=4, endpoint_config={EP2: {"max_concurrent_connections": 1}})
|
||||
with patch.object(router, "config", cfg):
|
||||
assert router.get_max_connections(EP3) == 4
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Priority / WRR routing tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
MODEL = "model:latest"
|
||||
|
||||
|
||||
def _all_loaded(ep):
|
||||
"""Side-effect: every endpoint advertises and has MODEL loaded."""
|
||||
return {MODEL}
|
||||
|
||||
|
||||
class TestPriorityRouting:
|
||||
"""Tests for priority_routing=True (WRR + config-order tiebreaking)."""
|
||||
|
||||
async def test_idle_picks_first_in_config_order(self):
|
||||
"""When all endpoints are idle, priority picks the first listed endpoint."""
|
||||
cfg = _make_cfg([EP1, EP2, EP3], priority_routing=True)
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
|
||||
patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
|
||||
):
|
||||
ep, _ = await router.choose_endpoint(MODEL, reserve=False)
|
||||
assert ep == EP1
|
||||
|
||||
async def test_lower_utilization_preferred_over_priority(self):
|
||||
"""An endpoint with lower ratio is preferred even if it has lower priority."""
|
||||
cfg = _make_cfg([EP1, EP2], priority_routing=True)
|
||||
# EP1 (priority 0) is busier: 1/2 = 0.5; EP2 (priority 1) is idle: 0/2 = 0.0
|
||||
router.usage_counts[EP1][MODEL] = 1
|
||||
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
|
||||
patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
|
||||
):
|
||||
ep, _ = await router.choose_endpoint(MODEL, reserve=False)
|
||||
assert ep == EP2
|
||||
|
||||
async def test_wrr_distribution_matches_expected_sequence(self):
|
||||
"""
|
||||
Full WRR sequence with heterogeneous capacities, mirroring the issue example:
|
||||
EP1 max=2, EP2 max=2, EP3 max=1
|
||||
|
||||
Expected routing order for 5 sequential requests:
|
||||
EP1, EP2, EP3, EP1, EP2
|
||||
"""
|
||||
cfg = _make_cfg(
|
||||
[EP1, EP2, EP3],
|
||||
max_conn=2,
|
||||
endpoint_config={EP3: {"max_concurrent_connections": 1}},
|
||||
priority_routing=True,
|
||||
)
|
||||
|
||||
expected = [EP1, EP2, EP3, EP1, EP2]
|
||||
actual = []
|
||||
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
|
||||
patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
|
||||
):
|
||||
for _ in expected:
|
||||
ep, _ = await router.choose_endpoint(MODEL, reserve=True)
|
||||
actual.append(ep)
|
||||
|
||||
assert actual == expected
|
||||
|
||||
async def test_saturated_picks_lowest_ratio_then_priority(self):
|
||||
"""When all endpoints are saturated, pick lowest utilization ratio; break ties by priority."""
|
||||
cfg = _make_cfg(
|
||||
[EP1, EP2, EP3],
|
||||
max_conn=1,
|
||||
endpoint_config={EP3: {"max_concurrent_connections": 2}},
|
||||
priority_routing=True,
|
||||
)
|
||||
# EP1 usage=1/1=1.0, EP2 usage=1/1=1.0, EP3 usage=1/2=0.5 → EP3 wins
|
||||
router.usage_counts[EP1][MODEL] = 1
|
||||
router.usage_counts[EP2][MODEL] = 1
|
||||
router.usage_counts[EP3][MODEL] = 1
|
||||
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
|
||||
patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
|
||||
):
|
||||
ep, _ = await router.choose_endpoint(MODEL, reserve=False)
|
||||
assert ep == EP3
|
||||
|
||||
async def test_saturated_ties_broken_by_priority(self):
|
||||
"""When all are saturated with equal ratio, config order wins."""
|
||||
cfg = _make_cfg([EP1, EP2, EP3], max_conn=1, priority_routing=True)
|
||||
router.usage_counts[EP1][MODEL] = 1
|
||||
router.usage_counts[EP2][MODEL] = 1
|
||||
router.usage_counts[EP3][MODEL] = 1
|
||||
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
|
||||
patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
|
||||
):
|
||||
ep, _ = await router.choose_endpoint(MODEL, reserve=False)
|
||||
assert ep == EP1
|
||||
|
||||
|
||||
class TestPriorityRoutingDisabled:
|
||||
"""Verify that priority_routing=False keeps the original random behaviour."""
|
||||
|
||||
async def test_idle_endpoints_are_randomised(self):
|
||||
"""Without priority routing, all-idle selection must eventually pick each endpoint."""
|
||||
cfg = _make_cfg([EP1, EP2, EP3], priority_routing=False)
|
||||
selected = set()
|
||||
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
|
||||
patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
|
||||
):
|
||||
for _ in range(30):
|
||||
router.usage_counts.clear()
|
||||
ep, _ = await router.choose_endpoint(MODEL, reserve=False)
|
||||
selected.add(ep)
|
||||
|
||||
# With 30 draws from 3 equally-idle endpoints, all three must appear
|
||||
assert selected == {EP1, EP2, EP3}
|
||||
197
test/test_db.py
Normal file
197
test/test_db.py
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
"""Direct unit tests for db.TokenDatabase — no router/app dependency."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from db import TokenDatabase
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db(tmp_path):
|
||||
inst = TokenDatabase(str(tmp_path / "tokens.db"))
|
||||
await inst.init_db()
|
||||
yield inst
|
||||
await inst.close()
|
||||
|
||||
|
||||
class TestInit:
|
||||
async def test_init_creates_tables(self, db):
|
||||
# Re-init must be idempotent
|
||||
await db.init_db()
|
||||
# Insert + read confirms tables exist
|
||||
await db.update_token_counts("http://ep", "m", 1, 2)
|
||||
rows = [r async for r in db.load_token_counts()]
|
||||
assert len(rows) == 1
|
||||
|
||||
async def test_creates_parent_directory(self, tmp_path):
|
||||
nested = tmp_path / "nested" / "subdir" / "x.db"
|
||||
inst = TokenDatabase(str(nested))
|
||||
await inst.init_db()
|
||||
try:
|
||||
assert nested.parent.exists()
|
||||
finally:
|
||||
await inst.close()
|
||||
|
||||
|
||||
class TestUpdateTokenCounts:
|
||||
async def test_insert_then_update_aggregates(self, db):
|
||||
await db.update_token_counts("http://ep", "m1", 10, 20)
|
||||
await db.update_token_counts("http://ep", "m1", 5, 7)
|
||||
rows = [r async for r in db.load_token_counts()]
|
||||
assert len(rows) == 1
|
||||
r = rows[0]
|
||||
assert r["endpoint"] == "http://ep"
|
||||
assert r["model"] == "m1"
|
||||
assert r["input_tokens"] == 15
|
||||
assert r["output_tokens"] == 27
|
||||
assert r["total_tokens"] == 42
|
||||
|
||||
async def test_independent_endpoint_model_pairs(self, db):
|
||||
await db.update_token_counts("http://ep1", "m1", 1, 1)
|
||||
await db.update_token_counts("http://ep1", "m2", 2, 2)
|
||||
await db.update_token_counts("http://ep2", "m1", 3, 3)
|
||||
rows = [r async for r in db.load_token_counts()]
|
||||
assert len(rows) == 3
|
||||
totals = {(r["endpoint"], r["model"]): r["total_tokens"] for r in rows}
|
||||
assert totals == {
|
||||
("http://ep1", "m1"): 2,
|
||||
("http://ep1", "m2"): 4,
|
||||
("http://ep2", "m1"): 6,
|
||||
}
|
||||
|
||||
|
||||
class TestBatchedCounts:
|
||||
async def test_update_batched_counts(self, db):
|
||||
counts = {
|
||||
"http://a": {"m": (4, 6)},
|
||||
"http://b": {"m": (1, 1), "n": (10, 0)},
|
||||
}
|
||||
await db.update_batched_counts(counts)
|
||||
rows = [r async for r in db.load_token_counts()]
|
||||
totals = {(r["endpoint"], r["model"]): r["total_tokens"] for r in rows}
|
||||
assert totals == {
|
||||
("http://a", "m"): 10,
|
||||
("http://b", "m"): 2,
|
||||
("http://b", "n"): 10,
|
||||
}
|
||||
|
||||
async def test_empty_batch_is_noop(self, db):
|
||||
await db.update_batched_counts({})
|
||||
rows = [r async for r in db.load_token_counts()]
|
||||
assert rows == []
|
||||
|
||||
|
||||
class TestTimeSeries:
|
||||
async def test_add_time_series_entry(self, db):
|
||||
# The aggregate FK requires the (endpoint,model) row to exist first
|
||||
await db.update_token_counts("http://ep", "m", 0, 0)
|
||||
await db.add_time_series_entry("http://ep", "m", 3, 4)
|
||||
await db.add_time_series_entry("http://ep", "m", 1, 1)
|
||||
rows = [r async for r in db.get_latest_time_series(limit=10)]
|
||||
assert len(rows) == 2
|
||||
# Newest-first ordering; both timestamps are within the same minute,
|
||||
# so just check totals are present and well-formed
|
||||
for r in rows:
|
||||
assert r["endpoint"] == "http://ep"
|
||||
assert r["model"] == "m"
|
||||
assert r["total_tokens"] == r["input_tokens"] + r["output_tokens"]
|
||||
|
||||
async def test_add_batched_time_series(self, db):
|
||||
await db.update_token_counts("http://ep", "m", 0, 0)
|
||||
now = int(datetime.now(tz=timezone.utc).timestamp())
|
||||
entries = [
|
||||
{"endpoint": "http://ep", "model": "m", "input_tokens": 1,
|
||||
"output_tokens": 2, "total_tokens": 3, "timestamp": now - 60},
|
||||
{"endpoint": "http://ep", "model": "m", "input_tokens": 4,
|
||||
"output_tokens": 5, "total_tokens": 9, "timestamp": now},
|
||||
]
|
||||
await db.add_batched_time_series(entries)
|
||||
rows = [r async for r in db.get_latest_time_series(limit=10)]
|
||||
assert len(rows) == 2
|
||||
assert rows[0]["timestamp"] >= rows[1]["timestamp"]
|
||||
|
||||
async def test_get_time_series_for_model_filters(self, db):
|
||||
await db.update_token_counts("http://ep", "m1", 0, 0)
|
||||
await db.update_token_counts("http://ep", "m2", 0, 0)
|
||||
now = int(datetime.now(tz=timezone.utc).timestamp())
|
||||
await db.add_batched_time_series([
|
||||
{"endpoint": "http://ep", "model": "m1", "input_tokens": 1,
|
||||
"output_tokens": 1, "total_tokens": 2, "timestamp": now},
|
||||
{"endpoint": "http://ep", "model": "m2", "input_tokens": 9,
|
||||
"output_tokens": 9, "total_tokens": 18, "timestamp": now},
|
||||
])
|
||||
rows = [r async for r in db.get_time_series_for_model("m1")]
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["total_tokens"] == 2
|
||||
|
||||
async def test_endpoint_distribution_for_model(self, db):
|
||||
await db.update_token_counts("http://a", "m", 0, 0)
|
||||
await db.update_token_counts("http://b", "m", 0, 0)
|
||||
now = int(datetime.now(tz=timezone.utc).timestamp())
|
||||
await db.add_batched_time_series([
|
||||
{"endpoint": "http://a", "model": "m", "input_tokens": 1,
|
||||
"output_tokens": 1, "total_tokens": 2, "timestamp": now},
|
||||
{"endpoint": "http://a", "model": "m", "input_tokens": 1,
|
||||
"output_tokens": 1, "total_tokens": 2, "timestamp": now},
|
||||
{"endpoint": "http://b", "model": "m", "input_tokens": 5,
|
||||
"output_tokens": 5, "total_tokens": 10, "timestamp": now},
|
||||
])
|
||||
dist = await db.get_endpoint_distribution_for_model("m")
|
||||
assert dist == {"http://a": 4, "http://b": 10}
|
||||
|
||||
|
||||
class TestGetTokenCountsForModel:
|
||||
async def test_aggregates_across_endpoints(self, db):
|
||||
await db.update_token_counts("http://a", "m", 1, 2)
|
||||
await db.update_token_counts("http://b", "m", 3, 4)
|
||||
result = await db.get_token_counts_for_model("m")
|
||||
assert result is not None
|
||||
assert result["endpoint"] == "aggregated"
|
||||
assert result["model"] == "m"
|
||||
assert result["input_tokens"] == 4
|
||||
assert result["output_tokens"] == 6
|
||||
assert result["total_tokens"] == 10
|
||||
|
||||
async def test_unknown_model_returns_zero_aggregate(self, db):
|
||||
# SUM(...) WHERE no-match returns one row with NULLs — exposed as zeros
|
||||
result = await db.get_token_counts_for_model("nope")
|
||||
assert result is not None
|
||||
assert result["input_tokens"] in (0, None)
|
||||
|
||||
|
||||
class TestAggregateTimeSeriesOlderThan:
|
||||
async def test_aggregates_old_entries_by_day(self, db):
|
||||
await db.update_token_counts("http://ep", "m", 0, 0)
|
||||
now = int(datetime.now(tz=timezone.utc).timestamp())
|
||||
old = now - (40 * 86400) # 40 days ago
|
||||
await db.add_batched_time_series([
|
||||
{"endpoint": "http://ep", "model": "m", "input_tokens": 1,
|
||||
"output_tokens": 1, "total_tokens": 2, "timestamp": old},
|
||||
{"endpoint": "http://ep", "model": "m", "input_tokens": 3,
|
||||
"output_tokens": 3, "total_tokens": 6, "timestamp": old + 60},
|
||||
{"endpoint": "http://ep", "model": "m", "input_tokens": 99,
|
||||
"output_tokens": 99, "total_tokens": 198, "timestamp": now},
|
||||
])
|
||||
n = await db.aggregate_time_series_older_than(30, trim_old=False)
|
||||
assert n == 1 # one (endpoint, model, day) group rolled up
|
||||
|
||||
async def test_invalid_days_falls_back_to_30(self, db):
|
||||
# Just ensure it doesn't blow up with a bogus value
|
||||
n = await db.aggregate_time_series_older_than(0)
|
||||
assert n == 0
|
||||
|
||||
async def test_trim_old_removes_aggregated_rows(self, db):
|
||||
await db.update_token_counts("http://ep", "m", 0, 0)
|
||||
now = int(datetime.now(tz=timezone.utc).timestamp())
|
||||
old = now - (40 * 86400)
|
||||
await db.add_batched_time_series([
|
||||
{"endpoint": "http://ep", "model": "m", "input_tokens": 1,
|
||||
"output_tokens": 1, "total_tokens": 2, "timestamp": old},
|
||||
{"endpoint": "http://ep", "model": "m", "input_tokens": 99,
|
||||
"output_tokens": 99, "total_tokens": 198, "timestamp": now},
|
||||
])
|
||||
await db.aggregate_time_series_older_than(30, trim_old=True)
|
||||
remaining = [r async for r in db.get_latest_time_series(limit=10)]
|
||||
# Only the recent (within-cutoff) row should remain
|
||||
assert len(remaining) == 1
|
||||
assert remaining[0]["total_tokens"] == 198
|
||||
180
test/test_fetch.py
Normal file
180
test/test_fetch.py
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
"""Tests for fetch.available_models and fetch.loaded_models using aioresponses mocking."""
|
||||
import time
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
from aioresponses import aioresponses
|
||||
|
||||
import router
|
||||
from conftest import TEST_OLLAMA, TEST_LLAMA
|
||||
|
||||
MOCK_OLLAMA_EP = "http://mock-ollama:11434"
|
||||
MOCK_LLAMA_EP = "http://mock-llama:8080/v1"
|
||||
|
||||
|
||||
def _make_cfg(ollama_eps=None, llama_eps=None, api_keys=None):
|
||||
cfg = MagicMock()
|
||||
cfg.endpoints = ollama_eps or [MOCK_OLLAMA_EP]
|
||||
cfg.llama_server_endpoints = llama_eps or [MOCK_LLAMA_EP]
|
||||
cfg.api_keys = api_keys or {}
|
||||
cfg.max_concurrent_connections = 2
|
||||
cfg.router_api_key = None
|
||||
return cfg
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_caches(aio_session):
|
||||
"""aio_session fixture already clears caches and sets up app_state."""
|
||||
yield
|
||||
|
||||
|
||||
class TestFetchAvailableModels:
|
||||
async def test_ollama_tags(self):
|
||||
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(
|
||||
f"{MOCK_OLLAMA_EP}/api/tags",
|
||||
payload={"models": [
|
||||
{"name": "llama3.2:latest"},
|
||||
{"name": "qwen2.5:7b"},
|
||||
]},
|
||||
)
|
||||
models = await router.fetch.available_models(MOCK_OLLAMA_EP)
|
||||
assert models == {"llama3.2:latest", "qwen2.5:7b"}
|
||||
|
||||
async def test_openai_compatible_models_endpoint(self):
|
||||
cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP])
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(
|
||||
f"{MOCK_LLAMA_EP}/models",
|
||||
payload={"data": [{"id": "unsloth/model:Q8_0"}]},
|
||||
)
|
||||
models = await router.fetch.available_models(MOCK_LLAMA_EP, api_key="tok")
|
||||
assert "unsloth/model:Q8_0" in models
|
||||
|
||||
async def test_caches_successful_result(self):
|
||||
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(
|
||||
f"{MOCK_OLLAMA_EP}/api/tags",
|
||||
payload={"models": [{"name": "llama3.2:latest"}]},
|
||||
)
|
||||
first = await router.fetch.available_models(MOCK_OLLAMA_EP)
|
||||
second = await router.fetch.available_models(MOCK_OLLAMA_EP)
|
||||
# second call must be served from cache without a second HTTP request
|
||||
assert first == second == {"llama3.2:latest"}
|
||||
|
||||
async def test_returns_empty_on_http_500(self):
|
||||
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(f"{MOCK_OLLAMA_EP}/api/tags", status=500, payload={"error": "oops"})
|
||||
models = await router.fetch.available_models(MOCK_OLLAMA_EP)
|
||||
assert models == set()
|
||||
|
||||
async def test_returns_empty_on_connection_error(self):
|
||||
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
|
||||
import aiohttp
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(
|
||||
f"{MOCK_OLLAMA_EP}/api/tags",
|
||||
exception=aiohttp.ClientConnectorError(
|
||||
connection_key=MagicMock(host="mock-ollama", port=11434),
|
||||
os_error=OSError(111, "refused"),
|
||||
),
|
||||
)
|
||||
models = await router.fetch.available_models(MOCK_OLLAMA_EP)
|
||||
assert models == set()
|
||||
|
||||
async def test_stale_cache_returned_while_refresh_runs(self):
|
||||
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(
|
||||
f"{MOCK_OLLAMA_EP}/api/tags",
|
||||
payload={"models": [{"name": "llama3.2:latest"}]},
|
||||
)
|
||||
await router.fetch.available_models(MOCK_OLLAMA_EP)
|
||||
|
||||
# Manually age cache into stale-but-valid window (300-600s)
|
||||
async with router._models_cache_lock:
|
||||
models, _ = router._models_cache[MOCK_OLLAMA_EP]
|
||||
router._models_cache[MOCK_OLLAMA_EP] = (models, time.time() - 400)
|
||||
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(
|
||||
f"{MOCK_OLLAMA_EP}/api/tags",
|
||||
payload={"models": [{"name": "llama3.2:latest"}]},
|
||||
)
|
||||
# Should return stale data immediately
|
||||
stale = await router.fetch.available_models(MOCK_OLLAMA_EP)
|
||||
assert "llama3.2:latest" in stale
|
||||
|
||||
async def test_error_cache_short_circuits(self):
|
||||
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
|
||||
# Seed error cache with a very recent error
|
||||
async with router._available_error_cache_lock:
|
||||
router._available_error_cache[MOCK_OLLAMA_EP] = time.time()
|
||||
|
||||
with patch.object(router, "config", cfg), aioresponses():
|
||||
# No HTTP mock registered — if a call happens it will raise
|
||||
models = await router.fetch.available_models(MOCK_OLLAMA_EP)
|
||||
assert models == set()
|
||||
|
||||
|
||||
class TestFetchLoadedModels:
|
||||
async def test_ollama_ps(self):
|
||||
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(
|
||||
f"{MOCK_OLLAMA_EP}/api/ps",
|
||||
payload={"models": [{"name": "llama3.2:latest"}]},
|
||||
)
|
||||
models = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
|
||||
assert models == {"llama3.2:latest"}
|
||||
|
||||
async def test_llama_server_filters_loaded(self):
|
||||
cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP])
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(
|
||||
f"{MOCK_LLAMA_EP}/models",
|
||||
payload={"data": [
|
||||
{"id": "model-a", "status": {"value": "loaded"}},
|
||||
{"id": "model-b", "status": {"value": "unloaded"}},
|
||||
]},
|
||||
)
|
||||
models = await router.fetch.loaded_models(MOCK_LLAMA_EP)
|
||||
assert models == {"model-a"}
|
||||
|
||||
async def test_llama_server_no_status_field_always_loaded(self):
|
||||
cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP])
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(
|
||||
f"{MOCK_LLAMA_EP}/models",
|
||||
payload={"data": [{"id": "always-on-model"}]},
|
||||
)
|
||||
models = await router.fetch.loaded_models(MOCK_LLAMA_EP)
|
||||
assert "always-on-model" in models
|
||||
|
||||
async def test_returns_empty_on_error(self):
|
||||
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(f"{MOCK_OLLAMA_EP}/api/ps", status=503, payload={})
|
||||
models = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
|
||||
assert models == set()
|
||||
|
||||
async def test_ext_openai_always_empty(self):
|
||||
ext_ep = "https://api.openai.com/v1"
|
||||
cfg = _make_cfg(ollama_eps=[ext_ep], llama_eps=[])
|
||||
with patch.object(router, "config", cfg):
|
||||
models = await router.fetch.loaded_models(ext_ep)
|
||||
assert models == set()
|
||||
|
||||
async def test_caches_result(self):
|
||||
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(
|
||||
f"{MOCK_OLLAMA_EP}/api/ps",
|
||||
payload={"models": [{"name": "qwen:7b"}]},
|
||||
)
|
||||
first = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
|
||||
second = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
|
||||
assert first == second
|
||||
181
test/test_openai_proxies.py
Normal file
181
test/test_openai_proxies.py
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
"""Cache-hit short-circuit tests for the OpenAI-compatible proxy routes.
|
||||
|
||||
These tests verify that when the LLM cache reports a hit, the route returns
|
||||
the cached payload *without* selecting an endpoint or contacting any backend.
|
||||
"""
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
import router
|
||||
|
||||
|
||||
_BYPASS = HTTPException(status_code=599, detail="bypassed")
|
||||
|
||||
|
||||
class _FakeCache:
|
||||
"""Minimal stand-in for cache.LLMCache.get_chat."""
|
||||
def __init__(self, response_bytes: bytes | None):
|
||||
self._resp = response_bytes
|
||||
self.calls: list[tuple] = []
|
||||
|
||||
async def get_chat(self, route, model, messages):
|
||||
self.calls.append((route, model, messages))
|
||||
return self._resp
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cache_hit_payload():
|
||||
return orjson.dumps({
|
||||
"id": "cmpl-xyz",
|
||||
"created": 1,
|
||||
"model": "test-model",
|
||||
"choices": [{"message": {"role": "assistant", "content": "from-cache"}}],
|
||||
"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
|
||||
})
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# /v1/chat/completions
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestOpenAIChatCompletionsCacheHit:
|
||||
async def test_nonstream_cache_hit_returns_cached_json(self, client, cache_hit_payload):
|
||||
fake = _FakeCache(cache_hit_payload)
|
||||
# Patch the route's references to both helpers — they're imported by name
|
||||
# into router's namespace at module load time.
|
||||
with (
|
||||
patch.object(router, "get_llm_cache", return_value=fake),
|
||||
patch.object(router, "choose_endpoint",
|
||||
AsyncMock(side_effect=AssertionError("backend must not be reached"))),
|
||||
):
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "test-model",
|
||||
"messages": [{"role": "user", "content": "ping"}],
|
||||
"stream": False,
|
||||
"nomyo": {"cache": True},
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
# Body is streamed; collect it
|
||||
body = resp.content
|
||||
parsed = orjson.loads(body)
|
||||
assert parsed["choices"][0]["message"]["content"] == "from-cache"
|
||||
assert fake.calls and fake.calls[0][0] == "openai_chat"
|
||||
|
||||
async def test_stream_cache_hit_returns_sse(self, client, cache_hit_payload):
|
||||
fake = _FakeCache(cache_hit_payload)
|
||||
with (
|
||||
patch.object(router, "get_llm_cache", return_value=fake),
|
||||
patch.object(router, "choose_endpoint",
|
||||
AsyncMock(side_effect=AssertionError("backend must not be reached"))),
|
||||
):
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "test-model",
|
||||
"messages": [{"role": "user", "content": "ping"}],
|
||||
"stream": True,
|
||||
"nomyo": {"cache": True},
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["content-type"].startswith("text/event-stream")
|
||||
text = resp.content.decode()
|
||||
# First SSE frame contains the cached content as a delta
|
||||
first_frame = text.split("\n\n")[0]
|
||||
assert first_frame.startswith("data: ")
|
||||
chunk = orjson.loads(first_frame[len("data: "):])
|
||||
assert chunk["choices"][0]["delta"]["content"] == "from-cache"
|
||||
# Stream is terminated with [DONE]
|
||||
assert "data: [DONE]" in text
|
||||
|
||||
async def test_cache_disabled_in_payload_bypasses_cache_check(self, client):
|
||||
"""When nomyo.cache=False, get_chat is never called even if a cache exists."""
|
||||
fake = _FakeCache(b"") # has a response, but should never be consulted
|
||||
with (
|
||||
patch.object(router, "get_llm_cache", return_value=fake),
|
||||
patch.object(router, "choose_endpoint",
|
||||
AsyncMock(side_effect=_BYPASS)),
|
||||
):
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "m",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"nomyo": {"cache": False},
|
||||
},
|
||||
)
|
||||
# Got past the cache short-circuit → endpoint selection invoked
|
||||
assert resp.status_code == 599
|
||||
assert fake.calls == []
|
||||
|
||||
async def test_no_cache_configured_bypasses_cache_check(self, client):
|
||||
"""get_llm_cache() returning None should not break the route."""
|
||||
with (
|
||||
patch.object(router, "get_llm_cache", return_value=None),
|
||||
patch.object(router, "choose_endpoint",
|
||||
AsyncMock(side_effect=_BYPASS)),
|
||||
):
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "m",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"nomyo": {"cache": True},
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 599
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# /v1/completions
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestOpenAICompletionsCacheHit:
|
||||
async def test_nonstream_cache_hit(self, client, cache_hit_payload):
|
||||
fake = _FakeCache(cache_hit_payload)
|
||||
with (
|
||||
patch.object(router, "get_llm_cache", return_value=fake),
|
||||
patch.object(router, "choose_endpoint",
|
||||
AsyncMock(side_effect=AssertionError("backend must not be reached"))),
|
||||
):
|
||||
resp = await client.post(
|
||||
"/v1/completions",
|
||||
json={
|
||||
"model": "test-model",
|
||||
"prompt": "Tell me a joke",
|
||||
"stream": False,
|
||||
"nomyo": {"cache": True},
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
# Prompt-style cache lookup is namespaced under "openai_completions"
|
||||
assert fake.calls[0][0] == "openai_completions"
|
||||
# Cache lookup receives the prompt as a single user message
|
||||
cached_msgs = fake.calls[0][2]
|
||||
assert cached_msgs == [{"role": "user", "content": "Tell me a joke"}]
|
||||
|
||||
async def test_stream_cache_hit(self, client, cache_hit_payload):
|
||||
fake = _FakeCache(cache_hit_payload)
|
||||
with (
|
||||
patch.object(router, "get_llm_cache", return_value=fake),
|
||||
patch.object(router, "choose_endpoint",
|
||||
AsyncMock(side_effect=AssertionError("backend must not be reached"))),
|
||||
):
|
||||
resp = await client.post(
|
||||
"/v1/completions",
|
||||
json={
|
||||
"model": "test-model",
|
||||
"prompt": "What is 2+2?",
|
||||
"stream": True,
|
||||
"nomyo": {"cache": True},
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["content-type"].startswith("text/event-stream")
|
||||
assert "data: [DONE]" in resp.content.decode()
|
||||
116
test/test_unit_context.py
Normal file
116
test/test_unit_context.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
"""Unit tests for context-window trimming logic."""
|
||||
import pytest
|
||||
import router
|
||||
|
||||
|
||||
def _msgs(roles_contents):
|
||||
return [{"role": r, "content": c} for r, c in roles_contents]
|
||||
|
||||
|
||||
class TestCountMessageTokens:
|
||||
def test_returns_int(self):
|
||||
msgs = _msgs([("user", "hello")])
|
||||
assert isinstance(router._count_message_tokens(msgs), int)
|
||||
|
||||
def test_empty_list(self):
|
||||
assert router._count_message_tokens([]) >= 0
|
||||
|
||||
def test_longer_content_more_tokens(self):
|
||||
short = _msgs([("user", "hi")])
|
||||
long_ = _msgs([("user", "a " * 500)])
|
||||
assert router._count_message_tokens(long_) > router._count_message_tokens(short)
|
||||
|
||||
def test_list_content(self):
|
||||
msgs = [{"role": "user", "content": [
|
||||
{"type": "text", "text": "what do you see?"},
|
||||
]}]
|
||||
tokens = router._count_message_tokens(msgs)
|
||||
assert tokens > 0
|
||||
|
||||
def test_multiple_messages(self):
|
||||
msgs = _msgs([("system", "you are helpful"), ("user", "hello"), ("assistant", "hi!")])
|
||||
assert router._count_message_tokens(msgs) > 10
|
||||
|
||||
|
||||
class TestTrimMessagesForContext:
|
||||
def test_short_history_unchanged(self):
|
||||
msgs = _msgs([("user", "hello"), ("assistant", "hi"), ("user", "bye")])
|
||||
result = router._trim_messages_for_context(msgs, n_ctx=4096)
|
||||
assert result == msgs
|
||||
|
||||
def test_system_messages_always_kept(self):
|
||||
msgs = (
|
||||
_msgs([("system", "you are helpful")])
|
||||
+ _msgs([("user", f"msg {i}") for i in range(50)])
|
||||
+ _msgs([("user", "final question")])
|
||||
)
|
||||
result = router._trim_messages_for_context(msgs, n_ctx=512)
|
||||
system_msgs = [m for m in result if m["role"] == "system"]
|
||||
assert len(system_msgs) == 1
|
||||
assert system_msgs[0]["content"] == "you are helpful"
|
||||
|
||||
def test_last_user_message_always_kept(self):
|
||||
msgs = _msgs([("user", f"old msg {i}") for i in range(100)] + [("user", "very important last question")])
|
||||
result = router._trim_messages_for_context(msgs, n_ctx=256)
|
||||
assert result[-1]["content"] == "very important last question"
|
||||
|
||||
def test_oldest_dropped_first(self):
|
||||
msgs = _msgs([
|
||||
("user", "oldest msg"),
|
||||
("assistant", "oldest reply"),
|
||||
("user", "newer msg"),
|
||||
("assistant", "newer reply"),
|
||||
("user", "newest"),
|
||||
])
|
||||
# Use very small target to force trimming
|
||||
result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=10)
|
||||
contents = [m["content"] for m in result]
|
||||
# "oldest msg" should be dropped before "newest"
|
||||
if "oldest msg" in contents:
|
||||
assert "newest" in contents
|
||||
else:
|
||||
assert "newest" in contents
|
||||
|
||||
def test_result_starts_with_user(self):
|
||||
msgs = _msgs([
|
||||
("assistant", "leftover assistant"),
|
||||
("user", "question"),
|
||||
])
|
||||
result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=20)
|
||||
if result:
|
||||
assert result[0]["role"] == "user"
|
||||
|
||||
def test_target_tokens_overrides_safety_margin(self):
|
||||
msgs = _msgs([("user", "a " * 200)])
|
||||
result_small = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=10)
|
||||
result_large = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=5000)
|
||||
# Both should return at least the last message
|
||||
assert len(result_small) >= 1
|
||||
assert len(result_large) >= 1
|
||||
|
||||
|
||||
class TestCalibratedTrimTarget:
|
||||
def test_returns_positive_int(self):
|
||||
msgs = [{"role": "user", "content": "hello " * 100}]
|
||||
result = router._calibrated_trim_target(msgs, n_ctx=4096, actual_tokens=3000)
|
||||
assert isinstance(result, int)
|
||||
assert result >= 1
|
||||
|
||||
def test_over_limit_reduces_target(self):
|
||||
msgs = [{"role": "user", "content": "a " * 500}]
|
||||
# actual_tokens > n_ctx means we need to shed more
|
||||
target = router._calibrated_trim_target(msgs, n_ctx=2048, actual_tokens=2500)
|
||||
assert target < router._count_message_tokens(msgs)
|
||||
|
||||
def test_well_within_limit_returns_current(self):
|
||||
msgs = [{"role": "user", "content": "hi"}]
|
||||
# actual_tokens << n_ctx means nothing to shed
|
||||
target = router._calibrated_trim_target(msgs, n_ctx=16384, actual_tokens=50)
|
||||
# Should return cur_tiktoken since to_shed == 0
|
||||
assert target == max(1, router._count_message_tokens(msgs))
|
||||
|
||||
def test_minimum_is_one(self):
|
||||
# Even if we need to shed everything, result is at least 1
|
||||
msgs = [{"role": "user", "content": "hello"}]
|
||||
target = router._calibrated_trim_target(msgs, n_ctx=100, actual_tokens=99999)
|
||||
assert target >= 1
|
||||
279
test/test_unit_helpers.py
Normal file
279
test/test_unit_helpers.py
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
"""Unit tests for pure helper functions in router.py (no network, no app)."""
|
||||
import time
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
import router
|
||||
|
||||
|
||||
class TestMaskSecrets:
|
||||
def test_masks_openai_key(self):
|
||||
text = "Authorization: Bearer sk-abcd1234XYZabcd1234XYZabcd1234XYZ"
|
||||
result = router._mask_secrets(text)
|
||||
assert "sk-***redacted***" in result
|
||||
assert "sk-abcd1234" not in result
|
||||
|
||||
def test_masks_api_key_assignment(self):
|
||||
result = router._mask_secrets("api_key: supersecretvalue123")
|
||||
assert "supersecretvalue123" not in result
|
||||
assert "***redacted***" in result
|
||||
|
||||
def test_masks_api_key_with_colon(self):
|
||||
result = router._mask_secrets("api-key: mykey")
|
||||
assert "mykey" not in result
|
||||
|
||||
def test_empty_string_returns_empty(self):
|
||||
assert router._mask_secrets("") == ""
|
||||
|
||||
def test_none_returns_none(self):
|
||||
assert router._mask_secrets(None) is None
|
||||
|
||||
def test_no_secrets_unchanged(self):
|
||||
text = "this is a normal log line"
|
||||
assert router._mask_secrets(text) == text
|
||||
|
||||
|
||||
class TestIsFresh:
|
||||
def test_fresh_within_ttl(self):
|
||||
cached_at = time.time() - 10
|
||||
assert router._is_fresh(cached_at, 300) is True
|
||||
|
||||
def test_expired_beyond_ttl(self):
|
||||
cached_at = time.time() - 400
|
||||
assert router._is_fresh(cached_at, 300) is False
|
||||
|
||||
def test_exactly_at_boundary(self):
|
||||
cached_at = time.time() - 300
|
||||
# May be True or False depending on timing, just verify it runs
|
||||
result = router._is_fresh(cached_at, 300)
|
||||
assert isinstance(result, bool)
|
||||
|
||||
def test_just_cached(self):
|
||||
assert router._is_fresh(time.time(), 1) is True
|
||||
|
||||
|
||||
class TestNormalizeLlamaModelName:
|
||||
def test_strips_hf_prefix(self):
|
||||
assert router._normalize_llama_model_name("unsloth/gpt-oss-20b-GGUF") == "gpt-oss-20b-GGUF"
|
||||
|
||||
def test_strips_quant_suffix(self):
|
||||
assert router._normalize_llama_model_name("model:Q8_0") == "model"
|
||||
|
||||
def test_strips_both(self):
|
||||
result = router._normalize_llama_model_name("unsloth/gpt-oss-20b-GGUF:F16")
|
||||
assert result == "gpt-oss-20b-GGUF"
|
||||
|
||||
def test_no_prefix_no_suffix(self):
|
||||
assert router._normalize_llama_model_name("plain-model") == "plain-model"
|
||||
|
||||
def test_multiple_slashes(self):
|
||||
result = router._normalize_llama_model_name("org/user/model-name:Q4_K_M")
|
||||
assert result == "model-name"
|
||||
|
||||
|
||||
class TestExtractLlamaQuant:
|
||||
def test_extracts_quant(self):
|
||||
assert router._extract_llama_quant("unsloth/model:Q8_0") == "Q8_0"
|
||||
|
||||
def test_no_quant_returns_empty(self):
|
||||
assert router._extract_llama_quant("plain-model") == ""
|
||||
|
||||
def test_f16(self):
|
||||
assert router._extract_llama_quant("model:F16") == "F16"
|
||||
|
||||
def test_q4_k_m(self):
|
||||
assert router._extract_llama_quant("model:Q4_K_M") == "Q4_K_M"
|
||||
|
||||
|
||||
class TestIsUnixSocketEndpoint:
|
||||
def test_sock_endpoint_detected(self):
|
||||
assert router._is_unix_socket_endpoint("http://192.168.0.52.sock/v1") is True
|
||||
|
||||
def test_regular_http_not_sock(self):
|
||||
assert router._is_unix_socket_endpoint("http://192.168.0.52:8080/v1") is False
|
||||
|
||||
def test_ollama_not_sock(self):
|
||||
assert router._is_unix_socket_endpoint("http://localhost:11434") is False
|
||||
|
||||
def test_dot_sock_in_host_detected(self):
|
||||
assert router._is_unix_socket_endpoint("http://llama.sock/v1") is True
|
||||
|
||||
|
||||
class TestGetSocketPath:
|
||||
def test_returns_run_user_path(self):
|
||||
import os
|
||||
path = router._get_socket_path("http://192.168.0.52.sock/v1")
|
||||
uid = os.getuid()
|
||||
assert path == f"/run/user/{uid}/192.168.0.52.sock"
|
||||
|
||||
|
||||
class TestIsBase64:
|
||||
def test_valid_base64(self):
|
||||
import base64
|
||||
data = base64.b64encode(b"hello world").decode()
|
||||
assert router.is_base64(data) is True
|
||||
|
||||
def test_invalid_base64(self):
|
||||
assert router.is_base64("not-base64!@#$") is False
|
||||
|
||||
def test_empty_string(self):
|
||||
# Empty string is valid base64 (decodes to empty bytes)
|
||||
assert router.is_base64("") is True
|
||||
|
||||
def test_non_string(self):
|
||||
# Non-strings fall through without returning True (returns None)
|
||||
assert not router.is_base64(12345)
|
||||
|
||||
|
||||
class TestIsLlamaModelLoaded:
|
||||
def test_status_dict_loaded(self):
|
||||
assert router._is_llama_model_loaded({"id": "m", "status": {"value": "loaded"}}) is True
|
||||
|
||||
def test_status_dict_unloaded(self):
|
||||
assert router._is_llama_model_loaded({"id": "m", "status": {"value": "unloaded"}}) is False
|
||||
|
||||
def test_status_string_loaded(self):
|
||||
assert router._is_llama_model_loaded({"id": "m", "status": "loaded"}) is True
|
||||
|
||||
def test_status_string_unloaded(self):
|
||||
assert router._is_llama_model_loaded({"id": "m", "status": "unloaded"}) is False
|
||||
|
||||
def test_no_status_field_always_loaded(self):
|
||||
# No status field → always available (single-model server)
|
||||
assert router._is_llama_model_loaded({"id": "m"}) is True
|
||||
|
||||
def test_status_none_always_loaded(self):
|
||||
assert router._is_llama_model_loaded({"id": "m", "status": None}) is True
|
||||
|
||||
|
||||
class TestEp2Base:
|
||||
def test_adds_v1_to_ollama(self):
|
||||
assert router.ep2base("http://localhost:11434") == "http://localhost:11434/v1"
|
||||
|
||||
def test_keeps_v1_if_present(self):
|
||||
assert router.ep2base("http://host/v1") == "http://host/v1"
|
||||
|
||||
def test_llama_server_endpoint_unchanged(self):
|
||||
ep = "http://192.168.0.50:8889/v1"
|
||||
assert router.ep2base(ep) == ep
|
||||
|
||||
|
||||
class TestDedupeOnKeys:
|
||||
def test_removes_duplicate_by_single_key(self):
|
||||
items = [{"name": "a", "x": 1}, {"name": "b", "x": 2}, {"name": "a", "x": 3}]
|
||||
result = router.dedupe_on_keys(items, ["name"])
|
||||
assert len(result) == 2
|
||||
assert result[0]["name"] == "a"
|
||||
assert result[1]["name"] == "b"
|
||||
|
||||
def test_removes_duplicate_by_two_keys(self):
|
||||
items = [
|
||||
{"digest": "abc", "name": "m1"},
|
||||
{"digest": "abc", "name": "m1"},
|
||||
{"digest": "def", "name": "m2"},
|
||||
]
|
||||
result = router.dedupe_on_keys(items, ["digest", "name"])
|
||||
assert len(result) == 2
|
||||
|
||||
def test_empty_list(self):
|
||||
assert router.dedupe_on_keys([], ["name"]) == []
|
||||
|
||||
def test_no_duplicates(self):
|
||||
items = [{"name": "a"}, {"name": "b"}, {"name": "c"}]
|
||||
assert len(router.dedupe_on_keys(items, ["name"])) == 3
|
||||
|
||||
|
||||
class TestFormatConnectionIssue:
|
||||
def test_connector_error_message(self):
|
||||
err = aiohttp.ClientConnectorError(
|
||||
connection_key=MagicMock(host="localhost", port=11434),
|
||||
os_error=OSError(111, "Connection refused"),
|
||||
)
|
||||
msg = router._format_connection_issue("http://localhost:11434", err)
|
||||
assert "localhost" in msg
|
||||
assert "Connection refused" in msg or "111" in msg
|
||||
|
||||
def test_timeout_error_message(self):
|
||||
msg = router._format_connection_issue("http://host:1234", asyncio.TimeoutError())
|
||||
assert "Timed out" in msg
|
||||
assert "host:1234" in msg
|
||||
|
||||
def test_generic_error(self):
|
||||
msg = router._format_connection_issue("http://host:1234", ValueError("boom"))
|
||||
assert "host:1234" in msg
|
||||
assert "boom" in msg
|
||||
|
||||
|
||||
class TestIsExtOpenaiEndpoint:
|
||||
def test_openai_com_is_ext(self):
|
||||
cfg = MagicMock()
|
||||
cfg.endpoints = []
|
||||
cfg.llama_server_endpoints = []
|
||||
with patch.object(router, "config", cfg):
|
||||
assert router.is_ext_openai_endpoint("https://api.openai.com/v1") is True
|
||||
|
||||
def test_ollama_default_port_not_ext(self):
|
||||
cfg = MagicMock()
|
||||
cfg.endpoints = ["http://host:11434"]
|
||||
cfg.llama_server_endpoints = []
|
||||
with patch.object(router, "config", cfg):
|
||||
assert router.is_ext_openai_endpoint("http://host:11434") is False
|
||||
|
||||
def test_llama_server_not_ext(self):
|
||||
cfg = MagicMock()
|
||||
cfg.endpoints = []
|
||||
cfg.llama_server_endpoints = ["http://host:8080/v1"]
|
||||
with patch.object(router, "config", cfg):
|
||||
assert router.is_ext_openai_endpoint("http://host:8080/v1") is False
|
||||
|
||||
def test_no_v1_not_ext(self):
|
||||
cfg = MagicMock()
|
||||
cfg.endpoints = ["http://host:11434"]
|
||||
cfg.llama_server_endpoints = []
|
||||
with patch.object(router, "config", cfg):
|
||||
assert router.is_ext_openai_endpoint("http://host:11434") is False
|
||||
|
||||
|
||||
class TestIsOpenaiCompatible:
|
||||
def test_v1_endpoint_compatible(self):
|
||||
cfg = MagicMock()
|
||||
cfg.llama_server_endpoints = []
|
||||
with patch.object(router, "config", cfg):
|
||||
assert router.is_openai_compatible("http://host/v1") is True
|
||||
|
||||
def test_ollama_not_compatible(self):
|
||||
cfg = MagicMock()
|
||||
cfg.llama_server_endpoints = []
|
||||
with patch.object(router, "config", cfg):
|
||||
assert router.is_openai_compatible("http://localhost:11434") is False
|
||||
|
||||
def test_llama_server_in_list_compatible(self):
|
||||
cfg = MagicMock()
|
||||
cfg.llama_server_endpoints = ["http://host:8080"]
|
||||
with patch.object(router, "config", cfg):
|
||||
assert router.is_openai_compatible("http://host:8080") is True
|
||||
|
||||
|
||||
class TestGetTrackingModel:
|
||||
def test_ollama_adds_latest(self):
|
||||
cfg = MagicMock()
|
||||
cfg.llama_server_endpoints = []
|
||||
with patch.object(router, "config", cfg):
|
||||
assert router.get_tracking_model("http://ollama:11434", "llama3.2") == "llama3.2:latest"
|
||||
|
||||
def test_ollama_keeps_existing_tag(self):
|
||||
cfg = MagicMock()
|
||||
cfg.llama_server_endpoints = []
|
||||
with patch.object(router, "config", cfg):
|
||||
assert router.get_tracking_model("http://ollama:11434", "llama3.2:7b") == "llama3.2:7b"
|
||||
|
||||
def test_llama_server_normalizes(self):
|
||||
ep = "http://host:8080/v1"
|
||||
cfg = MagicMock()
|
||||
cfg.llama_server_endpoints = [ep]
|
||||
with patch.object(router, "config", cfg):
|
||||
result = router.get_tracking_model(ep, "unsloth/model:Q8_0")
|
||||
assert result == "model"
|
||||
173
test/test_unit_rechunk.py
Normal file
173
test/test_unit_rechunk.py
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
"""Unit tests for router.rechunk — OpenAI ↔ Ollama chunk shape conversion."""
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
|
||||
import ollama
|
||||
|
||||
import router
|
||||
|
||||
|
||||
def _ns(**kw):
|
||||
return SimpleNamespace(**kw)
|
||||
|
||||
|
||||
def _stream_chunk(content="hi", role="assistant", finish_reason=None,
|
||||
usage=None, model="m"):
|
||||
"""Build a SimpleNamespace mimicking a streaming OpenAI chunk."""
|
||||
delta = _ns(content=content, role=role, reasoning=None, reasoning_content=None,
|
||||
tool_calls=None)
|
||||
choice = _ns(delta=delta, finish_reason=finish_reason, logprobs=None)
|
||||
return _ns(model=model, choices=[choice], usage=usage)
|
||||
|
||||
|
||||
def _nonstream_chunk(content="hi", role="assistant", finish_reason="stop",
|
||||
usage=None, model="m", tool_calls=None):
|
||||
"""Build a SimpleNamespace mimicking a non-streaming OpenAI ChatCompletion."""
|
||||
message = _ns(content=content, role=role, reasoning=None, reasoning_content=None,
|
||||
tool_calls=tool_calls)
|
||||
choice = _ns(message=message, finish_reason=finish_reason, logprobs=None)
|
||||
return _ns(model=model, choices=[choice], usage=usage)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# openai_chat_completion2ollama
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestChatCompletionToOllama:
|
||||
def test_streaming_content_chunk(self):
|
||||
chunk = _stream_chunk(content="hello", finish_reason=None, usage=None)
|
||||
out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter())
|
||||
assert isinstance(out, ollama.ChatResponse)
|
||||
assert out.message.role == "assistant"
|
||||
assert out.message.content == "hello"
|
||||
assert out.done is False # usage is None → not done yet
|
||||
assert out.model == "m"
|
||||
|
||||
def test_streaming_empty_content_defaults(self):
|
||||
# Some chunks have content=None — should coerce to empty string
|
||||
chunk = _stream_chunk(content=None, role=None)
|
||||
out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter())
|
||||
assert out.message.role == "assistant" # role defaulted
|
||||
assert out.message.content == ""
|
||||
|
||||
def test_final_usage_only_chunk_marks_done(self):
|
||||
usage = _ns(prompt_tokens=10, completion_tokens=5, total_tokens=15)
|
||||
chunk = _ns(model="m", choices=[], usage=usage)
|
||||
out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter())
|
||||
assert out.done is True
|
||||
assert out.done_reason == "stop"
|
||||
assert out.prompt_eval_count == 10
|
||||
assert out.eval_count == 5
|
||||
assert out.message.content == ""
|
||||
|
||||
def test_nonstreaming_with_content(self):
|
||||
usage = _ns(prompt_tokens=2, completion_tokens=3, total_tokens=5)
|
||||
chunk = _nonstream_chunk(content="response text", finish_reason="stop", usage=usage)
|
||||
out = router.rechunk.openai_chat_completion2ollama(chunk, False, time.perf_counter())
|
||||
assert out.done is True
|
||||
assert out.message.content == "response text"
|
||||
assert out.prompt_eval_count == 2
|
||||
assert out.eval_count == 3
|
||||
|
||||
def test_nonstreaming_tool_calls_converted(self):
|
||||
"""Tool calls with JSON string arguments are parsed into dicts."""
|
||||
tc = _ns(function=_ns(name="get_weather", arguments='{"city": "Paris"}'))
|
||||
usage = _ns(prompt_tokens=1, completion_tokens=1, total_tokens=2)
|
||||
chunk = _nonstream_chunk(
|
||||
content="", finish_reason="tool_calls", usage=usage, tool_calls=[tc]
|
||||
)
|
||||
out = router.rechunk.openai_chat_completion2ollama(chunk, False, time.perf_counter())
|
||||
assert out.message.tool_calls is not None
|
||||
assert len(out.message.tool_calls) == 1
|
||||
first = out.message.tool_calls[0]
|
||||
assert first.function.name == "get_weather"
|
||||
assert first.function.arguments == {"city": "Paris"}
|
||||
|
||||
def test_nonstreaming_tool_calls_with_invalid_json_fall_back_to_empty(self):
|
||||
tc = _ns(function=_ns(name="f", arguments="not-json"))
|
||||
usage = _ns(prompt_tokens=1, completion_tokens=1, total_tokens=2)
|
||||
chunk = _nonstream_chunk(content="", usage=usage, tool_calls=[tc])
|
||||
out = router.rechunk.openai_chat_completion2ollama(chunk, False, time.perf_counter())
|
||||
assert out.message.tool_calls[0].function.arguments == {}
|
||||
|
||||
def test_streaming_tool_calls_in_delta_are_skipped(self):
|
||||
"""Streaming mode must not assemble tool calls (caller handles it)."""
|
||||
chunk = _stream_chunk(content="x", finish_reason=None)
|
||||
# Even if a chunk somehow carried tool_calls in the delta, streaming
|
||||
# mode should ignore them.
|
||||
out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter())
|
||||
assert out.message.tool_calls is None
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# openai_completion2ollama
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestCompletionToOllama:
|
||||
def test_streaming_text_chunk(self):
|
||||
choice = _ns(text="word", finish_reason=None, reasoning=None)
|
||||
chunk = _ns(model="m", choices=[choice], usage=None)
|
||||
out = router.rechunk.openai_completion2ollama(chunk, True, time.perf_counter())
|
||||
assert isinstance(out, ollama.GenerateResponse)
|
||||
assert out.response == "word"
|
||||
assert out.done is False
|
||||
|
||||
def test_final_chunk_with_usage(self):
|
||||
usage = _ns(prompt_tokens=4, completion_tokens=6, total_tokens=10)
|
||||
choice = _ns(text="end", finish_reason="stop", reasoning=None)
|
||||
chunk = _ns(model="m", choices=[choice], usage=usage)
|
||||
out = router.rechunk.openai_completion2ollama(chunk, True, time.perf_counter())
|
||||
assert out.done is True
|
||||
assert out.prompt_eval_count == 4
|
||||
assert out.eval_count == 6
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# embeddings / embed
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestEmbeddingConversions:
|
||||
def test_openai_embeddings2ollama(self):
|
||||
chunk = _ns(data=[_ns(embedding=[0.1, 0.2, 0.3])])
|
||||
out = router.rechunk.openai_embeddings2ollama(chunk)
|
||||
assert isinstance(out, ollama.EmbeddingsResponse)
|
||||
assert list(out.embedding) == [0.1, 0.2, 0.3]
|
||||
|
||||
def test_openai_embed2ollama(self):
|
||||
chunk = _ns(data=[_ns(embedding=[0.5, 0.6])])
|
||||
out = router.rechunk.openai_embed2ollama(chunk, "my-embed-model")
|
||||
assert isinstance(out, ollama.EmbedResponse)
|
||||
assert out.model == "my-embed-model"
|
||||
assert list(out.embeddings[0]) == [0.5, 0.6]
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# extract_usage_from_llama_timings
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestExtractUsageFromLlamaTimings:
|
||||
def test_none_when_no_timings_attr(self):
|
||||
obj = _ns()
|
||||
assert router.rechunk.extract_usage_from_llama_timings(obj) is None
|
||||
|
||||
def test_prompt_plus_cache_sums(self):
|
||||
obj = _ns(timings={"prompt_n": 1, "cache_n": 236, "predicted_n": 35})
|
||||
prompt, completion = router.rechunk.extract_usage_from_llama_timings(obj)
|
||||
assert prompt == 237
|
||||
assert completion == 35
|
||||
|
||||
def test_missing_keys_default_to_zero(self):
|
||||
obj = _ns(timings={"predicted_n": 12})
|
||||
prompt, completion = router.rechunk.extract_usage_from_llama_timings(obj)
|
||||
assert prompt == 0
|
||||
assert completion == 12
|
||||
|
||||
def test_null_values_treated_as_zero(self):
|
||||
obj = _ns(timings={"prompt_n": None, "cache_n": None, "predicted_n": None})
|
||||
prompt, completion = router.rechunk.extract_usage_from_llama_timings(obj)
|
||||
assert prompt == 0
|
||||
assert completion == 0
|
||||
|
||||
def test_non_dict_timings_returns_none(self):
|
||||
obj = _ns(timings="not-a-dict")
|
||||
assert router.rechunk.extract_usage_from_llama_timings(obj) is None
|
||||
200
test/test_unit_transforms.py
Normal file
200
test/test_unit_transforms.py
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
"""Unit tests for message transformation functions."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import router
|
||||
|
||||
|
||||
class TestStripAssistantPrefill:
|
||||
def test_removes_trailing_assistant(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "prefill"},
|
||||
]
|
||||
result = router._strip_assistant_prefill(msgs)
|
||||
assert len(result) == 1
|
||||
assert result[0]["role"] == "user"
|
||||
|
||||
def test_keeps_non_trailing_assistant(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "response"},
|
||||
{"role": "user", "content": "follow-up"},
|
||||
]
|
||||
result = router._strip_assistant_prefill(msgs)
|
||||
assert len(result) == 3
|
||||
|
||||
def test_empty_list_unchanged(self):
|
||||
assert router._strip_assistant_prefill([]) == []
|
||||
|
||||
def test_single_user_message_unchanged(self):
|
||||
msgs = [{"role": "user", "content": "hi"}]
|
||||
assert router._strip_assistant_prefill(msgs) == msgs
|
||||
|
||||
|
||||
class TestTransformToolCallsToOpenAI:
|
||||
def test_adds_type_function(self):
|
||||
msgs = [{"role": "assistant", "tool_calls": [
|
||||
{"function": {"name": "get_weather", "arguments": {"city": "Berlin"}}}
|
||||
]}]
|
||||
result = router.transform_tool_calls_to_openai(msgs)
|
||||
tc = result[0]["tool_calls"][0]
|
||||
assert tc["type"] == "function"
|
||||
|
||||
def test_adds_id_when_missing(self):
|
||||
msgs = [{"role": "assistant", "tool_calls": [
|
||||
{"function": {"name": "fn", "arguments": {}}}
|
||||
]}]
|
||||
result = router.transform_tool_calls_to_openai(msgs)
|
||||
assert "id" in result[0]["tool_calls"][0]
|
||||
|
||||
def test_converts_dict_arguments_to_string(self):
|
||||
msgs = [{"role": "assistant", "tool_calls": [
|
||||
{"function": {"name": "fn", "arguments": {"key": "val"}}}
|
||||
]}]
|
||||
result = router.transform_tool_calls_to_openai(msgs)
|
||||
args = result[0]["tool_calls"][0]["function"]["arguments"]
|
||||
assert isinstance(args, str)
|
||||
import orjson
|
||||
parsed = orjson.loads(args)
|
||||
assert parsed == {"key": "val"}
|
||||
|
||||
def test_keeps_string_arguments_unchanged(self):
|
||||
msgs = [{"role": "assistant", "tool_calls": [
|
||||
{"function": {"name": "fn", "arguments": '{"key": "val"}'}}
|
||||
]}]
|
||||
result = router.transform_tool_calls_to_openai(msgs)
|
||||
args = result[0]["tool_calls"][0]["function"]["arguments"]
|
||||
assert args == '{"key": "val"}'
|
||||
|
||||
def test_links_tool_call_id_to_tool_response(self):
|
||||
msgs = [
|
||||
{"role": "assistant", "tool_calls": [
|
||||
{"function": {"name": "get_weather", "arguments": {}}}
|
||||
]},
|
||||
{"role": "tool", "name": "get_weather", "content": "sunny"},
|
||||
]
|
||||
result = router.transform_tool_calls_to_openai(msgs)
|
||||
tc_id = result[0]["tool_calls"][0]["id"]
|
||||
assert result[1].get("tool_call_id") == tc_id
|
||||
|
||||
def test_non_tool_messages_unchanged(self):
|
||||
msgs = [{"role": "user", "content": "hello"}]
|
||||
result = router.transform_tool_calls_to_openai(msgs)
|
||||
assert result == msgs
|
||||
|
||||
|
||||
class TestStripImagesFromMessages:
|
||||
def test_removes_image_url_parts(self):
|
||||
msgs = [{"role": "user", "content": [
|
||||
{"type": "text", "text": "what is this?"},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
|
||||
]}]
|
||||
result = router._strip_images_from_messages(msgs)
|
||||
content = result[0]["content"]
|
||||
assert content == "what is this?"
|
||||
|
||||
def test_keeps_text_only_messages(self):
|
||||
msgs = [{"role": "user", "content": "plain text"}]
|
||||
result = router._strip_images_from_messages(msgs)
|
||||
assert result[0]["content"] == "plain text"
|
||||
|
||||
def test_multiple_text_parts_kept_as_list(self):
|
||||
msgs = [{"role": "user", "content": [
|
||||
{"type": "text", "text": "part one"},
|
||||
{"type": "text", "text": "part two"},
|
||||
{"type": "image_url", "image_url": {"url": "data:..."}},
|
||||
]}]
|
||||
result = router._strip_images_from_messages(msgs)
|
||||
content = result[0]["content"]
|
||||
assert isinstance(content, list)
|
||||
assert len(content) == 2
|
||||
|
||||
def test_all_images_removed_empty_list(self):
|
||||
msgs = [{"role": "user", "content": [
|
||||
{"type": "image_url", "image_url": {"url": "data:..."}},
|
||||
]}]
|
||||
result = router._strip_images_from_messages(msgs)
|
||||
# Image-only content becomes empty list
|
||||
content = result[0]["content"]
|
||||
assert content == []
|
||||
|
||||
|
||||
class TestAccumulateOpenAITcDelta:
|
||||
def _make_chunk(self, index, name=None, args_fragment="", tc_id=None):
|
||||
delta = MagicMock()
|
||||
tc = MagicMock()
|
||||
tc.index = index
|
||||
tc.id = tc_id
|
||||
tc.function = MagicMock()
|
||||
tc.function.name = name
|
||||
tc.function.arguments = args_fragment
|
||||
delta.tool_calls = [tc]
|
||||
chunk = MagicMock()
|
||||
chunk.choices = [MagicMock(delta=delta)]
|
||||
return chunk
|
||||
|
||||
def test_first_delta_creates_entry(self):
|
||||
acc = {}
|
||||
chunk = self._make_chunk(0, name="my_fn", args_fragment='{"k"')
|
||||
router._accumulate_openai_tc_delta(chunk, acc)
|
||||
assert 0 in acc
|
||||
assert acc[0]["name"] == "my_fn"
|
||||
assert acc[0]["arguments"] == '{"k"'
|
||||
|
||||
def test_subsequent_deltas_concatenate_args(self):
|
||||
acc = {}
|
||||
router._accumulate_openai_tc_delta(self._make_chunk(0, name="fn", args_fragment='{"k"'), acc)
|
||||
router._accumulate_openai_tc_delta(self._make_chunk(0, args_fragment=': "v"}'), acc)
|
||||
assert acc[0]["arguments"] == '{"k": "v"}'
|
||||
|
||||
def test_multiple_tool_calls_tracked_separately(self):
|
||||
acc = {}
|
||||
c1 = self._make_chunk(0, name="fn1", args_fragment="{}")
|
||||
c2 = self._make_chunk(1, name="fn2", args_fragment="{}")
|
||||
chunk = MagicMock()
|
||||
tc1 = MagicMock()
|
||||
tc1.index = 0
|
||||
tc1.id = "id1"
|
||||
tc1.function = MagicMock(name="fn1", arguments="{}")
|
||||
tc2 = MagicMock()
|
||||
tc2.index = 1
|
||||
tc2.id = "id2"
|
||||
tc2.function = MagicMock(name="fn2", arguments="{}")
|
||||
chunk.choices = [MagicMock(delta=MagicMock(tool_calls=[tc1, tc2]))]
|
||||
router._accumulate_openai_tc_delta(chunk, acc)
|
||||
assert 0 in acc and 1 in acc
|
||||
|
||||
def test_no_choices_is_noop(self):
|
||||
acc = {}
|
||||
chunk = MagicMock(choices=[])
|
||||
router._accumulate_openai_tc_delta(chunk, acc)
|
||||
assert acc == {}
|
||||
|
||||
|
||||
class TestBuildOllamaToolCalls:
|
||||
def test_builds_from_accumulator(self):
|
||||
acc = {0: {"id": "call_abc", "name": "get_weather", "arguments": '{"city": "Berlin"}'}}
|
||||
result = router._build_ollama_tool_calls(acc)
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert result[0].function.name == "get_weather"
|
||||
assert result[0].function.arguments == {"city": "Berlin"}
|
||||
|
||||
def test_invalid_json_args_becomes_empty_dict(self):
|
||||
acc = {0: {"id": "c1", "name": "fn", "arguments": "not-json"}}
|
||||
result = router._build_ollama_tool_calls(acc)
|
||||
assert result[0].function.arguments == {}
|
||||
|
||||
def test_empty_accumulator_returns_none(self):
|
||||
assert router._build_ollama_tool_calls({}) is None
|
||||
|
||||
def test_preserves_order_by_index(self):
|
||||
acc = {
|
||||
1: {"id": "c2", "name": "fn2", "arguments": "{}"},
|
||||
0: {"id": "c1", "name": "fn1", "arguments": "{}"},
|
||||
}
|
||||
result = router._build_ollama_tool_calls(acc)
|
||||
assert result[0].function.name == "fn1"
|
||||
assert result[1].function.name == "fn2"
|
||||
Loading…
Add table
Add a link
Reference in a new issue