398 lines
18 KiB
Python
398 lines
18 KiB
Python
"""OpenAI **Responses API** routes (``/v1/responses`` and its retrieve / delete /
cancel companions).

The router speaks Chat Completions to its local backends, so this layer serves a
Responses request in one of two modes:

* **native** (external OpenAI): forwards via ``oclient.responses.create`` and
  streams the SDK's typed events straight back, rewriting the response ``id`` to
  a router-owned ``resp_`` id so chaining stays router-managed.
* **translated** (Ollama / llama-server): converts the request to chat, reuses
  the resilient ``create_chat_with_retries`` ladder, and re-emits the result as
  Responses typed SSE events (``requests/responses.py``).

State (``store`` / ``previous_response_id``) and background-task status live in the
router's SQLite DB (``db.py``); the router mints and owns every response id.
"""
|
|
import asyncio
|
|
import secrets
|
|
import time
|
|
|
|
import orjson
|
|
from fastapi import APIRouter, HTTPException, Request
|
|
from starlette.responses import JSONResponse, StreamingResponse
|
|
|
|
from cache import get_llm_cache
|
|
from config import get_config
|
|
from db import get_db
|
|
from fingerprint import _conversation_fingerprint
|
|
from state import token_queue, default_headers
|
|
from backends.normalize import is_ext_openai_endpoint
|
|
from backends.sessions import _make_openai_client
|
|
from routing import choose_endpoint, decrement_usage
|
|
from api.openai import create_chat_with_retries
|
|
from requests.responses import (
|
|
ChatToResponsesStream,
|
|
build_response_object,
|
|
chat_message_to_output_items,
|
|
messages_to_responses_input,
|
|
responses_input_to_messages,
|
|
responses_object_to_sse,
|
|
tools_responses_to_chat,
|
|
usage_chat_to_responses,
|
|
)
|
|
|
|
# FastAPI router carrying the Responses API routes declared in this module.
router = APIRouter()
|
|
|
|
# In-memory handles for background tasks so /cancel can reach a running task in
|
|
# this worker. Cross-worker cancel falls back to marking the DB row cancelled.
|
|
# Keyed by the router-minted response id: the POST handler registers the task it
# spawns, and the task removes its own entry when it finishes.
_background_tasks: dict[str, asyncio.Task] = {}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# small helpers
|
|
# ---------------------------------------------------------------------------
|
|
def _usage_tokens(usage):
|
|
"""Return ``(prompt, completion)`` tokens from a chat- or responses-shaped usage."""
|
|
if not usage:
|
|
return 0, 0
|
|
if "input_tokens" in usage:
|
|
return usage.get("input_tokens", 0) or 0, usage.get("output_tokens", 0) or 0
|
|
return usage.get("prompt_tokens", 0) or 0, usage.get("completion_tokens", 0) or 0
|
|
|
|
|
|
def _text_format_to_response_format(text):
|
|
"""Map Responses ``text.format`` → Chat Completions ``response_format`` (best effort)."""
|
|
if not isinstance(text, dict):
|
|
return None
|
|
fmt = text.get("format")
|
|
if not isinstance(fmt, dict):
|
|
return None
|
|
ftype = fmt.get("type")
|
|
if ftype == "json_object":
|
|
return {"type": "json_object"}
|
|
if ftype == "json_schema":
|
|
return {"type": "json_schema", "json_schema": {
|
|
k: fmt[k] for k in ("name", "schema", "strict", "description") if k in fmt
|
|
}}
|
|
return None
|
|
|
|
|
|
def _native_usage_from_response(data):
|
|
return data.get("usage")
|
|
|
|
|
|
async def _resolve_history_messages(previous_response_id):
|
|
"""Rebuild prior-turn chat messages from the stored response chain."""
|
|
if not previous_response_id:
|
|
return []
|
|
db = get_db()
|
|
chain = await db.get_response_chain(previous_response_id)
|
|
messages = []
|
|
for turn in chain:
|
|
# Each turn stored the chat messages that produced it + its output items.
|
|
for m in turn.get("input_messages") or []:
|
|
messages.append(m)
|
|
for item in turn.get("output_items") or []:
|
|
if item.get("type") == "message":
|
|
text = "".join(
|
|
p.get("text", "") for p in item.get("content") or []
|
|
if p.get("type") == "output_text"
|
|
)
|
|
if text:
|
|
messages.append({"role": "assistant", "content": text})
|
|
elif item.get("type") == "function_call":
|
|
messages.append({
|
|
"role": "assistant", "content": None,
|
|
"tool_calls": [{"id": item.get("call_id"), "type": "function",
|
|
"function": {"name": item.get("name"),
|
|
"arguments": item.get("arguments", "")}}],
|
|
})
|
|
return messages
|
|
|
|
|
|
class _NativeStream:
|
|
"""Re-emit an SDK Responses event stream, rewriting the response id and
|
|
capturing the final output/usage for storage."""
|
|
|
|
def __init__(self, response_id):
|
|
self.response_id = response_id
|
|
self.output_items = []
|
|
self.usage = None
|
|
|
|
async def events(self, sdk_gen):
|
|
async for event in sdk_gen:
|
|
data = event.model_dump() if hasattr(event, "model_dump") else event
|
|
etype = data.get("type", "")
|
|
resp = data.get("response")
|
|
if isinstance(resp, dict) and resp.get("id"):
|
|
resp["id"] = self.response_id
|
|
if etype in ("response.completed", "response.incomplete", "response.failed") \
|
|
and isinstance(resp, dict):
|
|
self.output_items = resp.get("output", []) or []
|
|
self.usage = resp.get("usage")
|
|
yield f"event: {etype}\ndata: {orjson.dumps(data).decode('utf-8')}\n\n".encode("utf-8")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# backend execution (non-streaming, used by background + non-stream sync)
|
|
# ---------------------------------------------------------------------------
|
|
async def _run_to_completion(*, native, oclient, endpoint, model, tracking_model,
|
|
send_params, native_params):
|
|
"""Drive the backend to completion (no client streaming).
|
|
|
|
Returns ``(output_items, usage)`` where usage is responses-shaped. Caller is
|
|
responsible for ``decrement_usage`` (translated failures self-decrement inside
|
|
``create_chat_with_retries``)."""
|
|
if native:
|
|
resp_obj = await oclient.responses.create(stream=False, **native_params)
|
|
data = resp_obj.model_dump()
|
|
return data.get("output", []) or [], data.get("usage")
|
|
async_gen = await create_chat_with_retries(oclient, {**send_params, "stream": False},
|
|
endpoint, model, tracking_model)
|
|
message = async_gen.choices[0].message.model_dump() if async_gen.choices else {}
|
|
output_items = chat_message_to_output_items(message)
|
|
usage = usage_chat_to_responses(
|
|
async_gen.usage.model_dump() if async_gen.usage is not None else None
|
|
)
|
|
return output_items, usage
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# POST /v1/responses
|
|
# ---------------------------------------------------------------------------
|
|
@router.post("/v1/responses")
async def openai_responses_proxy(request: Request):
    """Create a model response (``POST /v1/responses``).

    Flow: validate the JSON payload, rebuild the conversation from
    ``previous_response_id`` plus this turn's ``input``, optionally serve a cached
    response, reserve a backend slot with ``choose_endpoint`` and then run one of
    three modes:

    * ``background`` – store a ``queued`` row, run detached, return immediately;
    * ``stream`` – re-emit typed SSE events while the backend generates;
    * otherwise – wait for completion and return the response object.

    Every path that reserved a slot releases it with ``decrement_usage``,
    including the failure paths before the backend call is started.

    Raises:
        HTTPException: 400 when the body is not valid JSON or not a JSON object,
            when ``model`` / ``input`` are missing, when ``model`` is not a
            string, or when ``background`` is requested with ``store`` disabled.
    """
    config = get_config()
    try:
        # orjson parses bytes directly and reports undecodable input as a
        # JSONDecodeError, so a bad body is a 400 instead of an unhandled
        # UnicodeDecodeError.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    input_data = payload.get("input")
    instructions = payload.get("instructions")
    stream = bool(payload.get("stream"))
    # ``store`` defaults to true; an explicit JSON null means "use the default".
    store = payload.get("store")
    store = True if store is None else bool(store)
    background = bool(payload.get("background"))
    previous_response_id = payload.get("previous_response_id")
    tools = payload.get("tools")
    metadata = payload.get("metadata") or {}
    # ``nomyo`` may be absent, null or not an object; only a dict can enable caching.
    nomyo_opts = payload.get("nomyo")
    _cache_enabled = bool(nomyo_opts.get("cache", False)) if isinstance(nomyo_opts, dict) else False

    if not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not isinstance(model, str):
        raise HTTPException(status_code=400, detail="Field 'model' must be a string")
    if input_data is None:
        raise HTTPException(status_code=400, detail="Missing required field 'input'")
    if background and not store:
        raise HTTPException(status_code=400, detail="background mode requires store=true")

    if ":latest" in model:
        model = model.split(":latest")[0]

    # Resolve conversation: prior turns (from store) + this turn's input.
    history = await _resolve_history_messages(previous_response_id)
    messages = history + responses_input_to_messages(input_data, instructions)

    response_id = f"resp_{secrets.token_hex(24)}"
    created_at = int(time.time())

    # Cache lookup (foreground only) — before endpoint selection.
    _cache = get_llm_cache()
    if _cache is not None and _cache_enabled and not background:
        cached = await _cache.get_chat("openai_responses", model, messages)
        if cached is not None:
            # NOTE(review): a cache hit is answered under a fresh id but never
            # written to the store, so that id cannot be retrieved or chained
            # later — confirm whether this is intended.
            resp_obj = orjson.loads(cached)
            resp_obj["id"] = response_id
            if stream:
                async def _served_cached():
                    yield responses_object_to_sse(resp_obj)
                return StreamingResponse(_served_cached(), media_type="text/event-stream")
            return JSONResponse(content=resp_obj)

    # Build backend params for both shapes. This happens *before* a slot is
    # reserved so that a failure while translating the request cannot leak it.
    send_params = {"messages": messages, "model": model}
    _opt = {
        "temperature": payload.get("temperature"),
        "top_p": payload.get("top_p"),
        "max_tokens": payload.get("max_output_tokens"),
        "tools": tools_responses_to_chat(tools),
        "tool_choice": payload.get("tool_choice"),
        "response_format": _text_format_to_response_format(payload.get("text")),
    }
    send_params.update({k: v for k, v in _opt.items() if v is not None})

    native_instructions, native_input = messages_to_responses_input(messages)
    # Upstream storage is disabled: state is kept in the router's own store.
    native_params = {"model": model, "input": native_input, "store": False}
    _nopt = {
        "instructions": native_instructions,
        "temperature": payload.get("temperature"),
        "top_p": payload.get("top_p"),
        "max_output_tokens": payload.get("max_output_tokens"),
        "tools": tools,
        "tool_choice": payload.get("tool_choice"),
        "text": payload.get("text"),
        "reasoning": payload.get("reasoning"),
    }
    native_params.update({k: v for k, v in _nopt.items() if v is not None})

    # Endpoint selection (reserves a slot — must be released exactly once).
    _affinity_key = _conversation_fingerprint(model, messages, None)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
    try:
        oclient = _make_openai_client(endpoint, default_headers=default_headers,
                                      api_key=config.api_keys.get(endpoint, "no-key"))
        native = is_ext_openai_endpoint(endpoint)
    except Exception:
        await decrement_usage(endpoint, tracking_model)
        raise

    def _error_details(exc):
        """Shape an exception for the stored ``error`` column."""
        return {"message": str(exc)[:500], "type": type(exc).__name__}

    async def _persist(status, output_items=None, usage=None, error=None, insert=False):
        """Insert or update this response's row; a no-op when ``store`` is off."""
        if not store:
            return
        db = get_db()
        if insert:
            # NOTE(review): ``messages`` already contains the replayed history, so
            # each stored turn holds the whole conversation so far — confirm that
            # ``get_response_chain`` does not replay every ancestor as well.
            await db.store_response(
                response_id, previous_response_id=previous_response_id, model=model,
                status=status, created_at=created_at, input_messages=messages,
                output_items=output_items, usage=usage, instructions=instructions, error=error)
        else:
            await db.update_response_status(response_id, status, output_items=output_items,
                                            usage=usage, error=error)

    async def _track(usage):
        """Queue token counts for accounting when any were reported."""
        prompt_tok, comp_tok = _usage_tokens(usage)
        if prompt_tok or comp_tok:
            await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

    async def _cache_store(output_items, usage):
        """Best-effort write of the finished response to the LLM cache."""
        if _cache is None or not _cache_enabled or not output_items:
            return
        obj = build_response_object(response_id=response_id, model=model,
                                    output_items=output_items, usage=usage,
                                    created_at=created_at,
                                    previous_response_id=previous_response_id,
                                    instructions=instructions, metadata=metadata)
        try:
            await _cache.set_chat("openai_responses", model, messages, orjson.dumps(obj))
        except Exception as _ce:
            print(f"[cache] set_chat (openai_responses) failed: {_ce}")

    # NOTE(review): ``_run_to_completion`` documents that translated failures are
    # already released inside ``create_chat_with_retries``. The unconditional
    # ``finally`` releases in the background and non-streaming paths below may then
    # release the slot a second time — confirm against ``create_chat_with_retries``
    # and ``decrement_usage``.

    # ---- background: run detached, return queued immediately --------------
    if background:
        try:
            await _persist("queued", insert=True)
        except BaseException:
            # No task will be started, so nothing else would release the slot.
            await decrement_usage(endpoint, tracking_model)
            raise

        async def _bg_run():
            try:
                await get_db().update_response_status(response_id, "in_progress")
                output_items, usage = await _run_to_completion(
                    native=native, oclient=oclient, endpoint=endpoint, model=model,
                    tracking_model=tracking_model, send_params=send_params,
                    native_params=native_params)
                await _track(usage)
                await _persist("completed", output_items=output_items, usage=usage)
                await _cache_store(output_items, usage)
            except asyncio.CancelledError:
                await get_db().update_response_status(response_id, "cancelled")
                raise
            except Exception as e:
                # Detached task: record the failure instead of propagating it.
                await get_db().update_response_status(
                    response_id, "failed", error=_error_details(e))
            finally:
                await decrement_usage(endpoint, tracking_model)
                _background_tasks.pop(response_id, None)

        task = asyncio.create_task(_bg_run())
        _background_tasks[response_id] = task
        queued = build_response_object(response_id=response_id, model=model, output_items=[],
                                       status="queued", created_at=created_at,
                                       previous_response_id=previous_response_id,
                                       instructions=instructions, metadata=metadata)
        return JSONResponse(content=queued, status_code=200)

    # ---- streaming sync ----------------------------------------------------
    if stream:
        if native:
            try:
                source = await oclient.responses.create(stream=True, **native_params)
            except BaseException:
                # The SDK call knows nothing about the reserved slot; release it
                # here or a failed connect would leak it.
                await decrement_usage(endpoint, tracking_model)
                raise
            translator = _NativeStream(response_id)
        else:
            # Left outside a try on purpose: per the ``_run_to_completion``
            # docstring, ``create_chat_with_retries`` releases the slot itself
            # when it fails.
            source = await create_chat_with_retries(
                oclient, {**send_params, "stream": True,
                          "stream_options": {"include_usage": True}},
                endpoint, model, tracking_model)
            translator = ChatToResponsesStream(
                response_id, model, created_at=created_at,
                previous_response_id=previous_response_id, instructions=instructions,
                metadata=metadata)

        async def _stream():
            completed = False
            try:
                # Inside the try so a failing insert still releases the slot.
                await _persist("in_progress", insert=True)
                async for sse in translator.events(source):
                    yield sse
                await _track(translator.usage)
                await _persist("completed", output_items=translator.output_items,
                               usage=translator.usage)
                completed = True
                await _cache_store(translator.output_items, translator.usage)
            except Exception as exc:
                if not completed:
                    # Do not leave the stored row stuck in ``in_progress``.
                    try:
                        await _persist("failed", error=_error_details(exc))
                    except Exception as _pe:
                        print(f"[responses] could not mark {response_id} as failed: {_pe}")
                raise
            finally:
                await decrement_usage(endpoint, tracking_model)

        return StreamingResponse(_stream(), media_type="text/event-stream")

    # ---- non-streaming sync ------------------------------------------------
    try:
        output_items, usage = await _run_to_completion(
            native=native, oclient=oclient, endpoint=endpoint, model=model,
            tracking_model=tracking_model, send_params=send_params,
            native_params=native_params)
        await _track(usage)
        await _persist("completed", output_items=output_items, usage=usage, insert=True)
        await _cache_store(output_items, usage)
    finally:
        await decrement_usage(endpoint, tracking_model)

    resp_obj = build_response_object(
        response_id=response_id, model=model, output_items=output_items, usage=usage,
        created_at=created_at, previous_response_id=previous_response_id,
        instructions=instructions, metadata=metadata)
    return JSONResponse(content=resp_obj)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GET / DELETE / cancel
|
|
# ---------------------------------------------------------------------------
|
|
def _stored_to_response_object(row):
    """Turn a stored response row into a Responses API object.

    ``response_id`` is required on the row; every other column is optional. A
    missing/empty status is reported as ``"completed"`` and missing output items
    as an empty list.
    """
    fields = {
        "response_id": row["response_id"],
        "model": row.get("model"),
        "output_items": row.get("output_items") or [],
        "usage": row.get("usage"),
        "status": row.get("status") or "completed",
        "created_at": row.get("created_at"),
        "previous_response_id": row.get("previous_response_id"),
        "instructions": row.get("instructions"),
        "error": row.get("error"),
    }
    return build_response_object(**fields)
|
|
|
|
|
|
@router.get("/v1/responses/{response_id}")
async def get_response(response_id: str):
    """Return the stored response with this id, or 404 when there is none."""
    stored = await get_db().get_response(response_id)
    if stored is not None:
        return JSONResponse(content=_stored_to_response_object(stored))
    raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
|
|
|
|
|
|
@router.delete("/v1/responses/{response_id}")
async def delete_response(response_id: str):
    """Delete the stored response with this id; 404 when nothing was deleted."""
    was_deleted = await get_db().delete_response(response_id)
    if was_deleted:
        confirmation = {"id": response_id, "object": "response.deleted", "deleted": True}
        return JSONResponse(content=confirmation)
    raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
|
|
|
|
|
|
@router.post("/v1/responses/{response_id}/cancel")
async def cancel_response(response_id: str):
    """Cancel a stored response and return its resulting state.

    If the background task runs in this worker it is cancelled and given a short,
    bounded time to record its terminal status, so the returned object reflects
    the cancellation. Otherwise a row that is still ``queued`` / ``in_progress``
    is marked ``cancelled`` directly in the DB, so a polling client sees a
    terminal state (cross-worker limitation: the remote task itself keeps
    running). Rows already in another state are returned unchanged.

    Raises:
        HTTPException: 404 when the response does not exist (or was deleted while
            the cancel was in flight).
    """
    db = get_db()
    row = await db.get_response(response_id)
    if row is None:
        raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")

    task = _background_tasks.get(response_id)
    if task is not None and not task.done():
        task.cancel()
        # The task writes "cancelled" from its own CancelledError handler, which
        # only runs once the event loop gets back to it. Without waiting, the
        # re-read below would still return the pre-cancel status. ``asyncio.wait``
        # never raises the task's exception and the timeout keeps a stuck task
        # from blocking this request.
        await asyncio.wait({task}, timeout=5.0)
    elif row.get("status") in ("queued", "in_progress"):
        await db.update_response_status(response_id, "cancelled")

    row = await db.get_response(response_id)
    if row is None:
        # Deleted concurrently between the two reads.
        raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
    return JSONResponse(content=_stored_to_response_object(row))
|