nomyo-router/api/responses.py

398 lines
18 KiB
Python

"""OpenAI **Responses API** routes (``/v1/responses`` and its retrieve / delete /
cancel companions).
The router speaks Chat Completions to most of its backends, so this layer has two paths:
* **native** (external OpenAI): forwards via ``oclient.responses.create`` and
streams the SDK's typed events straight back, rewriting the response ``id`` to
a router-owned ``resp_`` id so chaining stays router-managed.
* **translated** (Ollama / llama-server): converts the request to chat, reuses
the resilient ``create_chat_with_retries`` ladder, and re-emits the result as
Responses typed SSE events (``requests/responses.py``).
State (``store`` / ``previous_response_id``) and background-task status live in the
router's SQLite DB (``db.py``); the router mints and owns every response id.
"""
import asyncio
import secrets
import time
import orjson
from fastapi import APIRouter, HTTPException, Request
from starlette.responses import JSONResponse, StreamingResponse
from cache import get_llm_cache
from config import get_config
from db import get_db
from fingerprint import _conversation_fingerprint
from state import token_queue, default_headers
from backends.normalize import is_ext_openai_endpoint
from backends.sessions import _make_openai_client
from routing import choose_endpoint, decrement_usage
from api.openai import create_chat_with_retries
from requests.responses import (
ChatToResponsesStream,
build_response_object,
chat_message_to_output_items,
messages_to_responses_input,
responses_input_to_messages,
responses_object_to_sse,
tools_responses_to_chat,
usage_chat_to_responses,
)
# Per-worker registry of live background tasks, keyed by response id, so that
# /cancel can reach a task running in this worker. A cancel that lands on a
# different worker can only mark the DB row cancelled.
_background_tasks: dict[str, asyncio.Task] = {}

# Router carrying the /v1/responses endpoints declared in this module.
router = APIRouter()
# ---------------------------------------------------------------------------
# small helpers
# ---------------------------------------------------------------------------
def _usage_tokens(usage):
"""Return ``(prompt, completion)`` tokens from a chat- or responses-shaped usage."""
if not usage:
return 0, 0
if "input_tokens" in usage:
return usage.get("input_tokens", 0) or 0, usage.get("output_tokens", 0) or 0
return usage.get("prompt_tokens", 0) or 0, usage.get("completion_tokens", 0) or 0
def _text_format_to_response_format(text):
"""Map Responses ``text.format`` → Chat Completions ``response_format`` (best effort)."""
if not isinstance(text, dict):
return None
fmt = text.get("format")
if not isinstance(fmt, dict):
return None
ftype = fmt.get("type")
if ftype == "json_object":
return {"type": "json_object"}
if ftype == "json_schema":
return {"type": "json_schema", "json_schema": {
k: fmt[k] for k in ("name", "schema", "strict", "description") if k in fmt
}}
return None
def _native_usage_from_response(data):
return data.get("usage")
def _output_items_to_assistant_message(output_items):
    """Fold one stored turn's output items into a single chat assistant message.

    ``message`` items contribute their ``output_text`` parts to ``content``
    (separate message items are joined with a newline). ``function_call``
    items become entries of ``tool_calls``. Merging them into one message
    matches the Chat Completions shape for parallel tool calls, where a single
    assistant message carries every ``tool_call`` that the following ``tool``
    messages answer. Returns ``None`` when the turn produced neither text nor
    tool calls.
    """
    texts = []
    tool_calls = []
    for item in output_items or []:
        item_type = item.get("type")
        if item_type == "message":
            text = "".join(
                part.get("text", "") for part in item.get("content") or []
                if part.get("type") == "output_text"
            )
            if text:
                texts.append(text)
        elif item_type == "function_call":
            tool_calls.append({
                "id": item.get("call_id"), "type": "function",
                "function": {"name": item.get("name"),
                             "arguments": item.get("arguments", "")},
            })
    if not texts and not tool_calls:
        return None
    message = {"role": "assistant", "content": "\n".join(texts) if texts else None}
    if tool_calls:
        message["tool_calls"] = tool_calls
    return message


async def _resolve_history_messages(previous_response_id):
    """Rebuild prior-turn chat messages from the stored response chain.

    For each stored turn (in the order ``get_response_chain`` returns them;
    presumably oldest first — TODO confirm in db.py), this appends the turn's
    stored ``input_messages`` followed by one assistant message built from its
    ``output_items``.

    The POST handler persists each turn with its *full* prompt (rebuilt history
    plus that turn's new input). When the chain also contains the earlier
    turns, that history has already been appended, so a stored prompt that
    starts with exactly the messages rebuilt so far contributes only its new
    suffix. A prompt that does not start that way is appended whole.

    Returns an empty list when ``previous_response_id`` is falsy.
    """
    if not previous_response_id:
        return []
    chain = await get_db().get_response_chain(previous_response_id)
    messages = []
    for turn in chain:
        turn_inputs = turn.get("input_messages") or []
        # Skip history already emitted for an earlier turn in the chain.
        if turn_inputs[:len(messages)] == messages:
            turn_inputs = turn_inputs[len(messages):]
        messages.extend(turn_inputs)
        assistant = _output_items_to_assistant_message(turn.get("output_items"))
        if assistant is not None:
            messages.append(assistant)
    return messages
class _NativeStream:
    """Relay a native (OpenAI SDK) Responses event stream as SSE bytes.

    Any event whose ``response`` payload is a dict with a truthy ``id`` has that
    id replaced by the router-minted ``response_id``. The ``output`` and
    ``usage`` of the last terminal event (completed / incomplete / failed) are
    kept on the instance so the caller can persist them after the stream ends.
    """

    def __init__(self, response_id):
        self.response_id = response_id
        self.output_items = []
        self.usage = None

    async def events(self, sdk_gen):
        """Yield each upstream event as an ``event:``/``data:`` SSE frame (bytes)."""
        terminal_types = ("response.completed", "response.incomplete", "response.failed")
        async for raw in sdk_gen:
            # SDK events are pydantic models; plain dicts are passed through as-is.
            payload = raw.model_dump() if hasattr(raw, "model_dump") else raw
            event_type = payload.get("type", "")
            response = payload.get("response")
            if isinstance(response, dict):
                if response.get("id"):
                    response["id"] = self.response_id
                if event_type in terminal_types:
                    self.output_items = response.get("output", []) or []
                    self.usage = response.get("usage")
            body = orjson.dumps(payload).decode("utf-8")
            yield f"event: {event_type}\ndata: {body}\n\n".encode("utf-8")
# ---------------------------------------------------------------------------
# backend execution (non-streaming, used by background + non-stream sync)
# ---------------------------------------------------------------------------
async def _chat_backend_call(oclient, params, endpoint, model, tracking_model):
    """Call ``create_chat_with_retries`` so that a failure releases the slot once.

    ``create_chat_with_retries`` releases the endpoint slot itself when it fails
    with an ordinary exception, so nothing more is done for those. A
    cancellation (a ``BaseException`` that is not an ``Exception``) is released
    here instead.
    NOTE(review): this assumes the helper's own release path only handles
    ``Exception`` — confirm against ``api/openai.py``.

    Returns whatever ``create_chat_with_retries`` returns. On return the caller
    owns the slot; if this raises, the slot has been released.
    """
    try:
        return await create_chat_with_retries(oclient, params, endpoint, model,
                                              tracking_model)
    except BaseException as exc:
        if not isinstance(exc, Exception):
            await decrement_usage(endpoint, tracking_model)
        raise


async def _run_to_completion(*, native, oclient, endpoint, model, tracking_model,
                             send_params, native_params):
    """Drive the backend to completion (no client streaming).

    Returns ``(output_items, usage)`` where usage is responses-shaped.

    Slot ownership: if this coroutine raises, the endpoint slot reserved by
    ``choose_endpoint`` has already been released. Translated failures are
    released by ``create_chat_with_retries``; every other failure is released
    here. Callers must not release it again. After a successful return the
    caller still owns the slot and must call ``decrement_usage`` exactly once.
    """
    if native:
        try:
            resp_obj = await oclient.responses.create(stream=False, **native_params)
            data = resp_obj.model_dump()
        except BaseException:
            await decrement_usage(endpoint, tracking_model)
            raise
        return data.get("output", []) or [], _native_usage_from_response(data)

    completion = await _chat_backend_call(
        oclient, {**send_params, "stream": False}, endpoint, model, tracking_model)
    try:
        message = completion.choices[0].message.model_dump() if completion.choices else {}
        output_items = chat_message_to_output_items(message)
        usage = usage_chat_to_responses(
            completion.usage.model_dump() if completion.usage is not None else None
        )
    except BaseException:
        # The chat call succeeded, so the slot is still ours to release.
        await decrement_usage(endpoint, tracking_model)
        raise
    return output_items, usage
# ---------------------------------------------------------------------------
# POST /v1/responses
# ---------------------------------------------------------------------------
@router.post("/v1/responses")
async def openai_responses_proxy(request: Request):
    """Create a model response (``POST /v1/responses``).

    Builds the conversation from the stored ``previous_response_id`` chain plus
    this turn's ``input``. It then optionally serves from the LLM cache
    (foreground only; hits are stored like any other response), reserves one
    backend slot, and runs in one of three modes:

    * ``background``: store a ``queued`` row, return it, finish in a detached task;
    * ``stream``: return Responses typed SSE events;
    * otherwise: return the finished response object as JSON.

    The reserved slot is released exactly once on every path.

    Raises ``HTTPException(400)`` for a malformed body or missing/invalid fields.
    """
    config = get_config()
    try:
        # orjson parses bytes directly and reports invalid UTF-8 as
        # JSONDecodeError, so a badly encoded body becomes a 400 instead of an
        # unhandled UnicodeDecodeError.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    input_data = payload.get("input")
    instructions = payload.get("instructions")
    stream = bool(payload.get("stream"))
    store = payload.get("store", True)
    background = bool(payload.get("background"))
    previous_response_id = payload.get("previous_response_id")
    tools = payload.get("tools")
    metadata = payload.get("metadata") or {}
    # Router-specific options; tolerate a missing, null, or non-object value.
    nomyo_opts = payload.get("nomyo")
    _cache_enabled = isinstance(nomyo_opts, dict) and bool(nomyo_opts.get("cache", False))

    if not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not isinstance(model, str):
        raise HTTPException(status_code=400, detail="Field 'model' must be a string")
    if input_data is None:
        raise HTTPException(status_code=400, detail="Missing required field 'input'")
    if background and not store:
        raise HTTPException(status_code=400, detail="background mode requires store=true")
    if ":latest" in model:
        model = model.split(":latest")[0]

    # Resolve conversation: prior turns (from store) + this turn's input.
    history = await _resolve_history_messages(previous_response_id)
    messages = history + responses_input_to_messages(input_data, instructions)
    response_id = f"resp_{secrets.token_hex(24)}"
    created_at = int(time.time())
    _cache = get_llm_cache()

    async def _persist(status, output_items=None, usage=None, error=None, insert=False):
        """Insert or update this response's DB row; no-op when ``store`` is falsy."""
        if not store:
            return
        db = get_db()
        if insert:
            await db.store_response(
                response_id, previous_response_id=previous_response_id, model=model,
                status=status, created_at=created_at, input_messages=messages,
                output_items=output_items, usage=usage, instructions=instructions, error=error)
        else:
            await db.update_response_status(response_id, status, output_items=output_items,
                                            usage=usage, error=error)

    async def _cache_store(output_items, usage):
        """Best-effort cache write of the finished response; errors are only logged."""
        if _cache is None or not _cache_enabled or not output_items:
            return
        obj = build_response_object(response_id=response_id, model=model,
                                    output_items=output_items, usage=usage,
                                    created_at=created_at,
                                    previous_response_id=previous_response_id,
                                    instructions=instructions, metadata=metadata)
        try:
            await _cache.set_chat("openai_responses", model, messages, orjson.dumps(obj))
        except Exception as _ce:
            print(f"[cache] set_chat (openai_responses) failed: {_ce}")

    # Cache lookup (foreground only) — before endpoint selection, so no slot is held.
    if _cache is not None and _cache_enabled and not background:
        cached = await _cache.get_chat("openai_responses", model, messages)
        if cached is not None:
            resp_obj = orjson.loads(cached)
            resp_obj["id"] = response_id
            # Store the hit under the freshly minted id so GET /v1/responses/{id}
            # and previous_response_id chaining work for it like any other response.
            await _persist("completed", output_items=resp_obj.get("output") or [],
                           usage=resp_obj.get("usage"), insert=True)
            if stream:
                async def _served_cached():
                    yield responses_object_to_sse(resp_obj)
                return StreamingResponse(_served_cached(), media_type="text/event-stream")
            return JSONResponse(content=resp_obj)

    # Build backend params for both shapes *before* reserving a slot, so a
    # malformed request field (e.g. ``tools``) cannot raise while a slot is held.
    send_params = {"messages": messages, "model": model}
    _opt = {
        "temperature": payload.get("temperature"),
        "top_p": payload.get("top_p"),
        "max_tokens": payload.get("max_output_tokens"),
        "tools": tools_responses_to_chat(tools),
        "tool_choice": payload.get("tool_choice"),
        "response_format": _text_format_to_response_format(payload.get("text")),
    }
    send_params.update({k: v for k, v in _opt.items() if v is not None})
    native_instructions, native_input = messages_to_responses_input(messages)
    native_params = {"model": model, "input": native_input, "store": False}
    _nopt = {
        "instructions": native_instructions,
        "temperature": payload.get("temperature"),
        "top_p": payload.get("top_p"),
        "max_output_tokens": payload.get("max_output_tokens"),
        "tools": tools,
        "tool_choice": payload.get("tool_choice"),
        "text": payload.get("text"),
        "reasoning": payload.get("reasoning"),
    }
    native_params.update({k: v for k, v in _nopt.items() if v is not None})

    # Endpoint selection (reserves a slot — must be released exactly once).
    _affinity_key = _conversation_fingerprint(model, messages, None)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
    try:
        oclient = _make_openai_client(endpoint, default_headers=default_headers,
                                      api_key=config.api_keys.get(endpoint, "no-key"))
        native = is_ext_openai_endpoint(endpoint)
    except BaseException:
        await decrement_usage(endpoint, tracking_model)
        raise

    async def _track(usage):
        """Queue token accounting for this endpoint when any tokens were reported."""
        prompt_tok, comp_tok = _usage_tokens(usage)
        if prompt_tok or comp_tok:
            await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

    # ---- background: run detached, return queued immediately --------------
    if background:
        try:
            await _persist("queued", insert=True)
        except BaseException:
            await decrement_usage(endpoint, tracking_model)
            raise

        async def _bg_run():
            slot_held = True  # cleared once _run_to_completion has released it
            try:
                await get_db().update_response_status(response_id, "in_progress")
                try:
                    output_items, usage = await _run_to_completion(
                        native=native, oclient=oclient, endpoint=endpoint, model=model,
                        tracking_model=tracking_model, send_params=send_params,
                        native_params=native_params)
                except BaseException:
                    slot_held = False  # _run_to_completion released it before raising
                    raise
                await _track(usage)
                await _persist("completed", output_items=output_items, usage=usage)
                await _cache_store(output_items, usage)
            except asyncio.CancelledError:
                await get_db().update_response_status(response_id, "cancelled")
                raise
            except Exception as e:
                await get_db().update_response_status(
                    response_id, "failed",
                    error={"message": str(e)[:500], "type": type(e).__name__})
            finally:
                if slot_held:
                    await decrement_usage(endpoint, tracking_model)
                _background_tasks.pop(response_id, None)

        task = asyncio.create_task(_bg_run())
        _background_tasks[response_id] = task
        queued = build_response_object(response_id=response_id, model=model, output_items=[],
                                       status="queued", created_at=created_at,
                                       previous_response_id=previous_response_id,
                                       instructions=instructions, metadata=metadata)
        return JSONResponse(content=queued, status_code=200)

    # ---- streaming sync ----------------------------------------------------
    if stream:
        if native:
            try:
                translator = _NativeStream(response_id)
                source = await oclient.responses.create(stream=True, **native_params)
            except BaseException:
                await decrement_usage(endpoint, tracking_model)
                raise
        else:
            try:
                translator = ChatToResponsesStream(
                    response_id, model, created_at=created_at,
                    previous_response_id=previous_response_id, instructions=instructions,
                    metadata=metadata)
            except BaseException:
                await decrement_usage(endpoint, tracking_model)
                raise
            # _chat_backend_call releases the slot itself if it raises.
            source = await _chat_backend_call(
                oclient, {**send_params, "stream": True,
                          "stream_options": {"include_usage": True}},
                endpoint, model, tracking_model)

        async def _stream():
            try:
                await _persist("in_progress", insert=True)
                async for sse in translator.events(source):
                    yield sse
                await _track(translator.usage)
                await _persist("completed", output_items=translator.output_items,
                               usage=translator.usage)
                await _cache_store(translator.output_items, translator.usage)
            except Exception as e:
                # Don't leave a stored response stuck at "in_progress" after a
                # backend/stream error; the original error still propagates.
                if store:
                    try:
                        await get_db().update_response_status(
                            response_id, "failed",
                            error={"message": str(e)[:500], "type": type(e).__name__})
                    except Exception as db_err:
                        print(f"[responses] could not mark {response_id} failed: {db_err}")
                raise
            finally:
                await decrement_usage(endpoint, tracking_model)

        return StreamingResponse(_stream(), media_type="text/event-stream")

    # ---- non-streaming sync ------------------------------------------------
    # _run_to_completion releases the slot itself if it raises.
    output_items, usage = await _run_to_completion(
        native=native, oclient=oclient, endpoint=endpoint, model=model,
        tracking_model=tracking_model, send_params=send_params,
        native_params=native_params)
    try:
        await _track(usage)
        await _persist("completed", output_items=output_items, usage=usage, insert=True)
        await _cache_store(output_items, usage)
    finally:
        await decrement_usage(endpoint, tracking_model)
    resp_obj = build_response_object(
        response_id=response_id, model=model, output_items=output_items, usage=usage,
        created_at=created_at, previous_response_id=previous_response_id,
        instructions=instructions, metadata=metadata)
    return JSONResponse(content=resp_obj)
