diff --git a/api/ollama.py b/api/ollama.py index 5ce339c..6ebed5f 100644 --- a/api/ollama.py +++ b/api/ollama.py @@ -91,6 +91,30 @@ async def _handle_stream_error( return orjson.dumps(err_payload) + b"\n" +async def _guarded_stream(inner, *, endpoint: str, model: str, tracking_model: str, context: str): + """Wrap a per-route body generator with the shared streaming contract. + + Every ``/api/*`` streaming handler needs the same three guarantees around its + body: surface backend errors transitively (via :func:`_handle_stream_error`), + let client-disconnect cancellation propagate untouched, and always decrement + the usage counter. Centralising them here keeps the four bodies free of + duplicated ``try/except/finally`` scaffolding. + """ + try: + async for item in inner: + yield item + except asyncio.CancelledError: + raise + except Exception as e: + try: + yield await _handle_stream_error(e, endpoint, model, context=context) + except Exception: + pass + finally: + # Ensure counter is decremented even if an exception occurs + await decrement_usage(endpoint, tracking_model) + + @router.post("/api/generate") async def proxy(request: Request): """ @@ -165,87 +189,79 @@ async def proxy(request: Request): else: client = ollama.AsyncClient(host=endpoint) - # 4. Async generator that streams data and decrements the counter + # 4. Async generator body (error handling + cleanup handled by _guarded_stream) async def stream_generate_response(): - try: - if use_openai: - start_ts = time.perf_counter() - async_gen = await oclient.completions.create(**params) - else: - async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive) - if stream == True: - content_parts: list[str] = [] - async for chunk in async_gen: - if use_openai: - chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts) - prompt_tok = chunk.prompt_eval_count or 0 - comp_tok = chunk.eval_count or 0 - if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) - if hasattr(chunk, "model_dump_json"): - json_line = chunk.model_dump_json() - else: - json_line = orjson.dumps(chunk) - # Accumulate and store cache on done chunk — before yield so it always runs - if _cache is not None and _cache_enabled: - if getattr(chunk, "response", None): - content_parts.append(chunk.response) - if getattr(chunk, "done", False): - assembled = orjson.dumps({ - k: v for k, v in { - "model": getattr(chunk, "model", model), - "response": "".join(content_parts), - "done": True, - "done_reason": getattr(chunk, "done_reason", "stop") or "stop", - "prompt_eval_count": getattr(chunk, "prompt_eval_count", None), - "eval_count": getattr(chunk, "eval_count", None), - "total_duration": getattr(chunk, "total_duration", None), - "eval_duration": getattr(chunk, "eval_duration", None), - }.items() if v is not None - }) + b"\n" - try: - await _cache.set_generate(model, prompt, system or "", assembled) - except Exception as _ce: - print(f"[cache] set_generate (streaming) failed: {_ce}") - yield json_line.encode("utf-8") + b"\n" - else: + if use_openai: + start_ts = time.perf_counter() + async_gen = await oclient.completions.create(**params) + else: + async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive) + if stream == True: + content_parts: list[str] = [] + async for chunk in async_gen: if use_openai: - response = rechunk.openai_completion2ollama(async_gen, stream, start_ts) - response = response.model_dump_json() + chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts) + prompt_tok = chunk.prompt_eval_count or 0 + comp_tok = chunk.eval_count or 0 + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) + if hasattr(chunk, "model_dump_json"): + json_line = chunk.model_dump_json() else: - response = async_gen.model_dump_json() - prompt_tok = async_gen.prompt_eval_count or 0 - comp_tok = async_gen.eval_count or 0 - if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) - json_line = ( - response - if hasattr(async_gen, "model_dump_json") - else orjson.dumps(async_gen) - ) - cache_bytes = json_line.encode("utf-8") + b"\n" - yield cache_bytes - # Cache non-streaming response + json_line = orjson.dumps(chunk) + # Accumulate and store cache on done chunk — before yield so it always runs if _cache is not None and _cache_enabled: - try: - await _cache.set_generate(model, prompt, system or "", cache_bytes) - except Exception as _ce: - print(f"[cache] set_generate (non-streaming) failed: {_ce}") + if getattr(chunk, "response", None): + content_parts.append(chunk.response) + if getattr(chunk, "done", False): + assembled = orjson.dumps({ + k: v for k, v in { + "model": getattr(chunk, "model", model), + "response": "".join(content_parts), + "done": True, + "done_reason": getattr(chunk, "done_reason", "stop") or "stop", + "prompt_eval_count": getattr(chunk, "prompt_eval_count", None), + "eval_count": getattr(chunk, "eval_count", None), + "total_duration": getattr(chunk, "total_duration", None), + "eval_duration": getattr(chunk, "eval_duration", None), + }.items() if v is not None + }) + b"\n" + try: + await _cache.set_generate(model, prompt, system or "", assembled) + except Exception as _ce: + print(f"[cache] set_generate (streaming) failed: {_ce}") + yield json_line.encode("utf-8") + b"\n" + else: + if use_openai: + response = rechunk.openai_completion2ollama(async_gen, stream, start_ts) + response = response.model_dump_json() + else: + response = async_gen.model_dump_json() + prompt_tok = async_gen.prompt_eval_count or 0 + comp_tok = async_gen.eval_count or 0 + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) + json_line = ( + response + if hasattr(async_gen, "model_dump_json") + else orjson.dumps(async_gen) + ) + cache_bytes = json_line.encode("utf-8") + b"\n" + yield cache_bytes + # Cache non-streaming response + if _cache is not None and _cache_enabled: + try: + await _cache.set_generate(model, prompt, system or "", cache_bytes) + except Exception as _ce: + print(f"[cache] set_generate (non-streaming) failed: {_ce}") - except asyncio.CancelledError: - raise - except Exception as e: - try: - yield await _handle_stream_error(e, endpoint, model, context="generate_proxy") - except Exception: - pass - finally: - # Ensure counter is decremented even if an exception occurs - await decrement_usage(endpoint, tracking_model) - - # 5. Return a StreamingResponse backed by the generator + # 5. Return a StreamingResponse backed by the guarded generator return StreamingResponse( - stream_generate_response(), + _guarded_stream( + stream_generate_response(), + endpoint=endpoint, model=model, + tracking_model=tracking_model, context="generate_proxy", + ), media_type="application/json", ) @@ -425,128 +441,130 @@ async def chat_proxy(request: Request): await decrement_usage(endpoint, tracking_model) raise - # 3. Async generator that streams chat data and decrements the counter + # 3. Async generator body (error handling + cleanup handled by _guarded_stream) async def stream_chat_response(): - try: - # The chat method returns a generator of dicts (or GenerateResponse) - if use_openai: - _async_gen = async_gen # established in handler scope above - else: - if opt == True: - # Use the dedicated MOE helper function - _async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive) - else: - _async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive, logprobs=logprobs, top_logprobs=top_logprobs) - if stream == True: - tc_acc = {} # accumulate OpenAI tool-call deltas across chunks - content_parts: list[str] = [] - async for chunk in _async_gen: - if use_openai: - _accumulate_openai_tc_delta(chunk, tc_acc) - chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts) - # Inject fully-accumulated tool calls only into the final chunk - if chunk.done and tc_acc and chunk.message: - chunk.message.tool_calls = _build_ollama_tool_calls(tc_acc) - # `chunk` can be a dict or a pydantic model – dump to JSON safely - prompt_tok = chunk.prompt_eval_count or 0 - comp_tok = chunk.eval_count or 0 - if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) - if hasattr(chunk, "model_dump_json"): - json_line = chunk.model_dump_json() - else: - json_line = orjson.dumps(chunk) - # Accumulate and store cache on done chunk — before yield so it always runs - # Works for both Ollama-native and OpenAI-compatible backends; chunks are - # already converted to Ollama format by rechunk before this point. - if getattr(chunk, "done", False): - # Detect context exhaustion mid-generation for small-ctx models - _dr = getattr(chunk, "done_reason", None) - # Only cache when no max_tokens limit was set — otherwise - # finish_reason=length might just mean max_tokens was hit, - # not that the context window was exhausted. - _req_max_tok = ( - params.get("max_tokens") or params.get("max_completion_tokens") or params.get("num_predict") - if use_openai else - (options.get("num_predict") if options else None) - ) - if _dr == "length" and not _req_max_tok: - _pt = getattr(chunk, "prompt_eval_count", 0) or 0 - _ct = getattr(chunk, "eval_count", 0) or 0 - _inferred_nctx = _pt + _ct - if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT: - _endpoint_nctx[(endpoint, model)] = _inferred_nctx - print(f"[ctx-cache] done_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True) - if _cache is not None and not _is_moe and _cache_enabled: - if chunk.message and getattr(chunk.message, "content", None): - content_parts.append(chunk.message.content) - if getattr(chunk, "done", False): - assembled = orjson.dumps({ - k: v for k, v in { - "model": getattr(chunk, "model", model), - "created_at": (lambda ca: ca.isoformat() if hasattr(ca, "isoformat") else ca)(getattr(chunk, "created_at", None)), - "message": {"role": "assistant", "content": "".join(content_parts)}, - "done": True, - "done_reason": getattr(chunk, "done_reason", "stop") or "stop", - "prompt_eval_count": getattr(chunk, "prompt_eval_count", None), - "eval_count": getattr(chunk, "eval_count", None), - "total_duration": getattr(chunk, "total_duration", None), - "eval_duration": getattr(chunk, "eval_duration", None), - }.items() if v is not None - }) + b"\n" - try: - await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, assembled) - except Exception as _ce: - print(f"[cache] set_chat (ollama_chat streaming) failed: {_ce}") - yield json_line.encode("utf-8") + b"\n" + # The chat method returns a generator of dicts (or GenerateResponse) + if use_openai: + _async_gen = async_gen # established in handler scope above + else: + if opt == True: + # Use the dedicated MOE helper function + _async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive) else: + _async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive, logprobs=logprobs, top_logprobs=top_logprobs) + if stream == True: + tc_acc = {} # accumulate OpenAI tool-call deltas across chunks + content_parts: list[str] = [] + async for chunk in _async_gen: if use_openai: - response = rechunk.openai_chat_completion2ollama(_async_gen, stream, start_ts) - response = response.model_dump_json() + _accumulate_openai_tc_delta(chunk, tc_acc) + chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts) + # Inject fully-accumulated tool calls only into the final chunk + if chunk.done and tc_acc and chunk.message: + chunk.message.tool_calls = _build_ollama_tool_calls(tc_acc) + # `chunk` can be a dict or a pydantic model – dump to JSON safely + prompt_tok = chunk.prompt_eval_count or 0 + comp_tok = chunk.eval_count or 0 + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) + if hasattr(chunk, "model_dump_json"): + json_line = chunk.model_dump_json() else: - response = _async_gen.model_dump_json() - prompt_tok = _async_gen.prompt_eval_count or 0 - comp_tok = _async_gen.eval_count or 0 - if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) - json_line = ( - response - if hasattr(_async_gen, "model_dump_json") - else orjson.dumps(_async_gen) - ) - cache_bytes = json_line.encode("utf-8") + b"\n" - yield cache_bytes - # Cache non-streaming response (non-MOE; works for both Ollama and OpenAI backends) + json_line = orjson.dumps(chunk) + # Accumulate and store cache on done chunk — before yield so it always runs + # Works for both Ollama-native and OpenAI-compatible backends; chunks are + # already converted to Ollama format by rechunk before this point. + if getattr(chunk, "done", False): + # Detect context exhaustion mid-generation for small-ctx models + _dr = getattr(chunk, "done_reason", None) + # Only cache when no max_tokens limit was set — otherwise + # finish_reason=length might just mean max_tokens was hit, + # not that the context window was exhausted. + _req_max_tok = ( + params.get("max_tokens") or params.get("max_completion_tokens") or params.get("num_predict") + if use_openai else + (options.get("num_predict") if options else None) + ) + if _dr == "length" and not _req_max_tok: + _pt = getattr(chunk, "prompt_eval_count", 0) or 0 + _ct = getattr(chunk, "eval_count", 0) or 0 + _inferred_nctx = _pt + _ct + if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT: + _endpoint_nctx[(endpoint, model)] = _inferred_nctx + print(f"[ctx-cache] done_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True) if _cache is not None and not _is_moe and _cache_enabled: - try: - await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, cache_bytes) - except Exception as _ce: - print(f"[cache] set_chat (ollama_chat non-streaming) failed: {_ce}") + if chunk.message and getattr(chunk.message, "content", None): + content_parts.append(chunk.message.content) + if getattr(chunk, "done", False): + assembled = orjson.dumps({ + k: v for k, v in { + "model": getattr(chunk, "model", model), + "created_at": (lambda ca: ca.isoformat() if hasattr(ca, "isoformat") else ca)(getattr(chunk, "created_at", None)), + "message": {"role": "assistant", "content": "".join(content_parts)}, + "done": True, + "done_reason": getattr(chunk, "done_reason", "stop") or "stop", + "prompt_eval_count": getattr(chunk, "prompt_eval_count", None), + "eval_count": getattr(chunk, "eval_count", None), + "total_duration": getattr(chunk, "total_duration", None), + "eval_duration": getattr(chunk, "eval_duration", None), + }.items() if v is not None + }) + b"\n" + try: + await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, assembled) + except Exception as _ce: + print(f"[cache] set_chat (ollama_chat streaming) failed: {_ce}") + yield json_line.encode("utf-8") + b"\n" + else: + if use_openai: + response = rechunk.openai_chat_completion2ollama(_async_gen, stream, start_ts) + response = response.model_dump_json() + else: + response = _async_gen.model_dump_json() + prompt_tok = _async_gen.prompt_eval_count or 0 + comp_tok = _async_gen.eval_count or 0 + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) + json_line = ( + response + if hasattr(_async_gen, "model_dump_json") + else orjson.dumps(_async_gen) + ) + cache_bytes = json_line.encode("utf-8") + b"\n" + yield cache_bytes + # Cache non-streaming response (non-MOE; works for both Ollama and OpenAI backends) + if _cache is not None and not _is_moe and _cache_enabled: + try: + await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, cache_bytes) + except Exception as _ce: + print(f"[cache] set_chat (ollama_chat non-streaming) failed: {_ce}") - except asyncio.CancelledError: - raise - except Exception as e: - try: - yield await _handle_stream_error(e, endpoint, model, context="chat_proxy") - except Exception: - pass - finally: - # Ensure counter is decremented even if an exception occurs - await decrement_usage(endpoint, tracking_model) - - # 4. Return a StreamingResponse backed by the generator + # 4. Return a StreamingResponse backed by the guarded generator media_type = "application/x-ndjson" if stream else "application/json" return StreamingResponse( - stream_chat_response(), + _guarded_stream( + stream_chat_response(), + endpoint=endpoint, model=model, + tracking_model=tracking_model, context="chat_proxy", + ), media_type=media_type, ) -@router.post("/api/embeddings") -async def embedding_proxy(request: Request): - """ - Proxy an embedding request to Ollama and reply with embeddings. +async def _handle_embedding_request( + request: Request, + *, + input_field: str, + context: str, + make_native, + make_openai, +): + """Shared implementation for ``/api/embeddings`` and ``/api/embed``. + The two routes differ only in the request field they read (``prompt`` vs + ``input``), the ollama SDK method they call, and the OpenAI rechunk helper. + Those are passed in via ``input_field`` and the ``make_native`` / + ``make_openai`` callables; everything else — parsing, endpoint selection, + serialization, and the streaming error/cleanup contract — is shared. """ config = get_config() # 1. Parse and validate request @@ -555,77 +573,7 @@ async def embedding_proxy(request: Request): payload = orjson.loads(body_bytes.decode("utf-8")) model = payload.get("model") - prompt = payload.get("prompt") - options = payload.get("options") - keep_alive = payload.get("keep_alive") - - if not model: - raise HTTPException( - status_code=400, detail="Missing required field 'model'" - ) - if not prompt: - raise HTTPException( - status_code=400, detail="Missing required field 'prompt'" - ) - except orjson.JSONDecodeError as e: - raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e - - # 2. Endpoint logic - endpoint, tracking_model = await choose_endpoint(model) - use_openai = is_openai_compatible(endpoint) - if use_openai: - if ":latest" in model: - model = model.split(":latest") - model = model[0] - client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key")) - else: - client = ollama.AsyncClient(host=endpoint) - # 3. Async generator that streams embedding data and decrements the counter - async def stream_embedding_response(): - try: - # The chat method returns a generator of dicts (or GenerateResponse) - if use_openai: - async_gen = await client.embeddings.create(input=prompt, model=model) - async_gen = rechunk.openai_embeddings2ollama(async_gen) - else: - async_gen = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive) - if hasattr(async_gen, "model_dump_json"): - json_line = async_gen.model_dump_json() - else: - json_line = orjson.dumps(async_gen) - yield json_line.encode("utf-8") + b"\n" - except asyncio.CancelledError: - raise - except Exception as e: - try: - yield await _handle_stream_error(e, endpoint, model, context="embeddings_proxy") - except Exception: - pass - finally: - # Ensure counter is decremented even if an exception occurs - await decrement_usage(endpoint, tracking_model) - - # 5. Return a StreamingResponse backed by the generator - return StreamingResponse( - stream_embedding_response(), - media_type="application/json", - ) - - -@router.post("/api/embed") -async def embed_proxy(request: Request): - """ - Proxy an embed request to Ollama and reply with embeddings. - - """ - config = get_config() - # 1. Parse and validate request - try: - body_bytes = await request.body() - payload = orjson.loads(body_bytes.decode("utf-8")) - - model = payload.get("model") - _input = payload.get("input") + value = payload.get(input_field) truncate = payload.get("truncate") options = payload.get("options") keep_alive = payload.get("keep_alive") @@ -634,9 +582,9 @@ async def embed_proxy(request: Request): raise HTTPException( status_code=400, detail="Missing required field 'model'" ) - if not _input: + if not value: raise HTTPException( - status_code=400, detail="Missing required field 'input'" + status_code=400, detail=f"Missing required field '{input_field}'" ) except orjson.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e @@ -651,38 +599,60 @@ async def embed_proxy(request: Request): client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key")) else: client = ollama.AsyncClient(host=endpoint) - # 3. Async generator that streams embed data and decrements the counter - async def stream_embedding_response(): - try: - # The chat method returns a generator of dicts (or GenerateResponse) - if use_openai: - async_gen = await client.embeddings.create(input=_input, model=model) - async_gen = rechunk.openai_embed2ollama(async_gen, model) - else: - async_gen = await client.embed(model=model, input=_input, truncate=truncate, options=options, keep_alive=keep_alive) - if hasattr(async_gen, "model_dump_json"): - json_line = async_gen.model_dump_json() - else: - json_line = orjson.dumps(async_gen) - yield json_line.encode("utf-8") + b"\n" - except asyncio.CancelledError: - raise - except Exception as e: - try: - yield await _handle_stream_error(e, endpoint, model, context="embed_proxy") - except Exception: - pass - finally: - # Ensure counter is decremented even if an exception occurs - await decrement_usage(endpoint, tracking_model) - # 4. Return a StreamingResponse backed by the generator + # 3. Async generator body (error handling + cleanup handled by _guarded_stream) + async def stream_embedding_response(): + if use_openai: + response = await make_openai(client, model, value) + else: + response = await make_native(client, model, value, options, keep_alive, truncate) + if hasattr(response, "model_dump_json"): + json_line = response.model_dump_json() + else: + json_line = orjson.dumps(response) + yield json_line.encode("utf-8") + b"\n" + + # 4. Return a StreamingResponse backed by the guarded generator return StreamingResponse( - stream_embedding_response(), + _guarded_stream( + stream_embedding_response(), + endpoint=endpoint, model=model, + tracking_model=tracking_model, context=context, + ), media_type="application/json", ) +@router.post("/api/embeddings") +async def embedding_proxy(request: Request): + """Proxy an embedding request to Ollama and reply with embeddings.""" + async def _native(client, model, value, options, keep_alive, truncate): + return await client.embeddings(model=model, prompt=value, options=options, keep_alive=keep_alive) + + async def _openai(client, model, value): + return rechunk.openai_embeddings2ollama(await client.embeddings.create(input=value, model=model)) + + return await _handle_embedding_request( + request, input_field="prompt", context="embeddings_proxy", + make_native=_native, make_openai=_openai, + ) + + +@router.post("/api/embed") +async def embed_proxy(request: Request): + """Proxy an embed request to Ollama and reply with embeddings.""" + async def _native(client, model, value, options, keep_alive, truncate): + return await client.embed(model=model, input=value, truncate=truncate, options=options, keep_alive=keep_alive) + + async def _openai(client, model, value): + return rechunk.openai_embed2ollama(await client.embeddings.create(input=value, model=model), model) + + return await _handle_embedding_request( + request, input_field="input", context="embed_proxy", + make_native=_native, make_openai=_openai, + ) + + @router.post("/api/create") async def create_proxy(request: Request): """