nomyo-router/requests/chat.py

218 lines
12 KiB
Python

"""High-level chat request orchestrator.
``_make_chat_request`` is the shared core that:
* picks an endpoint via ``choose_endpoint`` (which atomically reserves a slot),
* dispatches to either the native Ollama client or an OpenAI-compatible
client based on endpoint type,
* applies reactive context trimming when the backend rejects with
``exceed_context_size_error``,
* counts tokens for billing/SSE,
* always releases the reservation via ``decrement_usage`` in ``finally``.
``_make_moe_requests`` builds on it to implement the "Multiple Opinions
Ensemble" (MoE) flow: 3 responses + 3 critiques + 1 final answer.
"""
import asyncio
import re
import time
import ollama
import enhance
from config import get_config
from state import default_headers, token_queue
from context_window import _trim_messages_for_context, _calibrated_trim_target
from backends.normalize import is_openai_compatible
from backends.sessions import _make_openai_client
from routing import choose_endpoint, decrement_usage
from requests.messages import (
get_last_user_content,
transform_tool_calls_to_openai,
transform_images_to_data_urls,
_strip_assistant_prefill,
_strip_images_from_messages,
_accumulate_openai_tc_delta,
_build_ollama_tool_calls,
)
from requests.rechunk import rechunk
# Error-message fragments that OpenAI-compatible backends (llama.cpp server)
# use to report that the prompt does not fit the context window.
_CONTEXT_EXCEEDED_MARKERS = ("exceed_context_size_error", "exceeds the available context size")

# Ollama ``options`` key -> OpenAI chat-completions parameter name.
_OLLAMA_TO_OPENAI_OPTIONS = (
    ("num_predict", "max_tokens"),
    ("frequency_penalty", "frequency_penalty"),
    ("presence_penalty", "presence_penalty"),
    ("seed", "seed"),
    ("stop", "stop"),
    ("top_p", "top_p"),
    ("temperature", "temperature"),
)


def _is_context_exceeded(err_text: str) -> bool:
    """Return True if a backend error message reports a context-window overflow."""
    return any(marker in err_text for marker in _CONTEXT_EXCEEDED_MARKERS)


def _context_limits_from_error(exc: Exception, err_text: str) -> tuple[int, int]:
    """Extract ``(n_ctx, n_prompt_tokens)`` from a context-overflow error.

    Prefers the structured ``exc.body`` payload, accepting both the wrapped
    ``{"error": {...}}`` shape and an already-unwrapped error object; anything
    that is not a dict is ignored instead of crashing the error handler.
    Falls back to scraping ``err_text`` when ``n_ctx`` is missing.
    Values that cannot be determined are reported as 0.
    """
    body = getattr(exc, "body", None)
    detail = {}
    if isinstance(body, dict):
        inner = body.get("error", body)
        if isinstance(inner, dict):
            detail = inner
    n_ctx = detail.get("n_ctx", 0) or 0
    n_prompt = detail.get("n_prompt_tokens", 0) or 0
    if not n_ctx:
        match = re.search(r"'n_ctx':\s*(\d+)", err_text)
        if match:
            n_ctx = int(match.group(1))
        match = re.search(r"'n_prompt_tokens':\s*(\d+)", err_text)
        if match:
            n_prompt = int(match.group(1))
    return n_ctx, n_prompt


def _openai_response_format(fmt):
    """Translate an Ollama ``format`` value into an OpenAI ``response_format``.

    Ollama's ``format`` is either the string ``"json"`` or a JSON schema.
    OpenAI-compatible servers expect ``{"type": "json_object"}`` for the former
    and ``{"type": "json_schema", "json_schema": {"name": ..., "schema": ...}}``
    for the latter. A dict already shaped like the inner ``{"name", "schema"}``
    wrapper is passed through, with a default ``name`` added if missing.
    Returns None when no format was requested.
    """
    if fmt is None:
        return None
    if fmt == "json":
        return {"type": "json_object"}
    if isinstance(fmt, dict) and "schema" in fmt:
        return {"type": "json_schema", "json_schema": {"name": "response", **fmt}}
    return {"type": "json_schema", "json_schema": {"name": "response", "schema": fmt}}


