diff --git a/enhance.py b/enhance.py
index bb1551e..9be1fef 100644
--- a/enhance.py
+++ b/enhance.py
@@ -9,30 +9,41 @@ def moe(query: str, query_id: int, response: str) -> str:
User query: {query}
query_id: {query_id}
- The following is an assistant response to the original user query. Analyse the response, then critizise the response by discussing both strength and weakness of the response.
+ The following is an assistant response to the original user query. Analyse the response, then criticize it by discussing both strengths and weaknesses. Do not add additional commentary.
{response}
+
+ Respond in the format:
+ original_response
+ ---
+ Response Analysis:
+ your analysis
"""
return moe_prompt
-def moe_select_candiadate(query: str, candidates_with_feedback: list[str]) -> str:
+def moe_select_candidate(query: str, candidates: list[str]) -> str:
+ if not candidates:
+ raise ValueError("No candidates supplied")
+
+ candidate_sections = ""
+ for i, cand in enumerate(candidates[:3], start=0):
+ candidate_sections += f"""
+
+ {cand.message.content}
+
+ """
+
+ # Strict instruction: "Respond **only** with the final answer."
select_prompt = f"""
From the following responses for the user query: {query}
- select the best fitting candidate and formulate a final anser for the user.
-
- {candidates_with_feedback[0].message.content}
-
+ {candidate_sections}
-
- {candidates_with_feedback[1].message.content}
-
-
-
- {candidates_with_feedback[2].message.content}
-
+ Choose the best candidate and output the final answer in the language of the query.
+ **Do NOT** mention candidate numbers, strengths, weaknesses, or any other commentary.
+ Just give the final answer—nothing else.
"""
- return select_prompt
+ return select_prompt.strip()
diff --git a/router.py b/router.py
index fa70991..fb3a0fb 100644
--- a/router.py
+++ b/router.py
@@ -441,6 +441,155 @@ async def decrement_usage(endpoint: str, model: str) -> None:
# usage_counts.pop(endpoint, None)
await publish_snapshot()
+async def _make_chat_request(endpoint: str, model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
+ """
+ Helper function to make a chat request to a specific endpoint.
+ Handles endpoint selection, client creation, usage tracking, and request execution.
+ """
+ is_openai_endpoint = "/v1" in endpoint
+ if is_openai_endpoint:
+ if ":latest" in model:
+ model = model.split(":latest")[0]
+ if messages:
+ messages = transform_images_to_data_urls(messages)
+ params = {
+ "messages": messages,
+ "model": model,
+ }
+ optional_params = {
+ "tools": tools,
+ "stream": stream,
+ "stream_options": {"include_usage": True} if stream else None,
+ "max_tokens": options.get("num_predict") if options and "num_predict" in options else None,
+ "frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None,
+ "presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None,
+ "seed": options.get("seed") if options and "seed" in options else None,
+ "stop": options.get("stop") if options and "stop" in options else None,
+ "top_p": options.get("top_p") if options and "top_p" in options else None,
+ "temperature": options.get("temperature") if options and "temperature" in options else None,
+ "response_format": {"type": "json_schema", "json_schema": format} if format is not None else None
+ }
+ params.update({k: v for k, v in optional_params.items() if v is not None})
+ oclient = openai.AsyncOpenAI(base_url=endpoint, default_headers=default_headers, api_key=config.api_keys[endpoint])
+ else:
+ client = ollama.AsyncClient(host=endpoint)
+
+ await increment_usage(endpoint, model)
+
+ try:
+ if is_openai_endpoint:
+ start_ts = time.perf_counter()
+ response = await oclient.chat.completions.create(**params)
+ if stream:
+ # For streaming, we need to collect all chunks
+ chunks = []
+ async for chunk in response:
+ chunks.append(chunk)
+ if chunk.usage is not None:
+ prompt_tok = chunk.usage.prompt_tokens or 0
+ comp_tok = chunk.usage.completion_tokens or 0
+ if prompt_tok != 0 or comp_tok != 0:
+ await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+ # Convert to Ollama format
+ if chunks:
+ response = rechunk.openai_chat_completion2ollama(chunks[-1], stream, start_ts)
+ else:
+ prompt_tok = response.usage.prompt_tokens or 0
+ comp_tok = response.usage.completion_tokens or 0
+ if prompt_tok != 0 or comp_tok != 0:
+ await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+ response = rechunk.openai_chat_completion2ollama(response, stream, start_ts)
+ else:
+ response = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive)
+ if stream:
+ # For streaming, collect all chunks
+ chunks = []
+ async for chunk in response:
+ chunks.append(chunk)
+ prompt_tok = chunk.prompt_eval_count or 0
+ comp_tok = chunk.eval_count or 0
+ if prompt_tok != 0 or comp_tok != 0:
+ await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+ if chunks:
+ response = chunks[-1]
+ else:
+ prompt_tok = response.prompt_eval_count or 0
+ comp_tok = response.eval_count or 0
+ if prompt_tok != 0 or comp_tok != 0:
+ await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+
+ return response
+ finally:
+ await decrement_usage(endpoint, model)
+
+def get_last_user_content(messages):
+ """
+ Given a list of dicts (e.g., messages from an API),
+ return the 'content' of the last dict whose 'role' is 'user'.
+ If no such dict exists, return None.
+ """
+ # Reverse iterate so we stop at the first match
+ for msg in reversed(messages):
+ if msg.get("role") == "user":
+ return msg.get("content")
+ return None
+
+async def _make_moe_requests(model: str, messages: list, tools=None, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
+ """
+ Helper function to make MOE (Multiple Opinions Ensemble) requests.
+ Generates 3 responses, 3 critiques, and returns the final selected response.
+ """
+ query = get_last_user_content(messages)
+ if not query:
+ raise ValueError("No user query found in messages")
+
+ if options is None:
+ options = {}
+ options["temperature"] = 1
+
+ moe_reqs = []
+
+ # Generate 3 responses
+ response1_endpoint = await choose_endpoint(model)
+ response1_task = asyncio.create_task(_make_chat_request(response1_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
+ await asyncio.sleep(0.01) # Small delay to allow usage count to update
+
+ response2_endpoint = await choose_endpoint(model)
+ response2_task = asyncio.create_task(_make_chat_request(response2_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
+ await asyncio.sleep(0.01) # Small delay to allow usage count to update
+
+ response3_endpoint = await choose_endpoint(model)
+ response3_task = asyncio.create_task(_make_chat_request(response3_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
+ await asyncio.sleep(0.01) # Small delay to allow usage count to update
+
+ responses = await asyncio.gather(response1_task, response2_task, response3_task)
+
+ for n, r in enumerate(responses):
+ moe_req = enhance.moe(query, n, r.message.content)
+ moe_reqs.append(moe_req)
+
+ # Generate 3 critiques
+ critique1_endpoint = await choose_endpoint(model)
+ critique1_task = asyncio.create_task(_make_chat_request(critique1_endpoint, model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
+ await asyncio.sleep(0.01) # Small delay to allow usage count to update
+
+ critique2_endpoint = await choose_endpoint(model)
+ critique2_task = asyncio.create_task(_make_chat_request(critique2_endpoint, model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
+ await asyncio.sleep(0.01) # Small delay to allow usage count to update
+
+ critique3_endpoint = await choose_endpoint(model)
+ critique3_task = asyncio.create_task(_make_chat_request(critique3_endpoint, model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
+ await asyncio.sleep(0.01) # Small delay to allow usage count to update
+
+ critiques = await asyncio.gather(critique1_task, critique2_task, critique3_task)
+
+ # Select final response
+ m = enhance.moe_select_candidate(query, critiques)
+
+ # Generate final response
+ final_endpoint = await choose_endpoint(model)
+ return await _make_chat_request(final_endpoint, model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)
+
def iso8601_ns():
ns = time.time_ns()
sec, ns_rem = divmod(ns, 1_000_000_000)
@@ -687,7 +836,7 @@ async def choose_endpoint(model: str) -> str:
if model in models
]
- # 6️⃣
+ # 6️⃣
if not candidate_endpoints:
if ":latest" in model: #ollama naming convention not applicable to openai
model_without_latest = model.split(":latest")[0]
@@ -696,7 +845,9 @@ async def choose_endpoint(model: str) -> str:
if model_without_latest in models and is_ext_openai_endpoint(ep)
]
if not candidate_endpoints:
- model = model + ":latest"
+ # Only add :latest suffix if model doesn't already have a version suffix
+ if ":" not in model:
+ model = model + ":latest"
candidate_endpoints = [
ep for ep, models in zip(config.endpoints, advertised_sets)
if model in models
@@ -862,17 +1013,6 @@ async def chat_proxy(request: Request):
"""
Proxy a chat request to Ollama and stream the endpoint reply.
"""
- def last_user_content(messages):
- """
- Given a list of dicts (e.g., messages from an API),
- return the 'content' of the last dict whose 'role' is 'user'.
- If no such dict exists, return None.
- """
- # Reverse iterate so we stop at the first match
- for msg in reversed(messages):
- if msg.get("role") == "user":
- return msg.get("content")
- return None
# 1. Parse and validate request
try:
body_bytes = await request.body()
@@ -947,17 +1087,8 @@ async def chat_proxy(request: Request):
async_gen = await oclient.chat.completions.create(**params)
else:
if opt == True:
- query = last_user_content(messages)
- if query:
- options["temperature"] = 1
- moe_reqs = []
- responses = await asyncio.gather(*[client.chat(model=model, messages=messages, tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive) for _ in range(0,3)])
- for n,r in enumerate(responses):
- moe_req = enhance.moe(query, n, r.message.content)
- moe_reqs.append(moe_req)
- critiques = await asyncio.gather(*[client.chat(model=model, messages=[{"role": "user", "content": moe_req}], tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive) for moe_req in moe_reqs])
- m = enhance.moe_select_candiadate(query, critiques)
- async_gen = await client.chat(model=model, messages=[{"role": "user", "content": m}], tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive)
+ # Use the dedicated MOE helper function
+ async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive)
else:
async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive)
if stream == True: