Merge pull request #37 from nomyo-ai/dev-v0.7.x-semcache
Dev v0.7.x semcache addtl. feature
This commit is contained in:
commit
21d6835253
2 changed files with 369 additions and 36 deletions
|
|
@ -32,6 +32,7 @@ PyYAML==6.0.3
|
||||||
sniffio==1.3.1
|
sniffio==1.3.1
|
||||||
starlette==0.49.1
|
starlette==0.49.1
|
||||||
truststore==0.10.4
|
truststore==0.10.4
|
||||||
|
tiktoken==0.12.0
|
||||||
tqdm==4.67.1
|
tqdm==4.67.1
|
||||||
typing-inspection==0.4.1
|
typing-inspection==0.4.1
|
||||||
typing_extensions==4.14.1
|
typing_extensions==4.14.1
|
||||||
|
|
|
||||||
390
router.py
390
router.py
|
|
@ -78,6 +78,107 @@ def _mask_secrets(text: str) -> str:
|
||||||
text = re.sub(r"(?i)(api[-_ ]key\s*[:=]\s*)([^\s]+)", r"\1***redacted***", text)
|
text = re.sub(r"(?i)(api[-_ ]key\s*[:=]\s*)([^\s]+)", r"\1***redacted***", text)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Context-window sliding-window helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
try:
|
||||||
|
import tiktoken as _tiktoken
|
||||||
|
_tiktoken_enc = _tiktoken.get_encoding("cl100k_base")
|
||||||
|
except Exception:
|
||||||
|
_tiktoken_enc = None
|
||||||
|
|
||||||
|
def _count_message_tokens(messages: list) -> int:
|
||||||
|
"""Approximate token count for a message list.
|
||||||
|
|
||||||
|
Uses tiktoken cl100k_base when available (within ~5-15% of llama tokenizers).
|
||||||
|
Falls back to char/4 heuristic if tiktoken is unavailable.
|
||||||
|
Formula follows OpenAI's per-message overhead: 4 tokens/message + content + 2 priming.
|
||||||
|
"""
|
||||||
|
if _tiktoken_enc is None:
|
||||||
|
return sum(len(str(m.get("content", ""))) for m in messages) // 4
|
||||||
|
|
||||||
|
total = 2 # priming tokens
|
||||||
|
for msg in messages:
|
||||||
|
total += 4 # per-message role/separator overhead
|
||||||
|
content = msg.get("content", "")
|
||||||
|
if isinstance(content, str):
|
||||||
|
total += len(_tiktoken_enc.encode(content))
|
||||||
|
elif isinstance(content, list):
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, dict) and part.get("type") == "text":
|
||||||
|
total += len(_tiktoken_enc.encode(part.get("text", "")))
|
||||||
|
return total
|
||||||
|
|
||||||
|
def _trim_messages_for_context(
|
||||||
|
messages: list,
|
||||||
|
n_ctx: int,
|
||||||
|
safety_margin: int = None,
|
||||||
|
target_tokens: int = None,
|
||||||
|
) -> list:
|
||||||
|
"""Sliding-window trim — mirrors what llama.cpp context-shift used to do.
|
||||||
|
|
||||||
|
Keeps all system messages and the most recent non-system messages that fit
|
||||||
|
within (n_ctx - safety_margin) tokens. Oldest non-system messages are dropped
|
||||||
|
first (FIFO). The last message is always preserved.
|
||||||
|
|
||||||
|
safety_margin defaults to 1/4 of n_ctx to leave headroom for the generated
|
||||||
|
response, including RAG tool results and tool call JSON synthesis.
|
||||||
|
|
||||||
|
target_tokens: if provided, overrides the (n_ctx - safety_margin) target.
|
||||||
|
Pass a calibrated value when actual n_prompt_tokens is known from the error
|
||||||
|
body so that tiktoken underestimation vs the backend tokenizer is corrected.
|
||||||
|
"""
|
||||||
|
if target_tokens is not None:
|
||||||
|
target = target_tokens
|
||||||
|
else:
|
||||||
|
if safety_margin is None:
|
||||||
|
safety_margin = n_ctx // 4
|
||||||
|
target = n_ctx - safety_margin
|
||||||
|
system_msgs = [m for m in messages if m.get("role") == "system"]
|
||||||
|
non_system = [m for m in messages if m.get("role") != "system"]
|
||||||
|
|
||||||
|
while len(non_system) > 1:
|
||||||
|
if _count_message_tokens(system_msgs + non_system) <= target:
|
||||||
|
break
|
||||||
|
non_system.pop(0) # drop oldest non-system message
|
||||||
|
|
||||||
|
# Ensure the first non-system message is a user message (chat templates require it).
|
||||||
|
# Drop any leading assistant/tool messages that were left after trimming.
|
||||||
|
while non_system and non_system[0].get("role") != "user":
|
||||||
|
non_system.pop(0)
|
||||||
|
|
||||||
|
return system_msgs + non_system
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _calibrated_trim_target(msgs: list, n_ctx: int, actual_tokens: int) -> int:
|
||||||
|
"""Return a tiktoken-scale trim target based on how much backend tokens must be shed.
|
||||||
|
|
||||||
|
actual_tokens includes messages + tool schemas + overhead as counted by the backend.
|
||||||
|
_count_message_tokens only counts message text, so we cannot derive an accurate
|
||||||
|
per-token scale from the ratio. Instead we compute the *delta* we need to remove
|
||||||
|
in backend space, then convert just that delta to tiktoken scale (×1.2 buffer).
|
||||||
|
|
||||||
|
Example: actual=17993, n_ctx=16384, headroom=4096 → need to shed 5705 backend
|
||||||
|
tokens → shed 6846 tiktoken tokens from messages.
|
||||||
|
"""
|
||||||
|
cur_tiktoken = _count_message_tokens(msgs)
|
||||||
|
headroom = n_ctx // 4 # reserve for generated output
|
||||||
|
max_prompt = n_ctx - headroom # desired max backend tokens in prompt
|
||||||
|
to_shed = max(0, actual_tokens - max_prompt) # backend tokens we must drop
|
||||||
|
# Convert to tiktoken scale with 20% buffer (tiktoken underestimates llama by ~15-20%)
|
||||||
|
tiktoken_to_shed = int(to_shed * 1.2)
|
||||||
|
return max(1, cur_tiktoken - tiktoken_to_shed)
|
||||||
|
|
||||||
|
# Per-(endpoint, model) n_ctx cache.
|
||||||
|
# Populated from two sources:
|
||||||
|
# 1. 400 exceed_context_size_error body → n_ctx field
|
||||||
|
# 2. finish_reason/done_reason == "length" in streaming → prompt_tokens + completion_tokens
|
||||||
|
# Only used for proactive pre-trimming when n_ctx <= _CTX_TRIM_SMALL_LIMIT,
|
||||||
|
# so large-context models (200k+ for coding) are never touched.
|
||||||
|
_endpoint_nctx: dict[tuple[str, str], int] = {}
|
||||||
|
_CTX_TRIM_SMALL_LIMIT = 32768 # only proactively trim models with n_ctx at or below this
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Globals
|
# Globals
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
@ -707,10 +808,15 @@ class fetch:
|
||||||
# Check error cache with lock protection
|
# Check error cache with lock protection
|
||||||
async with _available_error_cache_lock:
|
async with _available_error_cache_lock:
|
||||||
if endpoint in _available_error_cache:
|
if endpoint in _available_error_cache:
|
||||||
if _is_fresh(_available_error_cache[endpoint], 300):
|
err_age = time.time() - _available_error_cache[endpoint]
|
||||||
# Still within the short error TTL – pretend nothing is available
|
if err_age < 30:
|
||||||
|
# Very fresh error (<30s) – endpoint likely still down, bail fast
|
||||||
return set()
|
return set()
|
||||||
# Error expired – remove it
|
elif err_age < 300:
|
||||||
|
# Stale error (30-300s) – endpoint may have recovered, probe in background
|
||||||
|
asyncio.create_task(fetch._refresh_available_models(endpoint, api_key))
|
||||||
|
return set()
|
||||||
|
# Error expired (>300s) – remove and fall through to fresh fetch
|
||||||
del _available_error_cache[endpoint]
|
del _available_error_cache[endpoint]
|
||||||
|
|
||||||
# Request coalescing: check if another request is already fetching this endpoint
|
# Request coalescing: check if another request is already fetching this endpoint
|
||||||
|
|
@ -983,7 +1089,37 @@ async def _make_chat_request(model: str, messages: list, tools=None, stream: boo
|
||||||
try:
|
try:
|
||||||
if use_openai:
|
if use_openai:
|
||||||
start_ts = time.perf_counter()
|
start_ts = time.perf_counter()
|
||||||
|
try:
|
||||||
response = await oclient.chat.completions.create(**params)
|
response = await oclient.chat.completions.create(**params)
|
||||||
|
except Exception as e:
|
||||||
|
_e_str = str(e)
|
||||||
|
print(f"[_make_chat_request] caught {type(e).__name__}: {_e_str[:200]}")
|
||||||
|
if "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str:
|
||||||
|
err_body = getattr(e, "body", {}) or {}
|
||||||
|
err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
|
||||||
|
n_ctx_limit = err_detail.get("n_ctx", 0)
|
||||||
|
actual_tokens = err_detail.get("n_prompt_tokens", 0)
|
||||||
|
if not n_ctx_limit:
|
||||||
|
raise
|
||||||
|
msgs_to_trim = params.get("messages", [])
|
||||||
|
cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
|
||||||
|
trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
|
||||||
|
print(f"[_make_chat_request] Context exceeded ({actual_tokens}/{n_ctx_limit} tokens, tiktoken_target={cal_target}), dropped {len(msgs_to_trim) - len(trimmed)} oldest message(s) and retrying")
|
||||||
|
try:
|
||||||
|
response = await oclient.chat.completions.create(**{**params, "messages": trimmed})
|
||||||
|
except Exception as e2:
|
||||||
|
if "exceed_context_size_error" in str(e2) or "exceeds the available context size" in str(e2):
|
||||||
|
print(f"[_make_chat_request] Context still exceeded after trimming, also stripping tools")
|
||||||
|
params_no_tools = {k: v for k, v in params.items() if k not in ("tools", "tool_choice")}
|
||||||
|
response = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed})
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
elif "image input is not supported" in _e_str:
|
||||||
|
print(f"[_make_chat_request] Model {model} doesn't support images, retrying with text-only messages")
|
||||||
|
params = {**params, "messages": _strip_images_from_messages(params.get("messages", []))}
|
||||||
|
response = await oclient.chat.completions.create(**params)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
if stream:
|
if stream:
|
||||||
# For streaming, we need to collect all chunks
|
# For streaming, we need to collect all chunks
|
||||||
chunks = []
|
chunks = []
|
||||||
|
|
@ -1212,6 +1348,22 @@ def transform_images_to_data_urls(message_list):
|
||||||
|
|
||||||
return message_list
|
return message_list
|
||||||
|
|
||||||
|
def _strip_images_from_messages(messages: list) -> list:
|
||||||
|
"""Remove image_url parts from message content, keeping only text."""
|
||||||
|
result = []
|
||||||
|
for msg in messages:
|
||||||
|
content = msg.get("content")
|
||||||
|
if isinstance(content, list):
|
||||||
|
text_only = [p for p in content if p.get("type") != "image_url"]
|
||||||
|
if len(text_only) == 1 and text_only[0].get("type") == "text":
|
||||||
|
content = text_only[0]["text"]
|
||||||
|
else:
|
||||||
|
content = text_only
|
||||||
|
result.append({**msg, "content": content})
|
||||||
|
else:
|
||||||
|
result.append(msg)
|
||||||
|
return result
|
||||||
|
|
||||||
def _accumulate_openai_tc_delta(chunk, accumulator: dict) -> None:
|
def _accumulate_openai_tc_delta(chunk, accumulator: dict) -> None:
|
||||||
"""Accumulate tool_call deltas from a single OpenAI streaming chunk.
|
"""Accumulate tool_call deltas from a single OpenAI streaming chunk.
|
||||||
|
|
||||||
|
|
@ -1825,23 +1977,86 @@ async def chat_proxy(request: Request):
|
||||||
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
||||||
else:
|
else:
|
||||||
client = ollama.AsyncClient(host=endpoint)
|
client = ollama.AsyncClient(host=endpoint)
|
||||||
|
# For OpenAI endpoints: make the API call in handler scope
|
||||||
|
# (try/except inside async generators is unreliable with Starlette's streaming)
|
||||||
|
start_ts = None
|
||||||
|
async_gen = None
|
||||||
|
if use_openai:
|
||||||
|
start_ts = time.perf_counter()
|
||||||
|
# Proactive trim: only for small-ctx models we've already seen run out of space
|
||||||
|
_lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model
|
||||||
|
_known_nctx = _endpoint_nctx.get((endpoint, _lookup_model))
|
||||||
|
if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
|
||||||
|
_pre_target = int((_known_nctx - _known_nctx // 4) / 1.2)
|
||||||
|
_pre_est = _count_message_tokens(params.get("messages", []))
|
||||||
|
if _pre_est > _pre_target:
|
||||||
|
_pre_msgs = params.get("messages", [])
|
||||||
|
_pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target)
|
||||||
|
_dropped = len(_pre_msgs) - len(_pre_trimmed)
|
||||||
|
print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
|
||||||
|
params = {**params, "messages": _pre_trimmed}
|
||||||
|
try:
|
||||||
|
async_gen = await oclient.chat.completions.create(**params)
|
||||||
|
except Exception as e:
|
||||||
|
_e_str = str(e)
|
||||||
|
print(f"[chat_proxy] caught {type(e).__name__}: {_e_str[:200]}")
|
||||||
|
if "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str:
|
||||||
|
err_body = getattr(e, "body", {}) or {}
|
||||||
|
err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
|
||||||
|
n_ctx_limit = err_detail.get("n_ctx", 0)
|
||||||
|
actual_tokens = err_detail.get("n_prompt_tokens", 0)
|
||||||
|
if not n_ctx_limit:
|
||||||
|
await decrement_usage(endpoint, tracking_model)
|
||||||
|
raise
|
||||||
|
if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
|
||||||
|
_endpoint_nctx[(endpoint, model)] = n_ctx_limit
|
||||||
|
msgs_to_trim = params.get("messages", [])
|
||||||
|
cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
|
||||||
|
trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
|
||||||
|
print(f"[chat_proxy] Context exceeded ({actual_tokens}/{n_ctx_limit} tokens, tiktoken_target={cal_target}), dropped {len(msgs_to_trim) - len(trimmed)} oldest message(s) and retrying")
|
||||||
|
try:
|
||||||
|
async_gen = await oclient.chat.completions.create(**{**params, "messages": trimmed})
|
||||||
|
except Exception as e2:
|
||||||
|
_e2_str = str(e2)
|
||||||
|
if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str:
|
||||||
|
print(f"[chat_proxy] Context still exceeded after trimming messages, also stripping tools")
|
||||||
|
params_no_tools = {k: v for k, v in params.items() if k not in ("tools", "tool_choice")}
|
||||||
|
try:
|
||||||
|
async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed})
|
||||||
|
except Exception:
|
||||||
|
await decrement_usage(endpoint, tracking_model)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
await decrement_usage(endpoint, tracking_model)
|
||||||
|
raise
|
||||||
|
elif "image input is not supported" in _e_str:
|
||||||
|
print(f"[chat_proxy] Model {model} doesn't support images, retrying with text-only messages")
|
||||||
|
try:
|
||||||
|
params = {**params, "messages": _strip_images_from_messages(params.get("messages", []))}
|
||||||
|
async_gen = await oclient.chat.completions.create(**params)
|
||||||
|
except Exception:
|
||||||
|
await decrement_usage(endpoint, tracking_model)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
await decrement_usage(endpoint, tracking_model)
|
||||||
|
raise
|
||||||
|
|
||||||
# 3. Async generator that streams chat data and decrements the counter
|
# 3. Async generator that streams chat data and decrements the counter
|
||||||
async def stream_chat_response():
|
async def stream_chat_response():
|
||||||
try:
|
try:
|
||||||
# The chat method returns a generator of dicts (or GenerateResponse)
|
# The chat method returns a generator of dicts (or GenerateResponse)
|
||||||
if use_openai:
|
if use_openai:
|
||||||
start_ts = time.perf_counter()
|
_async_gen = async_gen # established in handler scope above
|
||||||
async_gen = await oclient.chat.completions.create(**params)
|
|
||||||
else:
|
else:
|
||||||
if opt == True:
|
if opt == True:
|
||||||
# Use the dedicated MOE helper function
|
# Use the dedicated MOE helper function
|
||||||
async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive)
|
_async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive)
|
||||||
else:
|
else:
|
||||||
async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive, logprobs=logprobs, top_logprobs=top_logprobs)
|
_async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive, logprobs=logprobs, top_logprobs=top_logprobs)
|
||||||
if stream == True:
|
if stream == True:
|
||||||
tc_acc = {} # accumulate OpenAI tool-call deltas across chunks
|
tc_acc = {} # accumulate OpenAI tool-call deltas across chunks
|
||||||
content_parts: list[str] = []
|
content_parts: list[str] = []
|
||||||
async for chunk in async_gen:
|
async for chunk in _async_gen:
|
||||||
if use_openai:
|
if use_openai:
|
||||||
_accumulate_openai_tc_delta(chunk, tc_acc)
|
_accumulate_openai_tc_delta(chunk, tc_acc)
|
||||||
chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts)
|
chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts)
|
||||||
|
|
@ -1860,6 +2075,20 @@ async def chat_proxy(request: Request):
|
||||||
# Accumulate and store cache on done chunk — before yield so it always runs
|
# Accumulate and store cache on done chunk — before yield so it always runs
|
||||||
# Works for both Ollama-native and OpenAI-compatible backends; chunks are
|
# Works for both Ollama-native and OpenAI-compatible backends; chunks are
|
||||||
# already converted to Ollama format by rechunk before this point.
|
# already converted to Ollama format by rechunk before this point.
|
||||||
|
if getattr(chunk, "done", False):
|
||||||
|
# Detect context exhaustion mid-generation for small-ctx models
|
||||||
|
_dr = getattr(chunk, "done_reason", None)
|
||||||
|
# Only cache when no max_tokens limit was set — otherwise
|
||||||
|
# finish_reason=length might just mean max_tokens was hit,
|
||||||
|
# not that the context window was exhausted.
|
||||||
|
_req_max_tok = params.get("max_tokens") or params.get("max_completion_tokens") or params.get("num_predict")
|
||||||
|
if _dr == "length" and not _req_max_tok:
|
||||||
|
_pt = getattr(chunk, "prompt_eval_count", 0) or 0
|
||||||
|
_ct = getattr(chunk, "eval_count", 0) or 0
|
||||||
|
_inferred_nctx = _pt + _ct
|
||||||
|
if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
|
||||||
|
_endpoint_nctx[(endpoint, model)] = _inferred_nctx
|
||||||
|
print(f"[ctx-cache] done_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True)
|
||||||
if _cache is not None and not _is_moe and _cache_enabled:
|
if _cache is not None and not _is_moe and _cache_enabled:
|
||||||
if chunk.message and getattr(chunk.message, "content", None):
|
if chunk.message and getattr(chunk.message, "content", None):
|
||||||
content_parts.append(chunk.message.content)
|
content_parts.append(chunk.message.content)
|
||||||
|
|
@ -1884,18 +2113,18 @@ async def chat_proxy(request: Request):
|
||||||
yield json_line.encode("utf-8") + b"\n"
|
yield json_line.encode("utf-8") + b"\n"
|
||||||
else:
|
else:
|
||||||
if use_openai:
|
if use_openai:
|
||||||
response = rechunk.openai_chat_completion2ollama(async_gen, stream, start_ts)
|
response = rechunk.openai_chat_completion2ollama(_async_gen, stream, start_ts)
|
||||||
response = response.model_dump_json()
|
response = response.model_dump_json()
|
||||||
else:
|
else:
|
||||||
response = async_gen.model_dump_json()
|
response = _async_gen.model_dump_json()
|
||||||
prompt_tok = async_gen.prompt_eval_count or 0
|
prompt_tok = _async_gen.prompt_eval_count or 0
|
||||||
comp_tok = async_gen.eval_count or 0
|
comp_tok = _async_gen.eval_count or 0
|
||||||
if prompt_tok != 0 or comp_tok != 0:
|
if prompt_tok != 0 or comp_tok != 0:
|
||||||
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
|
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
|
||||||
json_line = (
|
json_line = (
|
||||||
response
|
response
|
||||||
if hasattr(async_gen, "model_dump_json")
|
if hasattr(_async_gen, "model_dump_json")
|
||||||
else orjson.dumps(async_gen)
|
else orjson.dumps(_async_gen)
|
||||||
)
|
)
|
||||||
cache_bytes = json_line.encode("utf-8") + b"\n"
|
cache_bytes = json_line.encode("utf-8") + b"\n"
|
||||||
yield cache_bytes
|
yield cache_bytes
|
||||||
|
|
@ -2604,9 +2833,13 @@ async def ps_details_proxy(request: Request):
|
||||||
*[_fetch_llama_props(ep, mid) for ep, mid in props_requests]
|
*[_fetch_llama_props(ep, mid) for ep, mid in props_requests]
|
||||||
)
|
)
|
||||||
|
|
||||||
for model_dict, (n_ctx, is_sleeping) in zip(llama_models_pending, props_results):
|
for (ep, raw_id), model_dict, (n_ctx, is_sleeping) in zip(props_requests, llama_models_pending, props_results):
|
||||||
if n_ctx is not None:
|
if n_ctx is not None:
|
||||||
model_dict["context_length"] = n_ctx
|
model_dict["context_length"] = n_ctx
|
||||||
|
if 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT:
|
||||||
|
normalized = _normalize_llama_model_name(raw_id)
|
||||||
|
_endpoint_nctx[(ep, normalized)] = n_ctx
|
||||||
|
print(f"[ctx-cache/ps] cached n_ctx={n_ctx} for ({ep},{normalized})", flush=True)
|
||||||
if not is_sleeping:
|
if not is_sleeping:
|
||||||
models.append(model_dict)
|
models.append(model_dict)
|
||||||
|
|
||||||
|
|
@ -2686,6 +2919,21 @@ async def openai_embedding_proxy(request: Request):
|
||||||
model = payload.get("model")
|
model = payload.get("model")
|
||||||
doc = payload.get("input")
|
doc = payload.get("input")
|
||||||
|
|
||||||
|
# Normalize multimodal input: extract only text parts for embedding models
|
||||||
|
if isinstance(doc, list):
|
||||||
|
normalized = []
|
||||||
|
for item in doc:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
# Multimodal content part - extract text only, skip images
|
||||||
|
if item.get("type") == "text":
|
||||||
|
normalized.append(item.get("text", ""))
|
||||||
|
# Skip image_url and other non-text types
|
||||||
|
else:
|
||||||
|
normalized.append(item)
|
||||||
|
doc = normalized if len(normalized) != 1 else normalized[0]
|
||||||
|
elif isinstance(doc, dict) and doc.get("type") == "text":
|
||||||
|
doc = doc.get("text", "")
|
||||||
|
|
||||||
if not model:
|
if not model:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="Missing required field 'model'"
|
status_code=400, detail="Missing required field 'model'"
|
||||||
|
|
@ -2819,7 +3067,7 @@ async def openai_chat_completions_proxy(request: Request):
|
||||||
endpoint, tracking_model = await choose_endpoint(model)
|
endpoint, tracking_model = await choose_endpoint(model)
|
||||||
base_url = ep2base(endpoint)
|
base_url = ep2base(endpoint)
|
||||||
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
||||||
# 3. Async generator that streams completions data and decrements the counter
|
# 3. Helpers and API call — done in handler scope so try/except works reliably
|
||||||
async def _normalize_images_in_messages(msgs: list) -> list:
|
async def _normalize_images_in_messages(msgs: list) -> list:
|
||||||
"""Fetch remote image URLs and convert them to base64 data URLs so
|
"""Fetch remote image URLs and convert them to base64 data URLs so
|
||||||
Ollama/llama-server can handle them without making outbound HTTP requests."""
|
Ollama/llama-server can handle them without making outbound HTTP requests."""
|
||||||
|
|
@ -2854,25 +3102,95 @@ async def openai_chat_completions_proxy(request: Request):
|
||||||
resolved.append({**msg, "content": new_content})
|
resolved.append({**msg, "content": new_content})
|
||||||
return resolved
|
return resolved
|
||||||
|
|
||||||
async def stream_ochat_response():
|
# Make the API call in handler scope — try/except inside async generators is unreliable
|
||||||
try:
|
# with Starlette's streaming machinery, so we resolve errors here before the generator starts.
|
||||||
# The chat method returns a generator of dicts (or GenerateResponse)
|
|
||||||
try:
|
|
||||||
# For non-external endpoints (Ollama, llama-server), resolve remote
|
|
||||||
# image URLs to base64 data URLs so the server can handle them locally.
|
|
||||||
send_params = params
|
send_params = params
|
||||||
if not is_ext_openai_endpoint(endpoint):
|
if not is_ext_openai_endpoint(endpoint):
|
||||||
resolved_msgs = await _normalize_images_in_messages(params.get("messages", []))
|
resolved_msgs = await _normalize_images_in_messages(params.get("messages", []))
|
||||||
send_params = {**params, "messages": resolved_msgs}
|
send_params = {**params, "messages": resolved_msgs}
|
||||||
|
# Proactive trim: only for small-ctx models we've already seen run out of space
|
||||||
|
_lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model
|
||||||
|
_known_nctx = _endpoint_nctx.get((endpoint, _lookup_model))
|
||||||
|
if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
|
||||||
|
_pre_target = int(((_known_nctx - _known_nctx // 4)) / 1.2)
|
||||||
|
_pre_est = _count_message_tokens(send_params.get("messages", []))
|
||||||
|
if _pre_est > _pre_target:
|
||||||
|
_pre_msgs = send_params.get("messages", [])
|
||||||
|
_pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target)
|
||||||
|
_dropped = len(_pre_msgs) - len(_pre_trimmed)
|
||||||
|
print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
|
||||||
|
send_params = {**send_params, "messages": _pre_trimmed}
|
||||||
|
try:
|
||||||
async_gen = await oclient.chat.completions.create(**send_params)
|
async_gen = await oclient.chat.completions.create(**send_params)
|
||||||
except openai.BadRequestError as e:
|
except Exception as e:
|
||||||
# If tools are not supported by the model, retry without tools
|
_e_str = str(e)
|
||||||
if "does not support tools" in str(e):
|
_is_ctx_err = "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str
|
||||||
print(f"[openai_chat_completions_proxy] Model {model} doesn't support tools, retrying without tools")
|
print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True)
|
||||||
|
if "does not support tools" in _e_str:
|
||||||
|
# Model doesn't support tools — retry without them
|
||||||
|
print(f"[ochat] retry: no tools", flush=True)
|
||||||
|
try:
|
||||||
params_without_tools = {k: v for k, v in send_params.items() if k != "tools"}
|
params_without_tools = {k: v for k, v in send_params.items() if k != "tools"}
|
||||||
async_gen = await oclient.chat.completions.create(**params_without_tools)
|
async_gen = await oclient.chat.completions.create(**params_without_tools)
|
||||||
else:
|
except Exception:
|
||||||
|
await decrement_usage(endpoint, tracking_model)
|
||||||
raise
|
raise
|
||||||
|
elif _is_ctx_err:
|
||||||
|
# Backend context limit hit — apply sliding-window trim (context-shift at message level)
|
||||||
|
err_body = getattr(e, "body", {}) or {}
|
||||||
|
err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
|
||||||
|
n_ctx_limit = err_detail.get("n_ctx", 0)
|
||||||
|
actual_tokens = err_detail.get("n_prompt_tokens", 0)
|
||||||
|
print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True)
|
||||||
|
if not n_ctx_limit:
|
||||||
|
await decrement_usage(endpoint, tracking_model)
|
||||||
|
raise
|
||||||
|
if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
|
||||||
|
_endpoint_nctx[(endpoint, model)] = n_ctx_limit
|
||||||
|
|
||||||
|
msgs_to_trim = send_params.get("messages", [])
|
||||||
|
try:
|
||||||
|
cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
|
||||||
|
trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
|
||||||
|
except Exception as _helper_exc:
|
||||||
|
print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True)
|
||||||
|
await decrement_usage(endpoint, tracking_model)
|
||||||
|
raise
|
||||||
|
dropped = len(msgs_to_trim) - len(trimmed_messages)
|
||||||
|
print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True)
|
||||||
|
try:
|
||||||
|
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages})
|
||||||
|
print(f"[ctx-trim] retry-1 ok", flush=True)
|
||||||
|
except Exception as e2:
|
||||||
|
_e2_str = str(e2)
|
||||||
|
if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str:
|
||||||
|
# Still too large — tool definitions likely consuming too many tokens, strip them too
|
||||||
|
print(f"[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True)
|
||||||
|
params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")}
|
||||||
|
try:
|
||||||
|
async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages})
|
||||||
|
print(f"[ctx-trim] retry-2 ok", flush=True)
|
||||||
|
except Exception:
|
||||||
|
await decrement_usage(endpoint, tracking_model)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
await decrement_usage(endpoint, tracking_model)
|
||||||
|
raise
|
||||||
|
elif "image input is not supported" in _e_str:
|
||||||
|
# Model doesn't support images — strip and retry
|
||||||
|
print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages")
|
||||||
|
try:
|
||||||
|
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))})
|
||||||
|
except Exception:
|
||||||
|
await decrement_usage(endpoint, tracking_model)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
await decrement_usage(endpoint, tracking_model)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# 4. Async generator — only streams the already-established async_gen
|
||||||
|
async def stream_ochat_response():
|
||||||
|
try:
|
||||||
if stream == True:
|
if stream == True:
|
||||||
content_parts: list[str] = []
|
content_parts: list[str] = []
|
||||||
usage_snapshot: dict = {}
|
usage_snapshot: dict = {}
|
||||||
|
|
@ -2909,6 +3227,15 @@ async def openai_chat_completions_proxy(request: Request):
|
||||||
prompt_tok, comp_tok = llama_usage
|
prompt_tok, comp_tok = llama_usage
|
||||||
if prompt_tok != 0 or comp_tok != 0:
|
if prompt_tok != 0 or comp_tok != 0:
|
||||||
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
|
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
|
||||||
|
# Detect context exhaustion mid-generation for small-ctx models.
|
||||||
|
# Guard: skip if max_tokens was set in the request — finish_reason=length
|
||||||
|
# could just mean the caller's token budget was exhausted, not the context window.
|
||||||
|
_req_max_tok = send_params.get("max_tokens") or send_params.get("max_completion_tokens")
|
||||||
|
if chunk.choices and chunk.choices[0].finish_reason == "length" and not _req_max_tok:
|
||||||
|
_inferred_nctx = (prompt_tok + comp_tok) or 0
|
||||||
|
if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
|
||||||
|
_endpoint_nctx[(endpoint, model)] = _inferred_nctx
|
||||||
|
print(f"[ctx-cache] finish_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True)
|
||||||
# Cache assembled streaming response — before [DONE] so it always runs
|
# Cache assembled streaming response — before [DONE] so it always runs
|
||||||
if _cache is not None and _cache_enabled and content_parts:
|
if _cache is not None and _cache_enabled and content_parts:
|
||||||
assembled = orjson.dumps({
|
assembled = orjson.dumps({
|
||||||
|
|
@ -3044,10 +3371,15 @@ async def openai_completions_proxy(request: Request):
|
||||||
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
||||||
|
|
||||||
# 3. Async generator that streams completions data and decrements the counter
|
# 3. Async generator that streams completions data and decrements the counter
|
||||||
|
# Make the API call in handler scope (try/except inside async generators is unreliable)
|
||||||
|
try:
|
||||||
|
async_gen = await oclient.completions.create(**params)
|
||||||
|
except Exception:
|
||||||
|
await decrement_usage(endpoint, tracking_model)
|
||||||
|
raise
|
||||||
|
|
||||||
async def stream_ocompletions_response(model=model):
|
async def stream_ocompletions_response(model=model):
|
||||||
try:
|
try:
|
||||||
# The chat method returns a generator of dicts (or GenerateResponse)
|
|
||||||
async_gen = await oclient.completions.create(**params)
|
|
||||||
if stream == True:
|
if stream == True:
|
||||||
text_parts: list[str] = []
|
text_parts: list[str] = []
|
||||||
usage_snapshot: dict = {}
|
usage_snapshot: dict = {}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue