From 4f42f350a3797f92c868307ae0689ceaf47aa72e Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Tue, 23 Jun 2026 14:56:04 +0200 Subject: [PATCH] fix: stale reservation counters be releasing it only once --- api/ollama.py | 179 +++++++++++++++++++++++++++-------------------- api/openai.py | 156 ++++++++++++++++++++--------------------- api/responses.py | 152 +++++++++++++++++++++++----------------- requests/chat.py | 64 +++++++++-------- routing.py | 30 +++++--- 5 files changed, 319 insertions(+), 262 deletions(-) diff --git a/api/ollama.py b/api/ollama.py index 6ae6027..7585a5b 100644 --- a/api/ollama.py +++ b/api/ollama.py @@ -167,31 +167,38 @@ async def proxy(request: Request): _affinity_key = _conversation_fingerprint(model, None, prompt) endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) - use_openai = is_openai_compatible(endpoint) - if use_openai: - if ":latest" in model: - model = model.split(":latest") - model = model[0] - params = { - "prompt": prompt, - "model": model, - } - - optional_params = { - "stream": stream, - "max_tokens": options.get("num_predict") if options and "num_predict" in options else None, - "frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None, - "presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None, - "seed": options.get("seed") if options and "seed" in options else None, - "stop": options.get("stop") if options and "stop" in options else None, - "top_p": options.get("top_p") if options and "top_p" in options else None, - "temperature": options.get("temperature") if options and "temperature" in options else None, - "suffix": suffix, + # _guarded_stream's finally releases the reservation once we hand off; until + # then any failure during request building / client construction (including + # CancelledError on client disconnect) must release it or the counter leaks. + try: + use_openai = is_openai_compatible(endpoint) + if use_openai: + if ":latest" in model: + model = model.split(":latest") + model = model[0] + params = { + "prompt": prompt, + "model": model, } - params.update({k: v for k, v in optional_params.items() if v is not None}) - oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) - else: - client = get_ollama_client(endpoint) + + optional_params = { + "stream": stream, + "max_tokens": options.get("num_predict") if options and "num_predict" in options else None, + "frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None, + "presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None, + "seed": options.get("seed") if options and "seed" in options else None, + "stop": options.get("stop") if options and "stop" in options else None, + "top_p": options.get("top_p") if options and "top_p" in options else None, + "temperature": options.get("temperature") if options and "temperature" in options else None, + "suffix": suffix, + } + params.update({k: v for k, v in optional_params.items() if v is not None}) + oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) + else: + client = get_ollama_client(endpoint) + except BaseException: + await decrement_usage(endpoint, tracking_model) + raise # 4. Async generator body (error handling + cleanup handled by _guarded_stream) async def stream_generate_response(): @@ -336,59 +343,70 @@ async def chat_proxy(request: Request): opt = False _affinity_key = _conversation_fingerprint(model, messages, None) endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) - use_openai = is_openai_compatible(endpoint) - if use_openai: - if ":latest" in model: - model = model.split(":latest") - model = model[0] - if messages: - if any("images" in m for m in messages): - messages = await asyncio.to_thread(transform_images_to_data_urls, messages) - messages = transform_tool_calls_to_openai(messages) - messages = _strip_assistant_prefill(messages) - params = { - "messages": messages, - "model": model, - } - optional_params = { - "tools": tools, - "stream": stream, - "stream_options": {"include_usage": True} if stream else None, - "max_tokens": options.get("num_predict") if options and "num_predict" in options else None, - "frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None, - "presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None, - "seed": options.get("seed") if options and "seed" in options else None, - "stop": options.get("stop") if options and "stop" in options else None, - "top_p": options.get("top_p") if options and "top_p" in options else None, - "temperature": options.get("temperature") if options and "temperature" in options else None, - "logprobs": logprobs if logprobs is not None else (options.get("logprobs") if options and "logprobs" in options else None), - "top_logprobs": top_logprobs if top_logprobs is not None else (options.get("top_logprobs") if options and "top_logprobs" in options else None), - "response_format": {"type": "json_schema", "json_schema": _format} if _format is not None else None - } - params.update({k: v for k, v in optional_params.items() if v is not None}) - oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) - else: - client = get_ollama_client(endpoint) + # Releasing the reservation is owned by _guarded_stream's finally once we hand + # off to the streaming generator. Until then, any failure during request + # building / client construction (including CancelledError on client + # disconnect) must release it here or the usage counter leaks. + try: + use_openai = is_openai_compatible(endpoint) + if use_openai: + if ":latest" in model: + model = model.split(":latest") + model = model[0] + if messages: + if any("images" in m for m in messages): + messages = await asyncio.to_thread(transform_images_to_data_urls, messages) + messages = transform_tool_calls_to_openai(messages) + messages = _strip_assistant_prefill(messages) + params = { + "messages": messages, + "model": model, + } + optional_params = { + "tools": tools, + "stream": stream, + "stream_options": {"include_usage": True} if stream else None, + "max_tokens": options.get("num_predict") if options and "num_predict" in options else None, + "frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None, + "presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None, + "seed": options.get("seed") if options and "seed" in options else None, + "stop": options.get("stop") if options and "stop" in options else None, + "top_p": options.get("top_p") if options and "top_p" in options else None, + "temperature": options.get("temperature") if options and "temperature" in options else None, + "logprobs": logprobs if logprobs is not None else (options.get("logprobs") if options and "logprobs" in options else None), + "top_logprobs": top_logprobs if top_logprobs is not None else (options.get("top_logprobs") if options and "top_logprobs" in options else None), + "response_format": {"type": "json_schema", "json_schema": _format} if _format is not None else None + } + params.update({k: v for k, v in optional_params.items() if v is not None}) + oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) + else: + client = get_ollama_client(endpoint) + except BaseException: + await decrement_usage(endpoint, tracking_model) + raise # For OpenAI endpoints: make the API call in handler scope # (try/except inside async generators is unreliable with Starlette's streaming) start_ts = None async_gen = None if use_openai: start_ts = time.perf_counter() - # Proactive trim: only for small-ctx models we've already seen run out of space - _lookup_model = _normalize_llama_model_name(model) if is_llama_server(endpoint) else model - _known_nctx = _endpoint_nctx.get((endpoint, _lookup_model)) - if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT: - _pre_target = int((_known_nctx - _known_nctx // 4) / 1.2) - _pre_est = _count_message_tokens(params.get("messages", [])) - if _pre_est > _pre_target: - _pre_msgs = params.get("messages", []) - _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target) - _dropped = len(_pre_msgs) - len(_pre_trimmed) - print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True) - params = {**params, "messages": _pre_trimmed} try: + # Proactive trim: only for small-ctx models we've already seen run out of space + _lookup_model = _normalize_llama_model_name(model) if is_llama_server(endpoint) else model + _known_nctx = _endpoint_nctx.get((endpoint, _lookup_model)) + if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT: + _pre_target = int((_known_nctx - _known_nctx // 4) / 1.2) + _pre_est = _count_message_tokens(params.get("messages", [])) + if _pre_est > _pre_target: + _pre_msgs = params.get("messages", []) + _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target) + _dropped = len(_pre_msgs) - len(_pre_trimmed) + print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True) + params = {**params, "messages": _pre_trimmed} async_gen = await oclient.chat.completions.create(**params) + except asyncio.CancelledError: + await decrement_usage(endpoint, tracking_model) + raise except Exception as e: _e_str = str(e) print(f"[chat_proxy] caught {type(e).__name__}: {_e_str[:200]}") @@ -595,14 +613,21 @@ async def _handle_embedding_request( # 2. Endpoint logic endpoint, tracking_model = await choose_endpoint(model) - use_openai = is_openai_compatible(endpoint) - if use_openai: - if ":latest" in model: - model = model.split(":latest") - model = model[0] - client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key")) - else: - client = get_ollama_client(endpoint) + # _guarded_stream's finally releases the reservation once we hand off; until + # then any failure during client construction (including CancelledError on + # client disconnect) must release it or the counter leaks. + try: + use_openai = is_openai_compatible(endpoint) + if use_openai: + if ":latest" in model: + model = model.split(":latest") + model = model[0] + client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key")) + else: + client = get_ollama_client(endpoint) + except BaseException: + await decrement_usage(endpoint, tracking_model) + raise # 3. Async generator body (error handling + cleanup handled by _guarded_stream) async def stream_embedding_response(): diff --git a/api/openai.py b/api/openai.py index 1f0d22d..dc5432c 100644 --- a/api/openai.py +++ b/api/openai.py @@ -61,8 +61,10 @@ async def create_chat_with_retries(oclient, send_params, endpoint, model, tracki request reroutes, then re-raise * ``image input is not supported`` → strip images and retry - On unrecoverable failure the endpoint usage counter is decremented and the - exception is re-raised. Returns the established async generator / response. + The caller owns the usage reservation taken by ``choose_endpoint``: this + function never decrements it. On unrecoverable failure the exception is + re-raised so the caller's guard releases the slot exactly once. Returns the + established async generator / response. """ config = get_config() try: @@ -74,12 +76,8 @@ async def create_chat_with_retries(oclient, send_params, endpoint, model, tracki if "does not support tools" in _e_str: # Model doesn't support tools — retry without them print(f"[ochat] retry: no tools", flush=True) - try: - params_without_tools = {k: v for k, v in send_params.items() if k != "tools"} - async_gen = await oclient.chat.completions.create(**params_without_tools) - except Exception: - await decrement_usage(endpoint, tracking_model) - raise + params_without_tools = {k: v for k, v in send_params.items() if k != "tools"} + async_gen = await oclient.chat.completions.create(**params_without_tools) elif _is_ctx_err: # Backend context limit hit — apply sliding-window trim (context-shift at message level) err_body = getattr(e, "body", {}) or {} @@ -97,7 +95,6 @@ async def create_chat_with_retries(oclient, send_params, endpoint, model, tracki actual_tokens = int(_m.group(1)) print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True) if not n_ctx_limit: - await decrement_usage(endpoint, tracking_model) raise if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT: _endpoint_nctx[(endpoint, model)] = n_ctx_limit @@ -108,7 +105,6 @@ async def create_chat_with_retries(oclient, send_params, endpoint, model, tracki trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target) except Exception as _helper_exc: print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True) - await decrement_usage(endpoint, tracking_model) raise dropped = len(msgs_to_trim) - len(trimmed_messages) print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True) @@ -121,14 +117,9 @@ async def create_chat_with_retries(oclient, send_params, endpoint, model, tracki # Still too large — tool definitions likely consuming too many tokens, strip them too print(f"[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True) params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")} - try: - async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages}) - print(f"[ctx-trim] retry-2 ok", flush=True) - except Exception: - await decrement_usage(endpoint, tracking_model) - raise + async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages}) + print(f"[ctx-trim] retry-2 ok", flush=True) else: - await decrement_usage(endpoint, tracking_model) raise elif _is_backend_connection_error(e): # Upstream connection failed (e.g. llama-server in router mode @@ -136,18 +127,12 @@ async def create_chat_with_retries(oclient, send_params, endpoint, model, tracki # next request reroutes; the client will retry this one. print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True) await _mark_backend_unhealthy(endpoint, model, _e_str) - await decrement_usage(endpoint, tracking_model) raise elif "image input is not supported" in _e_str: # Model doesn't support images — strip and retry print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages") - try: - async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))}) - except Exception: - await decrement_usage(endpoint, tracking_model) - raise + async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))}) else: - await decrement_usage(endpoint, tracking_model) raise return async_gen @@ -195,13 +180,14 @@ async def openai_embedding_proxy(request: Request): # 2. Endpoint logic endpoint, tracking_model = await choose_endpoint(model) - if is_openai_compatible(endpoint): - api_key = config.api_keys.get(endpoint, "no-key") - else: - api_key = "ollama" - oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=api_key) - + # The finally below releases the reservation for every exit — success, error, + # or CancelledError — so client construction is kept inside the guarded block. try: + if is_openai_compatible(endpoint): + api_key = config.api_keys.get(endpoint, "no-key") + else: + api_key = "ollama" + oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=api_key) async_gen = await oclient.embeddings.create(input=doc, model=model) result = async_gen.model_dump() for item in result.get("data", []): @@ -350,23 +336,30 @@ async def openai_chat_completions_proxy(request: Request): # Make the API call in handler scope — try/except inside async generators is unreliable # with Starlette's streaming machinery, so we resolve errors here before the generator starts. - send_params = params - if not is_ext_openai_endpoint(endpoint): - resolved_msgs = await _normalize_images_in_messages(params.get("messages", [])) - send_params = {**params, "messages": resolved_msgs} - # Proactive trim: only for small-ctx models we've already seen run out of space - _lookup_model = _normalize_llama_model_name(model) if is_llama_server(endpoint) else model - _known_nctx = _endpoint_nctx.get((endpoint, _lookup_model)) - if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT: - _pre_target = int(((_known_nctx - _known_nctx // 4)) / 1.2) - _pre_est = _count_message_tokens(send_params.get("messages", [])) - if _pre_est > _pre_target: - _pre_msgs = send_params.get("messages", []) - _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target) - _dropped = len(_pre_msgs) - len(_pre_trimmed) - print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True) - send_params = {**send_params, "messages": _pre_trimmed} - async_gen = await create_chat_with_retries(oclient, send_params, endpoint, model, tracking_model) + # The reservation taken by choose_endpoint is released by stream_ochat_response's finally + # once we hand off; until then, any failure here (including CancelledError on client + # disconnect during a cold model load) must release it or the counter leaks. + try: + send_params = params + if not is_ext_openai_endpoint(endpoint): + resolved_msgs = await _normalize_images_in_messages(params.get("messages", [])) + send_params = {**params, "messages": resolved_msgs} + # Proactive trim: only for small-ctx models we've already seen run out of space + _lookup_model = _normalize_llama_model_name(model) if is_llama_server(endpoint) else model + _known_nctx = _endpoint_nctx.get((endpoint, _lookup_model)) + if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT: + _pre_target = int(((_known_nctx - _known_nctx // 4)) / 1.2) + _pre_est = _count_message_tokens(send_params.get("messages", [])) + if _pre_est > _pre_target: + _pre_msgs = send_params.get("messages", []) + _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target) + _dropped = len(_pre_msgs) - len(_pre_trimmed) + print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True) + send_params = {**send_params, "messages": _pre_trimmed} + async_gen = await create_chat_with_retries(oclient, send_params, endpoint, model, tracking_model) + except BaseException: + await decrement_usage(endpoint, tracking_model) + raise # 4. Async generator — only streams the already-established async_gen async def stream_ochat_response(): @@ -547,12 +540,17 @@ async def openai_completions_proxy(request: Request): # 2. Endpoint logic _affinity_key = _conversation_fingerprint(model, None, prompt) endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) - oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) # 3. Async generator that streams completions data and decrements the counter - # Make the API call in handler scope (try/except inside async generators is unreliable) + # Make the API call in handler scope (try/except inside async generators is unreliable). + # The reservation is released by stream_ocompletions_response's finally once we hand off; + # until then any failure here — including CancelledError on client disconnect — releases it. try: + oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) async_gen = await oclient.completions.create(**params) + except asyncio.CancelledError: + await decrement_usage(endpoint, tracking_model) + raise except Exception as e: if _is_backend_connection_error(e): print(f"[ocompl] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True) @@ -775,36 +773,38 @@ async def rerank_proxy(request: Request): ), ) - if ":latest" in model: - model = model.split(":latest")[0] - - # Build upstream rerank request body – forward only recognised fields - upstream_payload: dict = {"model": model, "query": query, "documents": documents} - for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"): - if optional_key in payload: - upstream_payload[optional_key] = payload[optional_key] - - # Determine upstream URL: - # llama-server / llama-swap expose /v1/rerank (base already contains /v1) - # External OpenAI endpoints expose /rerank under their /v1 base - if is_llama_server(endpoint): - # llama-server / llama-swap: endpoint may or may not already contain /v1 - if "/v1" in endpoint: - rerank_url = f"{endpoint}/rerank" - else: - rerank_url = f"{endpoint}/v1/rerank" - else: - # External OpenAI-compatible: ep2base gives us the /v1 base - rerank_url = f"{ep2base(endpoint)}/rerank" - - api_key = config.api_keys.get(endpoint, "no-key") - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}", - } - - client: aiohttp.ClientSession = get_session(endpoint) + # The finally below releases the reservation for every exit (success, error, + # or CancelledError), so request building and session lookup stay inside it. try: + if ":latest" in model: + model = model.split(":latest")[0] + + # Build upstream rerank request body – forward only recognised fields + upstream_payload: dict = {"model": model, "query": query, "documents": documents} + for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"): + if optional_key in payload: + upstream_payload[optional_key] = payload[optional_key] + + # Determine upstream URL: + # llama-server / llama-swap expose /v1/rerank (base already contains /v1) + # External OpenAI endpoints expose /rerank under their /v1 base + if is_llama_server(endpoint): + # llama-server / llama-swap: endpoint may or may not already contain /v1 + if "/v1" in endpoint: + rerank_url = f"{endpoint}/rerank" + else: + rerank_url = f"{endpoint}/v1/rerank" + else: + # External OpenAI-compatible: ep2base gives us the /v1 base + rerank_url = f"{ep2base(endpoint)}/rerank" + + api_key = config.api_keys.get(endpoint, "no-key") + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + } + + client: aiohttp.ClientSession = get_session(endpoint) async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp: response_bytes = await resp.read() if resp.status >= 400: diff --git a/api/responses.py b/api/responses.py index 0a803d3..5627c98 100644 --- a/api/responses.py +++ b/api/responses.py @@ -140,9 +140,9 @@ async def _run_to_completion(*, native, oclient, endpoint, model, tracking_model send_params, native_params): """Drive the backend to completion (no client streaming). - Returns ``(output_items, usage)`` where usage is responses-shaped. Caller is - responsible for ``decrement_usage`` (translated failures self-decrement inside - ``create_chat_with_retries``).""" + Returns ``(output_items, usage)`` where usage is responses-shaped. The caller + owns the usage reservation and must release it (this function and + ``create_chat_with_retries`` never decrement).""" if native: resp_obj = await oclient.responses.create(stream=False, **native_params) data = resp_obj.model_dump() @@ -209,38 +209,46 @@ async def openai_responses_proxy(request: Request): return StreamingResponse(_served_cached(), media_type="text/event-stream") return JSONResponse(content=resp_obj) - # Endpoint selection (reserves a slot — must be released exactly once). + # Endpoint selection (reserves a slot — must be released exactly once). The + # release is owned by the per-branch finally (_bg_run / _stream / the + # non-streaming try) once we hand off; any failure during client/param + # construction (including CancelledError on client disconnect) must release + # it here or the usage counter leaks. _affinity_key = _conversation_fingerprint(model, messages, None) endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) - oclient = _make_openai_client(endpoint, default_headers=default_headers, - api_key=config.api_keys.get(endpoint, "no-key")) - native = is_ext_openai_endpoint(endpoint) + try: + oclient = _make_openai_client(endpoint, default_headers=default_headers, + api_key=config.api_keys.get(endpoint, "no-key")) + native = is_ext_openai_endpoint(endpoint) - # Build backend params for both shapes. - send_params = {"messages": messages, "model": model} - _opt = { - "temperature": payload.get("temperature"), - "top_p": payload.get("top_p"), - "max_tokens": payload.get("max_output_tokens"), - "tools": tools_responses_to_chat(tools), - "tool_choice": payload.get("tool_choice"), - "response_format": _text_format_to_response_format(payload.get("text")), - } - send_params.update({k: v for k, v in _opt.items() if v is not None}) + # Build backend params for both shapes. + send_params = {"messages": messages, "model": model} + _opt = { + "temperature": payload.get("temperature"), + "top_p": payload.get("top_p"), + "max_tokens": payload.get("max_output_tokens"), + "tools": tools_responses_to_chat(tools), + "tool_choice": payload.get("tool_choice"), + "response_format": _text_format_to_response_format(payload.get("text")), + } + send_params.update({k: v for k, v in _opt.items() if v is not None}) - native_instructions, native_input = messages_to_responses_input(messages) - native_params = {"model": model, "input": native_input, "store": False} - _nopt = { - "instructions": native_instructions, - "temperature": payload.get("temperature"), - "top_p": payload.get("top_p"), - "max_output_tokens": payload.get("max_output_tokens"), - "tools": tools, - "tool_choice": payload.get("tool_choice"), - "text": payload.get("text"), - "reasoning": payload.get("reasoning"), - } - native_params.update({k: v for k, v in _nopt.items() if v is not None}) + native_instructions, native_input = messages_to_responses_input(messages) + native_params = {"model": model, "input": native_input, "store": False} + _nopt = { + "instructions": native_instructions, + "temperature": payload.get("temperature"), + "top_p": payload.get("top_p"), + "max_output_tokens": payload.get("max_output_tokens"), + "tools": tools, + "tool_choice": payload.get("tool_choice"), + "text": payload.get("text"), + "reasoning": payload.get("reasoning"), + } + native_params.update({k: v for k, v in _nopt.items() if v is not None}) + except BaseException: + await decrement_usage(endpoint, tracking_model) + raise async def _persist(status, output_items=None, usage=None, error=None, insert=False): if not store: @@ -275,30 +283,37 @@ async def openai_responses_proxy(request: Request): # ---- background: run detached, return queued immediately -------------- if background: - await _persist("queued", insert=True) + # Once the task is created, _bg_run's finally owns the release. Guard the + # pre-task setup so a failure there (queued persist, task creation, or a + # client disconnect) still releases the reservation. + try: + await _persist("queued", insert=True) - async def _bg_run(): - try: - await get_db().update_response_status(response_id, "in_progress") - output_items, usage = await _run_to_completion( - native=native, oclient=oclient, endpoint=endpoint, model=model, - tracking_model=tracking_model, send_params=send_params, - native_params=native_params) - await _track(usage) - await _persist("completed", output_items=output_items, usage=usage) - await _cache_store(output_items, usage) - except asyncio.CancelledError: - await get_db().update_response_status(response_id, "cancelled") - raise - except Exception as e: - await get_db().update_response_status( - response_id, "failed", - error={"message": str(e)[:500], "type": type(e).__name__}) - finally: - await decrement_usage(endpoint, tracking_model) - _background_tasks.pop(response_id, None) + async def _bg_run(): + try: + await get_db().update_response_status(response_id, "in_progress") + output_items, usage = await _run_to_completion( + native=native, oclient=oclient, endpoint=endpoint, model=model, + tracking_model=tracking_model, send_params=send_params, + native_params=native_params) + await _track(usage) + await _persist("completed", output_items=output_items, usage=usage) + await _cache_store(output_items, usage) + except asyncio.CancelledError: + await get_db().update_response_status(response_id, "cancelled") + raise + except Exception as e: + await get_db().update_response_status( + response_id, "failed", + error={"message": str(e)[:500], "type": type(e).__name__}) + finally: + await decrement_usage(endpoint, tracking_model) + _background_tasks.pop(response_id, None) - task = asyncio.create_task(_bg_run()) + task = asyncio.create_task(_bg_run()) + except BaseException: + await decrement_usage(endpoint, tracking_model) + raise _background_tasks[response_id] = task queued = build_response_object(response_id=response_id, model=model, output_items=[], status="queued", created_at=created_at, @@ -308,18 +323,25 @@ async def openai_responses_proxy(request: Request): # ---- streaming sync ---------------------------------------------------- if stream: - if native: - source = await oclient.responses.create(stream=True, **native_params) - translator = _NativeStream(response_id) - else: - source = await create_chat_with_retries( - oclient, {**send_params, "stream": True, - "stream_options": {"include_usage": True}}, - endpoint, model, tracking_model) - translator = ChatToResponsesStream( - response_id, model, created_at=created_at, - previous_response_id=previous_response_id, instructions=instructions, - metadata=metadata) + # _stream's finally owns the release once iteration starts. Establishing + # the source can fail (or be cancelled) before that — release here, since + # create_chat_with_retries no longer self-decrements. + try: + if native: + source = await oclient.responses.create(stream=True, **native_params) + translator = _NativeStream(response_id) + else: + source = await create_chat_with_retries( + oclient, {**send_params, "stream": True, + "stream_options": {"include_usage": True}}, + endpoint, model, tracking_model) + translator = ChatToResponsesStream( + response_id, model, created_at=created_at, + previous_response_id=previous_response_id, instructions=instructions, + metadata=metadata) + except BaseException: + await decrement_usage(endpoint, tracking_model) + raise async def _stream(): await _persist("in_progress", insert=True) diff --git a/requests/chat.py b/requests/chat.py index 5bf1572..168e713 100644 --- a/requests/chat.py +++ b/requests/chat.py @@ -44,38 +44,40 @@ async def _make_chat_request(model: str, messages: list, tools=None, stream: boo """ config = get_config() endpoint, tracking_model = await choose_endpoint(model) # selects and atomically reserves - use_openai = is_openai_compatible(endpoint) - if use_openai: - if ":latest" in model: - model = model.split(":latest")[0] - if messages: - if any("images" in m for m in messages): - messages = await asyncio.to_thread(transform_images_to_data_urls, messages) - messages = transform_tool_calls_to_openai(messages) - messages = _strip_assistant_prefill(messages) - params = { - "messages": messages, - "model": model, - } - optional_params = { - "tools": tools, - "stream": stream, - "stream_options": {"include_usage": True} if stream else None, - "max_tokens": options.get("num_predict") if options and "num_predict" in options else None, - "frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None, - "presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None, - "seed": options.get("seed") if options and "seed" in options else None, - "stop": options.get("stop") if options and "stop" in options else None, - "top_p": options.get("top_p") if options and "top_p" in options else None, - "temperature": options.get("temperature") if options and "temperature" in options else None, - "response_format": {"type": "json_schema", "json_schema": format} if format is not None else None - } - params.update({k: v for k, v in optional_params.items() if v is not None}) - oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) - else: - client = ollama.AsyncClient(host=endpoint) - + # The finally below releases the reservation on every exit — success, error, + # or CancelledError — so request building and client construction stay inside it. try: + use_openai = is_openai_compatible(endpoint) + if use_openai: + if ":latest" in model: + model = model.split(":latest")[0] + if messages: + if any("images" in m for m in messages): + messages = await asyncio.to_thread(transform_images_to_data_urls, messages) + messages = transform_tool_calls_to_openai(messages) + messages = _strip_assistant_prefill(messages) + params = { + "messages": messages, + "model": model, + } + optional_params = { + "tools": tools, + "stream": stream, + "stream_options": {"include_usage": True} if stream else None, + "max_tokens": options.get("num_predict") if options and "num_predict" in options else None, + "frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None, + "presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None, + "seed": options.get("seed") if options and "seed" in options else None, + "stop": options.get("stop") if options and "stop" in options else None, + "top_p": options.get("top_p") if options and "top_p" in options else None, + "temperature": options.get("temperature") if options and "temperature" in options else None, + "response_format": {"type": "json_schema", "json_schema": format} if format is not None else None + } + params.update({k: v for k, v in optional_params.items() if v is not None}) + oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) + else: + client = ollama.AsyncClient(host=endpoint) + if use_openai: start_ts = time.perf_counter() try: diff --git a/routing.py b/routing.py index 0a1cc7f..9afe2f7 100644 --- a/routing.py +++ b/routing.py @@ -311,15 +311,23 @@ async def choose_endpoint(model: str, reserve: bool = True, if reserve: usage_counts[selected][tracking_model] += 1 snapshot = _capture_snapshot() - if snapshot is not None: - await _distribute_snapshot(snapshot) - # Record / refresh affinity *after* releasing usage_lock. - if reserve and config.conversation_affinity and affinity_key: - expires_at = time.monotonic() + config.conversation_affinity_ttl - async with _affinity_lock: - _affinity_map[affinity_key] = (selected, model, expires_at) - if len(_affinity_map) > _AFFINITY_MAX_ENTRIES: - now = time.monotonic() - for k in [k for k, v in _affinity_map.items() if v[2] < now]: - _affinity_map.pop(k, None) + # The slot is now reserved. Any failure (including CancelledError on client + # disconnect) between here and `return` would otherwise leak it — the caller + # never receives (endpoint, tracking_model) and so can never decrement it. + try: + if snapshot is not None: + await _distribute_snapshot(snapshot) + # Record / refresh affinity *after* releasing usage_lock. + if reserve and config.conversation_affinity and affinity_key: + expires_at = time.monotonic() + config.conversation_affinity_ttl + async with _affinity_lock: + _affinity_map[affinity_key] = (selected, model, expires_at) + if len(_affinity_map) > _AFFINITY_MAX_ENTRIES: + now = time.monotonic() + for k in [k for k, v in _affinity_map.items() if v[2] < now]: + _affinity_map.pop(k, None) + except BaseException: + if reserve: + await decrement_usage(selected, tracking_model) + raise return selected, tracking_model