feat: transparent openai responses api integration
This commit is contained in:
parent
e7407b86b3
commit
b28f175b61
7 changed files with 1674 additions and 86 deletions
189
api/openai.py
189
api/openai.py
|
|
@ -46,6 +46,110 @@ from routing import choose_endpoint, decrement_usage
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
async def create_chat_with_retries(oclient, send_params, endpoint, model, tracking_model):
|
||||
"""Call ``chat.completions.create`` with the router's resilience retries.
|
||||
|
||||
Encapsulates the recovery ladder shared by the chat-completions handler and
|
||||
the translated ``/v1/responses`` path:
|
||||
|
||||
* ``does not support tools`` → retry without ``tools``
|
||||
* llama-server context exhaustion → sliding-window message trim, with a
|
||||
second retry that also strips ``tools``/``tool_choice``
|
||||
* backend connection failure → mark (endpoint, model) unhealthy so the next
|
||||
request reroutes, then re-raise
|
||||
* ``image input is not supported`` → strip images and retry
|
||||
|
||||
On unrecoverable failure the endpoint usage counter is decremented and the
|
||||
exception is re-raised. Returns the established async generator / response.
|
||||
"""
|
||||
config = get_config()
|
||||
try:
|
||||
async_gen = await oclient.chat.completions.create(**send_params)
|
||||
except Exception as e:
|
||||
_e_str = str(e)
|
||||
_is_ctx_err = "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str
|
||||
print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True)
|
||||
if "does not support tools" in _e_str:
|
||||
# Model doesn't support tools — retry without them
|
||||
print(f"[ochat] retry: no tools", flush=True)
|
||||
try:
|
||||
params_without_tools = {k: v for k, v in send_params.items() if k != "tools"}
|
||||
async_gen = await oclient.chat.completions.create(**params_without_tools)
|
||||
except Exception:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
elif _is_ctx_err:
|
||||
# Backend context limit hit — apply sliding-window trim (context-shift at message level)
|
||||
err_body = getattr(e, "body", {}) or {}
|
||||
err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
|
||||
n_ctx_limit = err_detail.get("n_ctx", 0)
|
||||
actual_tokens = err_detail.get("n_prompt_tokens", 0)
|
||||
# Fallback: parse from string if body parsing yielded nothing (SDK may not parse llama-server errors)
|
||||
if not n_ctx_limit:
|
||||
import re as _re
|
||||
_m = _re.search(r"'n_ctx':\s*(\d+)", _e_str)
|
||||
if _m:
|
||||
n_ctx_limit = int(_m.group(1))
|
||||
_m = _re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str)
|
||||
if _m:
|
||||
actual_tokens = int(_m.group(1))
|
||||
print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True)
|
||||
if not n_ctx_limit:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
|
||||
_endpoint_nctx[(endpoint, model)] = n_ctx_limit
|
||||
|
||||
msgs_to_trim = send_params.get("messages", [])
|
||||
try:
|
||||
cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
|
||||
trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
|
||||
except Exception as _helper_exc:
|
||||
print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True)
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
dropped = len(msgs_to_trim) - len(trimmed_messages)
|
||||
print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True)
|
||||
try:
|
||||
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages})
|
||||
print(f"[ctx-trim] retry-1 ok", flush=True)
|
||||
except Exception as e2:
|
||||
_e2_str = str(e2)
|
||||
if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str:
|
||||
# Still too large — tool definitions likely consuming too many tokens, strip them too
|
||||
print(f"[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True)
|
||||
params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")}
|
||||
try:
|
||||
async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages})
|
||||
print(f"[ctx-trim] retry-2 ok", flush=True)
|
||||
except Exception:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
else:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
elif _is_backend_connection_error(e):
|
||||
# Upstream connection failed (e.g. llama-server in router mode
|
||||
# whose delegated worker died). Mark (endpoint, model) so the
|
||||
# next request reroutes; the client will retry this one.
|
||||
print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
|
||||
await _mark_backend_unhealthy(endpoint, model, _e_str)
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
elif "image input is not supported" in _e_str:
|
||||
# Model doesn't support images — strip and retry
|
||||
print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages")
|
||||
try:
|
||||
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))})
|
||||
except Exception:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
else:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
return async_gen
|
||||
|
||||
|
||||
@router.post("/v1/embeddings")
|
||||
async def openai_embedding_proxy(request: Request):
|
||||
"""
|
||||
|
|
@ -260,90 +364,7 @@ async def openai_chat_completions_proxy(request: Request):
|
|||
_dropped = len(_pre_msgs) - len(_pre_trimmed)
|
||||
print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
|
||||
send_params = {**send_params, "messages": _pre_trimmed}
|
||||
try:
|
||||
async_gen = await oclient.chat.completions.create(**send_params)
|
||||
except Exception as e:
|
||||
_e_str = str(e)
|
||||
_is_ctx_err = "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str
|
||||
print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True)
|
||||
if "does not support tools" in _e_str:
|
||||
# Model doesn't support tools — retry without them
|
||||
print(f"[ochat] retry: no tools", flush=True)
|
||||
try:
|
||||
params_without_tools = {k: v for k, v in send_params.items() if k != "tools"}
|
||||
async_gen = await oclient.chat.completions.create(**params_without_tools)
|
||||
except Exception:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
elif _is_ctx_err:
|
||||
# Backend context limit hit — apply sliding-window trim (context-shift at message level)
|
||||
err_body = getattr(e, "body", {}) or {}
|
||||
err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
|
||||
n_ctx_limit = err_detail.get("n_ctx", 0)
|
||||
actual_tokens = err_detail.get("n_prompt_tokens", 0)
|
||||
# Fallback: parse from string if body parsing yielded nothing (SDK may not parse llama-server errors)
|
||||
if not n_ctx_limit:
|
||||
import re as _re
|
||||
_m = _re.search(r"'n_ctx':\s*(\d+)", _e_str)
|
||||
if _m:
|
||||
n_ctx_limit = int(_m.group(1))
|
||||
_m = _re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str)
|
||||
if _m:
|
||||
actual_tokens = int(_m.group(1))
|
||||
print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True)
|
||||
if not n_ctx_limit:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
|
||||
_endpoint_nctx[(endpoint, model)] = n_ctx_limit
|
||||
|
||||
msgs_to_trim = send_params.get("messages", [])
|
||||
try:
|
||||
cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
|
||||
trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
|
||||
except Exception as _helper_exc:
|
||||
print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True)
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
dropped = len(msgs_to_trim) - len(trimmed_messages)
|
||||
print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True)
|
||||
try:
|
||||
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages})
|
||||
print(f"[ctx-trim] retry-1 ok", flush=True)
|
||||
except Exception as e2:
|
||||
_e2_str = str(e2)
|
||||
if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str:
|
||||
# Still too large — tool definitions likely consuming too many tokens, strip them too
|
||||
print(f"[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True)
|
||||
params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")}
|
||||
try:
|
||||
async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages})
|
||||
print(f"[ctx-trim] retry-2 ok", flush=True)
|
||||
except Exception:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
else:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
elif _is_backend_connection_error(e):
|
||||
# Upstream connection failed (e.g. llama-server in router mode
|
||||
# whose delegated worker died). Mark (endpoint, model) so the
|
||||
# next request reroutes; the client will retry this one.
|
||||
print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
|
||||
await _mark_backend_unhealthy(endpoint, model, _e_str)
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
elif "image input is not supported" in _e_str:
|
||||
# Model doesn't support images — strip and retry
|
||||
print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages")
|
||||
try:
|
||||
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))})
|
||||
except Exception:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
else:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
async_gen = await create_chat_with_retries(oclient, send_params, endpoint, model, tracking_model)
|
||||
|
||||
# 4. Async generator — only streams the already-established async_gen
|
||||
async def stream_ochat_response():
|
||||
|
|
|
|||
398
api/responses.py
Normal file
398
api/responses.py
Normal file
|
|
@ -0,0 +1,398 @@
|
|||
"""OpenAI **Responses API** routes (``/v1/responses`` and its retrieve / delete /
|
||||
cancel companions).
|
||||
|
||||
The router speaks Chat Completions to its backends, so this layer:
|
||||
|
||||
* **native** (external OpenAI): forwards via ``oclient.responses.create`` and
|
||||
streams the SDK's typed events straight back, rewriting the response ``id`` to
|
||||
a router-owned ``resp_`` id so chaining stays router-managed.
|
||||
* **translated** (Ollama / llama-server): converts the request to chat, reuses
|
||||
the resilient ``create_chat_with_retries`` ladder, and re-emits the result as
|
||||
Responses typed SSE events (``requests/responses.py``).
|
||||
|
||||
State (``store`` / ``previous_response_id``) and background-task status live in the
|
||||
router's SQLite DB (``db.py``); the router mints and owns every response id.
|
||||
"""
|
||||
import asyncio
|
||||
import secrets
|
||||
import time
|
||||
|
||||
import orjson
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from starlette.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from cache import get_llm_cache
|
||||
from config import get_config
|
||||
from db import get_db
|
||||
from fingerprint import _conversation_fingerprint
|
||||
from state import token_queue, default_headers
|
||||
from backends.normalize import is_ext_openai_endpoint
|
||||
from backends.sessions import _make_openai_client
|
||||
from routing import choose_endpoint, decrement_usage
|
||||
from api.openai import create_chat_with_retries
|
||||
from requests.responses import (
|
||||
ChatToResponsesStream,
|
||||
build_response_object,
|
||||
chat_message_to_output_items,
|
||||
messages_to_responses_input,
|
||||
responses_input_to_messages,
|
||||
responses_object_to_sse,
|
||||
tools_responses_to_chat,
|
||||
usage_chat_to_responses,
|
||||
)
|
||||
|
||||
# Router object the route decorators in this module register their handlers on.
router = APIRouter()

# In-memory handles for background tasks so /cancel can reach a running task in
# this worker. Cross-worker cancel falls back to marking the DB row cancelled.
# Keyed by the router-minted response id; a task removes its own entry when it ends.
_background_tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# small helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
def _usage_tokens(usage):
|
||||
"""Return ``(prompt, completion)`` tokens from a chat- or responses-shaped usage."""
|
||||
if not usage:
|
||||
return 0, 0
|
||||
if "input_tokens" in usage:
|
||||
return usage.get("input_tokens", 0) or 0, usage.get("output_tokens", 0) or 0
|
||||
return usage.get("prompt_tokens", 0) or 0, usage.get("completion_tokens", 0) or 0
|
||||
|
||||
|
||||
def _text_format_to_response_format(text):
|
||||
"""Map Responses ``text.format`` → Chat Completions ``response_format`` (best effort)."""
|
||||
if not isinstance(text, dict):
|
||||
return None
|
||||
fmt = text.get("format")
|
||||
if not isinstance(fmt, dict):
|
||||
return None
|
||||
ftype = fmt.get("type")
|
||||
if ftype == "json_object":
|
||||
return {"type": "json_object"}
|
||||
if ftype == "json_schema":
|
||||
return {"type": "json_schema", "json_schema": {
|
||||
k: fmt[k] for k in ("name", "schema", "strict", "description") if k in fmt
|
||||
}}
|
||||
return None
|
||||
|
||||
|
||||
def _native_usage_from_response(data):
|
||||
return data.get("usage")
|
||||
|
||||
|
||||
async def _resolve_history_messages(previous_response_id):
|
||||
"""Rebuild prior-turn chat messages from the stored response chain."""
|
||||
if not previous_response_id:
|
||||
return []
|
||||
db = get_db()
|
||||
chain = await db.get_response_chain(previous_response_id)
|
||||
messages = []
|
||||
for turn in chain:
|
||||
# Each turn stored the chat messages that produced it + its output items.
|
||||
for m in turn.get("input_messages") or []:
|
||||
messages.append(m)
|
||||
for item in turn.get("output_items") or []:
|
||||
if item.get("type") == "message":
|
||||
text = "".join(
|
||||
p.get("text", "") for p in item.get("content") or []
|
||||
if p.get("type") == "output_text"
|
||||
)
|
||||
if text:
|
||||
messages.append({"role": "assistant", "content": text})
|
||||
elif item.get("type") == "function_call":
|
||||
messages.append({
|
||||
"role": "assistant", "content": None,
|
||||
"tool_calls": [{"id": item.get("call_id"), "type": "function",
|
||||
"function": {"name": item.get("name"),
|
||||
"arguments": item.get("arguments", "")}}],
|
||||
})
|
||||
return messages
|
||||
|
||||
|
||||
class _NativeStream:
|
||||
"""Re-emit an SDK Responses event stream, rewriting the response id and
|
||||
capturing the final output/usage for storage."""
|
||||
|
||||
def __init__(self, response_id):
|
||||
self.response_id = response_id
|
||||
self.output_items = []
|
||||
self.usage = None
|
||||
|
||||
async def events(self, sdk_gen):
|
||||
async for event in sdk_gen:
|
||||
data = event.model_dump() if hasattr(event, "model_dump") else event
|
||||
etype = data.get("type", "")
|
||||
resp = data.get("response")
|
||||
if isinstance(resp, dict) and resp.get("id"):
|
||||
resp["id"] = self.response_id
|
||||
if etype in ("response.completed", "response.incomplete", "response.failed") \
|
||||
and isinstance(resp, dict):
|
||||
self.output_items = resp.get("output", []) or []
|
||||
self.usage = resp.get("usage")
|
||||
yield f"event: {etype}\ndata: {orjson.dumps(data).decode('utf-8')}\n\n".encode("utf-8")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# backend execution (non-streaming, used by background + non-stream sync)
|
||||
# ---------------------------------------------------------------------------
|
||||
class _EndpointSlot:
    """Release-once handle for the usage slot reserved by ``choose_endpoint``.

    ``create_chat_with_retries`` decrements the endpoint usage itself before
    re-raising, while the request handlers also release in ``finally``. Routing
    both through this handle keeps the release to exactly one call.
    """

    def __init__(self, endpoint, tracking_model):
        self.endpoint = endpoint
        self.tracking_model = tracking_model
        self.released = False

    def mark_released(self):
        """Record that another component already decremented the usage counter."""
        self.released = True

    async def release(self):
        """Decrement the usage counter unless the slot was already released."""
        if self.released:
            return
        self.released = True
        await decrement_usage(self.endpoint, self.tracking_model)


async def _create_chat_tracking_slot(oclient, send_params, endpoint, model,
                                     tracking_model, slot):
    """Run ``create_chat_with_retries`` and note on *slot* when it self-released.

    Only ``Exception`` is intercepted: ``create_chat_with_retries`` handles (and
    decrements on) ``Exception`` alone, so a cancellation leaves the slot held.
    """
    try:
        return await create_chat_with_retries(oclient, send_params, endpoint, model,
                                              tracking_model)
    except Exception:
        if slot is not None:
            slot.mark_released()
        raise


async def _run_to_completion(*, native, oclient, endpoint, model, tracking_model,
                             send_params, native_params, slot=None):
    """Drive the backend to completion (no client streaming).

    Args:
        native: True to call ``oclient.responses.create`` directly, False to go
            through the translated chat-completions path.
        oclient: OpenAI SDK client for the chosen endpoint.
        endpoint, model, tracking_model: routing identifiers forwarded to
            ``create_chat_with_retries``.
        send_params: chat-completions parameters (translated path).
        native_params: Responses parameters (native path).
        slot: optional ``_EndpointSlot``; marked released when the translated
            path fails inside ``create_chat_with_retries``, which has then
            already decremented the usage counter.

    Returns:
        ``(output_items, usage)`` where usage is responses-shaped.

    The caller releases the slot. Without *slot* it cannot tell that a translated
    failure already released it.
    """
    if native:
        resp_obj = await oclient.responses.create(stream=False, **native_params)
        data = resp_obj.model_dump()
        return data.get("output", []) or [], data.get("usage")

    completion = await _create_chat_tracking_slot(
        oclient, {**send_params, "stream": False}, endpoint, model, tracking_model, slot)
    message = completion.choices[0].message.model_dump() if completion.choices else {}
    output_items = chat_message_to_output_items(message)
    usage = usage_chat_to_responses(
        completion.usage.model_dump() if completion.usage is not None else None
    )
    return output_items, usage


# ---------------------------------------------------------------------------
# POST /v1/responses
# ---------------------------------------------------------------------------
@router.post("/v1/responses")
async def openai_responses_proxy(request: Request):
    """Handle ``POST /v1/responses`` in background, streaming or plain mode.

    The request is converted to chat messages (prior turns from the store plus
    this turn's input), an endpoint is chosen, and the call is made natively or
    through the translated chat path depending on the endpoint. The usage slot
    reserved by ``choose_endpoint`` is released exactly once on every path.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, a missing or
            non-string ``model``, a missing ``input``, or ``background`` without
            ``store``.
    """
    config = get_config()
    try:
        # orjson parses bytes directly and reports invalid UTF-8 as a
        # JSONDecodeError, so a malformed body is a 400 rather than a 500.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    input_data = payload.get("input")
    instructions = payload.get("instructions")
    stream = bool(payload.get("stream"))
    store = payload.get("store", True)
    background = bool(payload.get("background"))
    previous_response_id = payload.get("previous_response_id")
    tools = payload.get("tools")
    metadata = payload.get("metadata") or {}
    _nomyo = payload.get("nomyo")
    _cache_enabled = bool(_nomyo.get("cache", False)) if isinstance(_nomyo, dict) else False

    if not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not isinstance(model, str):
        raise HTTPException(status_code=400, detail="Field 'model' must be a string")
    if input_data is None:
        raise HTTPException(status_code=400, detail="Missing required field 'input'")
    if background and not store:
        raise HTTPException(status_code=400, detail="background mode requires store=true")

    if ":latest" in model:
        model = model.split(":latest")[0]

    # Resolve conversation: prior turns (from store) + this turn's input.
    history = await _resolve_history_messages(previous_response_id)
    turn_messages = responses_input_to_messages(input_data, instructions)
    messages = history + turn_messages

    response_id = f"resp_{secrets.token_hex(24)}"
    created_at = int(time.time())

    async def _persist(status, output_items=None, usage=None, error=None, insert=False):
        if not store:
            return
        db = get_db()
        if insert:
            # Store only this turn's own messages. History is rebuilt by walking
            # the chain, so storing history + input on every turn would repeat
            # earlier turns once per later turn.
            # NOTE(review): assumes get_response_chain returns every ancestor
            # turn — confirm against db.py.
            await db.store_response(
                response_id, previous_response_id=previous_response_id, model=model,
                status=status, created_at=created_at, input_messages=turn_messages,
                output_items=output_items, usage=usage, instructions=instructions,
                error=error)
        else:
            await db.update_response_status(response_id, status, output_items=output_items,
                                            usage=usage, error=error)

    # Cache lookup (foreground only) — before endpoint selection.
    _cache = get_llm_cache()
    if _cache is not None and _cache_enabled and not background:
        cached = await _cache.get_chat("openai_responses", model, messages)
        if cached is not None:
            resp_obj = orjson.loads(cached)
            resp_obj["id"] = response_id
            # The freshly minted id is handed to the client, so it must be
            # retrievable / chainable like any other stored response.
            await _persist("completed", output_items=resp_obj.get("output") or [],
                           usage=resp_obj.get("usage"), insert=True)
            if stream:
                async def _served_cached():
                    yield responses_object_to_sse(resp_obj)
                return StreamingResponse(_served_cached(), media_type="text/event-stream")
            return JSONResponse(content=resp_obj)

    # Build backend params for both shapes. Done before a slot is reserved so a
    # conversion error cannot leak one.
    send_params = {"messages": messages, "model": model}
    _opt = {
        "temperature": payload.get("temperature"),
        "top_p": payload.get("top_p"),
        "max_tokens": payload.get("max_output_tokens"),
        "tools": tools_responses_to_chat(tools),
        "tool_choice": payload.get("tool_choice"),
        "response_format": _text_format_to_response_format(payload.get("text")),
    }
    send_params.update({k: v for k, v in _opt.items() if v is not None})

    native_instructions, native_input = messages_to_responses_input(messages)
    # The router owns state, so the upstream is never asked to store.
    native_params = {"model": model, "input": native_input, "store": False}
    _nopt = {
        "instructions": native_instructions,
        "temperature": payload.get("temperature"),
        "top_p": payload.get("top_p"),
        "max_output_tokens": payload.get("max_output_tokens"),
        "tools": tools,
        "tool_choice": payload.get("tool_choice"),
        "text": payload.get("text"),
        "reasoning": payload.get("reasoning"),
    }
    native_params.update({k: v for k, v in _nopt.items() if v is not None})

    # Endpoint selection (reserves a slot — must be released exactly once).
    _affinity_key = _conversation_fingerprint(model, messages, None)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
    slot = _EndpointSlot(endpoint, tracking_model)
    try:
        oclient = _make_openai_client(endpoint, default_headers=default_headers,
                                      api_key=config.api_keys.get(endpoint, "no-key"))
        native = is_ext_openai_endpoint(endpoint)
    except BaseException:
        await slot.release()
        raise

    async def _track(usage):
        prompt_tok, comp_tok = _usage_tokens(usage)
        if prompt_tok or comp_tok:
            await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

    async def _cache_store(output_items, usage):
        if _cache is None or not _cache_enabled or not output_items:
            return
        obj = build_response_object(response_id=response_id, model=model,
                                    output_items=output_items, usage=usage,
                                    created_at=created_at,
                                    previous_response_id=previous_response_id,
                                    instructions=instructions, metadata=metadata)
        try:
            await _cache.set_chat("openai_responses", model, messages, orjson.dumps(obj))
        except Exception as _ce:
            # Caching is best effort; a failed write must not fail the request.
            print(f"[cache] set_chat (openai_responses) failed: {_ce}")

    def _error_payload(exc):
        return {"message": str(exc)[:500], "type": type(exc).__name__}

    # ---- background: run detached, return queued immediately --------------
    if background:
        try:
            await _persist("queued", insert=True)
        except BaseException:
            await slot.release()
            raise

        async def _bg_run():
            try:
                await get_db().update_response_status(response_id, "in_progress")
                output_items, usage = await _run_to_completion(
                    native=native, oclient=oclient, endpoint=endpoint, model=model,
                    tracking_model=tracking_model, send_params=send_params,
                    native_params=native_params, slot=slot)
                await _track(usage)
                await _persist("completed", output_items=output_items, usage=usage)
                await _cache_store(output_items, usage)
            except asyncio.CancelledError:
                await get_db().update_response_status(response_id, "cancelled")
                raise
            except Exception as e:
                await get_db().update_response_status(
                    response_id, "failed", error=_error_payload(e))
            finally:
                # Drop the handle first so a failing release cannot strand it.
                _background_tasks.pop(response_id, None)
                await slot.release()

        task = asyncio.create_task(_bg_run())
        _background_tasks[response_id] = task
        queued = build_response_object(response_id=response_id, model=model, output_items=[],
                                       status="queued", created_at=created_at,
                                       previous_response_id=previous_response_id,
                                       instructions=instructions, metadata=metadata)
        return JSONResponse(content=queued, status_code=200)

    # ---- streaming sync ----------------------------------------------------
    if stream:
        try:
            if native:
                source = await oclient.responses.create(stream=True, **native_params)
                translator = _NativeStream(response_id)
            else:
                source = await _create_chat_tracking_slot(
                    oclient, {**send_params, "stream": True,
                              "stream_options": {"include_usage": True}},
                    endpoint, model, tracking_model, slot)
                translator = ChatToResponsesStream(
                    response_id, model, created_at=created_at,
                    previous_response_id=previous_response_id, instructions=instructions,
                    metadata=metadata)
        except BaseException:
            # No-op when create_chat_with_retries already released the slot.
            await slot.release()
            raise

        async def _stream():
            try:
                await _persist("in_progress", insert=True)
                async for sse in translator.events(source):
                    yield sse
                await _track(translator.usage)
                await _persist("completed", output_items=translator.output_items,
                               usage=translator.usage)
                await _cache_store(translator.output_items, translator.usage)
            except Exception as exc:
                # Do not leave the stored row stuck in "in_progress".
                try:
                    await _persist("failed", error=_error_payload(exc))
                except Exception as _pe:
                    print(f"[responses] could not record stream failure: {_pe}", flush=True)
                raise
            finally:
                await slot.release()

        return StreamingResponse(_stream(), media_type="text/event-stream")

    # ---- non-streaming sync ------------------------------------------------
    try:
        output_items, usage = await _run_to_completion(
            native=native, oclient=oclient, endpoint=endpoint, model=model,
            tracking_model=tracking_model, send_params=send_params,
            native_params=native_params, slot=slot)
        await _track(usage)
        await _persist("completed", output_items=output_items, usage=usage, insert=True)
        await _cache_store(output_items, usage)
    finally:
        await slot.release()

    resp_obj = build_response_object(
        response_id=response_id, model=model, output_items=output_items, usage=usage,
        created_at=created_at, previous_response_id=previous_response_id,
        instructions=instructions, metadata=metadata)
    return JSONResponse(content=resp_obj)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET / DELETE / cancel
|
||||
# ---------------------------------------------------------------------------
|
||||
def _stored_to_response_object(row):
    """Render a stored response row as a Responses API object.

    ``response_id`` is required on *row*; every other column is optional. A
    missing or empty ``output_items`` becomes ``[]`` and a missing or empty
    ``status`` becomes ``"completed"``.
    """
    fields = {
        "response_id": row["response_id"],
        "model": row.get("model"),
        "output_items": row.get("output_items") or [],
        "usage": row.get("usage"),
        "status": row.get("status") or "completed",
        "created_at": row.get("created_at"),
        "previous_response_id": row.get("previous_response_id"),
        "instructions": row.get("instructions"),
        "error": row.get("error"),
    }
    return build_response_object(**fields)
|
||||
|
||||
|
||||
@router.get("/v1/responses/{response_id}")
async def get_response(response_id: str):
    """Return the stored response *response_id*; 404 when the store has no such row."""
    stored = await get_db().get_response(response_id)
    if stored is not None:
        return JSONResponse(content=_stored_to_response_object(stored))
    raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
|
||||
|
||||
|
||||
@router.delete("/v1/responses/{response_id}")
async def delete_response(response_id: str):
    """Delete the stored response *response_id*; 404 when nothing was deleted."""
    was_deleted = await get_db().delete_response(response_id)
    if was_deleted:
        confirmation = {"id": response_id, "object": "response.deleted", "deleted": True}
        return JSONResponse(content=confirmation)
    raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
|
||||
|
||||
|
||||
@router.post("/v1/responses/{response_id}/cancel")
async def cancel_response(response_id: str):
    """Cancel a background response and return its stored state.

    If the task runs in this worker it is cancelled and given a short time to
    record its terminal status, so the returned object reflects the cancellation
    rather than the pre-cancel row. If the task is unknown here (another worker,
    or it never started) and the row is still ``queued`` / ``in_progress``, the
    row itself is marked ``cancelled`` so a polling client sees a terminal state.

    Raises:
        HTTPException: 404 when the response does not exist (or was deleted
            while the cancellation was in progress).
    """
    db = get_db()
    row = await db.get_response(response_id)
    if row is None:
        raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")

    task = _background_tasks.get(response_id)
    if task is not None and not task.done():
        task.cancel()
        # ``cancel()`` only requests cancellation; the task's handler writes the
        # "cancelled" status when it next runs. ``asyncio.wait`` waits for that
        # without re-raising the task's CancelledError into this handler.
        await asyncio.wait({task}, timeout=5)
        row = await db.get_response(response_id)

    if row is not None and row.get("status") in ("queued", "in_progress"):
        # Cross-worker limitation, or the task did not record a terminal state.
        await db.update_response_status(response_id, "cancelled")

    row = await db.get_response(response_id)
    if row is None:
        raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
    return JSONResponse(content=_stored_to_response_object(row))
|
||||
Loading…
Add table
Add a link
Reference in a new issue