refac: modularize apis VII
This commit is contained in:
parent
e74f5d1ba6
commit
4b5a70e787
7 changed files with 2244 additions and 2108 deletions
0
api/__init__.py
Normal file
0
api/__init__.py
Normal file
278
api/management.py
Normal file
278
api/management.py
Normal file
|
|
@ -0,0 +1,278 @@
|
|||
"""Management / observability routes.
|
||||
|
||||
Read-only endpoints used by the dashboard and external monitoring:
|
||||
* usage counters and token-counts breakdown,
|
||||
* conversation-affinity introspection,
|
||||
* endpoint health summary,
|
||||
* LLM-response cache stats and invalidation,
|
||||
* SSE live-stream of usage updates,
|
||||
* hostname and ``/health`` probe.
|
||||
"""
|
||||
import asyncio
|
||||
import socket
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import orjson
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from starlette.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from cache import get_llm_cache
|
||||
from config import get_config
|
||||
from db import get_db
|
||||
from state import (
|
||||
usage_counts,
|
||||
token_usage_counts,
|
||||
_affinity_map,
|
||||
_affinity_lock,
|
||||
)
|
||||
from sse import subscribe, unsubscribe
|
||||
from backends.normalize import _normalize_llama_model_name
|
||||
from backends.probe import _endpoint_health
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/api/token_counts")
|
||||
async def token_counts_proxy():
|
||||
breakdown = []
|
||||
total = 0
|
||||
async for entry in get_db().load_token_counts():
|
||||
total += entry['total_tokens']
|
||||
breakdown.append({
|
||||
"endpoint": entry["endpoint"],
|
||||
"model": entry["model"],
|
||||
"input_tokens": entry["input_tokens"],
|
||||
"output_tokens": entry["output_tokens"],
|
||||
"total_tokens": entry["total_tokens"],
|
||||
})
|
||||
return {"total_tokens": total, "breakdown": breakdown}
|
||||
|
||||
|
||||
@router.post("/api/aggregate_time_series_days")
|
||||
async def aggregate_time_series_days_proxy(request: Request):
|
||||
"""
|
||||
Aggregate time_series entries older than days into daily aggregates by endpoint/model/date.
|
||||
"""
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
if not body_bytes:
|
||||
days = 30
|
||||
trim_old = False
|
||||
else:
|
||||
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||
days = int(payload.get("days", 30))
|
||||
trim_old = bool(payload.get("trim_old", False))
|
||||
except Exception:
|
||||
days = 30
|
||||
trim_old = False
|
||||
aggregated = await get_db().aggregate_time_series_older_than(days, trim_old=trim_old)
|
||||
return {"status": "ok", "days": days, "trim_old": trim_old, "aggregated_groups": aggregated}
|
||||
|
||||
|
||||
@router.post("/api/stats")
|
||||
async def stats_proxy(request: Request, model: Optional[str] = None):
|
||||
"""
|
||||
Return token usage statistics for a specific model.
|
||||
"""
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
|
||||
if not model:
|
||||
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||
model = payload.get("model")
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
except orjson.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
db = get_db()
|
||||
token_data = await db.get_token_counts_for_model(model)
|
||||
|
||||
if not token_data:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="No token data found for this model"
|
||||
)
|
||||
|
||||
time_series = [
|
||||
entry async for entry in db.get_time_series_for_model(model)
|
||||
]
|
||||
endpoint_distribution = await db.get_endpoint_distribution_for_model(model)
|
||||
|
||||
return {
|
||||
'model': model,
|
||||
'input_tokens': token_data['input_tokens'],
|
||||
'output_tokens': token_data['output_tokens'],
|
||||
'total_tokens': token_data['total_tokens'],
|
||||
'time_series': time_series,
|
||||
'endpoint_distribution': endpoint_distribution,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/api/affinity_stats")
|
||||
async def affinity_stats(request: Request):
|
||||
"""
|
||||
Aggregate live conversation-affinity pins, one entry per pinned conversation.
|
||||
Each entry exposes only the endpoint, model, and remaining TTL in seconds —
|
||||
no fingerprints or content. When conversation_affinity is disabled the
|
||||
`entries` list is always empty.
|
||||
"""
|
||||
config = get_config()
|
||||
if not config.conversation_affinity:
|
||||
return {"enabled": False, "ttl": config.conversation_affinity_ttl, "entries": []}
|
||||
|
||||
now = time.monotonic()
|
||||
entries: list[dict] = []
|
||||
llama_eps = set(config.llama_server_endpoints)
|
||||
async with _affinity_lock:
|
||||
for fp, (ep, mdl, expires_at) in list(_affinity_map.items()):
|
||||
remaining = expires_at - now
|
||||
if remaining <= 0:
|
||||
_affinity_map.pop(fp, None)
|
||||
continue
|
||||
# Mirror the normalisation used by /api/ps_details so the dashboard
|
||||
# can join affinity entries to PS rows by (endpoint, model).
|
||||
display_model = _normalize_llama_model_name(mdl) if ep in llama_eps else mdl
|
||||
entries.append({
|
||||
"endpoint": ep,
|
||||
"model": display_model,
|
||||
"remaining": round(remaining, 2),
|
||||
})
|
||||
return {
|
||||
"enabled": True,
|
||||
"ttl": config.conversation_affinity_ttl,
|
||||
"entries": entries,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/api/usage")
|
||||
async def usage_proxy(request: Request):
|
||||
"""
|
||||
Return a snapshot of the usage counter for each endpoint.
|
||||
Useful for debugging / monitoring.
|
||||
"""
|
||||
return {"usage_counts": usage_counts,
|
||||
"token_usage_counts": token_usage_counts}
|
||||
|
||||
|
||||
@router.get("/api/config")
|
||||
async def config_proxy(request: Request):
|
||||
"""
|
||||
Return a simple JSON object that contains the configured
|
||||
Ollama endpoints and llama_server_endpoints. The front‑end uses this
|
||||
to display which endpoints are being proxied and their health.
|
||||
Status is "error" when either liveness (/api/version) or routing
|
||||
health (/api/ps) fails — see issue #83.
|
||||
"""
|
||||
config = get_config()
|
||||
|
||||
async def check(url: str) -> dict:
|
||||
return {"url": url, **(await _endpoint_health(url, timeout=5))}
|
||||
|
||||
ollama_results = await asyncio.gather(*[check(ep) for ep in config.endpoints])
|
||||
llama_results = []
|
||||
if config.llama_server_endpoints:
|
||||
llama_results = await asyncio.gather(
|
||||
*[check(ep) for ep in config.llama_server_endpoints]
|
||||
)
|
||||
|
||||
return {
|
||||
"endpoints": ollama_results,
|
||||
"llama_server_endpoints": llama_results,
|
||||
"require_router_api_key": bool(config.router_api_key),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/api/cache/stats")
|
||||
async def cache_stats():
|
||||
"""Return hit/miss counters and configuration for the LLM response cache."""
|
||||
c = get_llm_cache()
|
||||
if c is None:
|
||||
return {"enabled": False}
|
||||
return {"enabled": True, **c.stats()}
|
||||
|
||||
|
||||
@router.post("/api/cache/invalidate")
|
||||
async def cache_invalidate():
|
||||
"""Clear all entries from the LLM response cache and reset counters."""
|
||||
c = get_llm_cache()
|
||||
if c is None:
|
||||
return {"enabled": False, "cleared": False}
|
||||
await c.clear()
|
||||
return {"enabled": True, "cleared": True}
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_proxy(request: Request):
|
||||
"""
|
||||
Health‑check endpoint for monitoring the proxy.
|
||||
|
||||
* Queries each configured endpoint for both liveness and routing health:
|
||||
Ollama endpoints are probed at `/api/version` AND `/api/ps`,
|
||||
OpenAI-compatible endpoints at `/models`.
|
||||
* Returns a JSON object containing:
|
||||
- `status`: "ok" if every endpoint replied to every probe, otherwise "error".
|
||||
- `endpoints`: a mapping of endpoint URL → `{status, version|detail}`.
|
||||
* The HTTP status code is 200 when everything is healthy, 503 otherwise.
|
||||
"""
|
||||
config = get_config()
|
||||
# Run all health checks in parallel.
|
||||
# Ollama endpoints expose /api/version (liveness) and /api/ps (routing
|
||||
# health — required by `choose_endpoint`). OpenAI-compatible endpoints
|
||||
# (vLLM, llama-server, external) expose /models, which serves both
|
||||
# purposes. Probing /api/version alone would miss the case where the
|
||||
# Ollama process is up but /api/ps is failing — see issue #83.
|
||||
all_endpoints = list(config.endpoints)
|
||||
llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints]
|
||||
all_endpoints += llama_eps_extra
|
||||
|
||||
probe_results = await asyncio.gather(
|
||||
*(_endpoint_health(ep) for ep in all_endpoints),
|
||||
)
|
||||
|
||||
health_summary = dict(zip(all_endpoints, probe_results))
|
||||
overall_ok = all(entry.get("status") == "ok" for entry in probe_results)
|
||||
|
||||
response_payload = {
|
||||
"status": "ok" if overall_ok else "error",
|
||||
"endpoints": health_summary,
|
||||
}
|
||||
|
||||
http_status = 200 if overall_ok else 503
|
||||
return JSONResponse(content=response_payload, status_code=http_status)
|
||||
|
||||
|
||||
@router.get("/api/hostname")
|
||||
async def get_hostname():
|
||||
"""Return the hostname of the machine running the router."""
|
||||
return JSONResponse(content={"hostname": socket.gethostname()})
|
||||
|
||||
|
||||
@router.get("/api/usage-stream")
|
||||
async def usage_stream(request: Request):
|
||||
"""
|
||||
Server‑Sent‑Events that emits a JSON payload every time the
|
||||
global `usage_counts` dictionary changes.
|
||||
"""
|
||||
async def event_generator():
|
||||
# The queue that receives *every* new snapshot
|
||||
queue = await subscribe()
|
||||
try:
|
||||
while True:
|
||||
# If the client disconnects, cancel the loop
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
data = await queue.get()
|
||||
if data is None:
|
||||
break
|
||||
# Send the data as a single SSE message
|
||||
yield f"data: {data}\n\n"
|
||||
finally:
|
||||
# Clean‑up: unsubscribe from the broadcast channel
|
||||
await unsubscribe(queue)
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
1106
api/ollama.py
Normal file
1106
api/ollama.py
Normal file
File diff suppressed because it is too large
Load diff
804
api/openai.py
Normal file
804
api/openai.py
Normal file
|
|
@ -0,0 +1,804 @@
|
|||
"""OpenAI-compatible routes (``/v1/embeddings``, ``/v1/chat/completions``,
|
||||
``/v1/completions``, ``/v1/models``, ``/v1/rerank`` and ``/rerank``).
|
||||
|
||||
The chat-completions and completions handlers carry the full reactive-trim
|
||||
logic for ``exceed_context_size_error`` plus connection-failure rerouting
|
||||
(``_mark_backend_unhealthy``). The streaming branches assemble cached
|
||||
responses on the fly so caching works for both streaming and non-streaming
|
||||
clients.
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
import math
|
||||
|
||||
import aiohttp
|
||||
import orjson
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from starlette.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from cache import get_llm_cache, openai_nonstream_to_sse
|
||||
from config import get_config
|
||||
from context_window import (
|
||||
_count_message_tokens,
|
||||
_trim_messages_for_context,
|
||||
_calibrated_trim_target,
|
||||
_endpoint_nctx,
|
||||
_CTX_TRIM_SMALL_LIMIT,
|
||||
)
|
||||
from fingerprint import _conversation_fingerprint
|
||||
from security import _mask_secrets
|
||||
from state import token_queue, app_state, default_headers
|
||||
from backends.health import _is_backend_connection_error, _mark_backend_unhealthy
|
||||
from backends.normalize import (
|
||||
dedupe_on_keys,
|
||||
ep2base,
|
||||
is_ext_openai_endpoint,
|
||||
is_openai_compatible,
|
||||
_normalize_llama_model_name,
|
||||
)
|
||||
from backends.probe import fetch
|
||||
from backends.sessions import _make_openai_client, get_session
|
||||
from requests.messages import _strip_assistant_prefill, _strip_images_from_messages
|
||||
from requests.rechunk import rechunk
|
||||
from routing import choose_endpoint, decrement_usage
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/v1/embeddings")
|
||||
async def openai_embedding_proxy(request: Request):
|
||||
"""
|
||||
Proxy an OpenAI API compatible embedding request to Ollama and reply with embeddings.
|
||||
|
||||
"""
|
||||
config = get_config()
|
||||
# 1. Parse and validate request
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||
|
||||
model = payload.get("model")
|
||||
doc = payload.get("input")
|
||||
|
||||
# Normalize multimodal input: extract only text parts for embedding models
|
||||
if isinstance(doc, list):
|
||||
normalized = []
|
||||
for item in doc:
|
||||
if isinstance(item, dict):
|
||||
# Multimodal content part - extract text only, skip images
|
||||
if item.get("type") == "text":
|
||||
normalized.append(item.get("text", ""))
|
||||
# Skip image_url and other non-text types
|
||||
else:
|
||||
normalized.append(item)
|
||||
doc = normalized if len(normalized) != 1 else normalized[0]
|
||||
elif isinstance(doc, dict) and doc.get("type") == "text":
|
||||
doc = doc.get("text", "")
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
if not doc:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'input'"
|
||||
)
|
||||
except orjson.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
endpoint, tracking_model = await choose_endpoint(model)
|
||||
if is_openai_compatible(endpoint):
|
||||
api_key = config.api_keys.get(endpoint, "no-key")
|
||||
else:
|
||||
api_key = "ollama"
|
||||
oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=api_key)
|
||||
|
||||
try:
|
||||
async_gen = await oclient.embeddings.create(input=doc, model=model)
|
||||
result = async_gen.model_dump()
|
||||
for item in result.get("data", []):
|
||||
emb = item.get("embedding")
|
||||
if emb:
|
||||
item["embedding"] = [0.0 if isinstance(v, float) and not math.isfinite(v) else v for v in emb]
|
||||
return JSONResponse(content=result)
|
||||
finally:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
async def openai_chat_completions_proxy(request: Request):
|
||||
"""
|
||||
Proxy an OpenAI API compatible chat completions request to Ollama and reply with a streaming response.
|
||||
|
||||
"""
|
||||
config = get_config()
|
||||
# 1. Parse and validate request
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||
|
||||
model = payload.get("model")
|
||||
messages = payload.get("messages")
|
||||
frequency_penalty = payload.get("frequency_penalty")
|
||||
presence_penalty = payload.get("presence_penalty")
|
||||
response_format = payload.get("response_format")
|
||||
seed = payload.get("seed")
|
||||
stop = payload.get("stop")
|
||||
stream = payload.get("stream")
|
||||
stream_options = payload.get("stream_options")
|
||||
temperature = payload.get("temperature")
|
||||
top_p = payload.get("top_p")
|
||||
max_tokens = payload.get("max_tokens")
|
||||
max_completion_tokens = payload.get("max_completion_tokens")
|
||||
tools = payload.get("tools")
|
||||
logprobs = payload.get("logprobs")
|
||||
top_logprobs = payload.get("top_logprobs")
|
||||
_cache_enabled = payload.get("nomyo", {}).get("cache", False)
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
if not isinstance(messages, list):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'messages' (must be a list)"
|
||||
)
|
||||
|
||||
if ":latest" in model:
|
||||
model = model.split(":latest")
|
||||
model = model[0]
|
||||
|
||||
messages = _strip_assistant_prefill(messages)
|
||||
params = {
|
||||
"messages": messages,
|
||||
"model": model,
|
||||
}
|
||||
|
||||
optional_params = {
|
||||
"tools": tools,
|
||||
"response_format": response_format,
|
||||
"stream_options": stream_options or {"include_usage": True },
|
||||
"max_completion_tokens": max_completion_tokens,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"seed": seed,
|
||||
"presence_penalty": presence_penalty,
|
||||
"frequency_penalty": frequency_penalty,
|
||||
"stop": stop,
|
||||
"stream": stream,
|
||||
"logprobs": logprobs,
|
||||
"top_logprobs": top_logprobs,
|
||||
}
|
||||
|
||||
params.update({k: v for k, v in optional_params.items() if v is not None})
|
||||
except orjson.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# Reject unsupported image formats (SVG) before doing any work
|
||||
for _msg in messages:
|
||||
for _item in (_msg.get("content") or []) if isinstance(_msg.get("content"), list) else []:
|
||||
if _item.get("type") == "image_url":
|
||||
_url = (_item.get("image_url") or {}).get("url", "")
|
||||
if _url.startswith("data:image/svg") or _url.lower().endswith(".svg"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="SVG images are not supported. Please convert the image to PNG or JPEG before sending.",
|
||||
)
|
||||
|
||||
# Cache lookup — before endpoint selection
|
||||
_cache = get_llm_cache()
|
||||
if _cache is not None and _cache_enabled:
|
||||
_cached = await _cache.get_chat("openai_chat", model, messages)
|
||||
if _cached is not None:
|
||||
if stream:
|
||||
_sse = openai_nonstream_to_sse(_cached, model)
|
||||
async def _serve_cached_ochat_stream():
|
||||
yield _sse
|
||||
return StreamingResponse(_serve_cached_ochat_stream(), media_type="text/event-stream")
|
||||
else:
|
||||
async def _serve_cached_ochat_json():
|
||||
yield _cached
|
||||
return StreamingResponse(_serve_cached_ochat_json(), media_type="application/json")
|
||||
|
||||
# 2. Endpoint logic
|
||||
_affinity_key = _conversation_fingerprint(model, messages, None)
|
||||
endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
|
||||
oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
||||
# 3. Helpers and API call — done in handler scope so try/except works reliably
|
||||
async def _normalize_images_in_messages(msgs: list) -> list:
|
||||
"""Fetch remote image URLs and convert them to base64 data URLs so
|
||||
Ollama/llama-server can handle them without making outbound HTTP requests."""
|
||||
resolved = []
|
||||
for msg in msgs:
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, list):
|
||||
resolved.append(msg)
|
||||
continue
|
||||
new_content = []
|
||||
for item in content:
|
||||
if item.get("type") == "image_url":
|
||||
url = (item.get("image_url") or {}).get("url", "")
|
||||
if url and not url.startswith("data:"):
|
||||
try:
|
||||
http: aiohttp.ClientSession = app_state["session"]
|
||||
async with http.get(url) as resp:
|
||||
ctype = resp.headers.get("Content-Type", "image/jpeg").split(";")[0].strip()
|
||||
img_bytes = await resp.read()
|
||||
b64 = base64.b64encode(img_bytes).decode("utf-8")
|
||||
new_content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{ctype};base64,{b64}"}
|
||||
})
|
||||
except Exception as _ie:
|
||||
print(f"[image] Failed to fetch image URL: {_ie}")
|
||||
new_content.append(item)
|
||||
else:
|
||||
new_content.append(item)
|
||||
else:
|
||||
new_content.append(item)
|
||||
resolved.append({**msg, "content": new_content})
|
||||
return resolved
|
||||
|
||||
# Make the API call in handler scope — try/except inside async generators is unreliable
|
||||
# with Starlette's streaming machinery, so we resolve errors here before the generator starts.
|
||||
send_params = params
|
||||
if not is_ext_openai_endpoint(endpoint):
|
||||
resolved_msgs = await _normalize_images_in_messages(params.get("messages", []))
|
||||
send_params = {**params, "messages": resolved_msgs}
|
||||
# Proactive trim: only for small-ctx models we've already seen run out of space
|
||||
_lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model
|
||||
_known_nctx = _endpoint_nctx.get((endpoint, _lookup_model))
|
||||
if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
|
||||
_pre_target = int(((_known_nctx - _known_nctx // 4)) / 1.2)
|
||||
_pre_est = _count_message_tokens(send_params.get("messages", []))
|
||||
if _pre_est > _pre_target:
|
||||
_pre_msgs = send_params.get("messages", [])
|
||||
_pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target)
|
||||
_dropped = len(_pre_msgs) - len(_pre_trimmed)
|
||||
print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
|
||||
send_params = {**send_params, "messages": _pre_trimmed}
|
||||
try:
|
||||
async_gen = await oclient.chat.completions.create(**send_params)
|
||||
except Exception as e:
|
||||
_e_str = str(e)
|
||||
_is_ctx_err = "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str
|
||||
print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True)
|
||||
if "does not support tools" in _e_str:
|
||||
# Model doesn't support tools — retry without them
|
||||
print(f"[ochat] retry: no tools", flush=True)
|
||||
try:
|
||||
params_without_tools = {k: v for k, v in send_params.items() if k != "tools"}
|
||||
async_gen = await oclient.chat.completions.create(**params_without_tools)
|
||||
except Exception:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
elif _is_ctx_err:
|
||||
# Backend context limit hit — apply sliding-window trim (context-shift at message level)
|
||||
err_body = getattr(e, "body", {}) or {}
|
||||
err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
|
||||
n_ctx_limit = err_detail.get("n_ctx", 0)
|
||||
actual_tokens = err_detail.get("n_prompt_tokens", 0)
|
||||
# Fallback: parse from string if body parsing yielded nothing (SDK may not parse llama-server errors)
|
||||
if not n_ctx_limit:
|
||||
import re as _re
|
||||
_m = _re.search(r"'n_ctx':\s*(\d+)", _e_str)
|
||||
if _m:
|
||||
n_ctx_limit = int(_m.group(1))
|
||||
_m = _re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str)
|
||||
if _m:
|
||||
actual_tokens = int(_m.group(1))
|
||||
print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True)
|
||||
if not n_ctx_limit:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
|
||||
_endpoint_nctx[(endpoint, model)] = n_ctx_limit
|
||||
|
||||
msgs_to_trim = send_params.get("messages", [])
|
||||
try:
|
||||
cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
|
||||
trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
|
||||
except Exception as _helper_exc:
|
||||
print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True)
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
dropped = len(msgs_to_trim) - len(trimmed_messages)
|
||||
print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True)
|
||||
try:
|
||||
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages})
|
||||
print(f"[ctx-trim] retry-1 ok", flush=True)
|
||||
except Exception as e2:
|
||||
_e2_str = str(e2)
|
||||
if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str:
|
||||
# Still too large — tool definitions likely consuming too many tokens, strip them too
|
||||
print(f"[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True)
|
||||
params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")}
|
||||
try:
|
||||
async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages})
|
||||
print(f"[ctx-trim] retry-2 ok", flush=True)
|
||||
except Exception:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
else:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
elif _is_backend_connection_error(e):
|
||||
# Upstream connection failed (e.g. llama-server in router mode
|
||||
# whose delegated worker died). Mark (endpoint, model) so the
|
||||
# next request reroutes; the client will retry this one.
|
||||
print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
|
||||
await _mark_backend_unhealthy(endpoint, model, _e_str)
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
elif "image input is not supported" in _e_str:
|
||||
# Model doesn't support images — strip and retry
|
||||
print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages")
|
||||
try:
|
||||
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))})
|
||||
except Exception:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
else:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
|
||||
# 4. Async generator — only streams the already-established async_gen
|
||||
async def stream_ochat_response():
|
||||
try:
|
||||
if stream == True:
|
||||
content_parts: list[str] = []
|
||||
usage_snapshot: dict = {}
|
||||
async for chunk in async_gen:
|
||||
data = (
|
||||
chunk.model_dump_json()
|
||||
if hasattr(chunk, "model_dump_json")
|
||||
else orjson.dumps(chunk)
|
||||
)
|
||||
if chunk.choices:
|
||||
delta = chunk.choices[0].delta
|
||||
has_content = delta.content is not None
|
||||
has_reasoning = (
|
||||
getattr(delta, "reasoning_content", None) is not None
|
||||
or getattr(delta, "reasoning", None) is not None
|
||||
)
|
||||
has_tool_calls = getattr(delta, "tool_calls", None) is not None
|
||||
if has_content or has_reasoning or has_tool_calls:
|
||||
yield f"data: {data}\n\n".encode("utf-8")
|
||||
if has_content and delta.content:
|
||||
content_parts.append(delta.content)
|
||||
elif chunk.usage is not None:
|
||||
# Forward the usage-only final chunk (e.g. from llama-server)
|
||||
yield f"data: {data}\n\n".encode("utf-8")
|
||||
prompt_tok = 0
|
||||
comp_tok = 0
|
||||
if chunk.usage is not None:
|
||||
prompt_tok = chunk.usage.prompt_tokens or 0
|
||||
comp_tok = chunk.usage.completion_tokens or 0
|
||||
usage_snapshot = {"prompt_tokens": prompt_tok, "completion_tokens": comp_tok, "total_tokens": prompt_tok + comp_tok}
|
||||
else:
|
||||
llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
|
||||
if llama_usage:
|
||||
prompt_tok, comp_tok = llama_usage
|
||||
if prompt_tok != 0 or comp_tok != 0:
|
||||
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
|
||||
# Detect context exhaustion mid-generation for small-ctx models.
|
||||
# Guard: skip if max_tokens was set in the request — finish_reason=length
|
||||
# could just mean the caller's token budget was exhausted, not the context window.
|
||||
_req_max_tok = send_params.get("max_tokens") or send_params.get("max_completion_tokens")
|
||||
if chunk.choices and chunk.choices[0].finish_reason == "length" and not _req_max_tok:
|
||||
_inferred_nctx = (prompt_tok + comp_tok) or 0
|
||||
if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
|
||||
_endpoint_nctx[(endpoint, model)] = _inferred_nctx
|
||||
print(f"[ctx-cache] finish_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True)
|
||||
# Cache assembled streaming response — before [DONE] so it always runs
|
||||
if _cache is not None and _cache_enabled and content_parts:
|
||||
assembled = orjson.dumps({
|
||||
"model": model,
|
||||
"choices": [{"index": 0, "message": {"role": "assistant", "content": "".join(content_parts)}, "finish_reason": "stop"}],
|
||||
**({"usage": usage_snapshot} if usage_snapshot else {}),
|
||||
}) + b"\n"
|
||||
try:
|
||||
await _cache.set_chat("openai_chat", model, messages, assembled)
|
||||
except Exception as _ce:
|
||||
print(f"[cache] set_chat (openai_chat streaming) failed: {_ce}")
|
||||
yield b"data: [DONE]\n\n"
|
||||
else:
|
||||
prompt_tok = 0
|
||||
comp_tok = 0
|
||||
if async_gen.usage is not None:
|
||||
prompt_tok = async_gen.usage.prompt_tokens or 0
|
||||
comp_tok = async_gen.usage.completion_tokens or 0
|
||||
else:
|
||||
llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
|
||||
if llama_usage:
|
||||
prompt_tok, comp_tok = llama_usage
|
||||
if prompt_tok != 0 or comp_tok != 0:
|
||||
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
|
||||
json_line = (
|
||||
async_gen.model_dump_json()
|
||||
if hasattr(async_gen, "model_dump_json")
|
||||
else orjson.dumps(async_gen)
|
||||
)
|
||||
cache_bytes = json_line.encode("utf-8") + b"\n"
|
||||
yield cache_bytes
|
||||
# Cache non-streaming response
|
||||
if _cache is not None and _cache_enabled:
|
||||
try:
|
||||
await _cache.set_chat("openai_chat", model, messages, cache_bytes)
|
||||
except Exception as _ce:
|
||||
print(f"[cache] set_chat (openai_chat non-streaming) failed: {_ce}")
|
||||
|
||||
finally:
|
||||
# Ensure counter is decremented even if an exception occurs
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
|
||||
# 4. Return a StreamingResponse backed by the generator
|
||||
return StreamingResponse(
|
||||
stream_ochat_response(),
|
||||
media_type="text/event-stream" if stream else "application/json",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/v1/completions")
|
||||
async def openai_completions_proxy(request: Request):
|
||||
"""
|
||||
Proxy an OpenAI API compatible chat completions request to Ollama and reply with a streaming response.
|
||||
|
||||
"""
|
||||
config = get_config()
|
||||
# 1. Parse and validate request
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||
|
||||
model = payload.get("model")
|
||||
prompt = payload.get("prompt")
|
||||
frequency_penalty = payload.get("frequency_penalty")
|
||||
presence_penalty = payload.get("presence_penalty")
|
||||
seed = payload.get("seed")
|
||||
stop = payload.get("stop")
|
||||
stream = payload.get("stream")
|
||||
stream_options = payload.get("stream_options")
|
||||
temperature = payload.get("temperature")
|
||||
top_p = payload.get("top_p")
|
||||
max_tokens = payload.get("max_tokens")
|
||||
max_completion_tokens = payload.get("max_completion_tokens")
|
||||
suffix = payload.get("suffix")
|
||||
_cache_enabled = payload.get("nomyo", {}).get("cache", False)
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
if not prompt:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'prompt'"
|
||||
)
|
||||
|
||||
if ":latest" in model:
|
||||
model = model.split(":latest")
|
||||
model = model[0]
|
||||
|
||||
params = {
|
||||
"prompt": prompt,
|
||||
"model": model,
|
||||
}
|
||||
|
||||
optional_params = {
|
||||
"frequency_penalty": frequency_penalty,
|
||||
"presence_penalty": presence_penalty,
|
||||
"seed": seed,
|
||||
"stop": stop,
|
||||
"stream": stream,
|
||||
"stream_options": stream_options or {"include_usage": True },
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"max_tokens": max_tokens,
|
||||
"max_completion_tokens": max_completion_tokens,
|
||||
"suffix": suffix
|
||||
}
|
||||
|
||||
params.update({k: v for k, v in optional_params.items() if v is not None})
|
||||
except orjson.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# Cache lookup — completions prompt mapped to a single-turn messages list
|
||||
_cache = get_llm_cache()
|
||||
_compl_messages = [{"role": "user", "content": prompt}]
|
||||
if _cache is not None and _cache_enabled:
|
||||
_cached = await _cache.get_chat("openai_completions", model, _compl_messages)
|
||||
if _cached is not None:
|
||||
if stream:
|
||||
_sse = openai_nonstream_to_sse(_cached, model)
|
||||
async def _serve_cached_ocompl_stream():
|
||||
yield _sse
|
||||
return StreamingResponse(_serve_cached_ocompl_stream(), media_type="text/event-stream")
|
||||
else:
|
||||
async def _serve_cached_ocompl_json():
|
||||
yield _cached
|
||||
return StreamingResponse(_serve_cached_ocompl_json(), media_type="application/json")
|
||||
|
||||
# 2. Endpoint logic
|
||||
_affinity_key = _conversation_fingerprint(model, None, prompt)
|
||||
endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
|
||||
oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
||||
|
||||
# 3. Async generator that streams completions data and decrements the counter
|
||||
# Make the API call in handler scope (try/except inside async generators is unreliable)
|
||||
try:
|
||||
async_gen = await oclient.completions.create(**params)
|
||||
except Exception as e:
|
||||
if _is_backend_connection_error(e):
|
||||
print(f"[ocompl] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
|
||||
await _mark_backend_unhealthy(endpoint, model, str(e))
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
|
||||
async def stream_ocompletions_response(model=model):
|
||||
try:
|
||||
if stream == True:
|
||||
text_parts: list[str] = []
|
||||
usage_snapshot: dict = {}
|
||||
async for chunk in async_gen:
|
||||
data = (
|
||||
chunk.model_dump_json()
|
||||
if hasattr(chunk, "model_dump_json")
|
||||
else orjson.dumps(chunk)
|
||||
)
|
||||
if chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
has_text = getattr(choice, "text", None) is not None
|
||||
has_reasoning = (
|
||||
getattr(choice, "reasoning_content", None) is not None
|
||||
or getattr(choice, "reasoning", None) is not None
|
||||
)
|
||||
if has_text or has_reasoning or choice.finish_reason is not None:
|
||||
yield f"data: {data}\n\n".encode("utf-8")
|
||||
if has_text and choice.text:
|
||||
text_parts.append(choice.text)
|
||||
elif chunk.usage is not None:
|
||||
# Forward the usage-only final chunk (e.g. from llama-server)
|
||||
yield f"data: {data}\n\n".encode("utf-8")
|
||||
prompt_tok = 0
|
||||
comp_tok = 0
|
||||
if chunk.usage is not None:
|
||||
prompt_tok = chunk.usage.prompt_tokens or 0
|
||||
comp_tok = chunk.usage.completion_tokens or 0
|
||||
usage_snapshot = {"prompt_tokens": prompt_tok, "completion_tokens": comp_tok, "total_tokens": prompt_tok + comp_tok}
|
||||
else:
|
||||
llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
|
||||
if llama_usage:
|
||||
prompt_tok, comp_tok = llama_usage
|
||||
if prompt_tok != 0 or comp_tok != 0:
|
||||
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
|
||||
# Cache assembled streaming response — before [DONE] so it always runs
|
||||
if _cache is not None and _cache_enabled and text_parts:
|
||||
assembled = orjson.dumps({
|
||||
"model": model,
|
||||
"choices": [{"index": 0, "message": {"role": "assistant", "content": "".join(text_parts)}, "finish_reason": "stop"}],
|
||||
**({"usage": usage_snapshot} if usage_snapshot else {}),
|
||||
}) + b"\n"
|
||||
try:
|
||||
await _cache.set_chat("openai_completions", model, _compl_messages, assembled)
|
||||
except Exception as _ce:
|
||||
print(f"[cache] set_chat (openai_completions streaming) failed: {_ce}")
|
||||
# Final DONE event
|
||||
yield b"data: [DONE]\n\n"
|
||||
else:
|
||||
prompt_tok = 0
|
||||
comp_tok = 0
|
||||
if async_gen.usage is not None:
|
||||
prompt_tok = async_gen.usage.prompt_tokens or 0
|
||||
comp_tok = async_gen.usage.completion_tokens or 0
|
||||
else:
|
||||
llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
|
||||
if llama_usage:
|
||||
prompt_tok, comp_tok = llama_usage
|
||||
if prompt_tok != 0 or comp_tok != 0:
|
||||
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
|
||||
json_line = (
|
||||
async_gen.model_dump_json()
|
||||
if hasattr(async_gen, "model_dump_json")
|
||||
else orjson.dumps(async_gen)
|
||||
)
|
||||
cache_bytes = json_line.encode("utf-8") + b"\n"
|
||||
yield cache_bytes
|
||||
# Cache non-streaming response
|
||||
if _cache is not None and _cache_enabled:
|
||||
try:
|
||||
await _cache.set_chat("openai_completions", model, _compl_messages, cache_bytes)
|
||||
except Exception as _ce:
|
||||
print(f"[cache] set_chat (openai_completions non-streaming) failed: {_ce}")
|
||||
|
||||
finally:
|
||||
# Ensure counter is decremented even if an exception occurs
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
|
||||
# 4. Return a StreamingResponse backed by the generator
|
||||
return StreamingResponse(
|
||||
stream_ocompletions_response(),
|
||||
media_type="text/event-stream" if stream else "application/json",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/v1/models")
|
||||
async def openai_models_proxy(request: Request):
|
||||
"""
|
||||
Proxy an OpenAI API models request to Ollama and llama-server endpoints and reply with a unique list of models.
|
||||
|
||||
For Ollama endpoints: queries /api/tags (all models)
|
||||
For llama-server endpoints: queries /v1/models and filters for status.value == "loaded"
|
||||
"""
|
||||
config = get_config()
|
||||
# 1. Query Ollama endpoints for all models via /api/tags
|
||||
ollama_tasks = [fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" not in ep]
|
||||
# 2. Query external OpenAI endpoints (Groq, OpenAI, etc.) via /models
|
||||
ext_openai_tasks = [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in config.endpoints if is_ext_openai_endpoint(ep)]
|
||||
# 3. Query llama-server endpoints for loaded models via /v1/models
|
||||
# Also query endpoints from llama_server_endpoints that may not be in config.endpoints
|
||||
all_llama_endpoints = set(config.llama_server_endpoints) | set(ep for ep in config.endpoints if ep in config.llama_server_endpoints)
|
||||
llama_tasks = [
|
||||
fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
|
||||
for ep in all_llama_endpoints
|
||||
]
|
||||
|
||||
ollama_models = await asyncio.gather(*ollama_tasks) if ollama_tasks else []
|
||||
ext_openai_models = await asyncio.gather(*ext_openai_tasks) if ext_openai_tasks else []
|
||||
llama_models = await asyncio.gather(*llama_tasks) if llama_tasks else []
|
||||
|
||||
models = {'data': []}
|
||||
|
||||
# Add Ollama models (if any)
|
||||
if ollama_models:
|
||||
for modellist in ollama_models:
|
||||
for model in modellist:
|
||||
if not "id" in model.keys(): # Relable Ollama models with OpenAI Model.id from Model.name
|
||||
model['id'] = model.get('name', model.get('id', ''))
|
||||
else:
|
||||
model['name'] = model['id']
|
||||
models['data'].append(model)
|
||||
|
||||
# Add external OpenAI models (if any)
|
||||
if ext_openai_models:
|
||||
for modellist in ext_openai_models:
|
||||
for model in modellist:
|
||||
if not "id" in model.keys():
|
||||
model['id'] = model.get('name', model.get('id', ''))
|
||||
else:
|
||||
model['name'] = model['id']
|
||||
models['data'].append(model)
|
||||
|
||||
# Add llama-server models (all available, not just loaded)
|
||||
if llama_models:
|
||||
for modellist in llama_models:
|
||||
for model in modellist:
|
||||
if not "id" in model.keys():
|
||||
model['id'] = model.get('name', model.get('id', ''))
|
||||
else:
|
||||
model['name'] = model['id']
|
||||
models['data'].append(model)
|
||||
|
||||
# 2. Return a JSONResponse with a deduplicated list of unique models for inference
|
||||
return JSONResponse(
|
||||
content={"data": dedupe_on_keys(models['data'], ['name'])},
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/v1/rerank")
|
||||
@router.post("/rerank")
|
||||
async def rerank_proxy(request: Request):
|
||||
"""
|
||||
Proxy a rerank request to a llama-server or external OpenAI-compatible endpoint.
|
||||
|
||||
Compatible with the Jina/Cohere rerank API convention used by llama-server,
|
||||
vLLM, and services such as Cohere and Jina AI.
|
||||
|
||||
Ollama does not natively support reranking; requests routed to a plain Ollama
|
||||
endpoint will receive a 501 Not Implemented response.
|
||||
|
||||
Request body:
|
||||
model (str, required) – reranker model name
|
||||
query (str, required) – search query
|
||||
documents (list[str], required) – candidate documents to rank
|
||||
top_n (int, optional) – limit returned results (default: all)
|
||||
return_documents (bool, optional) – include document text in results
|
||||
max_tokens_per_doc (int, optional) – truncation limit per document
|
||||
|
||||
Response (Jina/Cohere-compatible):
|
||||
{
|
||||
"id": "...",
|
||||
"model": "...",
|
||||
"usage": {"prompt_tokens": N, "total_tokens": N},
|
||||
"results": [{"index": 0, "relevance_score": 0.95}, ...]
|
||||
}
|
||||
"""
|
||||
config = get_config()
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||
|
||||
model = payload.get("model")
|
||||
query = payload.get("query")
|
||||
documents = payload.get("documents")
|
||||
|
||||
if not model:
|
||||
raise HTTPException(status_code=400, detail="Missing required field 'model'")
|
||||
if not query:
|
||||
raise HTTPException(status_code=400, detail="Missing required field 'query'")
|
||||
if not isinstance(documents, list) or not documents:
|
||||
raise HTTPException(status_code=400, detail="Missing or empty required field 'documents' (must be a non-empty list)")
|
||||
except orjson.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# Determine which endpoint serves this model
|
||||
try:
|
||||
endpoint, tracking_model = await choose_endpoint(model)
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
# Ollama endpoints have no native rerank support
|
||||
if not is_openai_compatible(endpoint):
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail=(
|
||||
f"Endpoint '{endpoint}' is a plain Ollama instance which does not support "
|
||||
"reranking. Use a llama-server or OpenAI-compatible endpoint with a "
|
||||
"dedicated reranker model."
|
||||
),
|
||||
)
|
||||
|
||||
if ":latest" in model:
|
||||
model = model.split(":latest")[0]
|
||||
|
||||
# Build upstream rerank request body – forward only recognised fields
|
||||
upstream_payload: dict = {"model": model, "query": query, "documents": documents}
|
||||
for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):
|
||||
if optional_key in payload:
|
||||
upstream_payload[optional_key] = payload[optional_key]
|
||||
|
||||
# Determine upstream URL:
|
||||
# llama-server exposes /v1/rerank (base already contains /v1 for llama_server_endpoints)
|
||||
# External OpenAI endpoints expose /rerank under their /v1 base
|
||||
if endpoint in config.llama_server_endpoints:
|
||||
# llama-server: endpoint may or may not already contain /v1
|
||||
if "/v1" in endpoint:
|
||||
rerank_url = f"{endpoint}/rerank"
|
||||
else:
|
||||
rerank_url = f"{endpoint}/v1/rerank"
|
||||
else:
|
||||
# External OpenAI-compatible: ep2base gives us the /v1 base
|
||||
rerank_url = f"{ep2base(endpoint)}/rerank"
|
||||
|
||||
api_key = config.api_keys.get(endpoint, "no-key")
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
}
|
||||
|
||||
client: aiohttp.ClientSession = get_session(endpoint)
|
||||
try:
|
||||
async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp:
|
||||
response_bytes = await resp.read()
|
||||
if resp.status >= 400:
|
||||
raise HTTPException(
|
||||
status_code=resp.status,
|
||||
detail=_mask_secrets(response_bytes.decode("utf-8", errors="replace")),
|
||||
)
|
||||
data = orjson.loads(response_bytes)
|
||||
|
||||
# Record token usage if the upstream returned a usage object
|
||||
usage = data.get("usage") or {}
|
||||
prompt_tok = usage.get("prompt_tokens") or 0
|
||||
total_tok = usage.get("total_tokens") or 0
|
||||
# For reranking there are no completion tokens; we record prompt tokens only
|
||||
if prompt_tok or total_tok:
|
||||
await token_queue.put((endpoint, tracking_model, prompt_tok, 0))
|
||||
|
||||
return JSONResponse(content=data)
|
||||
finally:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
30
api/static.py
Normal file
30
api/static.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
"""Static-asset and dashboard routes."""
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from starlette.responses import HTMLResponse, RedirectResponse
|
||||
|
||||
# Directory containing static files (resolved relative to project root).
|
||||
STATIC_DIR = Path(__file__).resolve().parent.parent / "static"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/favicon.ico")
|
||||
async def redirect_favicon():
|
||||
return RedirectResponse(url="/static/favicon.ico")
|
||||
|
||||
|
||||
@router.get("/", response_class=HTMLResponse)
|
||||
async def index(request: Request):
|
||||
"""
|
||||
Render the dynamic NOMYO Router dashboard listing the configured endpoints
|
||||
and the models details, availability & task status.
|
||||
"""
|
||||
index_path = STATIC_DIR / "index.html"
|
||||
try:
|
||||
return HTMLResponse(content=index_path.read_text(encoding="utf-8"), status_code=200)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Page not found")
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
|
@ -10,6 +10,7 @@ import pytest
|
|||
from fastapi import HTTPException
|
||||
|
||||
import router
|
||||
from api import openai as api_openai
|
||||
|
||||
|
||||
_BYPASS = HTTPException(status_code=599, detail="bypassed")
|
||||
|
|
@ -47,8 +48,8 @@ class TestOpenAIChatCompletionsCacheHit:
|
|||
# Patch the route's references to both helpers — they're imported by name
|
||||
# into router's namespace at module load time.
|
||||
with (
|
||||
patch.object(router, "get_llm_cache", return_value=fake),
|
||||
patch.object(router, "choose_endpoint",
|
||||
patch.object(api_openai, "get_llm_cache", return_value=fake),
|
||||
patch.object(api_openai, "choose_endpoint",
|
||||
AsyncMock(side_effect=AssertionError("backend must not be reached"))),
|
||||
):
|
||||
resp = await client.post(
|
||||
|
|
@ -70,8 +71,8 @@ class TestOpenAIChatCompletionsCacheHit:
|
|||
async def test_stream_cache_hit_returns_sse(self, client, cache_hit_payload):
|
||||
fake = _FakeCache(cache_hit_payload)
|
||||
with (
|
||||
patch.object(router, "get_llm_cache", return_value=fake),
|
||||
patch.object(router, "choose_endpoint",
|
||||
patch.object(api_openai, "get_llm_cache", return_value=fake),
|
||||
patch.object(api_openai, "choose_endpoint",
|
||||
AsyncMock(side_effect=AssertionError("backend must not be reached"))),
|
||||
):
|
||||
resp = await client.post(
|
||||
|
|
@ -98,8 +99,8 @@ class TestOpenAIChatCompletionsCacheHit:
|
|||
"""When nomyo.cache=False, get_chat is never called even if a cache exists."""
|
||||
fake = _FakeCache(b"") # has a response, but should never be consulted
|
||||
with (
|
||||
patch.object(router, "get_llm_cache", return_value=fake),
|
||||
patch.object(router, "choose_endpoint",
|
||||
patch.object(api_openai, "get_llm_cache", return_value=fake),
|
||||
patch.object(api_openai, "choose_endpoint",
|
||||
AsyncMock(side_effect=_BYPASS)),
|
||||
):
|
||||
resp = await client.post(
|
||||
|
|
@ -117,8 +118,8 @@ class TestOpenAIChatCompletionsCacheHit:
|
|||
async def test_no_cache_configured_bypasses_cache_check(self, client):
|
||||
"""get_llm_cache() returning None should not break the route."""
|
||||
with (
|
||||
patch.object(router, "get_llm_cache", return_value=None),
|
||||
patch.object(router, "choose_endpoint",
|
||||
patch.object(api_openai, "get_llm_cache", return_value=None),
|
||||
patch.object(api_openai, "choose_endpoint",
|
||||
AsyncMock(side_effect=_BYPASS)),
|
||||
):
|
||||
resp = await client.post(
|
||||
|
|
@ -140,8 +141,8 @@ class TestOpenAICompletionsCacheHit:
|
|||
async def test_nonstream_cache_hit(self, client, cache_hit_payload):
|
||||
fake = _FakeCache(cache_hit_payload)
|
||||
with (
|
||||
patch.object(router, "get_llm_cache", return_value=fake),
|
||||
patch.object(router, "choose_endpoint",
|
||||
patch.object(api_openai, "get_llm_cache", return_value=fake),
|
||||
patch.object(api_openai, "choose_endpoint",
|
||||
AsyncMock(side_effect=AssertionError("backend must not be reached"))),
|
||||
):
|
||||
resp = await client.post(
|
||||
|
|
@ -163,8 +164,8 @@ class TestOpenAICompletionsCacheHit:
|
|||
async def test_stream_cache_hit(self, client, cache_hit_payload):
|
||||
fake = _FakeCache(cache_hit_payload)
|
||||
with (
|
||||
patch.object(router, "get_llm_cache", return_value=fake),
|
||||
patch.object(router, "choose_endpoint",
|
||||
patch.object(api_openai, "get_llm_cache", return_value=fake),
|
||||
patch.object(api_openai, "choose_endpoint",
|
||||
AsyncMock(side_effect=AssertionError("backend must not be reached"))),
|
||||
):
|
||||
resp = await client.post(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue