refac: code deduplication for error handling and call sites
All checks were successful
Build and Publish Docker Image (Semantic Cache) / build (amd64, linux/amd64, docker-amd64) (push) Successful in 4m2s
Build and Publish Docker Image / build (amd64, linux/amd64, docker-amd64) (push) Successful in 1m37s
Build and Publish Docker Image (Semantic Cache) / build (arm64, linux/arm64, docker-arm64) (push) Successful in 17m43s
Build and Publish Docker Image (Semantic Cache) / merge (push) Successful in 34s
Build and Publish Docker Image / build (arm64, linux/arm64, docker-arm64) (push) Successful in 12m47s
Build and Publish Docker Image / merge (push) Successful in 33s
All checks were successful
Build and Publish Docker Image (Semantic Cache) / build (amd64, linux/amd64, docker-amd64) (push) Successful in 4m2s
Build and Publish Docker Image / build (amd64, linux/amd64, docker-amd64) (push) Successful in 1m37s
Build and Publish Docker Image (Semantic Cache) / build (arm64, linux/arm64, docker-arm64) (push) Successful in 17m43s
Build and Publish Docker Image (Semantic Cache) / merge (push) Successful in 34s
Build and Publish Docker Image / build (arm64, linux/arm64, docker-arm64) (push) Successful in 12m47s
Build and Publish Docker Image / merge (push) Successful in 33s
This commit is contained in:
parent
2dceece0d6
commit
497c87b02e
1 changed files with 253 additions and 283 deletions
536
api/ollama.py
536
api/ollama.py
|
|
@ -91,6 +91,30 @@ async def _handle_stream_error(
|
|||
return orjson.dumps(err_payload) + b"\n"
|
||||
|
||||
|
||||
async def _guarded_stream(inner, *, endpoint: str, model: str, tracking_model: str, context: str):
|
||||
"""Wrap a per-route body generator with the shared streaming contract.
|
||||
|
||||
Every ``/api/*`` streaming handler needs the same three guarantees around its
|
||||
body: surface backend errors transitively (via :func:`_handle_stream_error`),
|
||||
let client-disconnect cancellation propagate untouched, and always decrement
|
||||
the usage counter. Centralising them here keeps the four bodies free of
|
||||
duplicated ``try/except/finally`` scaffolding.
|
||||
"""
|
||||
try:
|
||||
async for item in inner:
|
||||
yield item
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
try:
|
||||
yield await _handle_stream_error(e, endpoint, model, context=context)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
# Ensure counter is decremented even if an exception occurs
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
|
||||
|
||||
@router.post("/api/generate")
|
||||
async def proxy(request: Request):
|
||||
"""
|
||||
|
|
@ -165,87 +189,79 @@ async def proxy(request: Request):
|
|||
else:
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
|
||||
# 4. Async generator that streams data and decrements the counter
|
||||
# 4. Async generator body (error handling + cleanup handled by _guarded_stream)
|
||||
async def stream_generate_response():
|
||||
try:
|
||||
if use_openai:
|
||||
start_ts = time.perf_counter()
|
||||
async_gen = await oclient.completions.create(**params)
|
||||
else:
|
||||
async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive)
|
||||
if stream == True:
|
||||
content_parts: list[str] = []
|
||||
async for chunk in async_gen:
|
||||
if use_openai:
|
||||
chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
|
||||
prompt_tok = chunk.prompt_eval_count or 0
|
||||
comp_tok = chunk.eval_count or 0
|
||||
if prompt_tok != 0 or comp_tok != 0:
|
||||
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
|
||||
if hasattr(chunk, "model_dump_json"):
|
||||
json_line = chunk.model_dump_json()
|
||||
else:
|
||||
json_line = orjson.dumps(chunk)
|
||||
# Accumulate and store cache on done chunk — before yield so it always runs
|
||||
if _cache is not None and _cache_enabled:
|
||||
if getattr(chunk, "response", None):
|
||||
content_parts.append(chunk.response)
|
||||
if getattr(chunk, "done", False):
|
||||
assembled = orjson.dumps({
|
||||
k: v for k, v in {
|
||||
"model": getattr(chunk, "model", model),
|
||||
"response": "".join(content_parts),
|
||||
"done": True,
|
||||
"done_reason": getattr(chunk, "done_reason", "stop") or "stop",
|
||||
"prompt_eval_count": getattr(chunk, "prompt_eval_count", None),
|
||||
"eval_count": getattr(chunk, "eval_count", None),
|
||||
"total_duration": getattr(chunk, "total_duration", None),
|
||||
"eval_duration": getattr(chunk, "eval_duration", None),
|
||||
}.items() if v is not None
|
||||
}) + b"\n"
|
||||
try:
|
||||
await _cache.set_generate(model, prompt, system or "", assembled)
|
||||
except Exception as _ce:
|
||||
print(f"[cache] set_generate (streaming) failed: {_ce}")
|
||||
yield json_line.encode("utf-8") + b"\n"
|
||||
else:
|
||||
if use_openai:
|
||||
start_ts = time.perf_counter()
|
||||
async_gen = await oclient.completions.create(**params)
|
||||
else:
|
||||
async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive)
|
||||
if stream == True:
|
||||
content_parts: list[str] = []
|
||||
async for chunk in async_gen:
|
||||
if use_openai:
|
||||
response = rechunk.openai_completion2ollama(async_gen, stream, start_ts)
|
||||
response = response.model_dump_json()
|
||||
chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
|
||||
prompt_tok = chunk.prompt_eval_count or 0
|
||||
comp_tok = chunk.eval_count or 0
|
||||
if prompt_tok != 0 or comp_tok != 0:
|
||||
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
|
||||
if hasattr(chunk, "model_dump_json"):
|
||||
json_line = chunk.model_dump_json()
|
||||
else:
|
||||
response = async_gen.model_dump_json()
|
||||
prompt_tok = async_gen.prompt_eval_count or 0
|
||||
comp_tok = async_gen.eval_count or 0
|
||||
if prompt_tok != 0 or comp_tok != 0:
|
||||
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
|
||||
json_line = (
|
||||
response
|
||||
if hasattr(async_gen, "model_dump_json")
|
||||
else orjson.dumps(async_gen)
|
||||
)
|
||||
cache_bytes = json_line.encode("utf-8") + b"\n"
|
||||
yield cache_bytes
|
||||
# Cache non-streaming response
|
||||
json_line = orjson.dumps(chunk)
|
||||
# Accumulate and store cache on done chunk — before yield so it always runs
|
||||
if _cache is not None and _cache_enabled:
|
||||
try:
|
||||
await _cache.set_generate(model, prompt, system or "", cache_bytes)
|
||||
except Exception as _ce:
|
||||
print(f"[cache] set_generate (non-streaming) failed: {_ce}")
|
||||
if getattr(chunk, "response", None):
|
||||
content_parts.append(chunk.response)
|
||||
if getattr(chunk, "done", False):
|
||||
assembled = orjson.dumps({
|
||||
k: v for k, v in {
|
||||
"model": getattr(chunk, "model", model),
|
||||
"response": "".join(content_parts),
|
||||
"done": True,
|
||||
"done_reason": getattr(chunk, "done_reason", "stop") or "stop",
|
||||
"prompt_eval_count": getattr(chunk, "prompt_eval_count", None),
|
||||
"eval_count": getattr(chunk, "eval_count", None),
|
||||
"total_duration": getattr(chunk, "total_duration", None),
|
||||
"eval_duration": getattr(chunk, "eval_duration", None),
|
||||
}.items() if v is not None
|
||||
}) + b"\n"
|
||||
try:
|
||||
await _cache.set_generate(model, prompt, system or "", assembled)
|
||||
except Exception as _ce:
|
||||
print(f"[cache] set_generate (streaming) failed: {_ce}")
|
||||
yield json_line.encode("utf-8") + b"\n"
|
||||
else:
|
||||
if use_openai:
|
||||
response = rechunk.openai_completion2ollama(async_gen, stream, start_ts)
|
||||
response = response.model_dump_json()
|
||||
else:
|
||||
response = async_gen.model_dump_json()
|
||||
prompt_tok = async_gen.prompt_eval_count or 0
|
||||
comp_tok = async_gen.eval_count or 0
|
||||
if prompt_tok != 0 or comp_tok != 0:
|
||||
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
|
||||
json_line = (
|
||||
response
|
||||
if hasattr(async_gen, "model_dump_json")
|
||||
else orjson.dumps(async_gen)
|
||||
)
|
||||
cache_bytes = json_line.encode("utf-8") + b"\n"
|
||||
yield cache_bytes
|
||||
# Cache non-streaming response
|
||||
if _cache is not None and _cache_enabled:
|
||||
try:
|
||||
await _cache.set_generate(model, prompt, system or "", cache_bytes)
|
||||
except Exception as _ce:
|
||||
print(f"[cache] set_generate (non-streaming) failed: {_ce}")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
try:
|
||||
yield await _handle_stream_error(e, endpoint, model, context="generate_proxy")
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
# Ensure counter is decremented even if an exception occurs
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
|
||||
# 5. Return a StreamingResponse backed by the generator
|
||||
# 5. Return a StreamingResponse backed by the guarded generator
|
||||
return StreamingResponse(
|
||||
stream_generate_response(),
|
||||
_guarded_stream(
|
||||
stream_generate_response(),
|
||||
endpoint=endpoint, model=model,
|
||||
tracking_model=tracking_model, context="generate_proxy",
|
||||
),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
|
|
@ -425,128 +441,130 @@ async def chat_proxy(request: Request):
|
|||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
|
||||
# 3. Async generator that streams chat data and decrements the counter
|
||||
# 3. Async generator body (error handling + cleanup handled by _guarded_stream)
|
||||
async def stream_chat_response():
|
||||
try:
|
||||
# The chat method returns a generator of dicts (or GenerateResponse)
|
||||
if use_openai:
|
||||
_async_gen = async_gen # established in handler scope above
|
||||
else:
|
||||
if opt == True:
|
||||
# Use the dedicated MOE helper function
|
||||
_async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive)
|
||||
else:
|
||||
_async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive, logprobs=logprobs, top_logprobs=top_logprobs)
|
||||
if stream == True:
|
||||
tc_acc = {} # accumulate OpenAI tool-call deltas across chunks
|
||||
content_parts: list[str] = []
|
||||
async for chunk in _async_gen:
|
||||
if use_openai:
|
||||
_accumulate_openai_tc_delta(chunk, tc_acc)
|
||||
chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts)
|
||||
# Inject fully-accumulated tool calls only into the final chunk
|
||||
if chunk.done and tc_acc and chunk.message:
|
||||
chunk.message.tool_calls = _build_ollama_tool_calls(tc_acc)
|
||||
# `chunk` can be a dict or a pydantic model – dump to JSON safely
|
||||
prompt_tok = chunk.prompt_eval_count or 0
|
||||
comp_tok = chunk.eval_count or 0
|
||||
if prompt_tok != 0 or comp_tok != 0:
|
||||
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
|
||||
if hasattr(chunk, "model_dump_json"):
|
||||
json_line = chunk.model_dump_json()
|
||||
else:
|
||||
json_line = orjson.dumps(chunk)
|
||||
# Accumulate and store cache on done chunk — before yield so it always runs
|
||||
# Works for both Ollama-native and OpenAI-compatible backends; chunks are
|
||||
# already converted to Ollama format by rechunk before this point.
|
||||
if getattr(chunk, "done", False):
|
||||
# Detect context exhaustion mid-generation for small-ctx models
|
||||
_dr = getattr(chunk, "done_reason", None)
|
||||
# Only cache when no max_tokens limit was set — otherwise
|
||||
# finish_reason=length might just mean max_tokens was hit,
|
||||
# not that the context window was exhausted.
|
||||
_req_max_tok = (
|
||||
params.get("max_tokens") or params.get("max_completion_tokens") or params.get("num_predict")
|
||||
if use_openai else
|
||||
(options.get("num_predict") if options else None)
|
||||
)
|
||||
if _dr == "length" and not _req_max_tok:
|
||||
_pt = getattr(chunk, "prompt_eval_count", 0) or 0
|
||||
_ct = getattr(chunk, "eval_count", 0) or 0
|
||||
_inferred_nctx = _pt + _ct
|
||||
if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
|
||||
_endpoint_nctx[(endpoint, model)] = _inferred_nctx
|
||||
print(f"[ctx-cache] done_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True)
|
||||
if _cache is not None and not _is_moe and _cache_enabled:
|
||||
if chunk.message and getattr(chunk.message, "content", None):
|
||||
content_parts.append(chunk.message.content)
|
||||
if getattr(chunk, "done", False):
|
||||
assembled = orjson.dumps({
|
||||
k: v for k, v in {
|
||||
"model": getattr(chunk, "model", model),
|
||||
"created_at": (lambda ca: ca.isoformat() if hasattr(ca, "isoformat") else ca)(getattr(chunk, "created_at", None)),
|
||||
"message": {"role": "assistant", "content": "".join(content_parts)},
|
||||
"done": True,
|
||||
"done_reason": getattr(chunk, "done_reason", "stop") or "stop",
|
||||
"prompt_eval_count": getattr(chunk, "prompt_eval_count", None),
|
||||
"eval_count": getattr(chunk, "eval_count", None),
|
||||
"total_duration": getattr(chunk, "total_duration", None),
|
||||
"eval_duration": getattr(chunk, "eval_duration", None),
|
||||
}.items() if v is not None
|
||||
}) + b"\n"
|
||||
try:
|
||||
await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, assembled)
|
||||
except Exception as _ce:
|
||||
print(f"[cache] set_chat (ollama_chat streaming) failed: {_ce}")
|
||||
yield json_line.encode("utf-8") + b"\n"
|
||||
# The chat method returns a generator of dicts (or GenerateResponse)
|
||||
if use_openai:
|
||||
_async_gen = async_gen # established in handler scope above
|
||||
else:
|
||||
if opt == True:
|
||||
# Use the dedicated MOE helper function
|
||||
_async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive)
|
||||
else:
|
||||
_async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive, logprobs=logprobs, top_logprobs=top_logprobs)
|
||||
if stream == True:
|
||||
tc_acc = {} # accumulate OpenAI tool-call deltas across chunks
|
||||
content_parts: list[str] = []
|
||||
async for chunk in _async_gen:
|
||||
if use_openai:
|
||||
response = rechunk.openai_chat_completion2ollama(_async_gen, stream, start_ts)
|
||||
response = response.model_dump_json()
|
||||
_accumulate_openai_tc_delta(chunk, tc_acc)
|
||||
chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts)
|
||||
# Inject fully-accumulated tool calls only into the final chunk
|
||||
if chunk.done and tc_acc and chunk.message:
|
||||
chunk.message.tool_calls = _build_ollama_tool_calls(tc_acc)
|
||||
# `chunk` can be a dict or a pydantic model – dump to JSON safely
|
||||
prompt_tok = chunk.prompt_eval_count or 0
|
||||
comp_tok = chunk.eval_count or 0
|
||||
if prompt_tok != 0 or comp_tok != 0:
|
||||
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
|
||||
if hasattr(chunk, "model_dump_json"):
|
||||
json_line = chunk.model_dump_json()
|
||||
else:
|
||||
response = _async_gen.model_dump_json()
|
||||
prompt_tok = _async_gen.prompt_eval_count or 0
|
||||
comp_tok = _async_gen.eval_count or 0
|
||||
if prompt_tok != 0 or comp_tok != 0:
|
||||
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
|
||||
json_line = (
|
||||
response
|
||||
if hasattr(_async_gen, "model_dump_json")
|
||||
else orjson.dumps(_async_gen)
|
||||
)
|
||||
cache_bytes = json_line.encode("utf-8") + b"\n"
|
||||
yield cache_bytes
|
||||
# Cache non-streaming response (non-MOE; works for both Ollama and OpenAI backends)
|
||||
json_line = orjson.dumps(chunk)
|
||||
# Accumulate and store cache on done chunk — before yield so it always runs
|
||||
# Works for both Ollama-native and OpenAI-compatible backends; chunks are
|
||||
# already converted to Ollama format by rechunk before this point.
|
||||
if getattr(chunk, "done", False):
|
||||
# Detect context exhaustion mid-generation for small-ctx models
|
||||
_dr = getattr(chunk, "done_reason", None)
|
||||
# Only cache when no max_tokens limit was set — otherwise
|
||||
# finish_reason=length might just mean max_tokens was hit,
|
||||
# not that the context window was exhausted.
|
||||
_req_max_tok = (
|
||||
params.get("max_tokens") or params.get("max_completion_tokens") or params.get("num_predict")
|
||||
if use_openai else
|
||||
(options.get("num_predict") if options else None)
|
||||
)
|
||||
if _dr == "length" and not _req_max_tok:
|
||||
_pt = getattr(chunk, "prompt_eval_count", 0) or 0
|
||||
_ct = getattr(chunk, "eval_count", 0) or 0
|
||||
_inferred_nctx = _pt + _ct
|
||||
if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
|
||||
_endpoint_nctx[(endpoint, model)] = _inferred_nctx
|
||||
print(f"[ctx-cache] done_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True)
|
||||
if _cache is not None and not _is_moe and _cache_enabled:
|
||||
try:
|
||||
await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, cache_bytes)
|
||||
except Exception as _ce:
|
||||
print(f"[cache] set_chat (ollama_chat non-streaming) failed: {_ce}")
|
||||
if chunk.message and getattr(chunk.message, "content", None):
|
||||
content_parts.append(chunk.message.content)
|
||||
if getattr(chunk, "done", False):
|
||||
assembled = orjson.dumps({
|
||||
k: v for k, v in {
|
||||
"model": getattr(chunk, "model", model),
|
||||
"created_at": (lambda ca: ca.isoformat() if hasattr(ca, "isoformat") else ca)(getattr(chunk, "created_at", None)),
|
||||
"message": {"role": "assistant", "content": "".join(content_parts)},
|
||||
"done": True,
|
||||
"done_reason": getattr(chunk, "done_reason", "stop") or "stop",
|
||||
"prompt_eval_count": getattr(chunk, "prompt_eval_count", None),
|
||||
"eval_count": getattr(chunk, "eval_count", None),
|
||||
"total_duration": getattr(chunk, "total_duration", None),
|
||||
"eval_duration": getattr(chunk, "eval_duration", None),
|
||||
}.items() if v is not None
|
||||
}) + b"\n"
|
||||
try:
|
||||
await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, assembled)
|
||||
except Exception as _ce:
|
||||
print(f"[cache] set_chat (ollama_chat streaming) failed: {_ce}")
|
||||
yield json_line.encode("utf-8") + b"\n"
|
||||
else:
|
||||
if use_openai:
|
||||
response = rechunk.openai_chat_completion2ollama(_async_gen, stream, start_ts)
|
||||
response = response.model_dump_json()
|
||||
else:
|
||||
response = _async_gen.model_dump_json()
|
||||
prompt_tok = _async_gen.prompt_eval_count or 0
|
||||
comp_tok = _async_gen.eval_count or 0
|
||||
if prompt_tok != 0 or comp_tok != 0:
|
||||
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
|
||||
json_line = (
|
||||
response
|
||||
if hasattr(_async_gen, "model_dump_json")
|
||||
else orjson.dumps(_async_gen)
|
||||
)
|
||||
cache_bytes = json_line.encode("utf-8") + b"\n"
|
||||
yield cache_bytes
|
||||
# Cache non-streaming response (non-MOE; works for both Ollama and OpenAI backends)
|
||||
if _cache is not None and not _is_moe and _cache_enabled:
|
||||
try:
|
||||
await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, cache_bytes)
|
||||
except Exception as _ce:
|
||||
print(f"[cache] set_chat (ollama_chat non-streaming) failed: {_ce}")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
try:
|
||||
yield await _handle_stream_error(e, endpoint, model, context="chat_proxy")
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
# Ensure counter is decremented even if an exception occurs
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
|
||||
# 4. Return a StreamingResponse backed by the generator
|
||||
# 4. Return a StreamingResponse backed by the guarded generator
|
||||
media_type = "application/x-ndjson" if stream else "application/json"
|
||||
return StreamingResponse(
|
||||
stream_chat_response(),
|
||||
_guarded_stream(
|
||||
stream_chat_response(),
|
||||
endpoint=endpoint, model=model,
|
||||
tracking_model=tracking_model, context="chat_proxy",
|
||||
),
|
||||
media_type=media_type,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/embeddings")
|
||||
async def embedding_proxy(request: Request):
|
||||
"""
|
||||
Proxy an embedding request to Ollama and reply with embeddings.
|
||||
async def _handle_embedding_request(
|
||||
request: Request,
|
||||
*,
|
||||
input_field: str,
|
||||
context: str,
|
||||
make_native,
|
||||
make_openai,
|
||||
):
|
||||
"""Shared implementation for ``/api/embeddings`` and ``/api/embed``.
|
||||
|
||||
The two routes differ only in the request field they read (``prompt`` vs
|
||||
``input``), the ollama SDK method they call, and the OpenAI rechunk helper.
|
||||
Those are passed in via ``input_field`` and the ``make_native`` /
|
||||
``make_openai`` callables; everything else — parsing, endpoint selection,
|
||||
serialization, and the streaming error/cleanup contract — is shared.
|
||||
"""
|
||||
config = get_config()
|
||||
# 1. Parse and validate request
|
||||
|
|
@ -555,77 +573,7 @@ async def embedding_proxy(request: Request):
|
|||
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||
|
||||
model = payload.get("model")
|
||||
prompt = payload.get("prompt")
|
||||
options = payload.get("options")
|
||||
keep_alive = payload.get("keep_alive")
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
if not prompt:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'prompt'"
|
||||
)
|
||||
except orjson.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
endpoint, tracking_model = await choose_endpoint(model)
|
||||
use_openai = is_openai_compatible(endpoint)
|
||||
if use_openai:
|
||||
if ":latest" in model:
|
||||
model = model.split(":latest")
|
||||
model = model[0]
|
||||
client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key"))
|
||||
else:
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
# 3. Async generator that streams embedding data and decrements the counter
|
||||
async def stream_embedding_response():
|
||||
try:
|
||||
# The chat method returns a generator of dicts (or GenerateResponse)
|
||||
if use_openai:
|
||||
async_gen = await client.embeddings.create(input=prompt, model=model)
|
||||
async_gen = rechunk.openai_embeddings2ollama(async_gen)
|
||||
else:
|
||||
async_gen = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive)
|
||||
if hasattr(async_gen, "model_dump_json"):
|
||||
json_line = async_gen.model_dump_json()
|
||||
else:
|
||||
json_line = orjson.dumps(async_gen)
|
||||
yield json_line.encode("utf-8") + b"\n"
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
try:
|
||||
yield await _handle_stream_error(e, endpoint, model, context="embeddings_proxy")
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
# Ensure counter is decremented even if an exception occurs
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
|
||||
# 5. Return a StreamingResponse backed by the generator
|
||||
return StreamingResponse(
|
||||
stream_embedding_response(),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/embed")
|
||||
async def embed_proxy(request: Request):
|
||||
"""
|
||||
Proxy an embed request to Ollama and reply with embeddings.
|
||||
|
||||
"""
|
||||
config = get_config()
|
||||
# 1. Parse and validate request
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||
|
||||
model = payload.get("model")
|
||||
_input = payload.get("input")
|
||||
value = payload.get(input_field)
|
||||
truncate = payload.get("truncate")
|
||||
options = payload.get("options")
|
||||
keep_alive = payload.get("keep_alive")
|
||||
|
|
@ -634,9 +582,9 @@ async def embed_proxy(request: Request):
|
|||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
if not _input:
|
||||
if not value:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'input'"
|
||||
status_code=400, detail=f"Missing required field '{input_field}'"
|
||||
)
|
||||
except orjson.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
|
@ -651,38 +599,60 @@ async def embed_proxy(request: Request):
|
|||
client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key"))
|
||||
else:
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
# 3. Async generator that streams embed data and decrements the counter
|
||||
async def stream_embedding_response():
|
||||
try:
|
||||
# The chat method returns a generator of dicts (or GenerateResponse)
|
||||
if use_openai:
|
||||
async_gen = await client.embeddings.create(input=_input, model=model)
|
||||
async_gen = rechunk.openai_embed2ollama(async_gen, model)
|
||||
else:
|
||||
async_gen = await client.embed(model=model, input=_input, truncate=truncate, options=options, keep_alive=keep_alive)
|
||||
if hasattr(async_gen, "model_dump_json"):
|
||||
json_line = async_gen.model_dump_json()
|
||||
else:
|
||||
json_line = orjson.dumps(async_gen)
|
||||
yield json_line.encode("utf-8") + b"\n"
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
try:
|
||||
yield await _handle_stream_error(e, endpoint, model, context="embed_proxy")
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
# Ensure counter is decremented even if an exception occurs
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
|
||||
# 4. Return a StreamingResponse backed by the generator
|
||||
# 3. Async generator body (error handling + cleanup handled by _guarded_stream)
|
||||
async def stream_embedding_response():
|
||||
if use_openai:
|
||||
response = await make_openai(client, model, value)
|
||||
else:
|
||||
response = await make_native(client, model, value, options, keep_alive, truncate)
|
||||
if hasattr(response, "model_dump_json"):
|
||||
json_line = response.model_dump_json()
|
||||
else:
|
||||
json_line = orjson.dumps(response)
|
||||
yield json_line.encode("utf-8") + b"\n"
|
||||
|
||||
# 4. Return a StreamingResponse backed by the guarded generator
|
||||
return StreamingResponse(
|
||||
stream_embedding_response(),
|
||||
_guarded_stream(
|
||||
stream_embedding_response(),
|
||||
endpoint=endpoint, model=model,
|
||||
tracking_model=tracking_model, context=context,
|
||||
),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/embeddings")
|
||||
async def embedding_proxy(request: Request):
|
||||
"""Proxy an embedding request to Ollama and reply with embeddings."""
|
||||
async def _native(client, model, value, options, keep_alive, truncate):
|
||||
return await client.embeddings(model=model, prompt=value, options=options, keep_alive=keep_alive)
|
||||
|
||||
async def _openai(client, model, value):
|
||||
return rechunk.openai_embeddings2ollama(await client.embeddings.create(input=value, model=model))
|
||||
|
||||
return await _handle_embedding_request(
|
||||
request, input_field="prompt", context="embeddings_proxy",
|
||||
make_native=_native, make_openai=_openai,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/embed")
|
||||
async def embed_proxy(request: Request):
|
||||
"""Proxy an embed request to Ollama and reply with embeddings."""
|
||||
async def _native(client, model, value, options, keep_alive, truncate):
|
||||
return await client.embed(model=model, input=value, truncate=truncate, options=options, keep_alive=keep_alive)
|
||||
|
||||
async def _openai(client, model, value):
|
||||
return rechunk.openai_embed2ollama(await client.embeddings.create(input=value, model=model), model)
|
||||
|
||||
return await _handle_embedding_request(
|
||||
request, input_field="input", context="embed_proxy",
|
||||
make_native=_native, make_openai=_openai,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/create")
|
||||
async def create_proxy(request: Request):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue