feat: transparent openai responses api integration
This commit is contained in:
parent
e7407b86b3
commit
b28f175b61
7 changed files with 1674 additions and 86 deletions
398
api/responses.py
Normal file
398
api/responses.py
Normal file
|
|
@ -0,0 +1,398 @@
|
|||
"""OpenAI **Responses API** routes (``/v1/responses`` and its retrieve / delete /
|
||||
cancel companions).
|
||||
|
||||
The router speaks Chat Completions to its backends, so this layer:
|
||||
|
||||
* **native** (external OpenAI): forwards via ``oclient.responses.create`` and
|
||||
streams the SDK's typed events straight back, rewriting the response ``id`` to
|
||||
a router-owned ``resp_`` id so chaining stays router-managed.
|
||||
* **translated** (Ollama / llama-server): converts the request to chat, reuses
|
||||
the resilient ``create_chat_with_retries`` ladder, and re-emits the result as
|
||||
Responses typed SSE events (``requests/responses.py``).
|
||||
|
||||
State (``store`` / ``previous_response_id``) and background-task status live in the
|
||||
router's SQLite DB (``db.py``); the router mints and owns every response id.
|
||||
"""
|
||||
import asyncio
|
||||
import secrets
|
||||
import time
|
||||
|
||||
import orjson
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from starlette.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from cache import get_llm_cache
|
||||
from config import get_config
|
||||
from db import get_db
|
||||
from fingerprint import _conversation_fingerprint
|
||||
from state import token_queue, default_headers
|
||||
from backends.normalize import is_ext_openai_endpoint
|
||||
from backends.sessions import _make_openai_client
|
||||
from routing import choose_endpoint, decrement_usage
|
||||
from api.openai import create_chat_with_retries
|
||||
from requests.responses import (
|
||||
ChatToResponsesStream,
|
||||
build_response_object,
|
||||
chat_message_to_output_items,
|
||||
messages_to_responses_input,
|
||||
responses_input_to_messages,
|
||||
responses_object_to_sse,
|
||||
tools_responses_to_chat,
|
||||
usage_chat_to_responses,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# In-memory handles for background tasks so /cancel can reach a running task in
|
||||
# this worker. Cross-worker cancel falls back to marking the DB row cancelled.
|
||||
_background_tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# small helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
def _usage_tokens(usage):
|
||||
"""Return ``(prompt, completion)`` tokens from a chat- or responses-shaped usage."""
|
||||
if not usage:
|
||||
return 0, 0
|
||||
if "input_tokens" in usage:
|
||||
return usage.get("input_tokens", 0) or 0, usage.get("output_tokens", 0) or 0
|
||||
return usage.get("prompt_tokens", 0) or 0, usage.get("completion_tokens", 0) or 0
|
||||
|
||||
|
||||
def _text_format_to_response_format(text):
|
||||
"""Map Responses ``text.format`` → Chat Completions ``response_format`` (best effort)."""
|
||||
if not isinstance(text, dict):
|
||||
return None
|
||||
fmt = text.get("format")
|
||||
if not isinstance(fmt, dict):
|
||||
return None
|
||||
ftype = fmt.get("type")
|
||||
if ftype == "json_object":
|
||||
return {"type": "json_object"}
|
||||
if ftype == "json_schema":
|
||||
return {"type": "json_schema", "json_schema": {
|
||||
k: fmt[k] for k in ("name", "schema", "strict", "description") if k in fmt
|
||||
}}
|
||||
return None
|
||||
|
||||
|
||||
def _native_usage_from_response(data):
|
||||
return data.get("usage")
|
||||
|
||||
|
||||
async def _resolve_history_messages(previous_response_id):
|
||||
"""Rebuild prior-turn chat messages from the stored response chain."""
|
||||
if not previous_response_id:
|
||||
return []
|
||||
db = get_db()
|
||||
chain = await db.get_response_chain(previous_response_id)
|
||||
messages = []
|
||||
for turn in chain:
|
||||
# Each turn stored the chat messages that produced it + its output items.
|
||||
for m in turn.get("input_messages") or []:
|
||||
messages.append(m)
|
||||
for item in turn.get("output_items") or []:
|
||||
if item.get("type") == "message":
|
||||
text = "".join(
|
||||
p.get("text", "") for p in item.get("content") or []
|
||||
if p.get("type") == "output_text"
|
||||
)
|
||||
if text:
|
||||
messages.append({"role": "assistant", "content": text})
|
||||
elif item.get("type") == "function_call":
|
||||
messages.append({
|
||||
"role": "assistant", "content": None,
|
||||
"tool_calls": [{"id": item.get("call_id"), "type": "function",
|
||||
"function": {"name": item.get("name"),
|
||||
"arguments": item.get("arguments", "")}}],
|
||||
})
|
||||
return messages
|
||||
|
||||
|
||||
class _NativeStream:
|
||||
"""Re-emit an SDK Responses event stream, rewriting the response id and
|
||||
capturing the final output/usage for storage."""
|
||||
|
||||
def __init__(self, response_id):
|
||||
self.response_id = response_id
|
||||
self.output_items = []
|
||||
self.usage = None
|
||||
|
||||
async def events(self, sdk_gen):
|
||||
async for event in sdk_gen:
|
||||
data = event.model_dump() if hasattr(event, "model_dump") else event
|
||||
etype = data.get("type", "")
|
||||
resp = data.get("response")
|
||||
if isinstance(resp, dict) and resp.get("id"):
|
||||
resp["id"] = self.response_id
|
||||
if etype in ("response.completed", "response.incomplete", "response.failed") \
|
||||
and isinstance(resp, dict):
|
||||
self.output_items = resp.get("output", []) or []
|
||||
self.usage = resp.get("usage")
|
||||
yield f"event: {etype}\ndata: {orjson.dumps(data).decode('utf-8')}\n\n".encode("utf-8")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# backend execution (non-streaming, used by background + non-stream sync)
|
||||
# ---------------------------------------------------------------------------
|
||||
async def _run_to_completion(*, native, oclient, endpoint, model, tracking_model,
|
||||
send_params, native_params):
|
||||
"""Drive the backend to completion (no client streaming).
|
||||
|
||||
Returns ``(output_items, usage)`` where usage is responses-shaped. Caller is
|
||||
responsible for ``decrement_usage`` (translated failures self-decrement inside
|
||||
``create_chat_with_retries``)."""
|
||||
if native:
|
||||
resp_obj = await oclient.responses.create(stream=False, **native_params)
|
||||
data = resp_obj.model_dump()
|
||||
return data.get("output", []) or [], data.get("usage")
|
||||
async_gen = await create_chat_with_retries(oclient, {**send_params, "stream": False},
|
||||
endpoint, model, tracking_model)
|
||||
message = async_gen.choices[0].message.model_dump() if async_gen.choices else {}
|
||||
output_items = chat_message_to_output_items(message)
|
||||
usage = usage_chat_to_responses(
|
||||
async_gen.usage.model_dump() if async_gen.usage is not None else None
|
||||
)
|
||||
return output_items, usage
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /v1/responses
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post("/v1/responses")
|
||||
async def openai_responses_proxy(request: Request):
|
||||
config = get_config()
|
||||
try:
|
||||
payload = orjson.loads((await request.body()).decode("utf-8"))
|
||||
except orjson.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
model = payload.get("model")
|
||||
input_data = payload.get("input")
|
||||
instructions = payload.get("instructions")
|
||||
stream = bool(payload.get("stream"))
|
||||
store = payload.get("store", True)
|
||||
background = bool(payload.get("background"))
|
||||
previous_response_id = payload.get("previous_response_id")
|
||||
tools = payload.get("tools")
|
||||
metadata = payload.get("metadata") or {}
|
||||
_cache_enabled = payload.get("nomyo", {}).get("cache", False)
|
||||
|
||||
if not model:
|
||||
raise HTTPException(status_code=400, detail="Missing required field 'model'")
|
||||
if input_data is None:
|
||||
raise HTTPException(status_code=400, detail="Missing required field 'input'")
|
||||
if background and not store:
|
||||
raise HTTPException(status_code=400, detail="background mode requires store=true")
|
||||
|
||||
if ":latest" in model:
|
||||
model = model.split(":latest")[0]
|
||||
|
||||
# Resolve conversation: prior turns (from store) + this turn's input.
|
||||
history = await _resolve_history_messages(previous_response_id)
|
||||
messages = history + responses_input_to_messages(input_data, instructions)
|
||||
|
||||
response_id = f"resp_{secrets.token_hex(24)}"
|
||||
created_at = int(time.time())
|
||||
|
||||
# Cache lookup (foreground only) — before endpoint selection.
|
||||
_cache = get_llm_cache()
|
||||
if _cache is not None and _cache_enabled and not background:
|
||||
cached = await _cache.get_chat("openai_responses", model, messages)
|
||||
if cached is not None:
|
||||
resp_obj = orjson.loads(cached)
|
||||
resp_obj["id"] = response_id
|
||||
if stream:
|
||||
async def _served_cached():
|
||||
yield responses_object_to_sse(resp_obj)
|
||||
return StreamingResponse(_served_cached(), media_type="text/event-stream")
|
||||
return JSONResponse(content=resp_obj)
|
||||
|
||||
# Endpoint selection (reserves a slot — must be released exactly once).
|
||||
_affinity_key = _conversation_fingerprint(model, messages, None)
|
||||
endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
|
||||
oclient = _make_openai_client(endpoint, default_headers=default_headers,
|
||||
api_key=config.api_keys.get(endpoint, "no-key"))
|
||||
native = is_ext_openai_endpoint(endpoint)
|
||||
|
||||
# Build backend params for both shapes.
|
||||
send_params = {"messages": messages, "model": model}
|
||||
_opt = {
|
||||
"temperature": payload.get("temperature"),
|
||||
"top_p": payload.get("top_p"),
|
||||
"max_tokens": payload.get("max_output_tokens"),
|
||||
"tools": tools_responses_to_chat(tools),
|
||||
"tool_choice": payload.get("tool_choice"),
|
||||
"response_format": _text_format_to_response_format(payload.get("text")),
|
||||
}
|
||||
send_params.update({k: v for k, v in _opt.items() if v is not None})
|
||||
|
||||
native_instructions, native_input = messages_to_responses_input(messages)
|
||||
native_params = {"model": model, "input": native_input, "store": False}
|
||||
_nopt = {
|
||||
"instructions": native_instructions,
|
||||
"temperature": payload.get("temperature"),
|
||||
"top_p": payload.get("top_p"),
|
||||
"max_output_tokens": payload.get("max_output_tokens"),
|
||||
"tools": tools,
|
||||
"tool_choice": payload.get("tool_choice"),
|
||||
"text": payload.get("text"),
|
||||
"reasoning": payload.get("reasoning"),
|
||||
}
|
||||
native_params.update({k: v for k, v in _nopt.items() if v is not None})
|
||||
|
||||
async def _persist(status, output_items=None, usage=None, error=None, insert=False):
|
||||
if not store:
|
||||
return
|
||||
db = get_db()
|
||||
if insert:
|
||||
await db.store_response(
|
||||
response_id, previous_response_id=previous_response_id, model=model,
|
||||
status=status, created_at=created_at, input_messages=messages,
|
||||
output_items=output_items, usage=usage, instructions=instructions, error=error)
|
||||
else:
|
||||
await db.update_response_status(response_id, status, output_items=output_items,
|
||||
usage=usage, error=error)
|
||||
|
||||
async def _track(usage):
|
||||
prompt_tok, comp_tok = _usage_tokens(usage)
|
||||
if prompt_tok or comp_tok:
|
||||
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
|
||||
|
||||
async def _cache_store(output_items, usage):
|
||||
if _cache is None or not _cache_enabled or not output_items:
|
||||
return
|
||||
obj = build_response_object(response_id=response_id, model=model,
|
||||
output_items=output_items, usage=usage,
|
||||
created_at=created_at,
|
||||
previous_response_id=previous_response_id,
|
||||
instructions=instructions, metadata=metadata)
|
||||
try:
|
||||
await _cache.set_chat("openai_responses", model, messages, orjson.dumps(obj))
|
||||
except Exception as _ce:
|
||||
print(f"[cache] set_chat (openai_responses) failed: {_ce}")
|
||||
|
||||
# ---- background: run detached, return queued immediately --------------
|
||||
if background:
|
||||
await _persist("queued", insert=True)
|
||||
|
||||
async def _bg_run():
|
||||
try:
|
||||
await get_db().update_response_status(response_id, "in_progress")
|
||||
output_items, usage = await _run_to_completion(
|
||||
native=native, oclient=oclient, endpoint=endpoint, model=model,
|
||||
tracking_model=tracking_model, send_params=send_params,
|
||||
native_params=native_params)
|
||||
await _track(usage)
|
||||
await _persist("completed", output_items=output_items, usage=usage)
|
||||
await _cache_store(output_items, usage)
|
||||
except asyncio.CancelledError:
|
||||
await get_db().update_response_status(response_id, "cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
await get_db().update_response_status(
|
||||
response_id, "failed",
|
||||
error={"message": str(e)[:500], "type": type(e).__name__})
|
||||
finally:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
_background_tasks.pop(response_id, None)
|
||||
|
||||
task = asyncio.create_task(_bg_run())
|
||||
_background_tasks[response_id] = task
|
||||
queued = build_response_object(response_id=response_id, model=model, output_items=[],
|
||||
status="queued", created_at=created_at,
|
||||
previous_response_id=previous_response_id,
|
||||
instructions=instructions, metadata=metadata)
|
||||
return JSONResponse(content=queued, status_code=200)
|
||||
|
||||
# ---- streaming sync ----------------------------------------------------
|
||||
if stream:
|
||||
if native:
|
||||
source = await oclient.responses.create(stream=True, **native_params)
|
||||
translator = _NativeStream(response_id)
|
||||
else:
|
||||
source = await create_chat_with_retries(
|
||||
oclient, {**send_params, "stream": True,
|
||||
"stream_options": {"include_usage": True}},
|
||||
endpoint, model, tracking_model)
|
||||
translator = ChatToResponsesStream(
|
||||
response_id, model, created_at=created_at,
|
||||
previous_response_id=previous_response_id, instructions=instructions,
|
||||
metadata=metadata)
|
||||
|
||||
async def _stream():
|
||||
await _persist("in_progress", insert=True)
|
||||
try:
|
||||
async for sse in translator.events(source):
|
||||
yield sse
|
||||
await _track(translator.usage)
|
||||
await _persist("completed", output_items=translator.output_items,
|
||||
usage=translator.usage)
|
||||
await _cache_store(translator.output_items, translator.usage)
|
||||
finally:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
|
||||
return StreamingResponse(_stream(), media_type="text/event-stream")
|
||||
|
||||
# ---- non-streaming sync ------------------------------------------------
|
||||
try:
|
||||
output_items, usage = await _run_to_completion(
|
||||
native=native, oclient=oclient, endpoint=endpoint, model=model,
|
||||
tracking_model=tracking_model, send_params=send_params,
|
||||
native_params=native_params)
|
||||
await _track(usage)
|
||||
await _persist("completed", output_items=output_items, usage=usage, insert=True)
|
||||
await _cache_store(output_items, usage)
|
||||
finally:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
|
||||
resp_obj = build_response_object(
|
||||
response_id=response_id, model=model, output_items=output_items, usage=usage,
|
||||
created_at=created_at, previous_response_id=previous_response_id,
|
||||
instructions=instructions, metadata=metadata)
|
||||
return JSONResponse(content=resp_obj)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET / DELETE / cancel
|
||||
# ---------------------------------------------------------------------------
|
||||
def _stored_to_response_object(row):
|
||||
return build_response_object(
|
||||
response_id=row["response_id"], model=row.get("model"),
|
||||
output_items=row.get("output_items") or [], usage=row.get("usage"),
|
||||
status=row.get("status") or "completed", created_at=row.get("created_at"),
|
||||
previous_response_id=row.get("previous_response_id"),
|
||||
instructions=row.get("instructions"), error=row.get("error"))
|
||||
|
||||
|
||||
@router.get("/v1/responses/{response_id}")
|
||||
async def get_response(response_id: str):
|
||||
row = await get_db().get_response(response_id)
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
|
||||
return JSONResponse(content=_stored_to_response_object(row))
|
||||
|
||||
|
||||
@router.delete("/v1/responses/{response_id}")
|
||||
async def delete_response(response_id: str):
|
||||
deleted = await get_db().delete_response(response_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
|
||||
return JSONResponse(content={"id": response_id, "object": "response.deleted", "deleted": True})
|
||||
|
||||
|
||||
@router.post("/v1/responses/{response_id}/cancel")
|
||||
async def cancel_response(response_id: str):
|
||||
row = await get_db().get_response(response_id)
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
|
||||
# Cancel the running task if it lives in this worker; otherwise just mark the
|
||||
# DB row so a polling client sees a terminal state (cross-worker limitation).
|
||||
task = _background_tasks.get(response_id)
|
||||
if task is not None and not task.done():
|
||||
task.cancel()
|
||||
elif row.get("status") in ("queued", "in_progress"):
|
||||
await get_db().update_response_status(response_id, "cancelled")
|
||||
row = await get_db().get_response(response_id)
|
||||
return JSONResponse(content=_stored_to_response_object(row))
|
||||
Loading…
Add table
Add a link
Reference in a new issue