diff --git a/.forgejo/workflows/docker-publish-semantic.yml b/.forgejo/workflows/docker-publish-semantic.yml index 163f1a1..ebbfd36 100644 --- a/.forgejo/workflows/docker-publish-semantic.yml +++ b/.forgejo/workflows/docker-publish-semantic.yml @@ -86,7 +86,7 @@ jobs: provenance: false build-args: | SEMANTIC_CACHE=true - tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-${{ matrix.arch }}-${{ github.run_id }} + tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-${{ matrix.arch }} merge: runs-on: docker-amd64 @@ -142,6 +142,6 @@ jobs: run: | docker buildx imagetools create \ $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \ - ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-amd64-${{ github.run_id }} \ - ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-arm64-${{ github.run_id }} + ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-amd64 \ + ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-arm64 diff --git a/.forgejo/workflows/docker-publish.yml b/.forgejo/workflows/docker-publish.yml index 09e145c..27cd879 100644 --- a/.forgejo/workflows/docker-publish.yml +++ b/.forgejo/workflows/docker-publish.yml @@ -77,7 +77,7 @@ jobs: platforms: ${{ matrix.platform }} push: true provenance: false - tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-${{ matrix.arch }}-${{ github.run_id }} + tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-${{ matrix.arch }} merge: runs-on: docker-amd64 @@ -133,6 +133,6 @@ jobs: run: | docker buildx imagetools create \ $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \ - ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-amd64-${{ github.run_id }} \ - ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-arm64-${{ github.run_id }} + ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-amd64 \ + ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-arm64 diff --git a/.forgejo/workflows/pr-tests.yml b/.forgejo/workflows/pr-tests.yml deleted file mode 100644 index aa96b84..0000000 --- a/.forgejo/workflows/pr-tests.yml +++ /dev/null @@ -1,39 +0,0 @@ -name: PR Tests -on: [pull_request] -jobs: - test: - runs-on: docker-arm64 - container: - image: python:3.12-slim - env: - CMAKE_BUILD_PARALLEL_LEVEL: "4" - steps: - - name: Install system deps - run: | - apt-get update - apt-get install -y --no-install-recommends \ - git ca-certificates \ - build-essential pkg-config - rm -rf /var/lib/apt/lists/* - - name: Checkout - run: | - git config --global --add safe.directory "$PWD" - git clone --depth=1 \ - "https://oauth2:${{ github.token }}@bitfreedom.net/code/${{ github.repository }}.git" . - git fetch --depth=1 origin "+${{ github.event.pull_request.head.sha }}:pr" - git checkout pr - - name: Fetch action source - run: | - git clone --depth=1 --branch master \ - "https://oauth2:${{ github.token }}@bitfreedom.net/code/nomyo-ai/actions.git" \ - ./.run-tests - - uses: ./.run-tests/run-tests - with: - setup: | - python -m pip install --upgrade pip - pip install -r requirements.txt - pip install -r test/requirements_test.txt - command: pytest test/ -m "not integration" --cov=router --cov=cache --cov=db --cov=enhance --cov-fail-under=45 --cov-report=term-missing --cov-report=xml --junitxml=report.xml - artifacts-path: | - report.xml - coverage.xml diff --git a/.gitignore b/.gitignore index 7cd8431..cfce37c 100644 --- a/.gitignore +++ b/.gitignore @@ -66,4 +66,7 @@ config.yaml # SQLite *.db* -*settings.json \ No newline at end of file +*settings.json + +# Test suite (local only, not committed yet) +test/ \ No newline at end of file diff --git a/doc/architecture.md b/doc/architecture.md index c2408d8..f725573 100644 --- a/doc/architecture.md +++ b/doc/architecture.md @@ -206,8 +206,6 @@ The `/health` endpoint provides comprehensive health status: } ``` -For Ollama endpoints the probe is a parallel check of `/api/version` (liveness) and `/api/ps` (the route used by `choose_endpoint` when selecting a backend for a request). Reporting `ok` only when both succeed prevents the router from advertising an endpoint as healthy while completion calls dead-end on `/api/ps`. The same dual probe backs `/api/config`, which the dashboard uses to render endpoint health. - ## Database Schema The router uses SQLite for persistent storage: diff --git a/doc/monitoring.md b/doc/monitoring.md index 9ce25ec..ab75d25 100644 --- a/doc/monitoring.md +++ b/doc/monitoring.md @@ -29,10 +29,6 @@ Response: - `200`: All endpoints healthy - `503`: One or more endpoints unhealthy -**Probe scope per endpoint**: -- **Ollama endpoints** are probed at both `/api/version` (liveness) and `/api/ps` (model-introspection used by the router). If either fails the endpoint is reported as `error`; the response still includes `version` when the daemon is reachable so operators can tell a partial failure from a full outage. The `detail` field names the failing probe, e.g. `"/api/ps: 502 …"`. -- **OpenAI-compatible / llama-server endpoints** are probed at `/models`. - ### Current Usage ```bash @@ -137,8 +133,6 @@ Response: } ``` -Uses the same dual-probe logic as `/health` (Ollama: `/api/version` + `/api/ps`; OpenAI-compatible: `/models`). An endpoint will report `error` whenever either probe fails. The dashboard renders the `detail` field as a tooltip on the status cell. - ### Cache Statistics ```bash diff --git a/requirements.txt b/requirements.txt index 8353c44..0b4db81 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ anyio==4.13.0 async-timeout==5.0.1 attrs==26.1.0 certifi==2026.4.22 -click==8.4.0 +click==8.3.3 distro==1.9.0 exceptiongroup==1.3.1 fastapi==0.136.1 @@ -16,10 +16,10 @@ h11==0.16.0 httpcore==1.0.9 httpx==0.28.1 idna==3.15 -jiter==0.15.0 +jiter==0.14.0 multidict==6.7.1 ollama==0.6.2 -openai==2.37.0 +openai==1.109.1 orjson>=3.11.5 numpy>=1.26 pillow==12.2.0 @@ -30,7 +30,7 @@ pydantic_core==2.46.4 python-dotenv==1.2.2 PyYAML==6.0.3 sniffio==1.3.1 -starlette==1.0.1 +starlette==1.0.0 truststore==0.10.4 tiktoken==0.13.0 tqdm==4.67.3 diff --git a/router.py b/router.py index c465ebc..08225fb 100644 --- a/router.py +++ b/router.py @@ -1000,7 +1000,7 @@ class fetch: async with client.get(f"{endpoint}/models") as resp: await _ensure_success(resp) data = await resp.json() - + # Filter for loaded models only items = data.get("data", []) models = { @@ -1012,19 +1012,11 @@ class fetch: # Update cache with lock protection async with _loaded_models_cache_lock: _loaded_models_cache[endpoint] = (models, time.time()) - # Probe succeeded — clear any stale error so the endpoint - # becomes routable again. - async with _loaded_error_cache_lock: - _loaded_error_cache.pop(endpoint, None) return models except Exception as e: # If anything goes wrong we simply assume the endpoint has no models message = _format_connection_issue(f"{endpoint}/models", e) print(f"[fetch.loaded_models] {message}") - # Record the failure so `choose_endpoint` can avoid routing - # to an unhealthy backend and repeated probes short-circuit. - async with _loaded_error_cache_lock: - _loaded_error_cache[endpoint] = time.time() return set() else: # Original Ollama /api/ps logic @@ -1039,15 +1031,11 @@ class fetch: # Update cache with lock protection async with _loaded_models_cache_lock: _loaded_models_cache[endpoint] = (models, time.time()) - async with _loaded_error_cache_lock: - _loaded_error_cache.pop(endpoint, None) return models except Exception as e: # If anything goes wrong we simply assume the endpoint has no models message = _format_connection_issue(f"{endpoint}/api/ps", e) print(f"[fetch.loaded_models] {message}") - async with _loaded_error_cache_lock: - _loaded_error_cache[endpoint] = time.time() return set() async def _refresh_loaded_models(endpoint: str) -> None: @@ -1865,28 +1853,6 @@ async def choose_endpoint(model: str, reserve: bool = True, load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints] loaded_sets = await asyncio.gather(*load_tasks) - # 3️⃣.5 Exclude endpoints whose loaded-model probe has been failing - # recently. Without this filter, an endpoint where `/api/ps` returns 5xx - # would appear with an empty loaded set but pass through to the - # free-slot fallback (step 4) — sending completion calls to an - # unhealthy backend. See issue #83. - async with _loaded_error_cache_lock: - unhealthy = { - ep for ep, ts in _loaded_error_cache.items() - if _is_fresh(ts, 300) - } - if unhealthy: - filtered = [ - (ep, models) for ep, models in zip(candidate_endpoints, loaded_sets) - if ep not in unhealthy - ] - if filtered: - candidate_endpoints = [ep for ep, _ in filtered] - loaded_sets = [models for _, models in filtered] - # If *every* candidate is unhealthy we still fall through with the - # original list — refusing to route is worse than retrying a - # possibly-recovered backend. - # Look up a possible affinity hint *before* taking usage_lock. The two # locks are never held together to avoid lock-ordering issues. affine_ep: Optional[str] = None @@ -3188,103 +3154,44 @@ async def usage_proxy(request: Request): "token_usage_counts": token_usage_counts} # ------------------------------------------------------------- -# 20. Endpoint health probes (shared by /api/config and /health) -# ------------------------------------------------------------- -async def _raw_probe( - ep: str, - route: str, - api_key: Optional[str] = None, - timeout: Optional[float] = None, -) -> tuple[bool, object]: - """Direct HTTP probe that distinguishes success from failure - (unlike `fetch.endpoint_details`, which returns [] on either). - Returns `(ok, payload_or_error_message)`. - """ - headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")} - if api_key is not None: - headers["Authorization"] = "Bearer " + api_key - url = f"{ep.rstrip('/')}/{route.lstrip('/')}" - req_kwargs = {} - if timeout is not None: - req_kwargs["timeout"] = aiohttp.ClientTimeout(total=timeout) - try: - client: aiohttp.ClientSession = get_session(ep) - async with client.get(url, headers=headers, **req_kwargs) as resp: - await _ensure_success(resp) - data = await resp.json() - return True, data - except Exception as exc: - return False, _format_connection_issue(url, exc) - - -async def _endpoint_health(ep: str, *, timeout: Optional[float] = None) -> dict: - """Probe an endpoint and return `{status, version?, detail?}`. - - Ollama endpoints get a dual probe of `/api/version` and `/api/ps` so - that a daemon which is reachable but has a broken model-introspection - path (issue #83) is reported as `error` rather than `ok`. - OpenAI-compatible endpoints use a single `/models` probe. - """ - if is_openai_compatible(ep): - ok, payload = await _raw_probe( - ep, "/models", config.api_keys.get(ep), timeout=timeout, - ) - if ok: - return {"status": "ok", "version": "latest"} - return {"status": "error", "detail": str(payload)} - - (version_ok, version_payload), (ps_ok, ps_payload) = await asyncio.gather( - _raw_probe(ep, "/api/version", timeout=timeout), - _raw_probe(ep, "/api/ps", timeout=timeout), - ) - - version_value = ( - version_payload.get("version") - if version_ok and isinstance(version_payload, dict) - else None - ) - - if version_ok and ps_ok: - return {"status": "ok", "version": version_value} - if not version_ok and not ps_ok: - return {"status": "error", "detail": str(version_payload)} - # Partial failure — daemon reachable but one probe failed. Report - # as "error" so callers can surface the issue; include `version` so - # the operator knows the daemon itself is alive. - if not ps_ok: - return { - "status": "error", - "version": version_value, - "detail": f"/api/ps: {ps_payload}", - } - return { - "status": "error", - "detail": f"/api/version: {version_payload}", - } - - -# ------------------------------------------------------------- -# 20b. Proxy config route – for monitoring and frontend usage +# 20. Proxy config route – for monitoring and frontent usage # ------------------------------------------------------------- @app.get("/api/config") async def config_proxy(request: Request): """ Return a simple JSON object that contains the configured - Ollama endpoints and llama_server_endpoints. The front‑end uses this - to display which endpoints are being proxied and their health. - Status is "error" when either liveness (/api/version) or routing - health (/api/ps) fails — see issue #83. + Ollama endpoints and llama_server_endpoints. The front‑end uses this to display + which endpoints are being proxied. """ - async def check(url: str) -> dict: - return {"url": url, **(await _endpoint_health(url, timeout=5))} + async def check_endpoint(url: str): + client: aiohttp.ClientSession = get_session(url) + headers = None + if "/v1" in url: + headers = {"Authorization": "Bearer " + config.api_keys.get(url, "no-key")} + target_url = f"{url}/models" + else: + target_url = f"{url}/api/version" - ollama_results = await asyncio.gather(*[check(ep) for ep in config.endpoints]) + try: + async with client.get(target_url, headers=headers, timeout=aiohttp.ClientTimeout(total=5)) as resp: + await _ensure_success(resp) + data = await resp.json() + if "/v1" in url: + return {"url": url, "status": "ok", "version": "latest"} + else: + return {"url": url, "status": "ok", "version": data.get("version")} + except Exception as e: + detail = _format_connection_issue(target_url, e) + return {"url": url, "status": "error", "detail": detail} + + # Check Ollama endpoints + ollama_results = await asyncio.gather(*[check_endpoint(ep) for ep in config.endpoints]) + + # Check llama-server endpoints llama_results = [] if config.llama_server_endpoints: - llama_results = await asyncio.gather( - *[check(ep) for ep in config.llama_server_endpoints] - ) - + llama_results = await asyncio.gather(*[check_endpoint(ep) for ep in config.llama_server_endpoints]) + return { "endpoints": ollama_results, "llama_server_endpoints": llama_results, @@ -4096,30 +4003,44 @@ async def health_proxy(request: Request): """ Health‑check endpoint for monitoring the proxy. - * Queries each configured endpoint for both liveness and routing health: - Ollama endpoints are probed at `/api/version` AND `/api/ps`, - OpenAI-compatible endpoints at `/models`. + * Queries each configured endpoint for its `/api/version` response. * Returns a JSON object containing: - - `status`: "ok" if every endpoint replied to every probe, otherwise "error". + - `status`: "ok" if every endpoint replied, otherwise "error". - `endpoints`: a mapping of endpoint URL → `{status, version|detail}`. * The HTTP status code is 200 when everything is healthy, 503 otherwise. """ # Run all health checks in parallel. - # Ollama endpoints expose /api/version (liveness) and /api/ps (routing - # health — required by `choose_endpoint`). OpenAI-compatible endpoints - # (vLLM, llama-server, external) expose /models, which serves both - # purposes. Probing /api/version alone would miss the case where the - # Ollama process is up but /api/ps is failing — see issue #83. + # Ollama endpoints expose /api/version; OpenAI-compatible endpoints (vLLM, + # llama-server, external) expose /models. Using /api/version against an + # OpenAI-compatible endpoint yields a 404 and noisy log output. all_endpoints = list(config.endpoints) llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints] all_endpoints += llama_eps_extra - probe_results = await asyncio.gather( - *(_endpoint_health(ep) for ep in all_endpoints), - ) + tasks = [] + for ep in all_endpoints: + if is_openai_compatible(ep): + tasks.append(fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True)) + else: + tasks.append(fetch.endpoint_details(ep, "/api/version", "version", skip_error_cache=True)) - health_summary = dict(zip(all_endpoints, probe_results)) - overall_ok = all(entry.get("status") == "ok" for entry in probe_results) + results = await asyncio.gather(*tasks, return_exceptions=True) + + health_summary = {} + overall_ok = True + + for ep, result in zip(all_endpoints, results): + if isinstance(result, Exception): + # Endpoint did not respond / returned an error + health_summary[ep] = {"status": "error", "detail": str(result)} + overall_ok = False + else: + # Successful response – report the reported version (Ollama) or + # indicate the endpoint is reachable (OpenAI-compatible). + if is_openai_compatible(ep): + health_summary[ep] = {"status": "ok"} + else: + health_summary[ep] = {"status": "ok", "version": result} response_payload = { "status": "ok" if overall_ok else "error", diff --git a/static/index.html b/static/index.html index 8c0b16c..b29f22b 100644 --- a/static/index.html +++ b/static/index.html @@ -192,10 +192,6 @@ color: #8b0000; font-weight: bold; } - .status-error[title] { - cursor: help; - text-decoration: underline dotted; - } .copy-link, .delete-link, .show-link, @@ -740,16 +736,6 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) { return await resp.json(); } - function escapeHtml(value) { - if (value === null || value === undefined) return ""; - return String(value) - .replace(/&/g, "&") - .replace(//g, ">") - .replace(/"/g, """) - .replace(/'/g, "'"); - } - function toggleDarkMode() { document.documentElement.classList.toggle("dark-mode"); } @@ -766,24 +752,40 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) { // Build HTML for both endpoints and llama_server_endpoints let html = ""; - const renderRow = (e) => { - const statusClass = - e.status === "ok" ? "status-ok" : "status-error"; - const version = e.version || "N/A"; - const titleAttr = e.detail - ? ` title="${escapeHtml(e.detail)}"` - : ""; - return ` + // Add Ollama endpoints + html += data.endpoints + .map((e) => { + const statusClass = + e.status === "ok" + ? "status-ok" + : "status-error"; + const version = e.version || "N/A"; + return ` - ${escapeHtml(e.url)} - ${escapeHtml(e.status)} - ${escapeHtml(version)} + ${e.url} + ${e.status} + ${version} `; - }; - - html += data.endpoints.map(renderRow).join(""); + }) + .join(""); + + // Add llama-server endpoints if (data.llama_server_endpoints && data.llama_server_endpoints.length > 0) { - html += data.llama_server_endpoints.map(renderRow).join(""); + html += data.llama_server_endpoints + .map((e) => { + const statusClass = + e.status === "ok" + ? "status-ok" + : "status-error"; + const version = e.version || "N/A"; + return ` + + ${e.url} + ${e.status} + ${version} + `; + }) + .join(""); } body.innerHTML = html; diff --git a/test/config_test.yaml b/test/config_test.yaml deleted file mode 100644 index 30f2fa3..0000000 --- a/test/config_test.yaml +++ /dev/null @@ -1,13 +0,0 @@ -endpoints: - - http://192.168.0.51:12434 - -llama_server_endpoints: - - http://192.168.0.51:12434/v1 - -max_concurrent_connections: 2 - -api_keys: - "http://192.168.0.51:12434": "ollama" - "http://192.168.0.51:12434/v1": "llama" - -cache_enabled: false diff --git a/test/conftest.py b/test/conftest.py deleted file mode 100644 index c5142da..0000000 --- a/test/conftest.py +++ /dev/null @@ -1,233 +0,0 @@ -""" -Test configuration for nomyo-router. - -Run from project root: - pytest test/ -v - pytest test/ -m "not integration" # skip real-server tests - pytest test/ -m integration -v # only real-server tests - -Environment variables: - NOMYO_TEST_OLLAMA Ollama endpoint (default: http://192.168.0.50:12434) - NOMYO_TEST_LLAMA llama-server endpoint (default: http://192.168.0.50:12434/v1) - NOMYO_TEST_MODEL_CHAT chat model to use (auto-discovered if unset) - NOMYO_TEST_EMBED_MODEL embedding model (auto-discovered if unset) -""" -import asyncio -import os -import ssl -import sys -import tempfile -from pathlib import Path -from unittest.mock import MagicMock, patch - -import aiohttp -import httpx -import pytest - -_TEST_DIR = Path(__file__).parent -# Must be set before importing router so module-level Config.from_yaml + Config field -# defaults pick these up. db_path is intentionally absent from config_test.yaml so the -# env-var default wins — keeps tests portable across CI runners (Linux/macOS/Windows). -os.environ.setdefault("NOMYO_ROUTER_CONFIG_PATH", str(_TEST_DIR / "config_test.yaml")) -os.environ.setdefault( - "NOMYO_ROUTER_DB_PATH", - str(Path(tempfile.gettempdir()) / "nomyo_router_test_tokens.db"), -) - -sys.path.insert(0, str(_TEST_DIR.parent)) - -import router # noqa: E402 - -TEST_OLLAMA = os.getenv("NOMYO_TEST_OLLAMA", "http://192.168.0.51:12434") -TEST_LLAMA = os.getenv("NOMYO_TEST_LLAMA", "http://192.168.0.51:12434/v1") - - -def pytest_configure(config): - config.addinivalue_line( - "markers", - "integration: tests that require a real backend at 192.168.0.50:12434", - ) - - -# ── Config mocks ───────────────────────────────────────────────────────────── - -@pytest.fixture -def mock_config(): - """Minimal config pointing at TEST_OLLAMA / TEST_LLAMA.""" - cfg = MagicMock() - cfg.endpoints = [TEST_OLLAMA] - cfg.llama_server_endpoints = [TEST_LLAMA] - cfg.api_keys = {TEST_OLLAMA: "ollama", TEST_LLAMA: "llama"} - cfg.max_concurrent_connections = 2 - cfg.router_api_key = None - cfg.cache_enabled = False - return cfg - - -@pytest.fixture -def mock_config_no_llama(): - """Config with Ollama only, no llama-server.""" - cfg = MagicMock() - cfg.endpoints = [TEST_OLLAMA] - cfg.llama_server_endpoints = [] - cfg.api_keys = {TEST_OLLAMA: "ollama"} - cfg.max_concurrent_connections = 2 - cfg.router_api_key = None - cfg.cache_enabled = False - return cfg - - -@pytest.fixture -def mock_config_with_key(): - """Config with router_api_key set (enables auth middleware).""" - cfg = MagicMock() - cfg.endpoints = [TEST_OLLAMA] - cfg.llama_server_endpoints = [] - cfg.api_keys = {} - cfg.max_concurrent_connections = 2 - cfg.router_api_key = "test-secret-key" - cfg.cache_enabled = False - return cfg - - -# ── aiohttp session (used by fetch tests + choose_endpoint tests) ───────────── - -@pytest.fixture -async def aio_session(): - """Real aiohttp session stored in app_state; intercepted by aioresponses.""" - ssl_ctx = ssl.create_default_context() - conn = aiohttp.TCPConnector(ssl=ssl_ctx) - session = aiohttp.ClientSession(connector=conn) - router.app_state["session"] = session - - # Clear caches to prevent test bleed - router._models_cache.clear() - router._loaded_models_cache.clear() - router._available_error_cache.clear() - router._loaded_error_cache.clear() - router._inflight_available_models.clear() - router._inflight_loaded_models.clear() - router._bg_refresh_available.clear() - router._bg_refresh_loaded.clear() - - yield session - - await session.close() - router.app_state["session"] = None - - -# ── Validation-only HTTP client (no real backend needed) ────────────────────── - -@pytest.fixture -async def client(mock_config, tmp_path): - """httpx client for validation/auth tests — no real backend calls made.""" - from db import TokenDatabase - - ssl_ctx = ssl.create_default_context() - conn = aiohttp.TCPConnector(ssl=ssl_ctx) - session = aiohttp.ClientSession(connector=conn) - - db_inst = TokenDatabase(str(tmp_path / "test.db")) - await db_inst.init_db() - - old_session = router.app_state.get("session") - old_db = router.db - - router.app_state["session"] = session - router.db = db_inst - - with patch.object(router, "config", mock_config): - transport = httpx.ASGITransport(app=router.app) - async with httpx.AsyncClient( - transport=transport, base_url="http://test", timeout=10.0 - ) as c: - yield c - - await session.close() - router.app_state["session"] = old_session - router.db = old_db - - -@pytest.fixture -async def client_auth(mock_config_with_key, tmp_path): - """httpx client with router_api_key configured (for auth middleware tests).""" - from db import TokenDatabase - - ssl_ctx = ssl.create_default_context() - conn = aiohttp.TCPConnector(ssl=ssl_ctx) - session = aiohttp.ClientSession(connector=conn) - - db_inst = TokenDatabase(str(tmp_path / "test_auth.db")) - await db_inst.init_db() - - old_session = router.app_state.get("session") - old_db = router.db - - router.app_state["session"] = session - router.db = db_inst - - with patch.object(router, "config", mock_config_with_key): - transport = httpx.ASGITransport(app=router.app) - async with httpx.AsyncClient( - transport=transport, base_url="http://test", timeout=10.0 - ) as c: - yield c - - await session.close() - router.app_state["session"] = old_session - router.db = old_db - - -# ── Integration client (full startup with real backend) ────────────────────── - -@pytest.fixture(scope="module") -async def integration_client(): - """Full app startup pointing at the real test server.""" - await router.startup_event() - transport = httpx.ASGITransport(app=router.app) - async with httpx.AsyncClient( - transport=transport, - base_url="http://test", - timeout=httpx.Timeout(60.0), - ) as c: - yield c - await router.shutdown_event() - - -# ── Model discovery fixtures ────────────────────────────────────────────────── - -@pytest.fixture(scope="module") -async def chat_model(integration_client): - """Return a chat/generation model name available on the test server.""" - env_model = os.getenv("NOMYO_TEST_MODEL_CHAT") - if env_model: - return env_model - resp = await integration_client.get("/api/tags") - if resp.status_code != 200: - pytest.skip("Cannot reach test server") - models = resp.json().get("models", []) - # Prefer small models for faster tests - for m in models: - name = m.get("name", "") - if any(x in name.lower() for x in ["0.5b", "1b", "3b", "1.5b", "2b"]): - return name - if models: - return models[0]["name"] - pytest.skip("No chat models available on test server") - - -@pytest.fixture(scope="module") -async def embed_model(integration_client): - """Return an embedding model name available on the test server.""" - env_model = os.getenv("NOMYO_TEST_EMBED_MODEL") - if env_model: - return env_model - resp = await integration_client.get("/api/tags") - if resp.status_code != 200: - pytest.skip("Cannot reach test server") - models = resp.json().get("models", []) - for m in models: - name = m.get("name", "") - if any(x in name.lower() for x in ["embed", "nomic", "minilm", "bge", "e5"]): - return name - pytest.skip("No embedding model available on test server") diff --git a/test/pytest.ini b/test/pytest.ini deleted file mode 100644 index 1d05e6d..0000000 --- a/test/pytest.ini +++ /dev/null @@ -1,7 +0,0 @@ -[pytest] -asyncio_mode = auto -markers = - integration: tests that require a real backend at 192.168.0.51:12434 -testpaths = . -filterwarnings = - ignore::pytest.PytestUnhandledThreadExceptionWarning diff --git a/test/requirements_test.txt b/test/requirements_test.txt deleted file mode 100644 index 8c7c53f..0000000 --- a/test/requirements_test.txt +++ /dev/null @@ -1,4 +0,0 @@ -pytest>=8.0 -pytest-asyncio>=0.24 -pytest-cov>=5.0 -aioresponses>=0.7 diff --git a/test/test.md b/test/test.md deleted file mode 100644 index 533a3ee..0000000 --- a/test/test.md +++ /dev/null @@ -1,60 +0,0 @@ -# Testing nomyo-router - -## Setup - -Install test dependencies (from the project root): - -```bash -pip install -r test/requirements_test.txt -``` - -## Running tests - -All commands run from the `test/` directory: - -```bash -cd test -``` - -**All non-integration tests** (no backend required): -```bash -pytest -m "not integration" -v -``` - -**Integration tests only** (requires backend at `192.168.0.51:12434`): -```bash -pytest -m integration -v -``` - -**Everything:** -```bash -pytest -v -``` - -## Test structure - -| File | What it covers | Backend needed | -|---|---|---| -| `test_unit_helpers.py` | Pure helper functions (`_mask_secrets`, `_is_fresh`, `ep2base`, etc.) | No | -| `test_unit_transforms.py` | Message transform functions (tool calls, image stripping, etc.) | No | -| `test_unit_context.py` | Context window trimming logic | No | -| `test_fetch.py` | `fetch.available_models` / `fetch.loaded_models` with mocked HTTP | No | -| `test_choose_endpoint.py` | `choose_endpoint` routing logic with mocked fetch layer | No | -| `test_api_validation.py` | HTTP 400/401/403 validation and auth middleware (in-process app) | No | -| `test_api_integration.py` | Full request/response against a real Ollama/llama-server backend | **Yes** | - -## Integration test backend - -Integration tests start the router in-process via `startup_event()` and route traffic -through `httpx.ASGITransport` — no separately running router instance is needed. - -They do require a reachable Ollama or llama-server backend. Override the defaults via -environment variables: - -```bash -export NOMYO_TEST_OLLAMA=http://192.168.0.51:12434 -export NOMYO_TEST_EMBED_MODEL=nomic-embed-text # optional, auto-discovered otherwise -export NOMYO_TEST_MODEL_CHAT=llama3.2 # optional, auto-discovered otherwise -``` - -If the backend is unreachable, integration tests are automatically skipped. diff --git a/test/test_api_integration.py b/test/test_api_integration.py deleted file mode 100644 index 6c40fdc..0000000 --- a/test/test_api_integration.py +++ /dev/null @@ -1,304 +0,0 @@ -""" -Integration tests against the real backend at 192.168.0.50:12434. - -Run with: - pytest test/test_api_integration.py -v -m integration - -All tests in this file are marked @pytest.mark.integration. -They require the test server to be reachable and to have at least one -chat model and one embedding model available. - -Env vars to pin specific models: - NOMYO_TEST_MODEL_CHAT e.g. qwen2.5:1.5b - NOMYO_TEST_EMBED_MODEL e.g. nomic-embed-text:latest -""" -import json - -import pytest - - -pytestmark = pytest.mark.integration - - -# ── Health / discovery routes ───────────────────────────────────────────────── - -class TestDiscoveryRoutes: - async def test_version(self, integration_client): - resp = await integration_client.get("/api/version") - assert resp.status_code == 200 - data = resp.json() - assert "version" in data - assert isinstance(data["version"], str) - - async def test_tags_returns_models(self, integration_client): - resp = await integration_client.get("/api/tags") - assert resp.status_code == 200 - data = resp.json() - assert "models" in data - assert isinstance(data["models"], list) - assert len(data["models"]) > 0 - - async def test_ps_returns_list(self, integration_client): - resp = await integration_client.get("/api/ps") - assert resp.status_code == 200 - data = resp.json() - assert "models" in data - assert isinstance(data["models"], list) - - async def test_v1_models_returns_data(self, integration_client): - resp = await integration_client.get("/v1/models") - assert resp.status_code == 200 - data = resp.json() - assert "data" in data - assert isinstance(data["data"], list) - - async def test_usage_returns_counts(self, integration_client): - resp = await integration_client.get("/api/usage") - assert resp.status_code == 200 - data = resp.json() - assert "usage_counts" in data - assert "token_usage_counts" in data - - async def test_config_returns_endpoints(self, integration_client): - resp = await integration_client.get("/api/config") - assert resp.status_code == 200 - data = resp.json() - assert "endpoints" in data - - async def test_hostname(self, integration_client): - resp = await integration_client.get("/api/hostname") - assert resp.status_code == 200 - assert "hostname" in resp.json() - - async def test_health(self, integration_client): - resp = await integration_client.get("/health") - assert resp.status_code in (200, 503) - data = resp.json() - assert data["status"] in ("ok", "error") - assert "endpoints" in data - - async def test_cache_stats(self, integration_client): - resp = await integration_client.get("/api/cache/stats") - assert resp.status_code == 200 - data = resp.json() - assert "enabled" in data - - -# ── /api/chat ───────────────────────────────────────────────────────────────── - -class TestApiChat: - async def test_non_streaming(self, integration_client, chat_model): - resp = await integration_client.post( - "/api/chat", - json={ - "model": chat_model, - "stream": False, - "messages": [{"role": "user", "content": "Reply with exactly: OK"}], - "options": {"num_predict": 10}, - }, - ) - assert resp.status_code == 200 - data = resp.json() - assert "message" in data - assert "content" in data["message"] - - async def test_streaming_ndjson(self, integration_client, chat_model): - resp = await integration_client.post( - "/api/chat", - json={ - "model": chat_model, - "stream": True, - "messages": [{"role": "user", "content": "Say hi"}], - "options": {"num_predict": 5}, - }, - ) - assert resp.status_code == 200 - lines = [l for l in resp.text.strip().split("\n") if l.strip()] - assert len(lines) >= 1 - for line in lines: - obj = json.loads(line) - assert "model" in obj - - async def test_non_streaming_has_token_counts(self, integration_client, chat_model): - resp = await integration_client.post( - "/api/chat", - json={ - "model": chat_model, - "stream": False, - "messages": [{"role": "user", "content": "Count to 3"}], - "options": {"num_predict": 20}, - }, - ) - assert resp.status_code == 200 - data = resp.json() - assert data.get("done") is True - # Token counts should be present in the final chunk - assert data.get("prompt_eval_count", 0) >= 0 - - async def test_system_message_honoured(self, integration_client, chat_model): - resp = await integration_client.post( - "/api/chat", - json={ - "model": chat_model, - "stream": False, - "messages": [ - {"role": "system", "content": "You are a helpful assistant. Always reply with exactly: PONG"}, - {"role": "user", "content": "PING"}, - ], - "options": {"num_predict": 10}, - }, - ) - assert resp.status_code == 200 - content = resp.json()["message"]["content"] - assert isinstance(content, str) - assert len(content) > 0 - - -# ── /api/generate ───────────────────────────────────────────────────────────── - -class TestApiGenerate: - async def test_non_streaming(self, integration_client, chat_model): - resp = await integration_client.post( - "/api/generate", - json={ - "model": chat_model, - "prompt": "Complete: The sky is", - "stream": False, - "options": {"num_predict": 5}, - }, - ) - assert resp.status_code == 200 - data = resp.json() - assert "response" in data - - async def test_streaming(self, integration_client, chat_model): - resp = await integration_client.post( - "/api/generate", - json={ - "model": chat_model, - "prompt": "One plus one equals", - "stream": True, - "options": {"num_predict": 5}, - }, - ) - assert resp.status_code == 200 - lines = [l for l in resp.text.strip().split("\n") if l.strip()] - assert len(lines) >= 1 - - -# ── /api/embed ──────────────────────────────────────────────────────────────── - -class TestApiEmbed: - async def test_embed_single_string(self, integration_client, embed_model): - resp = await integration_client.post( - "/api/embed", - json={"model": embed_model, "input": "The quick brown fox"}, - ) - assert resp.status_code == 200 - data = resp.json() - assert "embeddings" in data - assert isinstance(data["embeddings"], list) - assert len(data["embeddings"]) == 1 - assert len(data["embeddings"][0]) > 0 - - async def test_embed_multiple_inputs(self, integration_client, embed_model): - resp = await integration_client.post( - "/api/embed", - json={"model": embed_model, "input": ["sentence one", "sentence two"]}, - ) - assert resp.status_code == 200 - data = resp.json() - assert "embeddings" in data - assert len(data["embeddings"]) == 2 - - -# ── /v1/chat/completions ────────────────────────────────────────────────────── - -class TestOpenAIChatCompletions: - async def test_non_streaming(self, integration_client, chat_model): - model = chat_model.replace(":latest", "") if ":latest" in chat_model else chat_model - resp = await integration_client.post( - "/v1/chat/completions", - json={ - "model": model, - "messages": [{"role": "user", "content": "Reply OK"}], - "max_tokens": 10, - "stream": False, - }, - ) - assert resp.status_code == 200 - data = resp.json() - assert "choices" in data - assert len(data["choices"]) > 0 - assert "message" in data["choices"][0] - - async def test_streaming_sse(self, integration_client, chat_model): - model = chat_model.replace(":latest", "") if ":latest" in chat_model else chat_model - resp = await integration_client.post( - "/v1/chat/completions", - json={ - "model": model, - "messages": [{"role": "user", "content": "Hi"}], - "max_tokens": 5, - "stream": True, - }, - ) - assert resp.status_code == 200 - # Response should be SSE format - assert "data:" in resp.text or "[DONE]" in resp.text - - async def test_non_streaming_has_usage(self, integration_client, chat_model): - model = chat_model.replace(":latest", "") if ":latest" in chat_model else chat_model - resp = await integration_client.post( - "/v1/chat/completions", - json={ - "model": model, - "messages": [{"role": "user", "content": "Say yes"}], - "max_tokens": 5, - "stream": False, - }, - ) - assert resp.status_code == 200 - data = resp.json() - if "usage" in data and data["usage"]: - assert data["usage"].get("prompt_tokens", 0) >= 0 - - -# ── /v1/embeddings ──────────────────────────────────────────────────────────── - -class TestOpenAIEmbeddings: - async def test_single_input(self, integration_client, embed_model): - model = embed_model.replace(":latest", "") if ":latest" in embed_model else embed_model - resp = await integration_client.post( - "/v1/embeddings", - json={"model": model, "input": "Test sentence"}, - ) - assert resp.status_code == 200 - data = resp.json() - assert "data" in data - assert len(data["data"]) > 0 - embedding = data["data"][0].get("embedding") - assert isinstance(embedding, list) - assert len(embedding) > 0 - - -# ── Token counts (database-backed) ─────────────────────────────────────────── - -class TestTokenCounts: - async def test_token_counts_endpoint(self, integration_client): - resp = await integration_client.get("/api/token_counts") - assert resp.status_code == 200 - data = resp.json() - assert "total_tokens" in data - assert "breakdown" in data - - -# ── ps_details (extended ps) ───────────────────────────────────────────────── - -class TestPsDetails: - async def test_ps_details_returns_models(self, integration_client): - resp = await integration_client.get("/api/ps_details") - assert resp.status_code == 200 - data = resp.json() - assert "models" in data - assert isinstance(data["models"], list) diff --git a/test/test_api_validation.py b/test/test_api_validation.py deleted file mode 100644 index 5d2b52d..0000000 --- a/test/test_api_validation.py +++ /dev/null @@ -1,230 +0,0 @@ -""" -HTTP-level validation and auth middleware tests. - -These tests use an in-process httpx client and never reach a real backend: -all requests are rejected at the validation or auth layer before any -endpoint-selection or upstream HTTP calls occur. -""" -import pytest - - -class TestChatValidation: - async def test_missing_model_returns_400(self, client): - resp = await client.post( - "/api/chat", - json={"messages": [{"role": "user", "content": "hello"}]}, - ) - assert resp.status_code == 400 - assert "model" in resp.json()["detail"].lower() - - async def test_missing_messages_returns_400(self, client): - resp = await client.post("/api/chat", json={"model": "llama3.2"}) - assert resp.status_code == 400 - - async def test_invalid_json_returns_400(self, client): - resp = await client.post( - "/api/chat", - content=b"not-json", - headers={"Content-Type": "application/json"}, - ) - assert resp.status_code == 400 - - async def test_messages_not_list_returns_400(self, client): - resp = await client.post( - "/api/chat", - json={"model": "m", "messages": "not-a-list"}, - ) - assert resp.status_code == 400 - - async def test_options_not_dict_returns_400(self, client): - resp = await client.post( - "/api/chat", - json={"model": "m", "messages": [{"role": "user", "content": "hi"}], "options": "bad"}, - ) - assert resp.status_code == 400 - - -class TestGenerateValidation: - async def test_missing_model_returns_400(self, client): - resp = await client.post("/api/generate", json={"prompt": "hello"}) - assert resp.status_code == 400 - assert "model" in resp.json()["detail"].lower() - - async def test_missing_prompt_returns_400(self, client): - resp = await client.post("/api/generate", json={"model": "m"}) - assert resp.status_code == 400 - assert "prompt" in resp.json()["detail"].lower() - - async def test_invalid_json_returns_400(self, client): - resp = await client.post( - "/api/generate", - content=b"{bad-json", - headers={"Content-Type": "application/json"}, - ) - assert resp.status_code == 400 - - -class TestEmbedValidation: - async def test_missing_model_returns_400(self, client): - resp = await client.post("/api/embed", json={"input": "hello"}) - assert resp.status_code == 400 - - async def test_missing_input_returns_400(self, client): - resp = await client.post("/api/embed", json={"model": "nomic-embed-text"}) - assert resp.status_code == 400 - - -class TestEmbeddingsValidation: - async def test_missing_model_returns_400(self, client): - resp = await client.post("/api/embeddings", json={"prompt": "hello"}) - assert resp.status_code == 400 - - async def test_missing_prompt_returns_400(self, client): - resp = await client.post("/api/embeddings", json={"model": "nomic-embed-text"}) - assert resp.status_code == 400 - - -class TestOpenAIChatValidation: - async def test_missing_model_returns_400(self, client): - resp = await client.post( - "/v1/chat/completions", - json={"messages": [{"role": "user", "content": "hello"}]}, - ) - assert resp.status_code == 400 - - async def test_missing_messages_returns_400(self, client): - resp = await client.post( - "/v1/chat/completions", - json={"model": "gpt-4o"}, - ) - assert resp.status_code == 400 - - async def test_invalid_json_returns_400(self, client): - resp = await client.post( - "/v1/chat/completions", - content=b"}{", - headers={"Content-Type": "application/json"}, - ) - assert resp.status_code == 400 - - async def test_svg_image_rejected(self, client): - resp = await client.post( - "/v1/chat/completions", - json={ - "model": "vision-model", - "messages": [{ - "role": "user", - "content": [ - {"type": "text", "text": "describe"}, - {"type": "image_url", "image_url": {"url": "data:image/svg+xml;base64,abc"}}, - ], - }], - }, - ) - assert resp.status_code == 400 - assert "svg" in resp.json()["detail"].lower() - - -class TestOpenAICompletionsValidation: - async def test_missing_model_returns_400(self, client): - resp = await client.post("/v1/completions", json={"prompt": "hello"}) - assert resp.status_code == 400 - - async def test_missing_prompt_returns_400(self, client): - resp = await client.post("/v1/completions", json={"model": "m"}) - assert resp.status_code == 400 - - -class TestRerankValidation: - async def test_missing_model_returns_400(self, client): - resp = await client.post( - "/v1/rerank", - json={"query": "search query", "documents": ["doc1"]}, - ) - assert resp.status_code == 400 - - async def test_missing_query_returns_400(self, client): - resp = await client.post( - "/v1/rerank", - json={"model": "reranker", "documents": ["doc1"]}, - ) - assert resp.status_code == 400 - - async def test_empty_documents_returns_400(self, client): - resp = await client.post( - "/v1/rerank", - json={"model": "reranker", "query": "search", "documents": []}, - ) - assert resp.status_code == 400 - - -class TestShowValidation: - async def test_missing_model_returns_400(self, client): - resp = await client.post("/api/show", json={}) - assert resp.status_code == 400 - - -class TestCopyValidation: - async def test_missing_source_returns_400(self, client): - resp = await client.post("/api/copy", json={"destination": "dst"}) - assert resp.status_code == 400 - - async def test_missing_destination_returns_400(self, client): - resp = await client.post("/api/copy", json={"source": "src"}) - assert resp.status_code == 400 - - -class TestDeleteValidation: - async def test_missing_model_returns_400(self, client): - import json as _json - resp = await client.request( - "DELETE", - "/api/delete", - content=_json.dumps({}).encode(), - headers={"Content-Type": "application/json"}, - ) - assert resp.status_code == 400 - - -class TestAuthMiddleware: - async def test_no_key_returns_401(self, client_auth): - resp = await client_auth.post( - "/api/chat", - json={"model": "m", "messages": [{"role": "user", "content": "hi"}]}, - ) - assert resp.status_code == 401 - assert "Missing" in resp.json()["detail"] - - async def test_invalid_key_returns_403(self, client_auth): - resp = await client_auth.post( - "/api/chat", - headers={"Authorization": "Bearer wrong-key"}, - json={"model": "m", "messages": [{"role": "user", "content": "hi"}]}, - ) - assert resp.status_code == 403 - assert "Invalid" in resp.json()["detail"] - - async def test_valid_key_passes_middleware(self, client_auth): - # /api/usage reads in-memory counters only — no backend call needed - resp = await client_auth.get( - "/api/usage", - headers={"Authorization": "Bearer test-secret-key"}, - ) - assert resp.status_code == 200 - - async def test_key_via_query_param(self, client_auth): - resp = await client_auth.get("/api/usage?api_key=test-secret-key") - assert resp.status_code == 200 - - async def test_options_bypasses_auth(self, client_auth): - resp = await client_auth.options("/api/chat") - assert resp.status_code not in (401, 403) - - async def test_root_path_bypasses_auth(self, client_auth): - resp = await client_auth.get("/") - assert resp.status_code not in (401, 403) - - async def test_favicon_bypasses_auth(self, client_auth): - resp = await client_auth.get("/favicon.ico") - # Should not be blocked by auth (may 404 in test but not 401/403) - assert resp.status_code not in (401, 403) diff --git a/test/test_cache.py b/test/test_cache.py deleted file mode 100644 index f2ce1a9..0000000 --- a/test/test_cache.py +++ /dev/null @@ -1,333 +0,0 @@ -"""Unit tests for cache.LLMCache in exact-match mode (no sentence-transformers needed).""" -import tempfile -from pathlib import Path -from types import SimpleNamespace - -import orjson -import pytest - -import cache as cache_mod -from cache import ( - LLMCache, - _bm25_weighted_text, - get_llm_cache, - init_llm_cache, - openai_nonstream_to_sse, -) - -_CACHE_DB_PATH = str(Path(tempfile.gettempdir()) / "nomyo_test_cache.db") - - -def _exact_cfg(backend: str = "memory") -> SimpleNamespace: - """Config for exact-match mode — similarity=1.0 avoids embedding deps.""" - return SimpleNamespace( - cache_enabled=True, - cache_backend=backend, - cache_similarity=1.0, - cache_history_weight=0.3, - cache_ttl=300, - cache_db_path=_CACHE_DB_PATH, - cache_redis_url="redis://localhost:6379", - ) - - -# ────────────────────────────────────────────────────────────────────────────── -# Pure helpers -# ────────────────────────────────────────────────────────────────────────────── - -class TestBM25WeightedText: - def test_empty_history(self): - assert _bm25_weighted_text([]) == "" - - def test_history_without_content(self): - assert _bm25_weighted_text([{"role": "user"}, {"role": "assistant"}]) == "" - - def test_repeats_high_idf_terms(self): - history = [ - {"role": "user", "content": "Tell me about quantum entanglement"}, - {"role": "assistant", "content": "Quantum entanglement is a phenomenon"}, - {"role": "user", "content": "How does entanglement work?"}, - ] - out = _bm25_weighted_text(history) - # Rare/domain term ("entanglement") should appear; short stopwords (<=2 chars) dropped - assert "entanglement" in out - assert "is" not in out.split() - - -# ────────────────────────────────────────────────────────────────────────────── -# openai_nonstream_to_sse -# ────────────────────────────────────────────────────────────────────────────── - -class TestOpenAINonstreamToSSE: - def test_valid_chat_completion(self): - chat = { - "id": "x1", - "created": 123, - "model": "gpt-4o", - "choices": [{"message": {"role": "assistant", "content": "hello"}}], - "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}, - } - out = openai_nonstream_to_sse(orjson.dumps(chat), "gpt-4o") - text = out.decode() - assert text.startswith("data: ") - assert text.endswith("data: [DONE]\n\n") - # First chunk contains the original content - first = text.split("\n\n")[0][len("data: "):] - parsed = orjson.loads(first) - assert parsed["choices"][0]["delta"]["content"] == "hello" - assert parsed["usage"]["total_tokens"] == 3 - - def test_corrupt_bytes_return_done_only(self): - out = openai_nonstream_to_sse(b"not-json", "m") - assert out == b"data: [DONE]\n\n" - - -# ────────────────────────────────────────────────────────────────────────────── -# LLMCache internal helpers -# ────────────────────────────────────────────────────────────────────────────── - -class TestLLMCacheParsing: - def test_namespace_is_stable_and_isolated(self): - c = LLMCache(_exact_cfg()) - a = c._namespace("chat", "m1", "system A") - b = c._namespace("chat", "m1", "system A") - assert a == b - assert c._namespace("chat", "m1", "system B") != a - assert c._namespace("generate", "m1", "system A") != a - assert len(a) == 16 - - def test_parse_messages_flat_strings(self): - c = LLMCache(_exact_cfg()) - sys, hist, last = c._parse_messages([ - {"role": "system", "content": "be helpful"}, - {"role": "user", "content": "hi"}, - {"role": "assistant", "content": "hello"}, - {"role": "user", "content": "what is 2+2?"}, - ]) - assert sys == "be helpful" - assert last == "what is 2+2?" - assert hist == [ - {"role": "user", "content": "hi"}, - {"role": "assistant", "content": "hello"}, - ] - - def test_parse_messages_multimodal_content(self): - c = LLMCache(_exact_cfg()) - sys, _hist, last = c._parse_messages([ - {"role": "system", "content": "sys"}, - {"role": "user", "content": [ - {"type": "text", "text": "describe"}, - {"type": "image_url", "image_url": {"url": "data:..."}}, - ]}, - ]) - assert sys == "sys" - assert last == "describe" - - def test_parse_messages_no_user_message(self): - c = LLMCache(_exact_cfg()) - sys, hist, last = c._parse_messages([ - {"role": "system", "content": "sys only"}, - ]) - assert sys == "sys only" - assert last == "" - assert hist == [] - - -class TestPersonalTokenExtraction: - def test_email_extracted(self): - c = LLMCache(_exact_cfg()) - toks = c._extract_personal_tokens("Reach me at alice@example.com please") - assert "alice@example.com" in toks - - def test_numeric_id_after_keyword(self): - c = LLMCache(_exact_cfg()) - toks = c._extract_personal_tokens("User id: 123456") - assert "123456" in toks - - def test_identity_tag_names_extracted(self): - c = LLMCache(_exact_cfg()) - toks = c._extract_personal_tokens( - "[Tags: identity] User's name is Andreas Schwibbe" - ) - # Both name tokens should be extracted lowercased; stopwords dropped - assert "andreas" in toks - assert "schwibbe" in toks - assert "name" not in toks # in _IDENTITY_STOPWORDS - assert "user" not in toks - - def test_empty_system_returns_empty_set(self): - c = LLMCache(_exact_cfg()) - assert c._extract_personal_tokens("") == frozenset() - - -class TestResponseIsPersonalized: - def _resp(self, content: str) -> bytes: - return orjson.dumps({"choices": [{"message": {"content": content}}]}) - - def test_email_in_response_is_personalized(self): - c = LLMCache(_exact_cfg()) - assert c._response_is_personalized(self._resp("contact bob@x.com"), "") - - def test_uuid_in_response_is_personalized(self): - c = LLMCache(_exact_cfg()) - uuid = "550e8400-e29b-41d4-a716-446655440000" - assert c._response_is_personalized(self._resp(f"id={uuid}"), "") - - def test_long_numeric_id_in_response_is_personalized(self): - c = LLMCache(_exact_cfg()) - assert c._response_is_personalized(self._resp("account 12345678"), "") - - def test_identity_token_from_system_echoed_in_response(self): - c = LLMCache(_exact_cfg()) - system = "[Tags: identity] Andreas works here" - assert c._response_is_personalized( - self._resp("Yes, Andreas is logged in"), system - ) - - def test_generic_response_not_personalized(self): - c = LLMCache(_exact_cfg()) - assert not c._response_is_personalized( - self._resp("The capital of France is Paris."), "be helpful" - ) - - def test_ollama_message_format_parsed(self): - c = LLMCache(_exact_cfg()) - body = orjson.dumps({"message": {"content": "alice@example.com"}}) - assert c._response_is_personalized(body, "") - - def test_unparseable_body_with_bytes_is_conservative(self): - c = LLMCache(_exact_cfg()) - # Can't parse → returns True (err on the side of privacy) - assert c._response_is_personalized(b"binary-junk", "") - - def test_empty_response_not_personalized(self): - c = LLMCache(_exact_cfg()) - assert not c._response_is_personalized(b"", "anything") - - -# ────────────────────────────────────────────────────────────────────────────── -# End-to-end exact-match cache with the memory backend -# ────────────────────────────────────────────────────────────────────────────── - -@pytest.fixture -async def memcache(): - """LLMCache wired up with the in-memory backend (no external deps).""" - c = LLMCache(_exact_cfg("memory")) - await c.init() - return c - - -class TestExactMatchCache: - async def test_miss_then_set_then_hit(self, memcache): - msgs = [ - {"role": "system", "content": "be helpful"}, - {"role": "user", "content": "what is 2+2?"}, - ] - resp = orjson.dumps({"choices": [{"message": {"content": "4"}}]}) - - assert await memcache.get_chat("chat", "m1", msgs) is None - await memcache.set_chat("chat", "m1", msgs, resp) - hit = await memcache.get_chat("chat", "m1", msgs) - assert hit == resp - - async def test_namespace_isolation_by_system(self, memcache): - resp = orjson.dumps({"choices": [{"message": {"content": "ok"}}]}) - msgs_a = [ - {"role": "system", "content": "system A"}, - {"role": "user", "content": "same question"}, - ] - msgs_b = [ - {"role": "system", "content": "system B"}, - {"role": "user", "content": "same question"}, - ] - await memcache.set_chat("chat", "m", msgs_a, resp) - # Same question + different system prompt = different namespace = miss - assert await memcache.get_chat("chat", "m", msgs_b) is None - - async def test_namespace_isolation_by_route(self, memcache): - resp = orjson.dumps({"choices": [{"message": {"content": "ok"}}]}) - msgs = [{"role": "user", "content": "ping"}] - await memcache.set_chat("chat", "m", msgs, resp) - assert await memcache.get_chat("openai_chat", "m", msgs) is None - - async def test_no_user_message_is_noop(self, memcache): - msgs = [{"role": "system", "content": "sys only"}] - resp = orjson.dumps({"choices": [{"message": {"content": "x"}}]}) - # Both get and set should silently no-op - assert await memcache.get_chat("chat", "m", msgs) is None - await memcache.set_chat("chat", "m", msgs, resp) - assert await memcache.get_chat("chat", "m", msgs) is None - - async def test_personalized_response_generic_system_not_stored(self, memcache): - msgs = [ - {"role": "system", "content": "be helpful"}, # generic - {"role": "user", "content": "give me an email"}, - ] - # Response contains an email → would leak across users sharing the - # generic namespace → must NOT be stored at all - resp = orjson.dumps({"choices": [{"message": {"content": "bob@x.com"}}]}) - await memcache.set_chat("chat", "m", msgs, resp) - assert await memcache.get_chat("chat", "m", msgs) is None - - async def test_personalized_response_user_specific_system_stored(self, memcache): - msgs = [ - {"role": "system", "content": "User id: 998877 prefers concise answers"}, - {"role": "user", "content": "what is my id?"}, - ] - resp = orjson.dumps({"choices": [{"message": {"content": "Your id is 998877"}}]}) - await memcache.set_chat("chat", "m", msgs, resp) - # User-specific namespace → exact-match within this user is OK - assert await memcache.get_chat("chat", "m", msgs) == resp - - async def test_generate_convenience_wrappers(self, memcache): - resp = orjson.dumps({"response": "blue"}) - await memcache.set_generate("m", "what color is the sky?", "", resp) - assert await memcache.get_generate("m", "what color is the sky?") == resp - - -class TestStatsAndClear: - async def test_stats_tracks_hits_and_misses(self, memcache): - msgs = [{"role": "user", "content": "hello"}] - await memcache.get_chat("chat", "m", msgs) # miss - resp = orjson.dumps({"choices": [{"message": {"content": "hi"}}]}) - await memcache.set_chat("chat", "m", msgs, resp) - await memcache.get_chat("chat", "m", msgs) # hit - s = memcache.stats() - assert s["hits"] == 1 - assert s["misses"] == 1 - assert s["hit_rate"] == 0.5 - assert s["semantic"] is False - assert s["backend"] == "memory" - - async def test_clear_resets_counters_and_storage(self, memcache): - msgs = [{"role": "user", "content": "hi"}] - resp = orjson.dumps({"choices": [{"message": {"content": "ok"}}]}) - await memcache.set_chat("chat", "m", msgs, resp) - await memcache.get_chat("chat", "m", msgs) - await memcache.clear() - s = memcache.stats() - assert s["hits"] == 0 - assert s["misses"] == 0 - assert await memcache.get_chat("chat", "m", msgs) is None - - -# ────────────────────────────────────────────────────────────────────────────── -# Module-level helpers -# ────────────────────────────────────────────────────────────────────────────── - -class TestInitLLMCache: - async def test_disabled_returns_none(self): - cfg = _exact_cfg() - cfg.cache_enabled = False - result = await init_llm_cache(cfg) - assert result is None - - async def test_enabled_returns_initialized_cache(self): - cfg = _exact_cfg() - try: - result = await init_llm_cache(cfg) - assert result is not None - assert get_llm_cache() is result - finally: - # Reset singleton between tests - cache_mod._cache = None diff --git a/test/test_choose_endpoint.py b/test/test_choose_endpoint.py deleted file mode 100644 index ece609a..0000000 --- a/test/test_choose_endpoint.py +++ /dev/null @@ -1,399 +0,0 @@ -"""Tests for choose_endpoint routing logic with mocked fetch calls.""" -import time -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -import router - -EP1 = "http://ep1:11434" -EP2 = "http://ep2:11434" -EP3 = "http://ep3:11434" -LLAMA_EP = "http://llama:8080/v1" - - -def _make_cfg(endpoints, llama_eps=None, max_conn=2, endpoint_config=None, priority_routing=False): - cfg = MagicMock() - cfg.endpoints = endpoints - cfg.llama_server_endpoints = llama_eps or [] - cfg.api_keys = {} - cfg.max_concurrent_connections = max_conn - cfg.endpoint_config = endpoint_config or {} - cfg.priority_routing = priority_routing - cfg.router_api_key = None - return cfg - - -@pytest.fixture(autouse=True) -def reset_usage(): - """Clear usage_counts and error caches between tests to prevent bleed.""" - router.usage_counts.clear() - router._loaded_error_cache.clear() - yield - router.usage_counts.clear() - router._loaded_error_cache.clear() - - -class TestChooseEndpointBasic: - async def test_selects_single_candidate(self): - cfg = _make_cfg([EP1]) - with ( - patch.object(router, "config", cfg), - patch.object(router.fetch, "available_models", AsyncMock(return_value={"llama3.2:latest"})), - patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"llama3.2:latest"})), - ): - ep, tracking = await router.choose_endpoint("llama3.2:latest") - assert ep == EP1 - assert tracking == "llama3.2:latest" - - async def test_raises_when_no_endpoint_has_model(self): - cfg = _make_cfg([EP1, EP2]) - with ( - patch.object(router, "config", cfg), - patch.object(router.fetch, "available_models", AsyncMock(return_value=set())), - patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())), - ): - with pytest.raises(RuntimeError, match="advertise the model"): - await router.choose_endpoint("unknown-model:latest") - - async def test_prefers_loaded_endpoint(self): - cfg = _make_cfg([EP1, EP2]) - async def available(ep, *_): - return {"llama3.2:latest"} - - async def loaded(ep): - return {"llama3.2:latest"} if ep == EP2 else set() - - with ( - patch.object(router, "config", cfg), - patch.object(router.fetch, "available_models", side_effect=available), - patch.object(router.fetch, "loaded_models", side_effect=loaded), - ): - ep, _ = await router.choose_endpoint("llama3.2:latest") - assert ep == EP2 - - async def test_falls_back_to_free_slot(self): - cfg = _make_cfg([EP1, EP2]) - async def available(ep, *_): - return {"llama3.2:latest"} - - with ( - patch.object(router, "config", cfg), - patch.object(router.fetch, "available_models", side_effect=available), - patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())), - ): - ep, _ = await router.choose_endpoint("llama3.2:latest") - assert ep in (EP1, EP2) - - async def test_saturated_picks_least_busy(self): - cfg = _make_cfg([EP1, EP2]) - cfg.max_concurrent_connections = 1 - - async def available(ep, *_): - return {"llama3.2:latest"} - - # Saturate EP1 with 2 active connections, EP2 with 1 - router.usage_counts[EP1]["llama3.2:latest"] = 2 - router.usage_counts[EP2]["llama3.2:latest"] = 1 - - with ( - patch.object(router, "config", cfg), - patch.object(router.fetch, "available_models", side_effect=available), - patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())), - ): - ep, _ = await router.choose_endpoint("llama3.2:latest") - # Least-busy is EP2 - assert ep == EP2 - - async def test_excludes_endpoint_with_recent_loaded_error(self): - # Regression: issue #83 — when /api/ps fails for EP1 but EP1 - # still advertises the model via /api/tags, routing must not - # fall back to EP1 just because it has a free slot. - cfg = _make_cfg([EP1, EP2]) - - async def available(ep, *_): - return {"llama3.2:latest"} - - # EP1's /api/ps probe failed recently; EP2 is fine but the model - # is not loaded there. Without the health filter, EP1 would be - # picked by the free-slot fallback (step 4 in choose_endpoint). - router._loaded_error_cache[EP1] = time.time() - - with ( - patch.object(router, "config", cfg), - patch.object(router.fetch, "available_models", side_effect=available), - patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())), - ): - ep, _ = await router.choose_endpoint("llama3.2:latest") - assert ep == EP2 - - async def test_stale_loaded_error_does_not_exclude(self): - # Errors older than the 300s window must not keep an endpoint - # excluded forever. - cfg = _make_cfg([EP1]) - router._loaded_error_cache[EP1] = time.time() - 301 - - with ( - patch.object(router, "config", cfg), - patch.object(router.fetch, "available_models", AsyncMock(return_value={"m:latest"})), - patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"m:latest"})), - ): - ep, _ = await router.choose_endpoint("m:latest") - assert ep == EP1 - - async def test_all_unhealthy_still_routes(self): - # If every candidate has a fresh loaded-error we still try one - # (it may have recovered between the cache write and now) rather - # than refusing to route. - cfg = _make_cfg([EP1]) - router._loaded_error_cache[EP1] = time.time() - - with ( - patch.object(router, "config", cfg), - patch.object(router.fetch, "available_models", AsyncMock(return_value={"m:latest"})), - patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())), - ): - ep, _ = await router.choose_endpoint("m:latest") - assert ep == EP1 - - async def test_reserve_increments_usage(self): - cfg = _make_cfg([EP1]) - with ( - patch.object(router, "config", cfg), - patch.object(router.fetch, "available_models", AsyncMock(return_value={"model:latest"})), - patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"model:latest"})), - ): - ep, tracking = await router.choose_endpoint("model:latest", reserve=True) - assert router.usage_counts[ep][tracking] == 1 - - -class TestChooseEndpointModelNaming: - async def test_strips_latest_for_openai_endpoints(self): - cfg = _make_cfg(endpoints=[], llama_eps=[LLAMA_EP]) - cfg.endpoints = [] - - async def available(ep, *_): - # llama-server advertises without :latest - return {"gpt-4o"} - - with ( - patch.object(router, "config", cfg), - patch.object(router.fetch, "available_models", side_effect=available), - patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"gpt-4o"})), - ): - ep, _ = await router.choose_endpoint("gpt-4o:latest") - assert ep == LLAMA_EP - - async def test_adds_latest_for_ollama_when_bare_name(self): - cfg = _make_cfg([EP1]) - - async def available(ep, *_): - return {"llama3.2:latest"} - - with ( - patch.object(router, "config", cfg), - patch.object(router.fetch, "available_models", side_effect=available), - patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"llama3.2:latest"})), - ): - ep, _ = await router.choose_endpoint("llama3.2") - assert ep == EP1 - - -class TestChooseEndpointLoadBalancing: - async def test_random_selection_among_idle(self): - cfg = _make_cfg([EP1, EP2, EP3]) - selected = set() - - async def available(ep, *_): - return {"model:latest"} - - async def loaded(ep): - return {"model:latest"} - - for _ in range(20): - router.usage_counts.clear() - with ( - patch.object(router, "config", cfg), - patch.object(router.fetch, "available_models", side_effect=available), - patch.object(router.fetch, "loaded_models", side_effect=loaded), - ): - ep, _ = await router.choose_endpoint("model:latest", reserve=False) - selected.add(ep) - - # With 20 draws from 3 idle endpoints, all three should appear - assert len(selected) > 1 - - async def test_sort_by_load_ascending(self): - cfg = _make_cfg([EP1, EP2]) - router.usage_counts[EP1]["model:latest"] = 1 - router.usage_counts[EP2]["model:latest"] = 0 - - async def available(ep, *_): - return {"model:latest"} - - async def loaded(ep): - return {"model:latest"} - - with ( - patch.object(router, "config", cfg), - patch.object(router.fetch, "available_models", side_effect=available), - patch.object(router.fetch, "loaded_models", side_effect=loaded), - ): - ep, _ = await router.choose_endpoint("model:latest", reserve=False) - # EP2 has fewer active connections → should be selected - assert ep == EP2 - - -# --------------------------------------------------------------------------- -# get_max_connections unit tests -# --------------------------------------------------------------------------- - -class TestGetMaxConnections: - def test_returns_global_default_when_no_override(self): - cfg = _make_cfg([EP1, EP2], max_conn=3) - with patch.object(router, "config", cfg): - assert router.get_max_connections(EP1) == 3 - assert router.get_max_connections(EP2) == 3 - - def test_returns_per_endpoint_override(self): - cfg = _make_cfg( - [EP1, EP2], - max_conn=2, - endpoint_config={EP1: {"max_concurrent_connections": 5}}, - ) - with patch.object(router, "config", cfg): - assert router.get_max_connections(EP1) == 5 - assert router.get_max_connections(EP2) == 2 # falls back to global - - def test_unrecognised_endpoint_falls_back_to_global(self): - cfg = _make_cfg([EP1], max_conn=4, endpoint_config={EP2: {"max_concurrent_connections": 1}}) - with patch.object(router, "config", cfg): - assert router.get_max_connections(EP3) == 4 - - -# --------------------------------------------------------------------------- -# Priority / WRR routing tests -# --------------------------------------------------------------------------- - -MODEL = "model:latest" - - -def _all_loaded(ep): - """Side-effect: every endpoint advertises and has MODEL loaded.""" - return {MODEL} - - -class TestPriorityRouting: - """Tests for priority_routing=True (WRR + config-order tiebreaking).""" - - async def test_idle_picks_first_in_config_order(self): - """When all endpoints are idle, priority picks the first listed endpoint.""" - cfg = _make_cfg([EP1, EP2, EP3], priority_routing=True) - with ( - patch.object(router, "config", cfg), - patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)), - patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)), - ): - ep, _ = await router.choose_endpoint(MODEL, reserve=False) - assert ep == EP1 - - async def test_lower_utilization_preferred_over_priority(self): - """An endpoint with lower ratio is preferred even if it has lower priority.""" - cfg = _make_cfg([EP1, EP2], priority_routing=True) - # EP1 (priority 0) is busier: 1/2 = 0.5; EP2 (priority 1) is idle: 0/2 = 0.0 - router.usage_counts[EP1][MODEL] = 1 - - with ( - patch.object(router, "config", cfg), - patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)), - patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)), - ): - ep, _ = await router.choose_endpoint(MODEL, reserve=False) - assert ep == EP2 - - async def test_wrr_distribution_matches_expected_sequence(self): - """ - Full WRR sequence with heterogeneous capacities, mirroring the issue example: - EP1 max=2, EP2 max=2, EP3 max=1 - - Expected routing order for 5 sequential requests: - EP1, EP2, EP3, EP1, EP2 - """ - cfg = _make_cfg( - [EP1, EP2, EP3], - max_conn=2, - endpoint_config={EP3: {"max_concurrent_connections": 1}}, - priority_routing=True, - ) - - expected = [EP1, EP2, EP3, EP1, EP2] - actual = [] - - with ( - patch.object(router, "config", cfg), - patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)), - patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)), - ): - for _ in expected: - ep, _ = await router.choose_endpoint(MODEL, reserve=True) - actual.append(ep) - - assert actual == expected - - async def test_saturated_picks_lowest_ratio_then_priority(self): - """When all endpoints are saturated, pick lowest utilization ratio; break ties by priority.""" - cfg = _make_cfg( - [EP1, EP2, EP3], - max_conn=1, - endpoint_config={EP3: {"max_concurrent_connections": 2}}, - priority_routing=True, - ) - # EP1 usage=1/1=1.0, EP2 usage=1/1=1.0, EP3 usage=1/2=0.5 → EP3 wins - router.usage_counts[EP1][MODEL] = 1 - router.usage_counts[EP2][MODEL] = 1 - router.usage_counts[EP3][MODEL] = 1 - - with ( - patch.object(router, "config", cfg), - patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)), - patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)), - ): - ep, _ = await router.choose_endpoint(MODEL, reserve=False) - assert ep == EP3 - - async def test_saturated_ties_broken_by_priority(self): - """When all are saturated with equal ratio, config order wins.""" - cfg = _make_cfg([EP1, EP2, EP3], max_conn=1, priority_routing=True) - router.usage_counts[EP1][MODEL] = 1 - router.usage_counts[EP2][MODEL] = 1 - router.usage_counts[EP3][MODEL] = 1 - - with ( - patch.object(router, "config", cfg), - patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)), - patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)), - ): - ep, _ = await router.choose_endpoint(MODEL, reserve=False) - assert ep == EP1 - - -class TestPriorityRoutingDisabled: - """Verify that priority_routing=False keeps the original random behaviour.""" - - async def test_idle_endpoints_are_randomised(self): - """Without priority routing, all-idle selection must eventually pick each endpoint.""" - cfg = _make_cfg([EP1, EP2, EP3], priority_routing=False) - selected = set() - - with ( - patch.object(router, "config", cfg), - patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)), - patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)), - ): - for _ in range(30): - router.usage_counts.clear() - ep, _ = await router.choose_endpoint(MODEL, reserve=False) - selected.add(ep) - - # With 30 draws from 3 equally-idle endpoints, all three must appear - assert selected == {EP1, EP2, EP3} diff --git a/test/test_db.py b/test/test_db.py deleted file mode 100644 index 833b375..0000000 --- a/test/test_db.py +++ /dev/null @@ -1,197 +0,0 @@ -"""Direct unit tests for db.TokenDatabase — no router/app dependency.""" -from datetime import datetime, timezone - -import pytest - -from db import TokenDatabase - - -@pytest.fixture -async def db(tmp_path): - inst = TokenDatabase(str(tmp_path / "tokens.db")) - await inst.init_db() - yield inst - await inst.close() - - -class TestInit: - async def test_init_creates_tables(self, db): - # Re-init must be idempotent - await db.init_db() - # Insert + read confirms tables exist - await db.update_token_counts("http://ep", "m", 1, 2) - rows = [r async for r in db.load_token_counts()] - assert len(rows) == 1 - - async def test_creates_parent_directory(self, tmp_path): - nested = tmp_path / "nested" / "subdir" / "x.db" - inst = TokenDatabase(str(nested)) - await inst.init_db() - try: - assert nested.parent.exists() - finally: - await inst.close() - - -class TestUpdateTokenCounts: - async def test_insert_then_update_aggregates(self, db): - await db.update_token_counts("http://ep", "m1", 10, 20) - await db.update_token_counts("http://ep", "m1", 5, 7) - rows = [r async for r in db.load_token_counts()] - assert len(rows) == 1 - r = rows[0] - assert r["endpoint"] == "http://ep" - assert r["model"] == "m1" - assert r["input_tokens"] == 15 - assert r["output_tokens"] == 27 - assert r["total_tokens"] == 42 - - async def test_independent_endpoint_model_pairs(self, db): - await db.update_token_counts("http://ep1", "m1", 1, 1) - await db.update_token_counts("http://ep1", "m2", 2, 2) - await db.update_token_counts("http://ep2", "m1", 3, 3) - rows = [r async for r in db.load_token_counts()] - assert len(rows) == 3 - totals = {(r["endpoint"], r["model"]): r["total_tokens"] for r in rows} - assert totals == { - ("http://ep1", "m1"): 2, - ("http://ep1", "m2"): 4, - ("http://ep2", "m1"): 6, - } - - -class TestBatchedCounts: - async def test_update_batched_counts(self, db): - counts = { - "http://a": {"m": (4, 6)}, - "http://b": {"m": (1, 1), "n": (10, 0)}, - } - await db.update_batched_counts(counts) - rows = [r async for r in db.load_token_counts()] - totals = {(r["endpoint"], r["model"]): r["total_tokens"] for r in rows} - assert totals == { - ("http://a", "m"): 10, - ("http://b", "m"): 2, - ("http://b", "n"): 10, - } - - async def test_empty_batch_is_noop(self, db): - await db.update_batched_counts({}) - rows = [r async for r in db.load_token_counts()] - assert rows == [] - - -class TestTimeSeries: - async def test_add_time_series_entry(self, db): - # The aggregate FK requires the (endpoint,model) row to exist first - await db.update_token_counts("http://ep", "m", 0, 0) - await db.add_time_series_entry("http://ep", "m", 3, 4) - await db.add_time_series_entry("http://ep", "m", 1, 1) - rows = [r async for r in db.get_latest_time_series(limit=10)] - assert len(rows) == 2 - # Newest-first ordering; both timestamps are within the same minute, - # so just check totals are present and well-formed - for r in rows: - assert r["endpoint"] == "http://ep" - assert r["model"] == "m" - assert r["total_tokens"] == r["input_tokens"] + r["output_tokens"] - - async def test_add_batched_time_series(self, db): - await db.update_token_counts("http://ep", "m", 0, 0) - now = int(datetime.now(tz=timezone.utc).timestamp()) - entries = [ - {"endpoint": "http://ep", "model": "m", "input_tokens": 1, - "output_tokens": 2, "total_tokens": 3, "timestamp": now - 60}, - {"endpoint": "http://ep", "model": "m", "input_tokens": 4, - "output_tokens": 5, "total_tokens": 9, "timestamp": now}, - ] - await db.add_batched_time_series(entries) - rows = [r async for r in db.get_latest_time_series(limit=10)] - assert len(rows) == 2 - assert rows[0]["timestamp"] >= rows[1]["timestamp"] - - async def test_get_time_series_for_model_filters(self, db): - await db.update_token_counts("http://ep", "m1", 0, 0) - await db.update_token_counts("http://ep", "m2", 0, 0) - now = int(datetime.now(tz=timezone.utc).timestamp()) - await db.add_batched_time_series([ - {"endpoint": "http://ep", "model": "m1", "input_tokens": 1, - "output_tokens": 1, "total_tokens": 2, "timestamp": now}, - {"endpoint": "http://ep", "model": "m2", "input_tokens": 9, - "output_tokens": 9, "total_tokens": 18, "timestamp": now}, - ]) - rows = [r async for r in db.get_time_series_for_model("m1")] - assert len(rows) == 1 - assert rows[0]["total_tokens"] == 2 - - async def test_endpoint_distribution_for_model(self, db): - await db.update_token_counts("http://a", "m", 0, 0) - await db.update_token_counts("http://b", "m", 0, 0) - now = int(datetime.now(tz=timezone.utc).timestamp()) - await db.add_batched_time_series([ - {"endpoint": "http://a", "model": "m", "input_tokens": 1, - "output_tokens": 1, "total_tokens": 2, "timestamp": now}, - {"endpoint": "http://a", "model": "m", "input_tokens": 1, - "output_tokens": 1, "total_tokens": 2, "timestamp": now}, - {"endpoint": "http://b", "model": "m", "input_tokens": 5, - "output_tokens": 5, "total_tokens": 10, "timestamp": now}, - ]) - dist = await db.get_endpoint_distribution_for_model("m") - assert dist == {"http://a": 4, "http://b": 10} - - -class TestGetTokenCountsForModel: - async def test_aggregates_across_endpoints(self, db): - await db.update_token_counts("http://a", "m", 1, 2) - await db.update_token_counts("http://b", "m", 3, 4) - result = await db.get_token_counts_for_model("m") - assert result is not None - assert result["endpoint"] == "aggregated" - assert result["model"] == "m" - assert result["input_tokens"] == 4 - assert result["output_tokens"] == 6 - assert result["total_tokens"] == 10 - - async def test_unknown_model_returns_zero_aggregate(self, db): - # SUM(...) WHERE no-match returns one row with NULLs — exposed as zeros - result = await db.get_token_counts_for_model("nope") - assert result is not None - assert result["input_tokens"] in (0, None) - - -class TestAggregateTimeSeriesOlderThan: - async def test_aggregates_old_entries_by_day(self, db): - await db.update_token_counts("http://ep", "m", 0, 0) - now = int(datetime.now(tz=timezone.utc).timestamp()) - old = now - (40 * 86400) # 40 days ago - await db.add_batched_time_series([ - {"endpoint": "http://ep", "model": "m", "input_tokens": 1, - "output_tokens": 1, "total_tokens": 2, "timestamp": old}, - {"endpoint": "http://ep", "model": "m", "input_tokens": 3, - "output_tokens": 3, "total_tokens": 6, "timestamp": old + 60}, - {"endpoint": "http://ep", "model": "m", "input_tokens": 99, - "output_tokens": 99, "total_tokens": 198, "timestamp": now}, - ]) - n = await db.aggregate_time_series_older_than(30, trim_old=False) - assert n == 1 # one (endpoint, model, day) group rolled up - - async def test_invalid_days_falls_back_to_30(self, db): - # Just ensure it doesn't blow up with a bogus value - n = await db.aggregate_time_series_older_than(0) - assert n == 0 - - async def test_trim_old_removes_aggregated_rows(self, db): - await db.update_token_counts("http://ep", "m", 0, 0) - now = int(datetime.now(tz=timezone.utc).timestamp()) - old = now - (40 * 86400) - await db.add_batched_time_series([ - {"endpoint": "http://ep", "model": "m", "input_tokens": 1, - "output_tokens": 1, "total_tokens": 2, "timestamp": old}, - {"endpoint": "http://ep", "model": "m", "input_tokens": 99, - "output_tokens": 99, "total_tokens": 198, "timestamp": now}, - ]) - await db.aggregate_time_series_older_than(30, trim_old=True) - remaining = [r async for r in db.get_latest_time_series(limit=10)] - # Only the recent (within-cutoff) row should remain - assert len(remaining) == 1 - assert remaining[0]["total_tokens"] == 198 diff --git a/test/test_fetch.py b/test/test_fetch.py deleted file mode 100644 index 6f2ed50..0000000 --- a/test/test_fetch.py +++ /dev/null @@ -1,210 +0,0 @@ -"""Tests for fetch.available_models and fetch.loaded_models using aioresponses mocking.""" -import time -from unittest.mock import patch, MagicMock - -import pytest -from aioresponses import aioresponses - -import router -from conftest import TEST_OLLAMA, TEST_LLAMA - -MOCK_OLLAMA_EP = "http://mock-ollama:11434" -MOCK_LLAMA_EP = "http://mock-llama:8080/v1" - - -def _make_cfg(ollama_eps=None, llama_eps=None, api_keys=None): - cfg = MagicMock() - cfg.endpoints = ollama_eps or [MOCK_OLLAMA_EP] - cfg.llama_server_endpoints = llama_eps or [MOCK_LLAMA_EP] - cfg.api_keys = api_keys or {} - cfg.max_concurrent_connections = 2 - cfg.router_api_key = None - return cfg - - -@pytest.fixture(autouse=True) -def clear_caches(aio_session): - """aio_session fixture already clears caches and sets up app_state.""" - yield - - -class TestFetchAvailableModels: - async def test_ollama_tags(self): - cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) - with patch.object(router, "config", cfg), aioresponses() as m: - m.get( - f"{MOCK_OLLAMA_EP}/api/tags", - payload={"models": [ - {"name": "llama3.2:latest"}, - {"name": "qwen2.5:7b"}, - ]}, - ) - models = await router.fetch.available_models(MOCK_OLLAMA_EP) - assert models == {"llama3.2:latest", "qwen2.5:7b"} - - async def test_openai_compatible_models_endpoint(self): - cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP]) - with patch.object(router, "config", cfg), aioresponses() as m: - m.get( - f"{MOCK_LLAMA_EP}/models", - payload={"data": [{"id": "unsloth/model:Q8_0"}]}, - ) - models = await router.fetch.available_models(MOCK_LLAMA_EP, api_key="tok") - assert "unsloth/model:Q8_0" in models - - async def test_caches_successful_result(self): - cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) - with patch.object(router, "config", cfg), aioresponses() as m: - m.get( - f"{MOCK_OLLAMA_EP}/api/tags", - payload={"models": [{"name": "llama3.2:latest"}]}, - ) - first = await router.fetch.available_models(MOCK_OLLAMA_EP) - second = await router.fetch.available_models(MOCK_OLLAMA_EP) - # second call must be served from cache without a second HTTP request - assert first == second == {"llama3.2:latest"} - - async def test_returns_empty_on_http_500(self): - cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) - with patch.object(router, "config", cfg), aioresponses() as m: - m.get(f"{MOCK_OLLAMA_EP}/api/tags", status=500, payload={"error": "oops"}) - models = await router.fetch.available_models(MOCK_OLLAMA_EP) - assert models == set() - - async def test_returns_empty_on_connection_error(self): - cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) - import aiohttp - with patch.object(router, "config", cfg), aioresponses() as m: - m.get( - f"{MOCK_OLLAMA_EP}/api/tags", - exception=aiohttp.ClientConnectorError( - connection_key=MagicMock(host="mock-ollama", port=11434), - os_error=OSError(111, "refused"), - ), - ) - models = await router.fetch.available_models(MOCK_OLLAMA_EP) - assert models == set() - - async def test_stale_cache_returned_while_refresh_runs(self): - cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) - with patch.object(router, "config", cfg), aioresponses() as m: - m.get( - f"{MOCK_OLLAMA_EP}/api/tags", - payload={"models": [{"name": "llama3.2:latest"}]}, - ) - await router.fetch.available_models(MOCK_OLLAMA_EP) - - # Manually age cache into stale-but-valid window (300-600s) - async with router._models_cache_lock: - models, _ = router._models_cache[MOCK_OLLAMA_EP] - router._models_cache[MOCK_OLLAMA_EP] = (models, time.time() - 400) - - with patch.object(router, "config", cfg), aioresponses() as m: - m.get( - f"{MOCK_OLLAMA_EP}/api/tags", - payload={"models": [{"name": "llama3.2:latest"}]}, - ) - # Should return stale data immediately - stale = await router.fetch.available_models(MOCK_OLLAMA_EP) - assert "llama3.2:latest" in stale - - async def test_error_cache_short_circuits(self): - cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) - # Seed error cache with a very recent error - async with router._available_error_cache_lock: - router._available_error_cache[MOCK_OLLAMA_EP] = time.time() - - with patch.object(router, "config", cfg), aioresponses(): - # No HTTP mock registered — if a call happens it will raise - models = await router.fetch.available_models(MOCK_OLLAMA_EP) - assert models == set() - - -class TestFetchLoadedModels: - async def test_ollama_ps(self): - cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) - with patch.object(router, "config", cfg), aioresponses() as m: - m.get( - f"{MOCK_OLLAMA_EP}/api/ps", - payload={"models": [{"name": "llama3.2:latest"}]}, - ) - models = await router.fetch.loaded_models(MOCK_OLLAMA_EP) - assert models == {"llama3.2:latest"} - - async def test_llama_server_filters_loaded(self): - cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP]) - with patch.object(router, "config", cfg), aioresponses() as m: - m.get( - f"{MOCK_LLAMA_EP}/models", - payload={"data": [ - {"id": "model-a", "status": {"value": "loaded"}}, - {"id": "model-b", "status": {"value": "unloaded"}}, - ]}, - ) - models = await router.fetch.loaded_models(MOCK_LLAMA_EP) - assert models == {"model-a"} - - async def test_llama_server_no_status_field_always_loaded(self): - cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP]) - with patch.object(router, "config", cfg), aioresponses() as m: - m.get( - f"{MOCK_LLAMA_EP}/models", - payload={"data": [{"id": "always-on-model"}]}, - ) - models = await router.fetch.loaded_models(MOCK_LLAMA_EP) - assert "always-on-model" in models - - async def test_returns_empty_on_error(self): - cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) - with patch.object(router, "config", cfg), aioresponses() as m: - m.get(f"{MOCK_OLLAMA_EP}/api/ps", status=503, payload={}) - models = await router.fetch.loaded_models(MOCK_OLLAMA_EP) - assert models == set() - - async def test_ext_openai_always_empty(self): - ext_ep = "https://api.openai.com/v1" - cfg = _make_cfg(ollama_eps=[ext_ep], llama_eps=[]) - with patch.object(router, "config", cfg): - models = await router.fetch.loaded_models(ext_ep) - assert models == set() - - async def test_caches_result(self): - cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) - with patch.object(router, "config", cfg), aioresponses() as m: - m.get( - f"{MOCK_OLLAMA_EP}/api/ps", - payload={"models": [{"name": "qwen:7b"}]}, - ) - first = await router.fetch.loaded_models(MOCK_OLLAMA_EP) - second = await router.fetch.loaded_models(MOCK_OLLAMA_EP) - assert first == second - - async def test_records_error_in_loaded_error_cache_on_failure(self): - # Regression: issue #83 — /api/ps failures must be recorded so - # `choose_endpoint` can exclude unhealthy backends from routing. - cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) - with patch.object(router, "config", cfg), aioresponses() as m: - m.get(f"{MOCK_OLLAMA_EP}/api/ps", status=502, payload={}) - await router.fetch.loaded_models(MOCK_OLLAMA_EP) - assert MOCK_OLLAMA_EP in router._loaded_error_cache - - async def test_records_error_for_llama_server_on_failure(self): - cfg = _make_cfg(ollama_eps=[], llama_eps=[MOCK_LLAMA_EP]) - with patch.object(router, "config", cfg), aioresponses() as m: - m.get(f"{MOCK_LLAMA_EP}/models", status=502, payload={}) - await router.fetch.loaded_models(MOCK_LLAMA_EP) - assert MOCK_LLAMA_EP in router._loaded_error_cache - - async def test_clears_error_cache_on_subsequent_success(self): - cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) - # Pre-seed an old error so loaded_models() falls through to the - # network probe instead of short-circuiting on the error cache. - async with router._loaded_error_cache_lock: - router._loaded_error_cache[MOCK_OLLAMA_EP] = time.time() - 301 - with patch.object(router, "config", cfg), aioresponses() as m: - m.get( - f"{MOCK_OLLAMA_EP}/api/ps", - payload={"models": [{"name": "qwen:7b"}]}, - ) - await router.fetch.loaded_models(MOCK_OLLAMA_EP) - assert MOCK_OLLAMA_EP not in router._loaded_error_cache diff --git a/test/test_openai_proxies.py b/test/test_openai_proxies.py deleted file mode 100644 index 8a56c91..0000000 --- a/test/test_openai_proxies.py +++ /dev/null @@ -1,181 +0,0 @@ -"""Cache-hit short-circuit tests for the OpenAI-compatible proxy routes. - -These tests verify that when the LLM cache reports a hit, the route returns -the cached payload *without* selecting an endpoint or contacting any backend. -""" -from unittest.mock import AsyncMock, patch - -import orjson -import pytest -from fastapi import HTTPException - -import router - - -_BYPASS = HTTPException(status_code=599, detail="bypassed") - - -class _FakeCache: - """Minimal stand-in for cache.LLMCache.get_chat.""" - def __init__(self, response_bytes: bytes | None): - self._resp = response_bytes - self.calls: list[tuple] = [] - - async def get_chat(self, route, model, messages): - self.calls.append((route, model, messages)) - return self._resp - - -@pytest.fixture -def cache_hit_payload(): - return orjson.dumps({ - "id": "cmpl-xyz", - "created": 1, - "model": "test-model", - "choices": [{"message": {"role": "assistant", "content": "from-cache"}}], - "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, - }) - - -# ────────────────────────────────────────────────────────────────────────────── -# /v1/chat/completions -# ────────────────────────────────────────────────────────────────────────────── - -class TestOpenAIChatCompletionsCacheHit: - async def test_nonstream_cache_hit_returns_cached_json(self, client, cache_hit_payload): - fake = _FakeCache(cache_hit_payload) - # Patch the route's references to both helpers — they're imported by name - # into router's namespace at module load time. - with ( - patch.object(router, "get_llm_cache", return_value=fake), - patch.object(router, "choose_endpoint", - AsyncMock(side_effect=AssertionError("backend must not be reached"))), - ): - resp = await client.post( - "/v1/chat/completions", - json={ - "model": "test-model", - "messages": [{"role": "user", "content": "ping"}], - "stream": False, - "nomyo": {"cache": True}, - }, - ) - assert resp.status_code == 200 - # Body is streamed; collect it - body = resp.content - parsed = orjson.loads(body) - assert parsed["choices"][0]["message"]["content"] == "from-cache" - assert fake.calls and fake.calls[0][0] == "openai_chat" - - async def test_stream_cache_hit_returns_sse(self, client, cache_hit_payload): - fake = _FakeCache(cache_hit_payload) - with ( - patch.object(router, "get_llm_cache", return_value=fake), - patch.object(router, "choose_endpoint", - AsyncMock(side_effect=AssertionError("backend must not be reached"))), - ): - resp = await client.post( - "/v1/chat/completions", - json={ - "model": "test-model", - "messages": [{"role": "user", "content": "ping"}], - "stream": True, - "nomyo": {"cache": True}, - }, - ) - assert resp.status_code == 200 - assert resp.headers["content-type"].startswith("text/event-stream") - text = resp.content.decode() - # First SSE frame contains the cached content as a delta - first_frame = text.split("\n\n")[0] - assert first_frame.startswith("data: ") - chunk = orjson.loads(first_frame[len("data: "):]) - assert chunk["choices"][0]["delta"]["content"] == "from-cache" - # Stream is terminated with [DONE] - assert "data: [DONE]" in text - - async def test_cache_disabled_in_payload_bypasses_cache_check(self, client): - """When nomyo.cache=False, get_chat is never called even if a cache exists.""" - fake = _FakeCache(b"") # has a response, but should never be consulted - with ( - patch.object(router, "get_llm_cache", return_value=fake), - patch.object(router, "choose_endpoint", - AsyncMock(side_effect=_BYPASS)), - ): - resp = await client.post( - "/v1/chat/completions", - json={ - "model": "m", - "messages": [{"role": "user", "content": "hi"}], - "nomyo": {"cache": False}, - }, - ) - # Got past the cache short-circuit → endpoint selection invoked - assert resp.status_code == 599 - assert fake.calls == [] - - async def test_no_cache_configured_bypasses_cache_check(self, client): - """get_llm_cache() returning None should not break the route.""" - with ( - patch.object(router, "get_llm_cache", return_value=None), - patch.object(router, "choose_endpoint", - AsyncMock(side_effect=_BYPASS)), - ): - resp = await client.post( - "/v1/chat/completions", - json={ - "model": "m", - "messages": [{"role": "user", "content": "hi"}], - "nomyo": {"cache": True}, - }, - ) - assert resp.status_code == 599 - - -# ────────────────────────────────────────────────────────────────────────────── -# /v1/completions -# ────────────────────────────────────────────────────────────────────────────── - -class TestOpenAICompletionsCacheHit: - async def test_nonstream_cache_hit(self, client, cache_hit_payload): - fake = _FakeCache(cache_hit_payload) - with ( - patch.object(router, "get_llm_cache", return_value=fake), - patch.object(router, "choose_endpoint", - AsyncMock(side_effect=AssertionError("backend must not be reached"))), - ): - resp = await client.post( - "/v1/completions", - json={ - "model": "test-model", - "prompt": "Tell me a joke", - "stream": False, - "nomyo": {"cache": True}, - }, - ) - assert resp.status_code == 200 - # Prompt-style cache lookup is namespaced under "openai_completions" - assert fake.calls[0][0] == "openai_completions" - # Cache lookup receives the prompt as a single user message - cached_msgs = fake.calls[0][2] - assert cached_msgs == [{"role": "user", "content": "Tell me a joke"}] - - async def test_stream_cache_hit(self, client, cache_hit_payload): - fake = _FakeCache(cache_hit_payload) - with ( - patch.object(router, "get_llm_cache", return_value=fake), - patch.object(router, "choose_endpoint", - AsyncMock(side_effect=AssertionError("backend must not be reached"))), - ): - resp = await client.post( - "/v1/completions", - json={ - "model": "test-model", - "prompt": "What is 2+2?", - "stream": True, - "nomyo": {"cache": True}, - }, - ) - assert resp.status_code == 200 - assert resp.headers["content-type"].startswith("text/event-stream") - assert "data: [DONE]" in resp.content.decode() diff --git a/test/test_unit_context.py b/test/test_unit_context.py deleted file mode 100644 index de2b98a..0000000 --- a/test/test_unit_context.py +++ /dev/null @@ -1,116 +0,0 @@ -"""Unit tests for context-window trimming logic.""" -import pytest -import router - - -def _msgs(roles_contents): - return [{"role": r, "content": c} for r, c in roles_contents] - - -class TestCountMessageTokens: - def test_returns_int(self): - msgs = _msgs([("user", "hello")]) - assert isinstance(router._count_message_tokens(msgs), int) - - def test_empty_list(self): - assert router._count_message_tokens([]) >= 0 - - def test_longer_content_more_tokens(self): - short = _msgs([("user", "hi")]) - long_ = _msgs([("user", "a " * 500)]) - assert router._count_message_tokens(long_) > router._count_message_tokens(short) - - def test_list_content(self): - msgs = [{"role": "user", "content": [ - {"type": "text", "text": "what do you see?"}, - ]}] - tokens = router._count_message_tokens(msgs) - assert tokens > 0 - - def test_multiple_messages(self): - msgs = _msgs([("system", "you are helpful"), ("user", "hello"), ("assistant", "hi!")]) - assert router._count_message_tokens(msgs) > 10 - - -class TestTrimMessagesForContext: - def test_short_history_unchanged(self): - msgs = _msgs([("user", "hello"), ("assistant", "hi"), ("user", "bye")]) - result = router._trim_messages_for_context(msgs, n_ctx=4096) - assert result == msgs - - def test_system_messages_always_kept(self): - msgs = ( - _msgs([("system", "you are helpful")]) - + _msgs([("user", f"msg {i}") for i in range(50)]) - + _msgs([("user", "final question")]) - ) - result = router._trim_messages_for_context(msgs, n_ctx=512) - system_msgs = [m for m in result if m["role"] == "system"] - assert len(system_msgs) == 1 - assert system_msgs[0]["content"] == "you are helpful" - - def test_last_user_message_always_kept(self): - msgs = _msgs([("user", f"old msg {i}") for i in range(100)] + [("user", "very important last question")]) - result = router._trim_messages_for_context(msgs, n_ctx=256) - assert result[-1]["content"] == "very important last question" - - def test_oldest_dropped_first(self): - msgs = _msgs([ - ("user", "oldest msg"), - ("assistant", "oldest reply"), - ("user", "newer msg"), - ("assistant", "newer reply"), - ("user", "newest"), - ]) - # Use very small target to force trimming - result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=10) - contents = [m["content"] for m in result] - # "oldest msg" should be dropped before "newest" - if "oldest msg" in contents: - assert "newest" in contents - else: - assert "newest" in contents - - def test_result_starts_with_user(self): - msgs = _msgs([ - ("assistant", "leftover assistant"), - ("user", "question"), - ]) - result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=20) - if result: - assert result[0]["role"] == "user" - - def test_target_tokens_overrides_safety_margin(self): - msgs = _msgs([("user", "a " * 200)]) - result_small = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=10) - result_large = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=5000) - # Both should return at least the last message - assert len(result_small) >= 1 - assert len(result_large) >= 1 - - -class TestCalibratedTrimTarget: - def test_returns_positive_int(self): - msgs = [{"role": "user", "content": "hello " * 100}] - result = router._calibrated_trim_target(msgs, n_ctx=4096, actual_tokens=3000) - assert isinstance(result, int) - assert result >= 1 - - def test_over_limit_reduces_target(self): - msgs = [{"role": "user", "content": "a " * 500}] - # actual_tokens > n_ctx means we need to shed more - target = router._calibrated_trim_target(msgs, n_ctx=2048, actual_tokens=2500) - assert target < router._count_message_tokens(msgs) - - def test_well_within_limit_returns_current(self): - msgs = [{"role": "user", "content": "hi"}] - # actual_tokens << n_ctx means nothing to shed - target = router._calibrated_trim_target(msgs, n_ctx=16384, actual_tokens=50) - # Should return cur_tiktoken since to_shed == 0 - assert target == max(1, router._count_message_tokens(msgs)) - - def test_minimum_is_one(self): - # Even if we need to shed everything, result is at least 1 - msgs = [{"role": "user", "content": "hello"}] - target = router._calibrated_trim_target(msgs, n_ctx=100, actual_tokens=99999) - assert target >= 1 diff --git a/test/test_unit_helpers.py b/test/test_unit_helpers.py deleted file mode 100644 index d38eb37..0000000 --- a/test/test_unit_helpers.py +++ /dev/null @@ -1,279 +0,0 @@ -"""Unit tests for pure helper functions in router.py (no network, no app).""" -import time -import asyncio -from unittest.mock import MagicMock, patch - -import aiohttp -import pytest - -import router - - -class TestMaskSecrets: - def test_masks_openai_key(self): - text = "Authorization: Bearer sk-abcd1234XYZabcd1234XYZabcd1234XYZ" - result = router._mask_secrets(text) - assert "sk-***redacted***" in result - assert "sk-abcd1234" not in result - - def test_masks_api_key_assignment(self): - result = router._mask_secrets("api_key: supersecretvalue123") - assert "supersecretvalue123" not in result - assert "***redacted***" in result - - def test_masks_api_key_with_colon(self): - result = router._mask_secrets("api-key: mykey") - assert "mykey" not in result - - def test_empty_string_returns_empty(self): - assert router._mask_secrets("") == "" - - def test_none_returns_none(self): - assert router._mask_secrets(None) is None - - def test_no_secrets_unchanged(self): - text = "this is a normal log line" - assert router._mask_secrets(text) == text - - -class TestIsFresh: - def test_fresh_within_ttl(self): - cached_at = time.time() - 10 - assert router._is_fresh(cached_at, 300) is True - - def test_expired_beyond_ttl(self): - cached_at = time.time() - 400 - assert router._is_fresh(cached_at, 300) is False - - def test_exactly_at_boundary(self): - cached_at = time.time() - 300 - # May be True or False depending on timing, just verify it runs - result = router._is_fresh(cached_at, 300) - assert isinstance(result, bool) - - def test_just_cached(self): - assert router._is_fresh(time.time(), 1) is True - - -class TestNormalizeLlamaModelName: - def test_strips_hf_prefix(self): - assert router._normalize_llama_model_name("unsloth/gpt-oss-20b-GGUF") == "gpt-oss-20b-GGUF" - - def test_strips_quant_suffix(self): - assert router._normalize_llama_model_name("model:Q8_0") == "model" - - def test_strips_both(self): - result = router._normalize_llama_model_name("unsloth/gpt-oss-20b-GGUF:F16") - assert result == "gpt-oss-20b-GGUF" - - def test_no_prefix_no_suffix(self): - assert router._normalize_llama_model_name("plain-model") == "plain-model" - - def test_multiple_slashes(self): - result = router._normalize_llama_model_name("org/user/model-name:Q4_K_M") - assert result == "model-name" - - -class TestExtractLlamaQuant: - def test_extracts_quant(self): - assert router._extract_llama_quant("unsloth/model:Q8_0") == "Q8_0" - - def test_no_quant_returns_empty(self): - assert router._extract_llama_quant("plain-model") == "" - - def test_f16(self): - assert router._extract_llama_quant("model:F16") == "F16" - - def test_q4_k_m(self): - assert router._extract_llama_quant("model:Q4_K_M") == "Q4_K_M" - - -class TestIsUnixSocketEndpoint: - def test_sock_endpoint_detected(self): - assert router._is_unix_socket_endpoint("http://192.168.0.52.sock/v1") is True - - def test_regular_http_not_sock(self): - assert router._is_unix_socket_endpoint("http://192.168.0.52:8080/v1") is False - - def test_ollama_not_sock(self): - assert router._is_unix_socket_endpoint("http://localhost:11434") is False - - def test_dot_sock_in_host_detected(self): - assert router._is_unix_socket_endpoint("http://llama.sock/v1") is True - - -class TestGetSocketPath: - def test_returns_run_user_path(self): - import os - path = router._get_socket_path("http://192.168.0.52.sock/v1") - uid = os.getuid() - assert path == f"/run/user/{uid}/192.168.0.52.sock" - - -class TestIsBase64: - def test_valid_base64(self): - import base64 - data = base64.b64encode(b"hello world").decode() - assert router.is_base64(data) is True - - def test_invalid_base64(self): - assert router.is_base64("not-base64!@#$") is False - - def test_empty_string(self): - # Empty string is valid base64 (decodes to empty bytes) - assert router.is_base64("") is True - - def test_non_string(self): - # Non-strings fall through without returning True (returns None) - assert not router.is_base64(12345) - - -class TestIsLlamaModelLoaded: - def test_status_dict_loaded(self): - assert router._is_llama_model_loaded({"id": "m", "status": {"value": "loaded"}}) is True - - def test_status_dict_unloaded(self): - assert router._is_llama_model_loaded({"id": "m", "status": {"value": "unloaded"}}) is False - - def test_status_string_loaded(self): - assert router._is_llama_model_loaded({"id": "m", "status": "loaded"}) is True - - def test_status_string_unloaded(self): - assert router._is_llama_model_loaded({"id": "m", "status": "unloaded"}) is False - - def test_no_status_field_always_loaded(self): - # No status field → always available (single-model server) - assert router._is_llama_model_loaded({"id": "m"}) is True - - def test_status_none_always_loaded(self): - assert router._is_llama_model_loaded({"id": "m", "status": None}) is True - - -class TestEp2Base: - def test_adds_v1_to_ollama(self): - assert router.ep2base("http://localhost:11434") == "http://localhost:11434/v1" - - def test_keeps_v1_if_present(self): - assert router.ep2base("http://host/v1") == "http://host/v1" - - def test_llama_server_endpoint_unchanged(self): - ep = "http://192.168.0.50:8889/v1" - assert router.ep2base(ep) == ep - - -class TestDedupeOnKeys: - def test_removes_duplicate_by_single_key(self): - items = [{"name": "a", "x": 1}, {"name": "b", "x": 2}, {"name": "a", "x": 3}] - result = router.dedupe_on_keys(items, ["name"]) - assert len(result) == 2 - assert result[0]["name"] == "a" - assert result[1]["name"] == "b" - - def test_removes_duplicate_by_two_keys(self): - items = [ - {"digest": "abc", "name": "m1"}, - {"digest": "abc", "name": "m1"}, - {"digest": "def", "name": "m2"}, - ] - result = router.dedupe_on_keys(items, ["digest", "name"]) - assert len(result) == 2 - - def test_empty_list(self): - assert router.dedupe_on_keys([], ["name"]) == [] - - def test_no_duplicates(self): - items = [{"name": "a"}, {"name": "b"}, {"name": "c"}] - assert len(router.dedupe_on_keys(items, ["name"])) == 3 - - -class TestFormatConnectionIssue: - def test_connector_error_message(self): - err = aiohttp.ClientConnectorError( - connection_key=MagicMock(host="localhost", port=11434), - os_error=OSError(111, "Connection refused"), - ) - msg = router._format_connection_issue("http://localhost:11434", err) - assert "localhost" in msg - assert "Connection refused" in msg or "111" in msg - - def test_timeout_error_message(self): - msg = router._format_connection_issue("http://host:1234", asyncio.TimeoutError()) - assert "Timed out" in msg - assert "host:1234" in msg - - def test_generic_error(self): - msg = router._format_connection_issue("http://host:1234", ValueError("boom")) - assert "host:1234" in msg - assert "boom" in msg - - -class TestIsExtOpenaiEndpoint: - def test_openai_com_is_ext(self): - cfg = MagicMock() - cfg.endpoints = [] - cfg.llama_server_endpoints = [] - with patch.object(router, "config", cfg): - assert router.is_ext_openai_endpoint("https://api.openai.com/v1") is True - - def test_ollama_default_port_not_ext(self): - cfg = MagicMock() - cfg.endpoints = ["http://host:11434"] - cfg.llama_server_endpoints = [] - with patch.object(router, "config", cfg): - assert router.is_ext_openai_endpoint("http://host:11434") is False - - def test_llama_server_not_ext(self): - cfg = MagicMock() - cfg.endpoints = [] - cfg.llama_server_endpoints = ["http://host:8080/v1"] - with patch.object(router, "config", cfg): - assert router.is_ext_openai_endpoint("http://host:8080/v1") is False - - def test_no_v1_not_ext(self): - cfg = MagicMock() - cfg.endpoints = ["http://host:11434"] - cfg.llama_server_endpoints = [] - with patch.object(router, "config", cfg): - assert router.is_ext_openai_endpoint("http://host:11434") is False - - -class TestIsOpenaiCompatible: - def test_v1_endpoint_compatible(self): - cfg = MagicMock() - cfg.llama_server_endpoints = [] - with patch.object(router, "config", cfg): - assert router.is_openai_compatible("http://host/v1") is True - - def test_ollama_not_compatible(self): - cfg = MagicMock() - cfg.llama_server_endpoints = [] - with patch.object(router, "config", cfg): - assert router.is_openai_compatible("http://localhost:11434") is False - - def test_llama_server_in_list_compatible(self): - cfg = MagicMock() - cfg.llama_server_endpoints = ["http://host:8080"] - with patch.object(router, "config", cfg): - assert router.is_openai_compatible("http://host:8080") is True - - -class TestGetTrackingModel: - def test_ollama_adds_latest(self): - cfg = MagicMock() - cfg.llama_server_endpoints = [] - with patch.object(router, "config", cfg): - assert router.get_tracking_model("http://ollama:11434", "llama3.2") == "llama3.2:latest" - - def test_ollama_keeps_existing_tag(self): - cfg = MagicMock() - cfg.llama_server_endpoints = [] - with patch.object(router, "config", cfg): - assert router.get_tracking_model("http://ollama:11434", "llama3.2:7b") == "llama3.2:7b" - - def test_llama_server_normalizes(self): - ep = "http://host:8080/v1" - cfg = MagicMock() - cfg.llama_server_endpoints = [ep] - with patch.object(router, "config", cfg): - result = router.get_tracking_model(ep, "unsloth/model:Q8_0") - assert result == "model" diff --git a/test/test_unit_rechunk.py b/test/test_unit_rechunk.py deleted file mode 100644 index e0d01c9..0000000 --- a/test/test_unit_rechunk.py +++ /dev/null @@ -1,173 +0,0 @@ -"""Unit tests for router.rechunk — OpenAI ↔ Ollama chunk shape conversion.""" -import time -from types import SimpleNamespace - -import ollama - -import router - - -def _ns(**kw): - return SimpleNamespace(**kw) - - -def _stream_chunk(content="hi", role="assistant", finish_reason=None, - usage=None, model="m"): - """Build a SimpleNamespace mimicking a streaming OpenAI chunk.""" - delta = _ns(content=content, role=role, reasoning=None, reasoning_content=None, - tool_calls=None) - choice = _ns(delta=delta, finish_reason=finish_reason, logprobs=None) - return _ns(model=model, choices=[choice], usage=usage) - - -def _nonstream_chunk(content="hi", role="assistant", finish_reason="stop", - usage=None, model="m", tool_calls=None): - """Build a SimpleNamespace mimicking a non-streaming OpenAI ChatCompletion.""" - message = _ns(content=content, role=role, reasoning=None, reasoning_content=None, - tool_calls=tool_calls) - choice = _ns(message=message, finish_reason=finish_reason, logprobs=None) - return _ns(model=model, choices=[choice], usage=usage) - - -# ────────────────────────────────────────────────────────────────────────────── -# openai_chat_completion2ollama -# ────────────────────────────────────────────────────────────────────────────── - -class TestChatCompletionToOllama: - def test_streaming_content_chunk(self): - chunk = _stream_chunk(content="hello", finish_reason=None, usage=None) - out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter()) - assert isinstance(out, ollama.ChatResponse) - assert out.message.role == "assistant" - assert out.message.content == "hello" - assert out.done is False # usage is None → not done yet - assert out.model == "m" - - def test_streaming_empty_content_defaults(self): - # Some chunks have content=None — should coerce to empty string - chunk = _stream_chunk(content=None, role=None) - out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter()) - assert out.message.role == "assistant" # role defaulted - assert out.message.content == "" - - def test_final_usage_only_chunk_marks_done(self): - usage = _ns(prompt_tokens=10, completion_tokens=5, total_tokens=15) - chunk = _ns(model="m", choices=[], usage=usage) - out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter()) - assert out.done is True - assert out.done_reason == "stop" - assert out.prompt_eval_count == 10 - assert out.eval_count == 5 - assert out.message.content == "" - - def test_nonstreaming_with_content(self): - usage = _ns(prompt_tokens=2, completion_tokens=3, total_tokens=5) - chunk = _nonstream_chunk(content="response text", finish_reason="stop", usage=usage) - out = router.rechunk.openai_chat_completion2ollama(chunk, False, time.perf_counter()) - assert out.done is True - assert out.message.content == "response text" - assert out.prompt_eval_count == 2 - assert out.eval_count == 3 - - def test_nonstreaming_tool_calls_converted(self): - """Tool calls with JSON string arguments are parsed into dicts.""" - tc = _ns(function=_ns(name="get_weather", arguments='{"city": "Paris"}')) - usage = _ns(prompt_tokens=1, completion_tokens=1, total_tokens=2) - chunk = _nonstream_chunk( - content="", finish_reason="tool_calls", usage=usage, tool_calls=[tc] - ) - out = router.rechunk.openai_chat_completion2ollama(chunk, False, time.perf_counter()) - assert out.message.tool_calls is not None - assert len(out.message.tool_calls) == 1 - first = out.message.tool_calls[0] - assert first.function.name == "get_weather" - assert first.function.arguments == {"city": "Paris"} - - def test_nonstreaming_tool_calls_with_invalid_json_fall_back_to_empty(self): - tc = _ns(function=_ns(name="f", arguments="not-json")) - usage = _ns(prompt_tokens=1, completion_tokens=1, total_tokens=2) - chunk = _nonstream_chunk(content="", usage=usage, tool_calls=[tc]) - out = router.rechunk.openai_chat_completion2ollama(chunk, False, time.perf_counter()) - assert out.message.tool_calls[0].function.arguments == {} - - def test_streaming_tool_calls_in_delta_are_skipped(self): - """Streaming mode must not assemble tool calls (caller handles it).""" - chunk = _stream_chunk(content="x", finish_reason=None) - # Even if a chunk somehow carried tool_calls in the delta, streaming - # mode should ignore them. - out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter()) - assert out.message.tool_calls is None - - -# ────────────────────────────────────────────────────────────────────────────── -# openai_completion2ollama -# ────────────────────────────────────────────────────────────────────────────── - -class TestCompletionToOllama: - def test_streaming_text_chunk(self): - choice = _ns(text="word", finish_reason=None, reasoning=None) - chunk = _ns(model="m", choices=[choice], usage=None) - out = router.rechunk.openai_completion2ollama(chunk, True, time.perf_counter()) - assert isinstance(out, ollama.GenerateResponse) - assert out.response == "word" - assert out.done is False - - def test_final_chunk_with_usage(self): - usage = _ns(prompt_tokens=4, completion_tokens=6, total_tokens=10) - choice = _ns(text="end", finish_reason="stop", reasoning=None) - chunk = _ns(model="m", choices=[choice], usage=usage) - out = router.rechunk.openai_completion2ollama(chunk, True, time.perf_counter()) - assert out.done is True - assert out.prompt_eval_count == 4 - assert out.eval_count == 6 - - -# ────────────────────────────────────────────────────────────────────────────── -# embeddings / embed -# ────────────────────────────────────────────────────────────────────────────── - -class TestEmbeddingConversions: - def test_openai_embeddings2ollama(self): - chunk = _ns(data=[_ns(embedding=[0.1, 0.2, 0.3])]) - out = router.rechunk.openai_embeddings2ollama(chunk) - assert isinstance(out, ollama.EmbeddingsResponse) - assert list(out.embedding) == [0.1, 0.2, 0.3] - - def test_openai_embed2ollama(self): - chunk = _ns(data=[_ns(embedding=[0.5, 0.6])]) - out = router.rechunk.openai_embed2ollama(chunk, "my-embed-model") - assert isinstance(out, ollama.EmbedResponse) - assert out.model == "my-embed-model" - assert list(out.embeddings[0]) == [0.5, 0.6] - - -# ────────────────────────────────────────────────────────────────────────────── -# extract_usage_from_llama_timings -# ────────────────────────────────────────────────────────────────────────────── - -class TestExtractUsageFromLlamaTimings: - def test_none_when_no_timings_attr(self): - obj = _ns() - assert router.rechunk.extract_usage_from_llama_timings(obj) is None - - def test_prompt_plus_cache_sums(self): - obj = _ns(timings={"prompt_n": 1, "cache_n": 236, "predicted_n": 35}) - prompt, completion = router.rechunk.extract_usage_from_llama_timings(obj) - assert prompt == 237 - assert completion == 35 - - def test_missing_keys_default_to_zero(self): - obj = _ns(timings={"predicted_n": 12}) - prompt, completion = router.rechunk.extract_usage_from_llama_timings(obj) - assert prompt == 0 - assert completion == 12 - - def test_null_values_treated_as_zero(self): - obj = _ns(timings={"prompt_n": None, "cache_n": None, "predicted_n": None}) - prompt, completion = router.rechunk.extract_usage_from_llama_timings(obj) - assert prompt == 0 - assert completion == 0 - - def test_non_dict_timings_returns_none(self): - obj = _ns(timings="not-a-dict") - assert router.rechunk.extract_usage_from_llama_timings(obj) is None diff --git a/test/test_unit_transforms.py b/test/test_unit_transforms.py deleted file mode 100644 index 51160a0..0000000 --- a/test/test_unit_transforms.py +++ /dev/null @@ -1,200 +0,0 @@ -"""Unit tests for message transformation functions.""" -from unittest.mock import MagicMock - -import pytest - -import router - - -class TestStripAssistantPrefill: - def test_removes_trailing_assistant(self): - msgs = [ - {"role": "user", "content": "hello"}, - {"role": "assistant", "content": "prefill"}, - ] - result = router._strip_assistant_prefill(msgs) - assert len(result) == 1 - assert result[0]["role"] == "user" - - def test_keeps_non_trailing_assistant(self): - msgs = [ - {"role": "user", "content": "hello"}, - {"role": "assistant", "content": "response"}, - {"role": "user", "content": "follow-up"}, - ] - result = router._strip_assistant_prefill(msgs) - assert len(result) == 3 - - def test_empty_list_unchanged(self): - assert router._strip_assistant_prefill([]) == [] - - def test_single_user_message_unchanged(self): - msgs = [{"role": "user", "content": "hi"}] - assert router._strip_assistant_prefill(msgs) == msgs - - -class TestTransformToolCallsToOpenAI: - def test_adds_type_function(self): - msgs = [{"role": "assistant", "tool_calls": [ - {"function": {"name": "get_weather", "arguments": {"city": "Berlin"}}} - ]}] - result = router.transform_tool_calls_to_openai(msgs) - tc = result[0]["tool_calls"][0] - assert tc["type"] == "function" - - def test_adds_id_when_missing(self): - msgs = [{"role": "assistant", "tool_calls": [ - {"function": {"name": "fn", "arguments": {}}} - ]}] - result = router.transform_tool_calls_to_openai(msgs) - assert "id" in result[0]["tool_calls"][0] - - def test_converts_dict_arguments_to_string(self): - msgs = [{"role": "assistant", "tool_calls": [ - {"function": {"name": "fn", "arguments": {"key": "val"}}} - ]}] - result = router.transform_tool_calls_to_openai(msgs) - args = result[0]["tool_calls"][0]["function"]["arguments"] - assert isinstance(args, str) - import orjson - parsed = orjson.loads(args) - assert parsed == {"key": "val"} - - def test_keeps_string_arguments_unchanged(self): - msgs = [{"role": "assistant", "tool_calls": [ - {"function": {"name": "fn", "arguments": '{"key": "val"}'}} - ]}] - result = router.transform_tool_calls_to_openai(msgs) - args = result[0]["tool_calls"][0]["function"]["arguments"] - assert args == '{"key": "val"}' - - def test_links_tool_call_id_to_tool_response(self): - msgs = [ - {"role": "assistant", "tool_calls": [ - {"function": {"name": "get_weather", "arguments": {}}} - ]}, - {"role": "tool", "name": "get_weather", "content": "sunny"}, - ] - result = router.transform_tool_calls_to_openai(msgs) - tc_id = result[0]["tool_calls"][0]["id"] - assert result[1].get("tool_call_id") == tc_id - - def test_non_tool_messages_unchanged(self): - msgs = [{"role": "user", "content": "hello"}] - result = router.transform_tool_calls_to_openai(msgs) - assert result == msgs - - -class TestStripImagesFromMessages: - def test_removes_image_url_parts(self): - msgs = [{"role": "user", "content": [ - {"type": "text", "text": "what is this?"}, - {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, - ]}] - result = router._strip_images_from_messages(msgs) - content = result[0]["content"] - assert content == "what is this?" - - def test_keeps_text_only_messages(self): - msgs = [{"role": "user", "content": "plain text"}] - result = router._strip_images_from_messages(msgs) - assert result[0]["content"] == "plain text" - - def test_multiple_text_parts_kept_as_list(self): - msgs = [{"role": "user", "content": [ - {"type": "text", "text": "part one"}, - {"type": "text", "text": "part two"}, - {"type": "image_url", "image_url": {"url": "data:..."}}, - ]}] - result = router._strip_images_from_messages(msgs) - content = result[0]["content"] - assert isinstance(content, list) - assert len(content) == 2 - - def test_all_images_removed_empty_list(self): - msgs = [{"role": "user", "content": [ - {"type": "image_url", "image_url": {"url": "data:..."}}, - ]}] - result = router._strip_images_from_messages(msgs) - # Image-only content becomes empty list - content = result[0]["content"] - assert content == [] - - -class TestAccumulateOpenAITcDelta: - def _make_chunk(self, index, name=None, args_fragment="", tc_id=None): - delta = MagicMock() - tc = MagicMock() - tc.index = index - tc.id = tc_id - tc.function = MagicMock() - tc.function.name = name - tc.function.arguments = args_fragment - delta.tool_calls = [tc] - chunk = MagicMock() - chunk.choices = [MagicMock(delta=delta)] - return chunk - - def test_first_delta_creates_entry(self): - acc = {} - chunk = self._make_chunk(0, name="my_fn", args_fragment='{"k"') - router._accumulate_openai_tc_delta(chunk, acc) - assert 0 in acc - assert acc[0]["name"] == "my_fn" - assert acc[0]["arguments"] == '{"k"' - - def test_subsequent_deltas_concatenate_args(self): - acc = {} - router._accumulate_openai_tc_delta(self._make_chunk(0, name="fn", args_fragment='{"k"'), acc) - router._accumulate_openai_tc_delta(self._make_chunk(0, args_fragment=': "v"}'), acc) - assert acc[0]["arguments"] == '{"k": "v"}' - - def test_multiple_tool_calls_tracked_separately(self): - acc = {} - c1 = self._make_chunk(0, name="fn1", args_fragment="{}") - c2 = self._make_chunk(1, name="fn2", args_fragment="{}") - chunk = MagicMock() - tc1 = MagicMock() - tc1.index = 0 - tc1.id = "id1" - tc1.function = MagicMock(name="fn1", arguments="{}") - tc2 = MagicMock() - tc2.index = 1 - tc2.id = "id2" - tc2.function = MagicMock(name="fn2", arguments="{}") - chunk.choices = [MagicMock(delta=MagicMock(tool_calls=[tc1, tc2]))] - router._accumulate_openai_tc_delta(chunk, acc) - assert 0 in acc and 1 in acc - - def test_no_choices_is_noop(self): - acc = {} - chunk = MagicMock(choices=[]) - router._accumulate_openai_tc_delta(chunk, acc) - assert acc == {} - - -class TestBuildOllamaToolCalls: - def test_builds_from_accumulator(self): - acc = {0: {"id": "call_abc", "name": "get_weather", "arguments": '{"city": "Berlin"}'}} - result = router._build_ollama_tool_calls(acc) - assert result is not None - assert len(result) == 1 - assert result[0].function.name == "get_weather" - assert result[0].function.arguments == {"city": "Berlin"} - - def test_invalid_json_args_becomes_empty_dict(self): - acc = {0: {"id": "c1", "name": "fn", "arguments": "not-json"}} - result = router._build_ollama_tool_calls(acc) - assert result[0].function.arguments == {} - - def test_empty_accumulator_returns_none(self): - assert router._build_ollama_tool_calls({}) is None - - def test_preserves_order_by_index(self): - acc = { - 1: {"id": "c2", "name": "fn2", "arguments": "{}"}, - 0: {"id": "c1", "name": "fn1", "arguments": "{}"}, - } - result = router._build_ollama_tool_calls(acc) - assert result[0].function.name == "fn1" - assert result[1].function.name == "fn2"