fix(enhance.py): correct typo in function name from 'moe_select_candiadate' to 'moe_select_candidate'

feat(router.py): add helper function _make_chat_request for handling enhancing chat requests to endpoints
This commit is contained in:
Alpha Nerd 2025-12-15 10:35:56 +01:00
parent 5eb5490d16
commit 19a13cc613
2 changed files with 152 additions and 23 deletions

View file

@ -17,7 +17,7 @@ def moe(query: str, query_id: int, response: str) -> str:
"""
return moe_prompt
def moe_select_candiadate(query: str, candidates_with_feedback: list[str]) -> str:
def moe_select_candidate(query: str, candidates_with_feedback: list[str]) -> str:
select_prompt = f"""
From the following responses for the user query: {query}
select the best fitting candidate and formulate a final anser for the user.

173
router.py
View file

@ -441,6 +441,155 @@ async def decrement_usage(endpoint: str, model: str) -> None:
# usage_counts.pop(endpoint, None)
await publish_snapshot()
async def _make_chat_request(endpoint: str, model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
"""
Helper function to make a chat request to a specific endpoint.
Handles endpoint selection, client creation, usage tracking, and request execution.
"""
is_openai_endpoint = "/v1" in endpoint
if is_openai_endpoint:
if ":latest" in model:
model = model.split(":latest")[0]
if messages:
messages = transform_images_to_data_urls(messages)
params = {
"messages": messages,
"model": model,
}
optional_params = {
"tools": tools,
"stream": stream,
"stream_options": {"include_usage": True} if stream else None,
"max_tokens": options.get("num_predict") if options and "num_predict" in options else None,
"frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None,
"presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None,
"seed": options.get("seed") if options and "seed" in options else None,
"stop": options.get("stop") if options and "stop" in options else None,
"top_p": options.get("top_p") if options and "top_p" in options else None,
"temperature": options.get("temperature") if options and "temperature" in options else None,
"response_format": {"type": "json_schema", "json_schema": format} if format is not None else None
}
params.update({k: v for k, v in optional_params.items() if v is not None})
oclient = openai.AsyncOpenAI(base_url=endpoint, default_headers=default_headers, api_key=config.api_keys[endpoint])
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, model)
try:
if is_openai_endpoint:
start_ts = time.perf_counter()
response = await oclient.chat.completions.create(**params)
if stream:
# For streaming, we need to collect all chunks
chunks = []
async for chunk in response:
chunks.append(chunk)
if chunk.usage is not None:
prompt_tok = chunk.usage.prompt_tokens or 0
comp_tok = chunk.usage.completion_tokens or 0
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
# Convert to Ollama format
if chunks:
response = rechunk.openai_chat_completion2ollama(chunks[-1], stream, start_ts)
else:
prompt_tok = response.usage.prompt_tokens or 0
comp_tok = response.usage.completion_tokens or 0
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
response = rechunk.openai_chat_completion2ollama(response, stream, start_ts)
else:
response = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive)
if stream:
# For streaming, collect all chunks
chunks = []
async for chunk in response:
chunks.append(chunk)
prompt_tok = chunk.prompt_eval_count or 0
comp_tok = chunk.eval_count or 0
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
if chunks:
response = chunks[-1]
else:
prompt_tok = response.prompt_eval_count or 0
comp_tok = response.eval_count or 0
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
return response
finally:
await decrement_usage(endpoint, model)
def get_last_user_content(messages):
"""
Given a list of dicts (e.g., messages from an API),
return the 'content' of the last dict whose 'role' is 'user'.
If no such dict exists, return None.
"""
# Reverse iterate so we stop at the first match
for msg in reversed(messages):
if msg.get("role") == "user":
return msg.get("content")
return None
async def _make_moe_requests(model: str, messages: list, tools=None, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
"""
Helper function to make MOE (Multiple Opinions Ensemble) requests.
Generates 3 responses, 3 critiques, and returns the final selected response.
"""
query = get_last_user_content(messages)
if not query:
raise ValueError("No user query found in messages")
if options is None:
options = {}
options["temperature"] = 1
moe_reqs = []
# Generate 3 responses
response1_endpoint = await choose_endpoint(model)
response1_task = asyncio.create_task(_make_chat_request(response1_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
response2_endpoint = await choose_endpoint(model)
response2_task = asyncio.create_task(_make_chat_request(response2_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
response3_endpoint = await choose_endpoint(model)
response3_task = asyncio.create_task(_make_chat_request(response3_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
responses = await asyncio.gather(response1_task, response2_task, response3_task)
for n, r in enumerate(responses):
moe_req = enhance.moe(query, n, r.message.content)
moe_reqs.append(moe_req)
# Generate 3 critiques
critique1_endpoint = await choose_endpoint(model)
critique1_task = asyncio.create_task(_make_chat_request(critique1_endpoint, model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
critique2_endpoint = await choose_endpoint(model)
critique2_task = asyncio.create_task(_make_chat_request(critique2_endpoint, model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
critique3_endpoint = await choose_endpoint(model)
critique3_task = asyncio.create_task(_make_chat_request(critique3_endpoint, model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
critiques = await asyncio.gather(critique1_task, critique2_task, critique3_task)
# Select final response
m = enhance.moe_select_candidate(query, critiques)
# Generate final response
final_endpoint = await choose_endpoint(model)
return await _make_chat_request(final_endpoint, model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)
def iso8601_ns():
ns = time.time_ns()
sec, ns_rem = divmod(ns, 1_000_000_000)
@ -864,17 +1013,6 @@ async def chat_proxy(request: Request):
"""
Proxy a chat request to Ollama and stream the endpoint reply.
"""
def last_user_content(messages):
"""
Given a list of dicts (e.g., messages from an API),
return the 'content' of the last dict whose 'role' is 'user'.
If no such dict exists, return None.
"""
# Reverse iterate so we stop at the first match
for msg in reversed(messages):
if msg.get("role") == "user":
return msg.get("content")
return None
# 1. Parse and validate request
try:
body_bytes = await request.body()
@ -949,17 +1087,8 @@ async def chat_proxy(request: Request):
async_gen = await oclient.chat.completions.create(**params)
else:
if opt == True:
query = last_user_content(messages)
if query:
options["temperature"] = 1
moe_reqs = []
responses = await asyncio.gather(*[client.chat(model=model, messages=messages, tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive) for _ in range(0,3)])
for n,r in enumerate(responses):
moe_req = enhance.moe(query, n, r.message.content)
moe_reqs.append(moe_req)
critiques = await asyncio.gather(*[client.chat(model=model, messages=[{"role": "user", "content": moe_req}], tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive) for moe_req in moe_reqs])
m = enhance.moe_select_candiadate(query, critiques)
async_gen = await client.chat(model=model, messages=[{"role": "user", "content": m}], tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive)
# Use the dedicated MOE helper function
async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive)
else:
async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive)
if stream == True: