233 lines
7.7 KiB
Python
233 lines
7.7 KiB
Python
"""
|
|
Test configuration for nomyo-router.
|
|
|
|
Run from project root:
|
|
pytest test/ -v
|
|
pytest test/ -m "not integration" # skip real-server tests
|
|
pytest test/ -m integration -v # only real-server tests
|
|
|
|
Environment variables:
|
|
NOMYO_TEST_OLLAMA Ollama endpoint (default: http://192.168.0.51:12434)
|
|
NOMYO_TEST_LLAMA llama-server endpoint (default: http://192.168.0.51:12434/v1)
|
|
NOMYO_TEST_MODEL_CHAT chat model to use (auto-discovered if unset)
|
|
NOMYO_TEST_EMBED_MODEL embedding model (auto-discovered if unset)
|
|
"""
|
|
import asyncio
|
|
import os
|
|
import ssl
|
|
import sys
|
|
import tempfile
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import aiohttp
|
|
import httpx
|
|
import pytest
|
|
|
|
# Directory that holds this conftest and its companion ``config_test.yaml``.
_TEST_DIR = Path(__file__).parent

# Both variables have to be in place *before* ``router`` is imported, so that the
# module-level Config.from_yaml call and the Config field defaults pick them up.
# db_path is intentionally absent from config_test.yaml so the env-var default
# wins, which keeps tests portable across CI runners (Linux/macOS/Windows).
# ``setdefault`` leaves any value already exported by the caller untouched.
os.environ.setdefault(
    "NOMYO_ROUTER_CONFIG_PATH",
    str(_TEST_DIR.joinpath("config_test.yaml")),
)
os.environ.setdefault(
    "NOMYO_ROUTER_DB_PATH",
    str(Path(tempfile.gettempdir()).joinpath("nomyo_router_test_tokens.db")),
)

# Put the project root (parent of the test directory) first on the import path
# so the ``router`` module resolves from the working tree.
sys.path.insert(0, str(_TEST_DIR.parent))

import router  # noqa: E402

# Backend endpoints used by the fixtures below; overridable via environment.
TEST_OLLAMA = os.environ.get("NOMYO_TEST_OLLAMA", "http://192.168.0.51:12434")
TEST_LLAMA = os.environ.get("NOMYO_TEST_LLAMA", "http://192.168.0.51:12434/v1")
|
|
|
|
|
|
def pytest_configure(config):
    """Register the custom ``integration`` marker with pytest.

    The marker description names the environment variables that select the
    backend instead of a hard-coded address, so the text shown by
    ``pytest --markers`` cannot drift from the endpoint actually in use.

    Args:
        config: The pytest config object passed to this hook.
    """
    config.addinivalue_line(
        "markers",
        "integration: tests that require a real backend "
        "(set NOMYO_TEST_OLLAMA / NOMYO_TEST_LLAMA to point at it)",
    )
|
|
|
|
|
|
# ── Config mocks ─────────────────────────────────────────────────────────────
|
|
|
|
@pytest.fixture
def mock_config():
    """Build a minimal mocked config wired to TEST_OLLAMA and TEST_LLAMA.

    Returns a ``MagicMock`` with one Ollama endpoint, one llama-server
    endpoint, an API key per endpoint, a connection limit of 2, no router
    API key and caching turned off.
    """
    return MagicMock(
        endpoints=[TEST_OLLAMA],
        llama_server_endpoints=[TEST_LLAMA],
        api_keys={TEST_OLLAMA: "ollama", TEST_LLAMA: "llama"},
        max_concurrent_connections=2,
        router_api_key=None,
        cache_enabled=False,
    )