# ---------------------------------------------------------------------------
# GET / DELETE / cancel
# ---------------------------------------------------------------------------
def _stored_to_response_object(row):
    """Convert a stored DB row into a Responses API response object.

    A missing or empty ``output_items`` becomes ``[]``. A missing or empty
    ``status`` is reported as ``"completed"``.
    """
    fields = {
        "response_id": row["response_id"],
        "model": row.get("model"),
        "output_items": row.get("output_items") or [],
        "usage": row.get("usage"),
        "status": row.get("status") or "completed",
        "created_at": row.get("created_at"),
        "previous_response_id": row.get("previous_response_id"),
        "instructions": row.get("instructions"),
        "error": row.get("error"),
    }
    return build_response_object(**fields)
@router.get("/v1/responses/{response_id}")
async def get_response(response_id: str):
    """Return the stored response ``response_id`` as a Responses object (404 if absent)."""
    stored = await get_db().get_response(response_id)
    if stored is not None:
        return JSONResponse(content=_stored_to_response_object(stored))
    raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
@router.delete("/v1/responses/{response_id}")
async def delete_response(response_id: str):
    """Delete the stored response ``response_id``; 404 when nothing was deleted."""
    if await get_db().delete_response(response_id):
        body = {"id": response_id, "object": "response.deleted", "deleted": True}
        return JSONResponse(content=body)
    raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
