feat: transparent openai responses api integration

This commit is contained in:
Alpha Nerd 2026-06-10 18:48:26 +02:00
parent e7407b86b3
commit b28f175b61
Signed by: alpha-nerd
SSH key fingerprint: SHA256:QkkAgVoYi9TQ0UKPkiKSfnerZy2h4qhi3SVPXJmBN+M
7 changed files with 1674 additions and 86 deletions

View file

@ -46,6 +46,110 @@ from routing import choose_endpoint, decrement_usage
router = APIRouter()
async def create_chat_with_retries(oclient, send_params, endpoint, model, tracking_model):
"""Call ``chat.completions.create`` with the router's resilience retries.
Encapsulates the recovery ladder shared by the chat-completions handler and
the translated ``/v1/responses`` path:
* ``does not support tools`` retry without ``tools``
* llama-server context exhaustion sliding-window message trim, with a
second retry that also strips ``tools``/``tool_choice``
* backend connection failure mark (endpoint, model) unhealthy so the next
request reroutes, then re-raise
* ``image input is not supported`` strip images and retry
On unrecoverable failure the endpoint usage counter is decremented and the
exception is re-raised. Returns the established async generator / response.
"""
config = get_config()
try:
async_gen = await oclient.chat.completions.create(**send_params)
except Exception as e:
_e_str = str(e)
_is_ctx_err = "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str
print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True)
if "does not support tools" in _e_str:
# Model doesn't support tools — retry without them
print(f"[ochat] retry: no tools", flush=True)
try:
params_without_tools = {k: v for k, v in send_params.items() if k != "tools"}
async_gen = await oclient.chat.completions.create(**params_without_tools)
except Exception:
await decrement_usage(endpoint, tracking_model)
raise
elif _is_ctx_err:
# Backend context limit hit — apply sliding-window trim (context-shift at message level)
err_body = getattr(e, "body", {}) or {}
err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
n_ctx_limit = err_detail.get("n_ctx", 0)
actual_tokens = err_detail.get("n_prompt_tokens", 0)
# Fallback: parse from string if body parsing yielded nothing (SDK may not parse llama-server errors)
if not n_ctx_limit:
import re as _re
_m = _re.search(r"'n_ctx':\s*(\d+)", _e_str)
if _m:
n_ctx_limit = int(_m.group(1))
_m = _re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str)
if _m:
actual_tokens = int(_m.group(1))
print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True)
if not n_ctx_limit:
await decrement_usage(endpoint, tracking_model)
raise
if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
_endpoint_nctx[(endpoint, model)] = n_ctx_limit
msgs_to_trim = send_params.get("messages", [])
try:
cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
except Exception as _helper_exc:
print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True)
await decrement_usage(endpoint, tracking_model)
raise
dropped = len(msgs_to_trim) - len(trimmed_messages)
print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True)
try:
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages})
print(f"[ctx-trim] retry-1 ok", flush=True)
except Exception as e2:
_e2_str = str(e2)
if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str:
# Still too large — tool definitions likely consuming too many tokens, strip them too
print(f"[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True)
params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")}
try:
async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages})
print(f"[ctx-trim] retry-2 ok", flush=True)
except Exception:
await decrement_usage(endpoint, tracking_model)
raise
else:
await decrement_usage(endpoint, tracking_model)
raise
elif _is_backend_connection_error(e):
# Upstream connection failed (e.g. llama-server in router mode
# whose delegated worker died). Mark (endpoint, model) so the
# next request reroutes; the client will retry this one.
print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
await _mark_backend_unhealthy(endpoint, model, _e_str)
await decrement_usage(endpoint, tracking_model)
raise
elif "image input is not supported" in _e_str:
# Model doesn't support images — strip and retry
print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages")
try:
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))})
except Exception:
await decrement_usage(endpoint, tracking_model)
raise
else:
await decrement_usage(endpoint, tracking_model)
raise
return async_gen
@router.post("/v1/embeddings")
async def openai_embedding_proxy(request: Request):
"""
@ -260,90 +364,7 @@ async def openai_chat_completions_proxy(request: Request):
_dropped = len(_pre_msgs) - len(_pre_trimmed)
print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
send_params = {**send_params, "messages": _pre_trimmed}
try:
async_gen = await oclient.chat.completions.create(**send_params)
except Exception as e:
_e_str = str(e)
_is_ctx_err = "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str
print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True)
if "does not support tools" in _e_str:
# Model doesn't support tools — retry without them
print(f"[ochat] retry: no tools", flush=True)
try:
params_without_tools = {k: v for k, v in send_params.items() if k != "tools"}
async_gen = await oclient.chat.completions.create(**params_without_tools)
except Exception:
await decrement_usage(endpoint, tracking_model)
raise
elif _is_ctx_err:
# Backend context limit hit — apply sliding-window trim (context-shift at message level)
err_body = getattr(e, "body", {}) or {}
err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
n_ctx_limit = err_detail.get("n_ctx", 0)
actual_tokens = err_detail.get("n_prompt_tokens", 0)
# Fallback: parse from string if body parsing yielded nothing (SDK may not parse llama-server errors)
if not n_ctx_limit:
import re as _re
_m = _re.search(r"'n_ctx':\s*(\d+)", _e_str)
if _m:
n_ctx_limit = int(_m.group(1))
_m = _re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str)
if _m:
actual_tokens = int(_m.group(1))
print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True)
if not n_ctx_limit:
await decrement_usage(endpoint, tracking_model)
raise
if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
_endpoint_nctx[(endpoint, model)] = n_ctx_limit
msgs_to_trim = send_params.get("messages", [])
try:
cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
except Exception as _helper_exc:
print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True)
await decrement_usage(endpoint, tracking_model)
raise
dropped = len(msgs_to_trim) - len(trimmed_messages)
print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True)
try:
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages})
print(f"[ctx-trim] retry-1 ok", flush=True)
except Exception as e2:
_e2_str = str(e2)
if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str:
# Still too large — tool definitions likely consuming too many tokens, strip them too
print(f"[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True)
params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")}
try:
async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages})
print(f"[ctx-trim] retry-2 ok", flush=True)
except Exception:
await decrement_usage(endpoint, tracking_model)
raise
else:
await decrement_usage(endpoint, tracking_model)
raise
elif _is_backend_connection_error(e):
# Upstream connection failed (e.g. llama-server in router mode
# whose delegated worker died). Mark (endpoint, model) so the
# next request reroutes; the client will retry this one.
print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
await _mark_backend_unhealthy(endpoint, model, _e_str)
await decrement_usage(endpoint, tracking_model)
raise
elif "image input is not supported" in _e_str:
# Model doesn't support images — strip and retry
print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages")
try:
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))})
except Exception:
await decrement_usage(endpoint, tracking_model)
raise
else:
await decrement_usage(endpoint, tracking_model)
raise
async_gen = await create_chat_with_retries(oclient, send_params, endpoint, model, tracking_model)
# 4. Async generator — only streams the already-established async_gen
async def stream_ochat_response():

398
api/responses.py Normal file
View file

@ -0,0 +1,398 @@
"""OpenAI **Responses API** routes (``/v1/responses`` and its retrieve / delete /
cancel companions).
The router speaks Chat Completions to its backends, so this layer:
* **native** (external OpenAI): forwards via ``oclient.responses.create`` and
streams the SDK's typed events straight back, rewriting the response ``id`` to
a router-owned ``resp_`` id so chaining stays router-managed.
* **translated** (Ollama / llama-server): converts the request to chat, reuses
the resilient ``create_chat_with_retries`` ladder, and re-emits the result as
Responses typed SSE events (``requests/responses.py``).
State (``store`` / ``previous_response_id``) and background-task status live in the
router's SQLite DB (``db.py``); the router mints and owns every response id.
"""
import asyncio
import secrets
import time
import orjson
from fastapi import APIRouter, HTTPException, Request
from starlette.responses import JSONResponse, StreamingResponse
from cache import get_llm_cache
from config import get_config
from db import get_db
from fingerprint import _conversation_fingerprint
from state import token_queue, default_headers
from backends.normalize import is_ext_openai_endpoint
from backends.sessions import _make_openai_client
from routing import choose_endpoint, decrement_usage
from api.openai import create_chat_with_retries
from requests.responses import (
ChatToResponsesStream,
build_response_object,
chat_message_to_output_items,
messages_to_responses_input,
responses_input_to_messages,
responses_object_to_sse,
tools_responses_to_chat,
usage_chat_to_responses,
)
router = APIRouter()
# In-memory handles for background tasks so /cancel can reach a running task in
# this worker. Cross-worker cancel falls back to marking the DB row cancelled.
_background_tasks: dict[str, asyncio.Task] = {}
# ---------------------------------------------------------------------------
# small helpers
# ---------------------------------------------------------------------------
def _usage_tokens(usage):
"""Return ``(prompt, completion)`` tokens from a chat- or responses-shaped usage."""
if not usage:
return 0, 0
if "input_tokens" in usage:
return usage.get("input_tokens", 0) or 0, usage.get("output_tokens", 0) or 0
return usage.get("prompt_tokens", 0) or 0, usage.get("completion_tokens", 0) or 0
def _text_format_to_response_format(text):
"""Map Responses ``text.format`` → Chat Completions ``response_format`` (best effort)."""
if not isinstance(text, dict):
return None
fmt = text.get("format")
if not isinstance(fmt, dict):
return None
ftype = fmt.get("type")
if ftype == "json_object":
return {"type": "json_object"}
if ftype == "json_schema":
return {"type": "json_schema", "json_schema": {
k: fmt[k] for k in ("name", "schema", "strict", "description") if k in fmt
}}
return None
def _native_usage_from_response(data):
return data.get("usage")
async def _resolve_history_messages(previous_response_id):
"""Rebuild prior-turn chat messages from the stored response chain."""
if not previous_response_id:
return []
db = get_db()
chain = await db.get_response_chain(previous_response_id)
messages = []
for turn in chain:
# Each turn stored the chat messages that produced it + its output items.
for m in turn.get("input_messages") or []:
messages.append(m)
for item in turn.get("output_items") or []:
if item.get("type") == "message":
text = "".join(
p.get("text", "") for p in item.get("content") or []
if p.get("type") == "output_text"
)
if text:
messages.append({"role": "assistant", "content": text})
elif item.get("type") == "function_call":
messages.append({
"role": "assistant", "content": None,
"tool_calls": [{"id": item.get("call_id"), "type": "function",
"function": {"name": item.get("name"),
"arguments": item.get("arguments", "")}}],
})
return messages
class _NativeStream:
"""Re-emit an SDK Responses event stream, rewriting the response id and
capturing the final output/usage for storage."""
def __init__(self, response_id):
self.response_id = response_id
self.output_items = []
self.usage = None
async def events(self, sdk_gen):
async for event in sdk_gen:
data = event.model_dump() if hasattr(event, "model_dump") else event
etype = data.get("type", "")
resp = data.get("response")
if isinstance(resp, dict) and resp.get("id"):
resp["id"] = self.response_id
if etype in ("response.completed", "response.incomplete", "response.failed") \
and isinstance(resp, dict):
self.output_items = resp.get("output", []) or []
self.usage = resp.get("usage")
yield f"event: {etype}\ndata: {orjson.dumps(data).decode('utf-8')}\n\n".encode("utf-8")
# ---------------------------------------------------------------------------
# backend execution (non-streaming, used by background + non-stream sync)
# ---------------------------------------------------------------------------
async def _run_to_completion(*, native, oclient, endpoint, model, tracking_model,
send_params, native_params):
"""Drive the backend to completion (no client streaming).
Returns ``(output_items, usage)`` where usage is responses-shaped. Caller is
responsible for ``decrement_usage`` (translated failures self-decrement inside
``create_chat_with_retries``)."""
if native:
resp_obj = await oclient.responses.create(stream=False, **native_params)
data = resp_obj.model_dump()
return data.get("output", []) or [], data.get("usage")
async_gen = await create_chat_with_retries(oclient, {**send_params, "stream": False},
endpoint, model, tracking_model)
message = async_gen.choices[0].message.model_dump() if async_gen.choices else {}
output_items = chat_message_to_output_items(message)
usage = usage_chat_to_responses(
async_gen.usage.model_dump() if async_gen.usage is not None else None
)
return output_items, usage
# ---------------------------------------------------------------------------
# POST /v1/responses
# ---------------------------------------------------------------------------
@router.post("/v1/responses")
async def openai_responses_proxy(request: Request):
config = get_config()
try:
payload = orjson.loads((await request.body()).decode("utf-8"))
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
model = payload.get("model")
input_data = payload.get("input")
instructions = payload.get("instructions")
stream = bool(payload.get("stream"))
store = payload.get("store", True)
background = bool(payload.get("background"))
previous_response_id = payload.get("previous_response_id")
tools = payload.get("tools")
metadata = payload.get("metadata") or {}
_cache_enabled = payload.get("nomyo", {}).get("cache", False)
if not model:
raise HTTPException(status_code=400, detail="Missing required field 'model'")
if input_data is None:
raise HTTPException(status_code=400, detail="Missing required field 'input'")
if background and not store:
raise HTTPException(status_code=400, detail="background mode requires store=true")
if ":latest" in model:
model = model.split(":latest")[0]
# Resolve conversation: prior turns (from store) + this turn's input.
history = await _resolve_history_messages(previous_response_id)
messages = history + responses_input_to_messages(input_data, instructions)
response_id = f"resp_{secrets.token_hex(24)}"
created_at = int(time.time())
# Cache lookup (foreground only) — before endpoint selection.
_cache = get_llm_cache()
if _cache is not None and _cache_enabled and not background:
cached = await _cache.get_chat("openai_responses", model, messages)
if cached is not None:
resp_obj = orjson.loads(cached)
resp_obj["id"] = response_id
if stream:
async def _served_cached():
yield responses_object_to_sse(resp_obj)
return StreamingResponse(_served_cached(), media_type="text/event-stream")
return JSONResponse(content=resp_obj)
# Endpoint selection (reserves a slot — must be released exactly once).
_affinity_key = _conversation_fingerprint(model, messages, None)
endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
oclient = _make_openai_client(endpoint, default_headers=default_headers,
api_key=config.api_keys.get(endpoint, "no-key"))
native = is_ext_openai_endpoint(endpoint)
# Build backend params for both shapes.
send_params = {"messages": messages, "model": model}
_opt = {
"temperature": payload.get("temperature"),
"top_p": payload.get("top_p"),
"max_tokens": payload.get("max_output_tokens"),
"tools": tools_responses_to_chat(tools),
"tool_choice": payload.get("tool_choice"),
"response_format": _text_format_to_response_format(payload.get("text")),
}
send_params.update({k: v for k, v in _opt.items() if v is not None})
native_instructions, native_input = messages_to_responses_input(messages)
native_params = {"model": model, "input": native_input, "store": False}
_nopt = {
"instructions": native_instructions,
"temperature": payload.get("temperature"),
"top_p": payload.get("top_p"),
"max_output_tokens": payload.get("max_output_tokens"),
"tools": tools,
"tool_choice": payload.get("tool_choice"),
"text": payload.get("text"),
"reasoning": payload.get("reasoning"),
}
native_params.update({k: v for k, v in _nopt.items() if v is not None})
async def _persist(status, output_items=None, usage=None, error=None, insert=False):
if not store:
return
db = get_db()
if insert:
await db.store_response(
response_id, previous_response_id=previous_response_id, model=model,
status=status, created_at=created_at, input_messages=messages,
output_items=output_items, usage=usage, instructions=instructions, error=error)
else:
await db.update_response_status(response_id, status, output_items=output_items,
usage=usage, error=error)
async def _track(usage):
prompt_tok, comp_tok = _usage_tokens(usage)
if prompt_tok or comp_tok:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
async def _cache_store(output_items, usage):
if _cache is None or not _cache_enabled or not output_items:
return
obj = build_response_object(response_id=response_id, model=model,
output_items=output_items, usage=usage,
created_at=created_at,
previous_response_id=previous_response_id,
instructions=instructions, metadata=metadata)
try:
await _cache.set_chat("openai_responses", model, messages, orjson.dumps(obj))
except Exception as _ce:
print(f"[cache] set_chat (openai_responses) failed: {_ce}")
# ---- background: run detached, return queued immediately --------------
if background:
await _persist("queued", insert=True)
async def _bg_run():
try:
await get_db().update_response_status(response_id, "in_progress")
output_items, usage = await _run_to_completion(
native=native, oclient=oclient, endpoint=endpoint, model=model,
tracking_model=tracking_model, send_params=send_params,
native_params=native_params)
await _track(usage)
await _persist("completed", output_items=output_items, usage=usage)
await _cache_store(output_items, usage)
except asyncio.CancelledError:
await get_db().update_response_status(response_id, "cancelled")
raise
except Exception as e:
await get_db().update_response_status(
response_id, "failed",
error={"message": str(e)[:500], "type": type(e).__name__})
finally:
await decrement_usage(endpoint, tracking_model)
_background_tasks.pop(response_id, None)
task = asyncio.create_task(_bg_run())
_background_tasks[response_id] = task
queued = build_response_object(response_id=response_id, model=model, output_items=[],
status="queued", created_at=created_at,
previous_response_id=previous_response_id,
instructions=instructions, metadata=metadata)
return JSONResponse(content=queued, status_code=200)
# ---- streaming sync ----------------------------------------------------
if stream:
if native:
source = await oclient.responses.create(stream=True, **native_params)
translator = _NativeStream(response_id)
else:
source = await create_chat_with_retries(
oclient, {**send_params, "stream": True,
"stream_options": {"include_usage": True}},
endpoint, model, tracking_model)
translator = ChatToResponsesStream(
response_id, model, created_at=created_at,
previous_response_id=previous_response_id, instructions=instructions,
metadata=metadata)
async def _stream():
await _persist("in_progress", insert=True)
try:
async for sse in translator.events(source):
yield sse
await _track(translator.usage)
await _persist("completed", output_items=translator.output_items,
usage=translator.usage)
await _cache_store(translator.output_items, translator.usage)
finally:
await decrement_usage(endpoint, tracking_model)
return StreamingResponse(_stream(), media_type="text/event-stream")
# ---- non-streaming sync ------------------------------------------------
try:
output_items, usage = await _run_to_completion(
native=native, oclient=oclient, endpoint=endpoint, model=model,
tracking_model=tracking_model, send_params=send_params,
native_params=native_params)
await _track(usage)
await _persist("completed", output_items=output_items, usage=usage, insert=True)
await _cache_store(output_items, usage)
finally:
await decrement_usage(endpoint, tracking_model)
resp_obj = build_response_object(
response_id=response_id, model=model, output_items=output_items, usage=usage,
created_at=created_at, previous_response_id=previous_response_id,
instructions=instructions, metadata=metadata)
return JSONResponse(content=resp_obj)
# ---------------------------------------------------------------------------
# GET / DELETE / cancel
# ---------------------------------------------------------------------------
def _stored_to_response_object(row):
return build_response_object(
response_id=row["response_id"], model=row.get("model"),
output_items=row.get("output_items") or [], usage=row.get("usage"),
status=row.get("status") or "completed", created_at=row.get("created_at"),
previous_response_id=row.get("previous_response_id"),
instructions=row.get("instructions"), error=row.get("error"))
@router.get("/v1/responses/{response_id}")
async def get_response(response_id: str):
row = await get_db().get_response(response_id)
if row is None:
raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
return JSONResponse(content=_stored_to_response_object(row))
@router.delete("/v1/responses/{response_id}")
async def delete_response(response_id: str):
deleted = await get_db().delete_response(response_id)
if not deleted:
raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
return JSONResponse(content={"id": response_id, "object": "response.deleted", "deleted": True})
@router.post("/v1/responses/{response_id}/cancel")
async def cancel_response(response_id: str):
row = await get_db().get_response(response_id)
if row is None:
raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
# Cancel the running task if it lives in this worker; otherwise just mark the
# DB row so a polling client sees a terminal state (cross-worker limitation).
task = _background_tasks.get(response_id)
if task is not None and not task.done():
task.cancel()
elif row.get("status") in ("queued", "in_progress"):
await get_db().update_response_status(response_id, "cancelled")
row = await get_db().get_response(response_id)
return JSONResponse(content=_stored_to_response_object(row))