|
|
|
|
|
|
@pytest.fixture
def mock_config_no_llama():
    """Build a mocked config that has an Ollama endpoint but no llama-server.

    Identical in shape to the full mock config except that
    ``llama_server_endpoints`` is empty and only the Ollama API key is set.
    """
    return MagicMock(
        endpoints=[TEST_OLLAMA],
        llama_server_endpoints=[],
        api_keys={TEST_OLLAMA: "ollama"},
        max_concurrent_connections=2,
        router_api_key=None,
        cache_enabled=False,
    )
|
|
|
|
|
|
@pytest.fixture
def mock_config_with_key():
    """Build a mocked config whose ``router_api_key`` is set.

    Used for the auth-middleware tests: the router key is
    ``"test-secret-key"``, there are no llama-server endpoints and no
    per-endpoint API keys.
    """
    return MagicMock(
        endpoints=[TEST_OLLAMA],
        llama_server_endpoints=[],
        api_keys={},
        max_concurrent_connections=2,
        router_api_key="test-secret-key",
        cache_enabled=False,
    )
|
|
|
|
|
|
# ── aiohttp session (used by fetch tests + choose_endpoint tests) ─────────────
|
|
|
|
@pytest.fixture
async def aio_session():
    """Yield a real aiohttp session installed in ``router.app_state``.

    The session is meant to be intercepted by aioresponses in the tests that
    use it. Before yielding, every router-level cache and in-flight/background
    bookkeeping container is emptied so state cannot bleed between tests.
    On teardown the session is closed and the ``app_state`` slot is reset
    to ``None``.
    """
    http_session = aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(ssl=ssl.create_default_context())
    )
    router.app_state["session"] = http_session

    # Start each test from a clean slate.
    for stale in (
        router._models_cache,
        router._loaded_models_cache,
        router._available_error_cache,
        router._loaded_error_cache,
        router._inflight_available_models,
        router._inflight_loaded_models,
        router._bg_refresh_available,
        router._bg_refresh_loaded,
    ):
        stale.clear()

    yield http_session

    await http_session.close()
    router.app_state["session"] = None
|
|
|
|
|
|
# ── Validation-only HTTP client (no real backend needed) ──────────────────────
|
|
|
|
@pytest.fixture
async def client(mock_config, tmp_path):
    """Yield an httpx client for validation/auth tests; no real backend is called.

    The fixture installs a fresh aiohttp session and a throwaway
    ``TokenDatabase`` (under ``tmp_path``) on the ``router`` module, patches
    ``router.config`` with ``mock_config`` and serves the app in-process via
    ``httpx.ASGITransport``.

    Cleanup is guarded with ``try``/``finally`` so that the aiohttp session is
    closed even when database setup fails, and the previous
    ``router.app_state["session"]`` / ``router.db`` values are restored even
    when closing the client raises.

    Args:
        mock_config: Mocked router configuration to patch in.
        tmp_path: pytest-provided temporary directory for the test database.
    """
    from db import TokenDatabase

    session = aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(ssl=ssl.create_default_context())
    )
    try:
        db_inst = TokenDatabase(str(tmp_path / "test.db"))
        await db_inst.init_db()

        # Remember the current globals so they can be put back afterwards.
        old_session = router.app_state.get("session")
        old_db = router.db
        router.app_state["session"] = session
        router.db = db_inst
        try:
            with patch.object(router, "config", mock_config):
                transport = httpx.ASGITransport(app=router.app)
                async with httpx.AsyncClient(
                    transport=transport, base_url="http://test", timeout=10.0
                ) as c:
                    yield c
        finally:
            router.app_state["session"] = old_session
            router.db = old_db
    finally:
        await session.close()
|
|
|
|
|
|
@pytest.fixture
async def client_auth(mock_config_with_key, tmp_path):
    """Yield an httpx client with ``router_api_key`` configured (auth tests).

    The fixture installs a fresh aiohttp session and a throwaway
    ``TokenDatabase`` (under ``tmp_path``) on the ``router`` module, patches
    ``router.config`` with ``mock_config_with_key`` and serves the app
    in-process via ``httpx.ASGITransport``.

    Cleanup is guarded with ``try``/``finally`` so that the aiohttp session is
    closed even when database setup fails, and the previous
    ``router.app_state["session"]`` / ``router.db`` values are restored even
    when closing the client raises.

    Args:
        mock_config_with_key: Mocked router configuration with an API key set.
        tmp_path: pytest-provided temporary directory for the test database.
    """
    from db import TokenDatabase

    session = aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(ssl=ssl.create_default_context())
    )
    try:
        db_inst = TokenDatabase(str(tmp_path / "test_auth.db"))
        await db_inst.init_db()

        # Remember the current globals so they can be put back afterwards.
        old_session = router.app_state.get("session")
        old_db = router.db
        router.app_state["session"] = session
        router.db = db_inst
        try:
            with patch.object(router, "config", mock_config_with_key):
                transport = httpx.ASGITransport(app=router.app)
                async with httpx.AsyncClient(
                    transport=transport, base_url="http://test", timeout=10.0
                ) as c:
                    yield c
        finally:
            router.app_state["session"] = old_session
            router.db = old_db
    finally:
        await session.close()
|
|
|
|
|
|
# ── Integration client (full startup with real backend) ──────────────────────
|
|
|
|
@pytest.fixture(scope="module")
async def integration_client():
    """Yield an httpx client after a full app startup against the real test server.

    ``router.startup_event()`` is awaited first; the app is then served
    in-process through ``httpx.ASGITransport`` with a 60 second timeout.
    ``router.shutdown_event()`` runs in a ``finally`` clause so a successful
    startup is always paired with a shutdown, even if creating or closing the
    client raises.
    """
    await router.startup_event()
    try:
        transport = httpx.ASGITransport(app=router.app)
        async with httpx.AsyncClient(
            transport=transport,
            base_url="http://test",
            timeout=httpx.Timeout(60.0),
        ) as c:
            yield c
    finally:
        await router.shutdown_event()
|
|
|
|
|
|
# ── Model discovery fixtures ──────────────────────────────────────────────────
|
|
|
|
@pytest.fixture(scope="module")
|
|
async def chat_model(integration_client):
|
|
"""Return a chat/generation model name available on the test server."""
|
|
env_model = os.getenv("NOMYO_TEST_MODEL_CHAT")
|
|
if env_model:
|
|
return env_model
|
|
resp = await integration_client.get("/api/tags")
|
|
if resp.status_code != 200:
|
|
pytest.skip("Cannot reach test server")
|
|
models = resp.json().get("models", [])
|
|
# Prefer small models for faster tests
|
|
for m in models:
|
|
name = m.get("name", "")
|
|
if any(x in name.lower() for x in ["0.5b", "1b", "3b", "1.5b", "2b"]):
|
|
return name
|
|
if models:
|
|
return models[0]["name"]
|
|
pytest.skip("No chat models available on test server")
|
|
|
|
|
|
@pytest.fixture(scope="module")
async def embed_model(integration_client):
    """Return an embedding model name available on the test server.

    ``NOMYO_TEST_EMBED_MODEL`` takes precedence when set. Otherwise the first
    model from ``/api/tags`` whose lower-cased name contains one of the known
    embedding hints is returned. The test is skipped when the server does not
    answer with HTTP 200 or no such model is listed.
    """
    override = os.getenv("NOMYO_TEST_EMBED_MODEL")
    if override:
        return override

    response = await integration_client.get("/api/tags")
    if response.status_code != 200:
        pytest.skip("Cannot reach test server")

    hints = ["embed", "nomic", "minilm", "bge", "e5"]
    candidates = (entry.get("name", "") for entry in response.json().get("models", []))
    found = next(
        (
            candidate
            for candidate in candidates
            if any(hint in candidate.lower() for hint in hints)
        ),
        None,
    )
    if found is None:
        pytest.skip("No embedding model available on test server")
    return found
|