diff --git a/README.md b/README.md index fb60988..52f7b99 100644 --- a/README.md +++ b/README.md @@ -132,6 +132,41 @@ This way the Ollama backend servers are utilized more efficient than by simply u NOMYO Router also supports OpenAI API compatible v1 backend servers. +## OpenAI Responses API + +In addition to Chat Completions, NOMYO Router exposes the OpenAI **Responses API**: + +``` +POST /v1/responses # create a response (streaming or non-streaming) +GET /v1/responses/{id} # retrieve a stored response +DELETE /v1/responses/{id} # delete a stored response +POST /v1/responses/{id}/cancel # cancel a background response +``` + +It works transparently across **all** backends. When the routed model lives on a native +Responses backend (external OpenAI), the router calls that backend's Responses endpoint directly; for Ollama and llama-server the +router translates Responses ⇄ Chat Completions in both directions (request, response, and streaming +typed SSE events), so clients get a consistent `/v1/responses` surface regardless of backend. + +### Conversation state (`store` / `previous_response_id`) + +The router **owns conversation state itself** (persisted in its SQLite DB) rather than delegating to +the upstream provider, so `store` and `previous_response_id` behave identically on every backend. +On a follow-up request the router rehydrates the prior turns from its DB and expands them into the +conversation; outbound native calls always send `store=false`. Trade-off: this forgoes OpenAI's +server-side reasoning-state reuse in exchange for uniform, backend-agnostic chaining. + +### Background mode + +`background:true` (which requires `store:true`) returns immediately with a response object whose `status` is `queued`; the +request runs server-side and the client polls `GET /v1/responses/{id}` until the status reaches a +terminal state (`completed` / `failed` / `cancelled`). `POST /v1/responses/{id}/cancel` aborts it. + +Limitations: streaming reconnect-resume via `starting_after` is not yet implemented. 
In a +multi-worker/replica deployment polling works via the shared DB, but `cancel` only reaches the +running task in the worker that started it (other workers just mark the stored row cancelled). A +background task interrupted by a server restart is reconciled to `failed` on the next startup. + ## Semantic LLM Cache NOMYO Router includes an optional semantic cache that serves repeated or semantically similar LLM requests from cache — no endpoint round-trip, no token cost, response in <10 ms. @@ -172,7 +207,7 @@ Each request is keyed on `model + system_prompt` (exact) combined with a weighte ### Cached routes -`/api/chat` · `/api/generate` · `/v1/chat/completions` · `/v1/completions` +`/api/chat` · `/api/generate` · `/v1/chat/completions` · `/v1/completions` · `/v1/responses` ### Cache management diff --git a/api/openai.py b/api/openai.py index 4662f50..ab24f54 100644 --- a/api/openai.py +++ b/api/openai.py @@ -46,6 +46,110 @@ from routing import choose_endpoint, decrement_usage router = APIRouter() +async def create_chat_with_retries(oclient, send_params, endpoint, model, tracking_model): + """Call ``chat.completions.create`` with the router's resilience retries. + + Encapsulates the recovery ladder shared by the chat-completions handler and + the translated ``/v1/responses`` path: + + * ``does not support tools`` → retry without ``tools`` + * llama-server context exhaustion → sliding-window message trim, with a + second retry that also strips ``tools``/``tool_choice`` + * backend connection failure → mark (endpoint, model) unhealthy so the next + request reroutes, then re-raise + * ``image input is not supported`` → strip images and retry + + On unrecoverable failure the endpoint usage counter is decremented and the + exception is re-raised. Returns the established async generator / response. 
+ """ + config = get_config() + try: + async_gen = await oclient.chat.completions.create(**send_params) + except Exception as e: + _e_str = str(e) + _is_ctx_err = "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str + print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True) + if "does not support tools" in _e_str: + # Model doesn't support tools — retry without them + print(f"[ochat] retry: no tools", flush=True) + try: + params_without_tools = {k: v for k, v in send_params.items() if k != "tools"} + async_gen = await oclient.chat.completions.create(**params_without_tools) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + elif _is_ctx_err: + # Backend context limit hit — apply sliding-window trim (context-shift at message level) + err_body = getattr(e, "body", {}) or {} + err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {} + n_ctx_limit = err_detail.get("n_ctx", 0) + actual_tokens = err_detail.get("n_prompt_tokens", 0) + # Fallback: parse from string if body parsing yielded nothing (SDK may not parse llama-server errors) + if not n_ctx_limit: + import re as _re + _m = _re.search(r"'n_ctx':\s*(\d+)", _e_str) + if _m: + n_ctx_limit = int(_m.group(1)) + _m = _re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str) + if _m: + actual_tokens = int(_m.group(1)) + print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True) + if not n_ctx_limit: + await decrement_usage(endpoint, tracking_model) + raise + if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT: + _endpoint_nctx[(endpoint, model)] = n_ctx_limit + + msgs_to_trim = send_params.get("messages", []) + try: + cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens) + trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target) + except Exception as _helper_exc: + print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: 
{str(_helper_exc)[:100]}", flush=True) + await decrement_usage(endpoint, tracking_model) + raise + dropped = len(msgs_to_trim) - len(trimmed_messages) + print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True) + try: + async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages}) + print(f"[ctx-trim] retry-1 ok", flush=True) + except Exception as e2: + _e2_str = str(e2) + if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str: + # Still too large — tool definitions likely consuming too many tokens, strip them too + print(f"[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True) + params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")} + try: + async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages}) + print(f"[ctx-trim] retry-2 ok", flush=True) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + else: + await decrement_usage(endpoint, tracking_model) + raise + elif _is_backend_connection_error(e): + # Upstream connection failed (e.g. llama-server in router mode + # whose delegated worker died). Mark (endpoint, model) so the + # next request reroutes; the client will retry this one. 
+ print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True) + await _mark_backend_unhealthy(endpoint, model, _e_str) + await decrement_usage(endpoint, tracking_model) + raise + elif "image input is not supported" in _e_str: + # Model doesn't support images — strip and retry + print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages") + try: + async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))}) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + else: + await decrement_usage(endpoint, tracking_model) + raise + return async_gen + + @router.post("/v1/embeddings") async def openai_embedding_proxy(request: Request): """ @@ -260,90 +364,7 @@ async def openai_chat_completions_proxy(request: Request): _dropped = len(_pre_msgs) - len(_pre_trimmed) print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True) send_params = {**send_params, "messages": _pre_trimmed} - try: - async_gen = await oclient.chat.completions.create(**send_params) - except Exception as e: - _e_str = str(e) - _is_ctx_err = "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str - print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True) - if "does not support tools" in _e_str: - # Model doesn't support tools — retry without them - print(f"[ochat] retry: no tools", flush=True) - try: - params_without_tools = {k: v for k, v in send_params.items() if k != "tools"} - async_gen = await oclient.chat.completions.create(**params_without_tools) - except Exception: - await decrement_usage(endpoint, tracking_model) - raise - elif _is_ctx_err: - # Backend context limit hit — apply sliding-window trim (context-shift at message level) - err_body = getattr(e, "body", {}) or {} - err_detail = 
err_body.get("error", {}) if isinstance(err_body, dict) else {} - n_ctx_limit = err_detail.get("n_ctx", 0) - actual_tokens = err_detail.get("n_prompt_tokens", 0) - # Fallback: parse from string if body parsing yielded nothing (SDK may not parse llama-server errors) - if not n_ctx_limit: - import re as _re - _m = _re.search(r"'n_ctx':\s*(\d+)", _e_str) - if _m: - n_ctx_limit = int(_m.group(1)) - _m = _re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str) - if _m: - actual_tokens = int(_m.group(1)) - print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True) - if not n_ctx_limit: - await decrement_usage(endpoint, tracking_model) - raise - if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT: - _endpoint_nctx[(endpoint, model)] = n_ctx_limit - - msgs_to_trim = send_params.get("messages", []) - try: - cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens) - trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target) - except Exception as _helper_exc: - print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True) - await decrement_usage(endpoint, tracking_model) - raise - dropped = len(msgs_to_trim) - len(trimmed_messages) - print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True) - try: - async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages}) - print(f"[ctx-trim] retry-1 ok", flush=True) - except Exception as e2: - _e2_str = str(e2) - if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str: - # Still too large — tool definitions likely consuming too many tokens, strip them too - print(f"[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True) - params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")} - try: - async_gen = await 
oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages}) - print(f"[ctx-trim] retry-2 ok", flush=True) - except Exception: - await decrement_usage(endpoint, tracking_model) - raise - else: - await decrement_usage(endpoint, tracking_model) - raise - elif _is_backend_connection_error(e): - # Upstream connection failed (e.g. llama-server in router mode - # whose delegated worker died). Mark (endpoint, model) so the - # next request reroutes; the client will retry this one. - print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True) - await _mark_backend_unhealthy(endpoint, model, _e_str) - await decrement_usage(endpoint, tracking_model) - raise - elif "image input is not supported" in _e_str: - # Model doesn't support images — strip and retry - print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages") - try: - async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))}) - except Exception: - await decrement_usage(endpoint, tracking_model) - raise - else: - await decrement_usage(endpoint, tracking_model) - raise + async_gen = await create_chat_with_retries(oclient, send_params, endpoint, model, tracking_model) # 4. Async generator — only streams the already-established async_gen async def stream_ochat_response(): diff --git a/api/responses.py b/api/responses.py new file mode 100644 index 0000000..0a803d3 --- /dev/null +++ b/api/responses.py @@ -0,0 +1,398 @@ +"""OpenAI **Responses API** routes (``/v1/responses`` and its retrieve / delete / +cancel companions). + +The router speaks Chat Completions to its backends, so this layer: + + * **native** (external OpenAI): forwards via ``oclient.responses.create`` and + streams the SDK's typed events straight back, rewriting the response ``id`` to + a router-owned ``resp_`` id so chaining stays router-managed. 
+ * **translated** (Ollama / llama-server): converts the request to chat, reuses + the resilient ``create_chat_with_retries`` ladder, and re-emits the result as + Responses typed SSE events (``requests/responses.py``). + +State (``store`` / ``previous_response_id``) and background-task status live in the +router's SQLite DB (``db.py``); the router mints and owns every response id. +""" +import asyncio +import secrets +import time + +import orjson +from fastapi import APIRouter, HTTPException, Request +from starlette.responses import JSONResponse, StreamingResponse + +from cache import get_llm_cache +from config import get_config +from db import get_db +from fingerprint import _conversation_fingerprint +from state import token_queue, default_headers +from backends.normalize import is_ext_openai_endpoint +from backends.sessions import _make_openai_client +from routing import choose_endpoint, decrement_usage +from api.openai import create_chat_with_retries +from requests.responses import ( + ChatToResponsesStream, + build_response_object, + chat_message_to_output_items, + messages_to_responses_input, + responses_input_to_messages, + responses_object_to_sse, + tools_responses_to_chat, + usage_chat_to_responses, +) + +router = APIRouter() + +# In-memory handles for background tasks so /cancel can reach a running task in +# this worker. Cross-worker cancel falls back to marking the DB row cancelled. 
+_background_tasks: dict[str, asyncio.Task] = {} + + +# --------------------------------------------------------------------------- +# small helpers +# --------------------------------------------------------------------------- +def _usage_tokens(usage): + """Return ``(prompt, completion)`` tokens from a chat- or responses-shaped usage.""" + if not usage: + return 0, 0 + if "input_tokens" in usage: + return usage.get("input_tokens", 0) or 0, usage.get("output_tokens", 0) or 0 + return usage.get("prompt_tokens", 0) or 0, usage.get("completion_tokens", 0) or 0 + + +def _text_format_to_response_format(text): + """Map Responses ``text.format`` → Chat Completions ``response_format`` (best effort).""" + if not isinstance(text, dict): + return None + fmt = text.get("format") + if not isinstance(fmt, dict): + return None + ftype = fmt.get("type") + if ftype == "json_object": + return {"type": "json_object"} + if ftype == "json_schema": + return {"type": "json_schema", "json_schema": { + k: fmt[k] for k in ("name", "schema", "strict", "description") if k in fmt + }} + return None + + +def _native_usage_from_response(data): + return data.get("usage") + + +async def _resolve_history_messages(previous_response_id): + """Rebuild prior-turn chat messages from the stored response chain.""" + if not previous_response_id: + return [] + db = get_db() + chain = await db.get_response_chain(previous_response_id) + messages = [] + for turn in chain: + # Each turn stored the chat messages that produced it + its output items. 
+ for m in turn.get("input_messages") or []: + messages.append(m) + for item in turn.get("output_items") or []: + if item.get("type") == "message": + text = "".join( + p.get("text", "") for p in item.get("content") or [] + if p.get("type") == "output_text" + ) + if text: + messages.append({"role": "assistant", "content": text}) + elif item.get("type") == "function_call": + messages.append({ + "role": "assistant", "content": None, + "tool_calls": [{"id": item.get("call_id"), "type": "function", + "function": {"name": item.get("name"), + "arguments": item.get("arguments", "")}}], + }) + return messages + + +class _NativeStream: + """Re-emit an SDK Responses event stream, rewriting the response id and + capturing the final output/usage for storage.""" + + def __init__(self, response_id): + self.response_id = response_id + self.output_items = [] + self.usage = None + + async def events(self, sdk_gen): + async for event in sdk_gen: + data = event.model_dump() if hasattr(event, "model_dump") else event + etype = data.get("type", "") + resp = data.get("response") + if isinstance(resp, dict) and resp.get("id"): + resp["id"] = self.response_id + if etype in ("response.completed", "response.incomplete", "response.failed") \ + and isinstance(resp, dict): + self.output_items = resp.get("output", []) or [] + self.usage = resp.get("usage") + yield f"event: {etype}\ndata: {orjson.dumps(data).decode('utf-8')}\n\n".encode("utf-8") + + +# --------------------------------------------------------------------------- +# backend execution (non-streaming, used by background + non-stream sync) +# --------------------------------------------------------------------------- +async def _run_to_completion(*, native, oclient, endpoint, model, tracking_model, + send_params, native_params): + """Drive the backend to completion (no client streaming). + + Returns ``(output_items, usage)`` where usage is responses-shaped. 
Caller is + responsible for ``decrement_usage`` (translated failures self-decrement inside + ``create_chat_with_retries``).""" + if native: + resp_obj = await oclient.responses.create(stream=False, **native_params) + data = resp_obj.model_dump() + return data.get("output", []) or [], data.get("usage") + async_gen = await create_chat_with_retries(oclient, {**send_params, "stream": False}, + endpoint, model, tracking_model) + message = async_gen.choices[0].message.model_dump() if async_gen.choices else {} + output_items = chat_message_to_output_items(message) + usage = usage_chat_to_responses( + async_gen.usage.model_dump() if async_gen.usage is not None else None + ) + return output_items, usage + + +# --------------------------------------------------------------------------- +# POST /v1/responses +# --------------------------------------------------------------------------- +@router.post("/v1/responses") +async def openai_responses_proxy(request: Request): + config = get_config() + try: + payload = orjson.loads((await request.body()).decode("utf-8")) + except orjson.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + model = payload.get("model") + input_data = payload.get("input") + instructions = payload.get("instructions") + stream = bool(payload.get("stream")) + store = payload.get("store", True) + background = bool(payload.get("background")) + previous_response_id = payload.get("previous_response_id") + tools = payload.get("tools") + metadata = payload.get("metadata") or {} + _cache_enabled = payload.get("nomyo", {}).get("cache", False) + + if not model: + raise HTTPException(status_code=400, detail="Missing required field 'model'") + if input_data is None: + raise HTTPException(status_code=400, detail="Missing required field 'input'") + if background and not store: + raise HTTPException(status_code=400, detail="background mode requires store=true") + + if ":latest" in model: + model = model.split(":latest")[0] 
+ + # Resolve conversation: prior turns (from store) + this turn's input. + history = await _resolve_history_messages(previous_response_id) + messages = history + responses_input_to_messages(input_data, instructions) + + response_id = f"resp_{secrets.token_hex(24)}" + created_at = int(time.time()) + + # Cache lookup (foreground only) — before endpoint selection. + _cache = get_llm_cache() + if _cache is not None and _cache_enabled and not background: + cached = await _cache.get_chat("openai_responses", model, messages) + if cached is not None: + resp_obj = orjson.loads(cached) + resp_obj["id"] = response_id + if stream: + async def _served_cached(): + yield responses_object_to_sse(resp_obj) + return StreamingResponse(_served_cached(), media_type="text/event-stream") + return JSONResponse(content=resp_obj) + + # Endpoint selection (reserves a slot — must be released exactly once). + _affinity_key = _conversation_fingerprint(model, messages, None) + endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) + oclient = _make_openai_client(endpoint, default_headers=default_headers, + api_key=config.api_keys.get(endpoint, "no-key")) + native = is_ext_openai_endpoint(endpoint) + + # Build backend params for both shapes. 
+ send_params = {"messages": messages, "model": model} + _opt = { + "temperature": payload.get("temperature"), + "top_p": payload.get("top_p"), + "max_tokens": payload.get("max_output_tokens"), + "tools": tools_responses_to_chat(tools), + "tool_choice": payload.get("tool_choice"), + "response_format": _text_format_to_response_format(payload.get("text")), + } + send_params.update({k: v for k, v in _opt.items() if v is not None}) + + native_instructions, native_input = messages_to_responses_input(messages) + native_params = {"model": model, "input": native_input, "store": False} + _nopt = { + "instructions": native_instructions, + "temperature": payload.get("temperature"), + "top_p": payload.get("top_p"), + "max_output_tokens": payload.get("max_output_tokens"), + "tools": tools, + "tool_choice": payload.get("tool_choice"), + "text": payload.get("text"), + "reasoning": payload.get("reasoning"), + } + native_params.update({k: v for k, v in _nopt.items() if v is not None}) + + async def _persist(status, output_items=None, usage=None, error=None, insert=False): + if not store: + return + db = get_db() + if insert: + await db.store_response( + response_id, previous_response_id=previous_response_id, model=model, + status=status, created_at=created_at, input_messages=messages, + output_items=output_items, usage=usage, instructions=instructions, error=error) + else: + await db.update_response_status(response_id, status, output_items=output_items, + usage=usage, error=error) + + async def _track(usage): + prompt_tok, comp_tok = _usage_tokens(usage) + if prompt_tok or comp_tok: + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) + + async def _cache_store(output_items, usage): + if _cache is None or not _cache_enabled or not output_items: + return + obj = build_response_object(response_id=response_id, model=model, + output_items=output_items, usage=usage, + created_at=created_at, + previous_response_id=previous_response_id, + instructions=instructions, 
metadata=metadata) + try: + await _cache.set_chat("openai_responses", model, messages, orjson.dumps(obj)) + except Exception as _ce: + print(f"[cache] set_chat (openai_responses) failed: {_ce}") + + # ---- background: run detached, return queued immediately -------------- + if background: + await _persist("queued", insert=True) + + async def _bg_run(): + try: + await get_db().update_response_status(response_id, "in_progress") + output_items, usage = await _run_to_completion( + native=native, oclient=oclient, endpoint=endpoint, model=model, + tracking_model=tracking_model, send_params=send_params, + native_params=native_params) + await _track(usage) + await _persist("completed", output_items=output_items, usage=usage) + await _cache_store(output_items, usage) + except asyncio.CancelledError: + await get_db().update_response_status(response_id, "cancelled") + raise + except Exception as e: + await get_db().update_response_status( + response_id, "failed", + error={"message": str(e)[:500], "type": type(e).__name__}) + finally: + await decrement_usage(endpoint, tracking_model) + _background_tasks.pop(response_id, None) + + task = asyncio.create_task(_bg_run()) + _background_tasks[response_id] = task + queued = build_response_object(response_id=response_id, model=model, output_items=[], + status="queued", created_at=created_at, + previous_response_id=previous_response_id, + instructions=instructions, metadata=metadata) + return JSONResponse(content=queued, status_code=200) + + # ---- streaming sync ---------------------------------------------------- + if stream: + if native: + source = await oclient.responses.create(stream=True, **native_params) + translator = _NativeStream(response_id) + else: + source = await create_chat_with_retries( + oclient, {**send_params, "stream": True, + "stream_options": {"include_usage": True}}, + endpoint, model, tracking_model) + translator = ChatToResponsesStream( + response_id, model, created_at=created_at, + 
previous_response_id=previous_response_id, instructions=instructions, + metadata=metadata) + + async def _stream(): + await _persist("in_progress", insert=True) + try: + async for sse in translator.events(source): + yield sse + await _track(translator.usage) + await _persist("completed", output_items=translator.output_items, + usage=translator.usage) + await _cache_store(translator.output_items, translator.usage) + finally: + await decrement_usage(endpoint, tracking_model) + + return StreamingResponse(_stream(), media_type="text/event-stream") + + # ---- non-streaming sync ------------------------------------------------ + try: + output_items, usage = await _run_to_completion( + native=native, oclient=oclient, endpoint=endpoint, model=model, + tracking_model=tracking_model, send_params=send_params, + native_params=native_params) + await _track(usage) + await _persist("completed", output_items=output_items, usage=usage, insert=True) + await _cache_store(output_items, usage) + finally: + await decrement_usage(endpoint, tracking_model) + + resp_obj = build_response_object( + response_id=response_id, model=model, output_items=output_items, usage=usage, + created_at=created_at, previous_response_id=previous_response_id, + instructions=instructions, metadata=metadata) + return JSONResponse(content=resp_obj) + + +# --------------------------------------------------------------------------- +# GET / DELETE / cancel +# --------------------------------------------------------------------------- +def _stored_to_response_object(row): + return build_response_object( + response_id=row["response_id"], model=row.get("model"), + output_items=row.get("output_items") or [], usage=row.get("usage"), + status=row.get("status") or "completed", created_at=row.get("created_at"), + previous_response_id=row.get("previous_response_id"), + instructions=row.get("instructions"), error=row.get("error")) + + +@router.get("/v1/responses/{response_id}") +async def get_response(response_id: str): + 
row = await get_db().get_response(response_id) + if row is None: + raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found") + return JSONResponse(content=_stored_to_response_object(row)) + + +@router.delete("/v1/responses/{response_id}") +async def delete_response(response_id: str): + deleted = await get_db().delete_response(response_id) + if not deleted: + raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found") + return JSONResponse(content={"id": response_id, "object": "response.deleted", "deleted": True}) + + +@router.post("/v1/responses/{response_id}/cancel") +async def cancel_response(response_id: str): + row = await get_db().get_response(response_id) + if row is None: + raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found") + # Cancel the running task if it lives in this worker; otherwise just mark the + # DB row so a polling client sees a terminal state (cross-worker limitation). + task = _background_tasks.get(response_id) + if task is not None and not task.done(): + task.cancel() + elif row.get("status") in ("queued", "in_progress"): + await get_db().update_response_status(response_id, "cancelled") + row = await get_db().get_response(response_id) + return JSONResponse(content=_stored_to_response_object(row)) diff --git a/db.py b/db.py index 3621144..03a5393 100644 --- a/db.py +++ b/db.py @@ -1,4 +1,4 @@ -import aiosqlite, asyncio +import aiosqlite, asyncio, orjson from typing import Optional from pathlib import Path from datetime import datetime, timezone @@ -75,6 +75,24 @@ class TokenDatabase: ''') await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_timestamp ON token_time_series(timestamp)') await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_model_ts ON token_time_series(model, timestamp)') + # Responses API state — the router owns conversation state for the + # /v1/responses family (store / previous_response_id) and tracks + # background-task 
status here so polling survives across workers. + await db.execute(''' + CREATE TABLE IF NOT EXISTS stored_responses ( + response_id TEXT PRIMARY KEY, + previous_response_id TEXT, + model TEXT, + status TEXT, + created_at INTEGER, + input_messages TEXT, + output_items TEXT, + usage TEXT, + instructions TEXT, + error TEXT + ) + ''') + await db.execute('CREATE INDEX IF NOT EXISTS idx_stored_responses_prev ON stored_responses(previous_response_id)') await db.commit() async def update_token_counts(self, endpoint: str, model: str, input_tokens: int, output_tokens: int): @@ -319,3 +337,158 @@ class TokenDatabase: await db.commit() return aggregated_count + + # ----------------------------------------------------------------- + # Responses API state (store / previous_response_id / background) + # ----------------------------------------------------------------- + @staticmethod + def _row_to_response(row) -> dict: + """Map a stored_responses row to a plain dict, decoding JSON columns.""" + def _loads(val): + if val is None: + return None + try: + return orjson.loads(val) + except (orjson.JSONDecodeError, TypeError): + return None + return { + 'response_id': row[0], + 'previous_response_id': row[1], + 'model': row[2], + 'status': row[3], + 'created_at': row[4], + 'input_messages': _loads(row[5]), + 'output_items': _loads(row[6]), + 'usage': _loads(row[7]), + 'instructions': row[8], + 'error': _loads(row[9]), + } + + async def store_response( + self, + response_id: str, + *, + previous_response_id: Optional[str], + model: str, + status: str, + created_at: int, + input_messages: list, + output_items: Optional[list] = None, + usage: Optional[dict] = None, + instructions: Optional[str] = None, + error: Optional[dict] = None, + ): + """Insert or replace a stored Responses-API response row.""" + db = await self._get_connection() + async with self._operation_lock: + await db.execute(''' + INSERT INTO stored_responses + (response_id, previous_response_id, model, status, created_at, 
+ input_messages, output_items, usage, instructions, error) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT (response_id) DO UPDATE SET + previous_response_id = excluded.previous_response_id, + model = excluded.model, + status = excluded.status, + created_at = excluded.created_at, + input_messages = excluded.input_messages, + output_items = excluded.output_items, + usage = excluded.usage, + instructions = excluded.instructions, + error = excluded.error + ''', ( + response_id, previous_response_id, model, status, created_at, + orjson.dumps(input_messages).decode("utf-8"), + orjson.dumps(output_items).decode("utf-8") if output_items is not None else None, + orjson.dumps(usage).decode("utf-8") if usage is not None else None, + instructions, + orjson.dumps(error).decode("utf-8") if error is not None else None, + )) + await db.commit() + + async def update_response_status( + self, + response_id: str, + status: str, + *, + output_items: Optional[list] = None, + usage: Optional[dict] = None, + error: Optional[dict] = None, + ): + """Update the status (and optionally output/usage/error) of a stored response.""" + db = await self._get_connection() + async with self._operation_lock: + await db.execute(''' + UPDATE stored_responses + SET status = ?, + output_items = COALESCE(?, output_items), + usage = COALESCE(?, usage), + error = COALESCE(?, error) + WHERE response_id = ? 
async def get_response_chain(self, response_id: str, max_turns: int = 50) -> list:
    """Collect a response and its ancestors, ordered oldest-first.

    Starting at ``response_id`` the walk follows ``previous_response_id``
    links through ``get_response``. It stops when a link is empty, when an id
    repeats (cycle guard), when a stored row is missing, or once
    ``max_turns`` rows have been gathered - so the result holds at most the
    ``max_turns`` most recent turns.
    """
    visited: set = set()
    newest_first: list = []
    cursor_id = response_id
    while True:
        exhausted = len(newest_first) >= max_turns
        if not cursor_id or cursor_id in visited or exhausted:
            break
        visited.add(cursor_id)
        record = await self.get_response(cursor_id)
        if record is None:
            # Dangling link: keep what was gathered so far.
            break
        newest_first.append(record)
        cursor_id = record.get('previous_response_id')
    # The walk went newest -> oldest; callers want chronological order.
    return newest_first[::-1]
+ + A background task lives in a worker's event loop; a process restart loses + it while the DB row stays ``queued``/``in_progress`` forever. Reconcile + those to ``failed`` so polling clients get a terminal state. + """ + db = await self._get_connection() + async with self._operation_lock: + cursor = await db.execute(''' + UPDATE stored_responses + SET status = 'failed', + error = ? + WHERE status IN ('queued', 'in_progress') + ''', (orjson.dumps({"message": "Response interrupted by server restart", "type": "server_error"}).decode("utf-8"),)) + await db.commit() + return cursor.rowcount diff --git a/requests/responses.py b/requests/responses.py new file mode 100644 index 0000000..907dc42 --- /dev/null +++ b/requests/responses.py @@ -0,0 +1,492 @@ +"""Translation between the OpenAI **Responses API** and **Chat Completions**. + +The router speaks Chat Completions to every backend (Ollama, llama-server, +external OpenAI). To expose ``/v1/responses`` transparently on top of that, this +module converts in both directions: + + * request: Responses ``input`` / ``instructions`` / ``tools`` → chat ``messages`` / ``tools`` + * response: chat ``choices[0].message`` → Responses ``output`` items + * stream: chat completion deltas → Responses typed SSE events + +Pure functions / a stream-translator class — no I/O, mirroring the style of +``requests/messages.py``. The native passthrough path (external OpenAI) does not +use this module; it forwards the SDK's Responses objects directly. +""" +import secrets +import time + +import orjson + +from requests.messages import _accumulate_openai_tc_delta + + +# --------------------------------------------------------------------------- +# Request direction: Responses → Chat Completions +# --------------------------------------------------------------------------- +def _responses_content_to_chat(content): + """Convert a Responses message ``content`` into Chat Completions content. 
+ + Collapses a single text part to a plain string (what most backends expect); + keeps a multimodal list otherwise. + """ + if content is None or isinstance(content, str): + return content + if not isinstance(content, list): + return str(content) + parts = [] + for p in content: + if not isinstance(p, dict): + parts.append({"type": "text", "text": str(p)}) + continue + ptype = p.get("type") + if ptype in ("input_text", "output_text", "text"): + parts.append({"type": "text", "text": p.get("text", "")}) + elif ptype in ("input_image", "image_url"): + url = p.get("image_url") + if isinstance(url, dict): + url = url.get("url") + if url: + parts.append({"type": "image_url", "image_url": {"url": url}}) + # input_file / refusal / reasoning parts have no chat equivalent → skip + if len(parts) == 1 and parts[0].get("type") == "text": + return parts[0]["text"] + return parts + + +def _input_item_to_message(item): + """Convert a single Responses ``input`` item to a chat message (or None).""" + if isinstance(item, str): + return {"role": "user", "content": item} + if not isinstance(item, dict): + return None + + itype = item.get("type") + + if itype == "function_call": + return { + "role": "assistant", + "content": None, + "tool_calls": [{ + "id": item.get("call_id") or item.get("id"), + "type": "function", + "function": { + "name": item.get("name"), + "arguments": item.get("arguments", ""), + }, + }], + } + + if itype == "function_call_output": + output = item.get("output", "") + if not isinstance(output, str): + output = orjson.dumps(output).decode("utf-8") + return { + "role": "tool", + "tool_call_id": item.get("call_id") or item.get("id"), + "content": output, + } + + if itype in ("reasoning",): + # No Chat Completions equivalent — drop. 
def responses_input_to_messages(input_data, instructions=None):
    """Build a Chat Completions ``messages`` list from Responses ``input``.

    ``instructions`` becomes a leading system message; a string ``input``
    becomes a single user message; a list ``input`` is mapped item by item
    through ``_input_item_to_message`` (items it maps to ``None`` are dropped).
    Any other ``input`` type contributes no messages.

    Consecutive ``function_call`` items are folded into ONE assistant message
    carrying several ``tool_calls``. The Responses API represents parallel
    tool calls as separate items, whereas Chat Completions expects a single
    assistant message followed by one ``tool`` message per call; emitting one
    assistant message per call produces a transcript strict backends reject.
    """
    messages = []
    if instructions:
        messages.append({"role": "system", "content": instructions})
    if input_data is None:
        return messages
    if isinstance(input_data, str):
        messages.append({"role": "user", "content": input_data})
        return messages
    if isinstance(input_data, list):
        # Assistant message still collecting tool calls, or None once any
        # other kind of message has been emitted after it.
        open_calls = None
        for item in input_data:
            msg = _input_item_to_message(item)
            if msg is None:
                # Dropped items (e.g. reasoning) do not break a run of calls.
                continue
            is_call = (
                isinstance(item, dict)
                and item.get("type") == "function_call"
                and bool(msg.get("tool_calls"))
            )
            if is_call and open_calls is not None:
                open_calls["tool_calls"].extend(msg["tool_calls"])
                continue
            messages.append(msg)
            open_calls = msg if is_call else None
    return messages
def responses_object_to_sse(resp):
    """Render a *finished* Responses object as a valid SSE event stream.

    Used to serve cache/store hits to streaming clients without a backend call.
    Returns the whole stream as one ``bytes`` value made of
    ``event: <type>\\ndata: <json>\\n\\n`` frames whose ``sequence_number``
    starts at 0 and increases by one per frame.

    Event order: ``response.created`` and ``response.in_progress`` (both with
    an emptied, in-progress copy of ``resp``), then for every output item
    ``response.output_item.added``, the item's content events and
    ``response.output_item.done``, and finally ``response.completed`` with
    ``resp`` itself.

    The ``added`` event carries an *empty* item (``content: []`` for messages,
    ``arguments: ""`` for function calls), mirroring what the live stream
    translator emits. A client that rebuilds the response from the events
    would otherwise see the text twice: once in the pre-populated item and
    again via the ``content_part.added`` / ``output_text.delta`` events.
    Function-call arguments are replayed as one
    ``response.function_call_arguments.delta`` (skipped when empty) followed
    by ``response.function_call_arguments.done``.
    """
    frames = []

    def emit(etype, payload):
        # Frames are appended in emission order, so the current length is
        # exactly the next sequence number.
        body = {"type": etype, "sequence_number": len(frames), **payload}
        frames.append(
            f"event: {etype}\ndata: {orjson.dumps(body).decode('utf-8')}\n\n".encode("utf-8")
        )

    in_progress = {**resp, "status": "in_progress", "output": [], "output_text": ""}
    emit("response.created", {"response": in_progress})
    emit("response.in_progress", {"response": in_progress})

    for oi, item in enumerate(resp.get("output") or []):
        itype = item.get("type")
        iid = item.get("id")

        # Open the item empty; its payload follows in the events below.
        opening = {**item, "status": "in_progress"}
        if itype == "message":
            opening["content"] = []
        elif itype == "function_call":
            opening["arguments"] = ""
        emit("response.output_item.added", {"output_index": oi, "item": opening})

        if itype == "message":
            for ci, part in enumerate(item.get("content") or []):
                if part.get("type") != "output_text":
                    continue
                where = {"item_id": iid, "output_index": oi, "content_index": ci}
                text = part.get("text", "")
                emit("response.content_part.added", {
                    **where,
                    "part": {"type": "output_text", "text": "", "annotations": []}})
                emit("response.output_text.delta", {**where, "delta": text})
                emit("response.output_text.done", {**where, "text": text})
                emit("response.content_part.done", {**where, "part": part})
        elif itype == "function_call":
            arguments = item.get("arguments") or ""
            if arguments:
                emit("response.function_call_arguments.delta", {
                    "item_id": iid, "output_index": oi, "delta": arguments})
            emit("response.function_call_arguments.done", {
                "item_id": iid, "output_index": oi, "arguments": arguments})

        emit("response.output_item.done", {"output_index": oi, "item": item})

    emit("response.completed", {"response": resp})
    return b"".join(frames)
def output_items_to_text(output_items):
    """Concatenate the ``output_text`` parts of all message items.

    Walks ``output_items`` in order and joins, without a separator, the
    ``text`` of every ``output_text`` part found in items of type
    ``message``. ``None`` / empty input yields ``""``.

    Malformed data is skipped instead of raising: entries or parts that are
    not dicts, and parts whose ``text`` is missing, ``None`` or not a string,
    contribute nothing. A single odd part (for example ``"text": null`` in a
    stored row) therefore cannot break building the response body.
    """
    chunks = []
    for item in output_items or []:
        if not isinstance(item, dict) or item.get("type") != "message":
            continue
        for part in item.get("content") or []:
            if not isinstance(part, dict) or part.get("type") != "output_text":
                continue
            text = part.get("text")
            if isinstance(text, str):
                chunks.append(text)
    return "".join(chunks)
class ChatToResponsesStream:
    """Translate a Chat Completions streaming generator into Responses events.

    Usage::

        translator = ChatToResponsesStream(response_id, model, created_at)
        async for sse_bytes in translator.events(chat_async_gen):
            yield sse_bytes
        # translator.output_items / translator.usage now populated for storage

    Emits the ordered event family
    ``response.created`` → ``response.in_progress`` →
    (``response.output_item.added`` → ``response.content_part.added`` →
    ``response.output_text.delta``* → ``response.output_text.done`` →
    ``response.content_part.done`` → ``response.output_item.done``) and/or
    function-call item events → ``response.completed`` (carrying usage).
    """

    def __init__(self, response_id, model, created_at=None,
                 previous_response_id=None, instructions=None, metadata=None):
        """Remember the response envelope fields and reset the stream state.

        ``created_at`` falls back to the current epoch second and ``metadata``
        to an empty dict when they are falsy.
        """
        self.response_id = response_id
        self.model = model
        self.created_at = created_at or int(time.time())
        self.previous_response_id = previous_response_id
        self.instructions = instructions
        self.metadata = metadata or {}
        # Incremented before every event, so the first event is numbered 0.
        self.seq = -1
        # Filled in by events(): final output items and chat-style usage.
        self.output_items = []
        self.usage = None

    def _snapshot(self, status, output=None):
        """Build the response object embedded in lifecycle events.

        Delegates to ``build_response_object`` with the stored envelope
        fields, the given ``status`` and ``output`` (an empty list when
        ``output`` is ``None``), plus whatever usage has been seen so far.
        """
        return build_response_object(
            response_id=self.response_id,
            model=self.model,
            output_items=output if output is not None else [],
            usage=self.usage,
            status=status,
            created_at=self.created_at,
            previous_response_id=self.previous_response_id,
            instructions=self.instructions,
            metadata=self.metadata,
        )

    def _event(self, etype, payload):
        """Return one encoded SSE frame and advance the sequence counter.

        The JSON body is ``payload`` plus ``type`` and ``sequence_number``;
        the frame has the form ``event: <etype>\\ndata: <json>\\n\\n``.
        """
        self.seq += 1
        body = {"type": etype, "sequence_number": self.seq, **payload}
        return f"event: {etype}\ndata: {orjson.dumps(body).decode('utf-8')}\n\n".encode("utf-8")

    async def events(self, async_gen):
        """Consume chat-completion chunks and yield Responses SSE frames (bytes).

        Chunks are read via attribute access (``usage``, ``choices``,
        ``choices[0].delta`` with ``content`` / ``tool_calls``). Only the
        first choice of each chunk is used. After the generator is exhausted
        the open items are closed, ``self.output_items`` is set, and a final
        ``response.completed`` frame is yielded.
        """
        yield self._event("response.created", {"response": self._snapshot("in_progress")})
        yield self._event("response.in_progress", {"response": self._snapshot("in_progress")})

        # Next free output index; shared by the message item and tool calls.
        next_oi = 0
        # text message state
        msg_item_id = None
        msg_oi = None
        text_parts = []
        # function-call state, keyed by chat tool_call index
        tc_state = {}  # idx -> {oi, item_id, call_id, name, args}

        async for chunk in async_gen:
            # Usage is recorded whenever a chunk carries it, even when that
            # chunk has no choices (it is then skipped right below).
            usage = getattr(chunk, "usage", None)
            if usage is not None:
                self.usage = {
                    "prompt_tokens": usage.prompt_tokens or 0,
                    "completion_tokens": usage.completion_tokens or 0,
                }
            choices = getattr(chunk, "choices", None)
            if not choices:
                continue
            delta = choices[0].delta

            content_piece = getattr(delta, "content", None)
            if content_piece:
                # The message item is opened lazily, on the first non-empty
                # text delta, so tool-call-only replies get no message item.
                if msg_item_id is None:
                    msg_item_id = _new_id("msg")
                    msg_oi = next_oi
                    next_oi += 1
                    item = {
                        "id": msg_item_id, "type": "message", "status": "in_progress",
                        "role": "assistant", "content": [],
                    }
                    yield self._event("response.output_item.added",
                                      {"output_index": msg_oi, "item": item})
                    yield self._event("response.content_part.added", {
                        "item_id": msg_item_id, "output_index": msg_oi, "content_index": 0,
                        "part": {"type": "output_text", "text": "", "annotations": []},
                    })
                text_parts.append(content_piece)
                yield self._event("response.output_text.delta", {
                    "item_id": msg_item_id, "output_index": msg_oi, "content_index": 0,
                    "delta": content_piece,
                })

            for tc in getattr(delta, "tool_calls", None) or []:
                idx = tc.index
                fn = getattr(tc, "function", None)
                if idx not in tc_state:
                    # First fragment for this tool-call index: open a new
                    # function_call item. A call id is generated when the
                    # fragment does not carry one.
                    item_id = _new_id("fc")
                    state = {
                        "oi": next_oi, "item_id": item_id,
                        "call_id": getattr(tc, "id", None) or _new_id("call"),
                        "name": (fn.name if fn else None), "args": "",
                    }
                    next_oi += 1
                    tc_state[idx] = state
                    yield self._event("response.output_item.added", {
                        "output_index": state["oi"],
                        "item": {
                            "id": item_id, "type": "function_call", "status": "in_progress",
                            "call_id": state["call_id"], "name": state["name"], "arguments": "",
                        },
                    })
                else:
                    # Later fragment: id / name may arrive late and then
                    # overwrite what was recorded when the item was opened.
                    state = tc_state[idx]
                    if getattr(tc, "id", None):
                        state["call_id"] = tc.id
                    if fn and fn.name:
                        state["name"] = fn.name
                # Runs for first and later fragments alike: accumulate and
                # forward the argument text.
                if fn and fn.arguments:
                    state["args"] += fn.arguments
                    yield self._event("response.function_call_arguments.delta", {
                        "item_id": state["item_id"], "output_index": state["oi"],
                        "delta": fn.arguments,
                    })

        # finalize message item
        if msg_item_id is not None:
            full_text = "".join(text_parts)
            yield self._event("response.output_text.done", {
                "item_id": msg_item_id, "output_index": msg_oi, "content_index": 0,
                "text": full_text,
            })
            done_part = {"type": "output_text", "text": full_text, "annotations": []}
            yield self._event("response.content_part.done", {
                "item_id": msg_item_id, "output_index": msg_oi, "content_index": 0,
                "part": done_part,
            })
            msg_item = {
                "id": msg_item_id, "type": "message", "status": "completed",
                "role": "assistant", "content": [done_part],
            }
            yield self._event("response.output_item.done",
                              {"output_index": msg_oi, "item": msg_item})

        # finalize function-call items (in output-index order)
        # (tc_state preserves insertion order and output indexes were handed
        # out in that same order.)
        tc_items = {}
        for idx, state in tc_state.items():
            yield self._event("response.function_call_arguments.done", {
                "item_id": state["item_id"], "output_index": state["oi"],
                "arguments": state["args"],
            })
            fc_item = {
                "id": state["item_id"], "type": "function_call", "status": "completed",
                "call_id": state["call_id"], "name": state["name"], "arguments": state["args"],
            }
            tc_items[state["oi"]] = fc_item
            yield self._event("response.output_item.done",
                              {"output_index": state["oi"], "item": fc_item})

        # assemble final output items ordered by output index
        ordered = []
        if msg_item_id is not None:
            ordered.append((msg_oi, msg_item))
        ordered.extend(tc_items.items())
        self.output_items = [item for _, item in sorted(ordered, key=lambda kv: kv[0])]

        yield self._event("response.completed",
                          {"response": self._snapshot("completed", self.output_items)})
+ _orphaned = await db.fail_orphaned_responses() + if _orphaned: + print(f"[startup] Marked {_orphaned} orphaned background response(s) as failed.") + # Load existing token counts from database async for count_entry in db.load_token_counts(): endpoint = count_entry['endpoint'] diff --git a/test/test_responses.py b/test/test_responses.py new file mode 100644 index 0000000..73d3ff2 --- /dev/null +++ b/test/test_responses.py @@ -0,0 +1,460 @@ +"""Tests for the OpenAI Responses API support (api/responses.py + requests/responses.py). + +Covers the pure translation layer, the translated (Ollama-style) and native +(external-OpenAI) backend paths, conversation storage / chaining, background mode, +and the retrieve / delete / cancel routes. +""" +import asyncio +from contextlib import ExitStack, contextmanager +from types import SimpleNamespace as NS +from unittest.mock import AsyncMock, MagicMock, patch + +import orjson +import pytest + +import router +from api import responses as api_responses +from requests import responses as rt + + +# ────────────────────────────────────────────────────────────────────────────── +# Pure translation unit tests (no app / no I/O) +# ────────────────────────────────────────────────────────────────────────────── + +class TestTranslationInputToMessages: + def test_string_input(self): + msgs = rt.responses_input_to_messages("hello") + assert msgs == [{"role": "user", "content": "hello"}] + + def test_instructions_become_system(self): + msgs = rt.responses_input_to_messages("hi", instructions="be brief") + assert msgs[0] == {"role": "system", "content": "be brief"} + assert msgs[1] == {"role": "user", "content": "hi"} + + def test_item_list_text_and_image(self): + items = [{ + "type": "message", "role": "user", + "content": [ + {"type": "input_text", "text": "describe"}, + {"type": "input_image", "image_url": "http://x/y.png"}, + ], + }] + msgs = rt.responses_input_to_messages(items) + assert msgs[0]["role"] == "user" + assert 
msgs[0]["content"] == [ + {"type": "text", "text": "describe"}, + {"type": "image_url", "image_url": {"url": "http://x/y.png"}}, + ] + + def test_single_text_part_collapses_to_string(self): + items = [{"type": "message", "role": "user", + "content": [{"type": "input_text", "text": "yo"}]}] + assert rt.responses_input_to_messages(items)[0]["content"] == "yo" + + def test_function_call_roundtrip(self): + items = [ + {"type": "function_call", "call_id": "c1", "name": "get", "arguments": "{\"x\":1}"}, + {"type": "function_call_output", "call_id": "c1", "output": "42"}, + ] + msgs = rt.responses_input_to_messages(items) + assert msgs[0]["role"] == "assistant" + assert msgs[0]["tool_calls"][0]["id"] == "c1" + assert msgs[0]["tool_calls"][0]["function"]["name"] == "get" + assert msgs[1] == {"role": "tool", "tool_call_id": "c1", "content": "42"} + + +class TestTranslationResponseDirection: + def test_chat_message_to_output_items_text(self): + items = rt.chat_message_to_output_items({"role": "assistant", "content": "hi there"}) + assert len(items) == 1 + assert items[0]["type"] == "message" + assert items[0]["content"][0] == {"type": "output_text", "text": "hi there", "annotations": []} + + def test_chat_message_to_output_items_tool_call(self): + items = rt.chat_message_to_output_items({ + "role": "assistant", "content": None, + "tool_calls": [{"id": "c9", "function": {"name": "f", "arguments": "{}"}}], + }) + assert items[0]["type"] == "function_call" + assert items[0]["call_id"] == "c9" + assert items[0]["name"] == "f" + + def test_usage_mapping(self): + u = rt.usage_chat_to_responses({"prompt_tokens": 7, "completion_tokens": 3}) + assert u == {"input_tokens": 7, "output_tokens": 3, "total_tokens": 10} + + def test_build_response_object_output_text(self): + items = rt.chat_message_to_output_items({"role": "assistant", "content": "abc"}) + obj = rt.build_response_object(response_id="resp_1", model="m", output_items=items) + assert obj["object"] == "response" + assert 
obj["output_text"] == "abc" + assert obj["status"] == "completed" + + def test_tools_responses_to_chat(self): + tools = [{"type": "function", "name": "f", "description": "d", "parameters": {"type": "object"}}] + chat_tools = rt.tools_responses_to_chat(tools) + assert chat_tools == [{"type": "function", + "function": {"name": "f", "description": "d", + "parameters": {"type": "object"}}}] + + def test_messages_to_responses_input(self): + instr, items = rt.messages_to_responses_input([ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "yo"}, + ]) + assert instr == "sys" + assert items[0] == {"role": "user", "content": [{"type": "input_text", "text": "hi"}]} + assert items[1] == {"role": "assistant", "content": [{"type": "output_text", "text": "yo"}]} + + +# ────────────────────────────────────────────────────────────────────────────── +# Fakes for backend generators +# ────────────────────────────────────────────────────────────────────────────── + +def _fake_completion(content="hello world", usage=(3, 5)): + msg = MagicMock() + msg.model_dump.return_value = {"role": "assistant", "content": content} + usage_obj = MagicMock() + usage_obj.model_dump.return_value = { + "prompt_tokens": usage[0], "completion_tokens": usage[1], "total_tokens": sum(usage)} + return NS(choices=[NS(message=msg)], usage=usage_obj) + + +def _chunk(content=None, tool_calls=None): + return NS(choices=[NS(delta=NS(content=content, tool_calls=tool_calls), + finish_reason=None)], usage=None) + + +def _usage_chunk(p, c): + return NS(choices=[], usage=NS(prompt_tokens=p, completion_tokens=c)) + + +def _text_chunks(): + async def _gen(): + yield _chunk(content="Hel") + yield _chunk(content="lo") + yield _usage_chunk(3, 5) + return _gen() + + +def _toolcall_chunks(): + tc0 = NS(index=0, id="call_1", function=NS(name="lookup", arguments='{"q":')) + tc1 = NS(index=0, id=None, function=NS(name=None, arguments='"hi"}')) + + async def _gen(): + 
yield _chunk(tool_calls=[tc0]) + yield _chunk(tool_calls=[tc1]) + yield _usage_chunk(4, 2) + return _gen() + + +class _FakeEvent: + def __init__(self, data): + self._data = data + + def model_dump(self): + return self._data + + +def _native_event_stream(): + async def _gen(): + yield _FakeEvent({"type": "response.created", + "response": {"id": "resp_openai", "status": "in_progress", "output": []}}) + yield _FakeEvent({"type": "response.output_text.delta", + "item_id": "msg_1", "output_index": 0, "delta": "hi"}) + yield _FakeEvent({"type": "response.completed", "response": { + "id": "resp_openai", "status": "completed", + "output": [{"type": "message", "role": "assistant", + "content": [{"type": "output_text", "text": "hi"}]}], + "usage": {"input_tokens": 2, "output_tokens": 1, "total_tokens": 3}}}) + return _gen() + + +def _sse_events(text): + """Split an SSE body into a list of (event_type, data_dict).""" + out = [] + for frame in text.strip().split("\n\n"): + if not frame.strip(): + continue + etype = data = None + for line in frame.splitlines(): + if line.startswith("event: "): + etype = line[len("event: "):] + elif line.startswith("data: "): + data = orjson.loads(line[len("data: "):]) + out.append((etype, data)) + return out + + +@contextmanager +def _enter(*cms): + """Enter a variable number of context managers (works with *unpacked tuples).""" + with ExitStack() as stack: + for cm in cms: + stack.enter_context(cm) + yield + + +def _patch_backend(native=False, endpoint="http://ollama:11434"): + """Context managers patching endpoint selection + client construction.""" + return ( + patch.object(api_responses, "choose_endpoint", + AsyncMock(return_value=(endpoint, "test-model:latest"))), + patch.object(api_responses, "decrement_usage", AsyncMock()), + patch.object(api_responses, "is_ext_openai_endpoint", return_value=native), + patch.object(api_responses, "_make_openai_client", return_value=MagicMock()), + patch.object(api_responses, "get_llm_cache", 
return_value=None), + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# Translated path (Ollama-style backend) +# ────────────────────────────────────────────────────────────────────────────── + +class TestTranslatedPath: + async def test_nonstream(self, client): + with _enter(*_patch_backend(native=False), + patch.object(api_responses, "create_chat_with_retries", + AsyncMock(return_value=_fake_completion("hello world")))): + resp = await client.post("/v1/responses", + json={"model": "test-model", "input": "hi", "store": False}) + assert resp.status_code == 200 + body = resp.json() + assert body["object"] == "response" + assert body["output_text"] == "hello world" + assert body["usage"] == {"input_tokens": 3, "output_tokens": 5, "total_tokens": 8} + assert body["id"].startswith("resp_") + + async def test_stream_event_sequence(self, client): + with _enter(*_patch_backend(native=False), + patch.object(api_responses, "create_chat_with_retries", + AsyncMock(return_value=_text_chunks()))): + resp = await client.post("/v1/responses", + json={"model": "test-model", "input": "hi", + "stream": True, "store": False}) + assert resp.status_code == 200 + assert resp.headers["content-type"].startswith("text/event-stream") + events = _sse_events(resp.content.decode()) + types = [e[0] for e in events] + assert types[0] == "response.created" + assert "response.output_text.delta" in types + assert types[-1] == "response.completed" + # concatenated deltas reconstruct the content + deltas = "".join(d["delta"] for t, d in events if t == "response.output_text.delta") + assert deltas == "Hello" + # completed event carries usage + completed = [d for t, d in events if t == "response.completed"][0] + assert completed["response"]["usage"]["input_tokens"] == 3 + + async def test_stream_tool_calls(self, client): + with _enter(*_patch_backend(native=False), + patch.object(api_responses, "create_chat_with_retries", + 
AsyncMock(return_value=_toolcall_chunks()))): + resp = await client.post("/v1/responses", + json={"model": "test-model", "input": "lookup hi", + "stream": True, "store": False}) + events = _sse_events(resp.content.decode()) + types = [e[0] for e in events] + assert "response.function_call_arguments.delta" in types + assert "response.function_call_arguments.done" in types + args = "".join(d["delta"] for t, d in events + if t == "response.function_call_arguments.delta") + assert args == '{"q":"hi"}' + completed = [d for t, d in events if t == "response.completed"][0] + fc = [i for i in completed["response"]["output"] if i["type"] == "function_call"][0] + assert fc["name"] == "lookup" + assert fc["arguments"] == '{"q":"hi"}' + + +# ────────────────────────────────────────────────────────────────────────────── +# Native path (external OpenAI backend) +# ────────────────────────────────────────────────────────────────────────────── + +class TestNativePath: + async def test_nonstream_passthrough_rewrites_id(self, client): + oclient = MagicMock() + resp_obj = MagicMock() + resp_obj.model_dump.return_value = { + "id": "resp_openai", "status": "completed", + "output": [{"type": "message", "role": "assistant", + "content": [{"type": "output_text", "text": "native hi"}]}], + "usage": {"input_tokens": 2, "output_tokens": 3, "total_tokens": 5}} + oclient.responses.create = AsyncMock(return_value=resp_obj) + with (patch.object(api_responses, "choose_endpoint", + AsyncMock(return_value=("https://api.openai.com/v1", "gpt"))), + patch.object(api_responses, "decrement_usage", AsyncMock()), + patch.object(api_responses, "is_ext_openai_endpoint", return_value=True), + patch.object(api_responses, "_make_openai_client", return_value=oclient), + patch.object(api_responses, "get_llm_cache", return_value=None)): + resp = await client.post("/v1/responses", + json={"model": "gpt", "input": "hi", "store": False}) + body = resp.json() + assert body["output_text"] == "native hi" + assert 
body["id"].startswith("resp_") and body["id"] != "resp_openai" + # native call must not delegate state upstream + assert oclient.responses.create.call_args.kwargs["store"] is False + + async def test_stream_passthrough(self, client): + oclient = MagicMock() + oclient.responses.create = AsyncMock(return_value=_native_event_stream()) + with (patch.object(api_responses, "choose_endpoint", + AsyncMock(return_value=("https://api.openai.com/v1", "gpt"))), + patch.object(api_responses, "decrement_usage", AsyncMock()), + patch.object(api_responses, "is_ext_openai_endpoint", return_value=True), + patch.object(api_responses, "_make_openai_client", return_value=oclient), + patch.object(api_responses, "get_llm_cache", return_value=None)): + resp = await client.post("/v1/responses", + json={"model": "gpt", "input": "hi", + "stream": True, "store": False}) + events = _sse_events(resp.content.decode()) + # the completed event's response id is rewritten to the router id + completed = [d for t, d in events if t == "response.completed"][0] + assert completed["response"]["id"].startswith("resp_") + assert completed["response"]["id"] != "resp_openai" + + +# ────────────────────────────────────────────────────────────────────────────── +# Storage + chaining + retrieve/delete +# ────────────────────────────────────────────────────────────────────────────── + +class TestStorageAndChaining: + async def test_store_and_retrieve(self, client): + with _enter(*_patch_backend(native=False), + patch.object(api_responses, "create_chat_with_retries", + AsyncMock(return_value=_fake_completion("remembered")))): + created = await client.post("/v1/responses", + json={"model": "test-model", "input": "hi", "store": True}) + rid = created.json()["id"] + got = await client.get(f"/v1/responses/{rid}") + assert got.status_code == 200 + assert got.json()["output_text"] == "remembered" + + async def test_previous_response_id_rehydrates_history(self, client): + # First turn + with 
_enter(*_patch_backend(native=False), + patch.object(api_responses, "create_chat_with_retries", + AsyncMock(return_value=_fake_completion("turn-one")))): + first = await client.post("/v1/responses", + json={"model": "test-model", "input": "first?", "store": True}) + rid = first.json()["id"] + + # Second turn references the first — capture the messages sent to the backend + capture = AsyncMock(return_value=_fake_completion("turn-two")) + with _enter(*_patch_backend(native=False), + patch.object(api_responses, "create_chat_with_retries", capture)): + await client.post("/v1/responses", + json={"model": "test-model", "input": "second?", + "previous_response_id": rid, "store": True}) + sent_messages = capture.call_args.args[1]["messages"] + contents = [m.get("content") for m in sent_messages] + assert "first?" in contents # prior user turn replayed + assert "turn-one" in contents # prior assistant turn replayed + assert "second?" in contents # current turn appended + + async def test_delete(self, client): + with _enter(*_patch_backend(native=False), + patch.object(api_responses, "create_chat_with_retries", + AsyncMock(return_value=_fake_completion("bye")))): + created = await client.post("/v1/responses", + json={"model": "test-model", "input": "hi", "store": True}) + rid = created.json()["id"] + deleted = await client.delete(f"/v1/responses/{rid}") + assert deleted.status_code == 200 + assert deleted.json()["deleted"] is True + assert (await client.get(f"/v1/responses/{rid}")).status_code == 404 + + async def test_retrieve_missing_404(self, client): + assert (await client.get("/v1/responses/resp_missing")).status_code == 404 + + +# ────────────────────────────────────────────────────────────────────────────── +# Background mode +# ────────────────────────────────────────────────────────────────────────────── + +class TestBackgroundMode: + async def test_background_requires_store(self, client): + resp = await client.post("/v1/responses", + json={"model": "test-model", 
"input": "hi", + "background": True, "store": False}) + assert resp.status_code == 400 + + async def test_background_lifecycle(self, client): + with _enter(*_patch_backend(native=False), + patch.object(api_responses, "create_chat_with_retries", + AsyncMock(return_value=_fake_completion("bg-done")))): + created = await client.post("/v1/responses", + json={"model": "test-model", "input": "hi", + "background": True, "store": True}) + assert created.status_code == 200 + assert created.json()["status"] == "queued" + rid = created.json()["id"] + # poll until terminal + status = None + for _ in range(100): + await asyncio.sleep(0.01) + got = await client.get(f"/v1/responses/{rid}") + status = got.json()["status"] + if status in ("completed", "failed", "cancelled"): + break + assert status == "completed" + assert got.json()["output_text"] == "bg-done" + + async def test_fail_orphaned_responses(self, client): + db = router.db + await db.store_response("resp_orphan", previous_response_id=None, model="m", + status="in_progress", created_at=0, input_messages=[]) + n = await db.fail_orphaned_responses() + assert n >= 1 + row = await db.get_response("resp_orphan") + assert row["status"] == "failed" + + +# ────────────────────────────────────────────────────────────────────────────── +# Cache parity +# ────────────────────────────────────────────────────────────────────────────── + +class _FakeCache: + def __init__(self, response_bytes): + self._resp = response_bytes + self.calls = [] + + async def get_chat(self, route, model, messages): + self.calls.append((route, model, messages)) + return self._resp + + +class TestCacheParity: + async def test_cache_hit_served_as_response(self, client): + cached = orjson.dumps(rt.build_response_object( + response_id="resp_cached", model="test-model", + output_items=rt.chat_message_to_output_items( + {"role": "assistant", "content": "from-cache"}))) + fake = _FakeCache(cached) + with (patch.object(api_responses, "get_llm_cache", 
return_value=fake), + patch.object(api_responses, "choose_endpoint", + AsyncMock(side_effect=AssertionError("backend must not be reached")))): + resp = await client.post("/v1/responses", + json={"model": "test-model", "input": "ping", + "store": False, "nomyo": {"cache": True}}) + assert resp.status_code == 200 + assert resp.json()["output_text"] == "from-cache" + assert fake.calls and fake.calls[0][0] == "openai_responses" + + async def test_cache_hit_served_as_sse(self, client): + cached = orjson.dumps(rt.build_response_object( + response_id="resp_cached", model="test-model", + output_items=rt.chat_message_to_output_items( + {"role": "assistant", "content": "from-cache"}))) + fake = _FakeCache(cached) + with (patch.object(api_responses, "get_llm_cache", return_value=fake), + patch.object(api_responses, "choose_endpoint", + AsyncMock(side_effect=AssertionError("backend must not be reached")))): + resp = await client.post("/v1/responses", + json={"model": "test-model", "input": "ping", + "stream": True, "store": False, + "nomyo": {"cache": True}}) + assert resp.headers["content-type"].startswith("text/event-stream") + events = _sse_events(resp.content.decode()) + deltas = "".join(d["delta"] for t, d in events if t == "response.output_text.delta") + assert deltas == "from-cache"