diff --git a/api/ollama.py b/api/ollama.py index d62d6b4..5ce339c 100644 --- a/api/ollama.py +++ b/api/ollama.py @@ -61,6 +61,36 @@ from routing import choose_endpoint, decrement_usage router = APIRouter() +async def _handle_stream_error( + exc: Exception, endpoint: str, model: str, *, context: str +) -> bytes: + """Surface an upstream backend error transitively from a streaming generator. + + Errors raised while iterating a backend response (e.g. an ollama + ``ResponseError`` for a 504 Gateway Time-out) would otherwise escape the + StreamingResponse generator and be dumped by Starlette as an opaque + "Exception in ASGI application" traceback with no indication of which + endpoint/model failed. This logs the failure with that context — which is + what makes the many timeout errors greppable and analyzable — marks the + backend unhealthy when it is a connection-class failure, and returns a + terminal Ollama-format ``{"error": ...}`` line so the client receives a + meaningful error instead of a silently truncated stream. + """ + status_code = getattr(exc, "status_code", None) + err_msg = getattr(exc, "error", None) or str(exc) + print( + f"[{context}] upstream error from ({endpoint}, {model}) " + f"status={status_code} type={type(exc).__name__}: {str(err_msg)[:500]}", + flush=True, + ) + if _is_backend_connection_error(exc): + await _mark_backend_unhealthy(endpoint, model, str(err_msg)) + err_payload = {"error": str(err_msg)} + if status_code is not None: + err_payload["status_code"] = status_code + return orjson.dumps(err_payload) + b"\n" + + @router.post("/api/generate") async def proxy(request: Request): """ @@ -202,6 +232,13 @@ async def proxy(request: Request): except Exception as _ce: print(f"[cache] set_generate (non-streaming) failed: {_ce}") + except asyncio.CancelledError: + raise + except Exception as e: + try: + yield await _handle_stream_error(e, endpoint, model, context="generate_proxy") + except Exception: + pass finally: # Ensure counter is decremented even if an exception occurs await decrement_usage(endpoint, tracking_model) @@ -486,6 +523,13 @@ async def chat_proxy(request: Request): except Exception as _ce: print(f"[cache] set_chat (ollama_chat non-streaming) failed: {_ce}") + except asyncio.CancelledError: + raise + except Exception as e: + try: + yield await _handle_stream_error(e, endpoint, model, context="chat_proxy") + except Exception: + pass finally: # Ensure counter is decremented even if an exception occurs await decrement_usage(endpoint, tracking_model) @@ -550,6 +594,13 @@ async def embedding_proxy(request: Request): else: json_line = orjson.dumps(async_gen) yield json_line.encode("utf-8") + b"\n" + except asyncio.CancelledError: + raise + except Exception as e: + try: + yield await _handle_stream_error(e, endpoint, model, context="embeddings_proxy") + except Exception: + pass finally: # Ensure counter is decremented even if an exception occurs await decrement_usage(endpoint, tracking_model) @@ -614,6 +665,13 @@ async def embed_proxy(request: Request): else: json_line = orjson.dumps(async_gen) yield json_line.encode("utf-8") + b"\n" + except asyncio.CancelledError: + raise + except Exception as e: + try: + yield await _handle_stream_error(e, endpoint, model, context="embed_proxy") + except Exception: + pass finally: # Ensure counter is decremented even if an exception occurs await decrement_usage(endpoint, tracking_model)