@router.post("/v1/responses/{response_id}/cancel")
async def cancel_response(response_id: str):
    """Cancel a background response and return its response object.

    If the task runs in this worker, it is cancelled and briefly awaited. This
    lets its own ``CancelledError`` handler record ``cancelled`` before the row
    is re-read.

    A row that is still ``queued``/``in_progress`` afterwards is marked
    ``cancelled`` directly. That covers three cases: the task lives in another
    worker, it was cancelled before it ever ran, or the wait timed out.
    Cross-worker limitation: a task running in another worker is not stopped.

    Raises ``HTTPException(404)`` if the response does not exist, or disappears
    while cancelling.
    """
    settle_timeout = 5.0  # seconds to wait for a local task to process cancel()
    db = get_db()
    row = await db.get_response(response_id)
    if row is None:
        raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
    task = _background_tasks.get(response_id)
    if task is not None and not task.done():
        task.cancel()
        # cancel() only schedules the CancelledError; wait for the task to run its
        # handler so the row read below reflects the terminal state.
        await asyncio.wait({task}, timeout=settle_timeout)
        row = await db.get_response(response_id)
    if row is not None and row.get("status") in ("queued", "in_progress"):
        await db.update_response_status(response_id, "cancelled")
        row = await db.get_response(response_id)
    if row is None:
        raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
    return JSONResponse(content=_stored_to_response_object(row))