def _openai_usage(obj) -> tuple[int, int]:
    """Return ``(prompt_tokens, completion_tokens)`` reported by an OpenAI response or chunk.

    Uses the standard ``usage`` field when present, otherwise falls back to
    ``rechunk.extract_usage_from_llama_timings``; returns (0, 0) if neither
    yields counts.
    """
    usage = obj.usage
    if usage is not None:
        return usage.prompt_tokens or 0, usage.completion_tokens or 0
    llama_usage = rechunk.extract_usage_from_llama_timings(obj)
    if llama_usage:
        prompt_tok, comp_tok = llama_usage
        return prompt_tok, comp_tok
    return 0, 0


async def _openai_create_with_fallbacks(oclient, params: dict, model: str):
    """Call ``chat.completions.create`` with reactive recovery for known rejections.

    * Context overflow: trim the oldest messages to fit the reported ``n_ctx``
      and retry. If that still overflows, retry once more with ``tools`` /
      ``tool_choice`` removed.
    * "image input is not supported": retry with images stripped from messages.

    Any other error, or an overflow whose ``n_ctx`` cannot be determined, is
    re-raised unchanged.
    """
    try:
        return await oclient.chat.completions.create(**params)
    except Exception as e:
        err_text = str(e)
        print(f"[_make_chat_request] caught {type(e).__name__}: {err_text[:200]}")
        if _is_context_exceeded(err_text):
            n_ctx_limit, actual_tokens = _context_limits_from_error(e, err_text)
            if not n_ctx_limit:
                raise
            msgs_to_trim = params.get("messages", [])
            cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
            trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
            print(f"[_make_chat_request] Context exceeded ({actual_tokens}/{n_ctx_limit} tokens, tiktoken_target={cal_target}), dropped {len(msgs_to_trim) - len(trimmed)} oldest message(s) and retrying")
            try:
                return await oclient.chat.completions.create(**{**params, "messages": trimmed})
            except Exception as e2:
                if not _is_context_exceeded(str(e2)):
                    raise
                print("[_make_chat_request] Context still exceeded after trimming, also stripping tools")
                params_no_tools = {k: v for k, v in params.items() if k not in ("tools", "tool_choice")}
                return await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed})
        if "image input is not supported" in err_text:
            print(f"[_make_chat_request] Model {model} doesn't support images, retrying with text-only messages")
            text_only = {**params, "messages": _strip_images_from_messages(params.get("messages", []))}
            return await oclient.chat.completions.create(**text_only)
        raise


async def _openai_chat(config, endpoint, tracking_model, model, messages, tools, stream, format, options):
    """Run one chat request against an OpenAI-compatible endpoint; return it in Ollama format.

    Token counts are pushed to ``token_queue`` as
    ``(endpoint, tracking_model, prompt, completion)``. When streaming, the
    stream is fully consumed here. The result is built from the final chunk,
    with tool-call deltas accumulated across all chunks. If the stream yields
    no chunks, the raw stream object is returned (matching prior behavior).
    """
    if ":latest" in model:
        model = model.split(":latest")[0]
    if messages:
        if any("images" in m for m in messages):
            # Image conversion may do blocking work, so keep it off the event loop.
            messages = await asyncio.to_thread(transform_images_to_data_urls, messages)
        messages = transform_tool_calls_to_openai(messages)
        messages = _strip_assistant_prefill(messages)

    opts = options or {}
    optional_params = {
        "tools": tools,
        "stream": stream,
        "stream_options": {"include_usage": True} if stream else None,
        "response_format": _openai_response_format(format),
        **{dst: opts.get(src) for src, dst in _OLLAMA_TO_OPENAI_OPTIONS},
    }
    params = {"messages": messages, "model": model}
    # Only send parameters that were actually set; None would override server defaults.
    params.update({k: v for k, v in optional_params.items() if v is not None})

    oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
    start_ts = time.perf_counter()
    response = await _openai_create_with_fallbacks(oclient, params, model)

    if not stream:
        prompt_tok, comp_tok = _openai_usage(response)
        if prompt_tok or comp_tok:
            await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
        return rechunk.openai_chat_completion2ollama(response, stream, start_ts)

    last_chunk = None
    tc_acc = {}  # tool-call deltas accumulated across the whole stream
    async for chunk in response:
        last_chunk = chunk
        _accumulate_openai_tc_delta(chunk, tc_acc)
        prompt_tok, comp_tok = _openai_usage(chunk)
        if prompt_tok or comp_tok:
            await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
    if last_chunk is None:
        return response
    result = rechunk.openai_chat_completion2ollama(last_chunk, stream, start_ts)
    # The final chunk only carries the last delta; inject the fully assembled tool calls.
    if tc_acc and result.message:
        result.message.tool_calls = _build_ollama_tool_calls(tc_acc)
    return result


async def _ollama_chat(endpoint, tracking_model, model, messages, tools, stream, think, format, options, keep_alive):
    """Run one chat request against a native Ollama endpoint.

    Token counts are pushed to ``token_queue``. When streaming, the stream is
    fully consumed and its last chunk is returned. If the stream is empty, the
    raw stream object is returned (matching prior behavior).
    """
    client = ollama.AsyncClient(host=endpoint)
    response = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive)
    if not stream:
        prompt_tok = response.prompt_eval_count or 0
        comp_tok = response.eval_count or 0
        if prompt_tok or comp_tok:
            await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
        return response

    last_chunk = None
    async for chunk in response:
        last_chunk = chunk
        prompt_tok = chunk.prompt_eval_count or 0
        comp_tok = chunk.eval_count or 0
        if prompt_tok or comp_tok:
            await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
    return response if last_chunk is None else last_chunk


async def _make_chat_request(model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
    """Send one chat request to a routed endpoint and return an Ollama-style response.

    Selects and reserves an endpoint via ``choose_endpoint``. Dispatches to an
    OpenAI-compatible client or the native Ollama client based on the
    endpoint. Forwards token counts to ``token_queue``. Always releases the
    reservation via ``decrement_usage``.

    Args:
        model: Requested model name (``:latest`` is stripped for OpenAI endpoints).
        messages: Ollama-style chat messages.
        tools: Tool definitions forwarded to the backend.
        stream: If True, the backend is streamed and consumed here; the return
            value is derived from the final chunk.
        think: Forwarded to the Ollama client only.
        format: Ollama ``format`` value (``"json"`` or a JSON schema).
        options: Ollama options. For OpenAI endpoints only a subset is mapped.
        keep_alive: Forwarded to the Ollama client only.

    Returns:
        The (final) chat response in Ollama format.
    """
    config = get_config()
    endpoint, tracking_model = await choose_endpoint(model)  # selects and atomically reserves
    # Everything after the reservation must be inside the try. Message
    # transformation and client construction can raise or be cancelled, and
    # the slot would otherwise never be released.
    try:
        if is_openai_compatible(endpoint):
            return await _openai_chat(config, endpoint, tracking_model, model, messages, tools, stream, format, options)
        return await _ollama_chat(endpoint, tracking_model, model, messages, tools, stream, think, format, options, keep_alive)
    finally:
        await decrement_usage(endpoint, tracking_model)
async def _gather_cancel_on_error(coros):
    """Run *coros* concurrently and return their results in order.

    Unlike a bare ``asyncio.gather``, if any coroutine raises (or the caller is
    cancelled), the still-running siblings are cancelled. They are not left to
    finish in the background and burn backend capacity for a result nobody
    will read.
    """
    tasks = [asyncio.create_task(c) for c in coros]
    try:
        return await asyncio.gather(*tasks)
    except BaseException:
        for task in tasks:
            task.cancel()  # no-op for tasks that already finished
        raise


async def _make_moe_requests(model: str, messages: list, tools=None, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
    """Answer via MOE (Multiple Opinions Ensemble).

    Runs 3 parallel responses, then 3 parallel critiques built with
    ``enhance.moe``. A candidate prompt is picked with
    ``enhance.moe_select_candidate``, and one final request produces the answer.
    Every request uses ``temperature=1``.

    Raises:
        ValueError: If ``messages`` contains no user query.
    """
    query = get_last_user_content(messages)
    if not query:
        raise ValueError("No user query found in messages")
    # Work on a copy so the caller's options dict is not mutated.
    options = {} if options is None else dict(options)
    options["temperature"] = 1

    def _request(msgs):
        return _make_chat_request(model, msgs, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)

    # choose_endpoint is called inside _make_chat_request and atomically
    # reserves a slot, so the parallel requests see each other's load immediately.
    responses = await _gather_cancel_on_error([_request(messages) for _ in range(3)])
    moe_reqs = [enhance.moe(query, n, r.message.content) for n, r in enumerate(responses)]

    critiques = await _gather_cancel_on_error(
        [_request([{"role": "user", "content": req}]) for req in moe_reqs]
    )

    final_prompt = enhance.moe_select_candidate(query, critiques)
    return await _request([{"role": "user", "content": final_prompt}])