diff --git a/enhance.py b/enhance.py
index bb1551e..9be1fef 100644
--- a/enhance.py
+++ b/enhance.py
@@ -9,30 +9,41 @@ def moe(query: str, query_id: int, response: str) -> str:
     User query:
     {query}
    query_id: {query_id}
 
-    The following is an assistant response to the original user query. Analyse the response, then critizise the response by discussing both strength and weakness of the response.
+    The following is an assistant response to the original user query. Analyse the response, then criticize it by discussing both strengths and weaknesses.
     Do not add additional commentary.
     {response}
+
+    Respond in the format:
+    original_response
+    ---
+    Response Analysis:
+    your analysis
     """
     return moe_prompt
 
-def moe_select_candiadate(query: str, candidates_with_feedback: list[str]) -> str:
+def moe_select_candidate(query: str, candidates: list) -> str:
+    if not candidates:
+        raise ValueError("No candidates supplied")
+
+    # Each candidate is a chat response exposing .message.content
+    candidate_sections = ""
+    for cand in candidates[:3]:
+        candidate_sections += f"""
+
+    {cand.message.content}
+
+    """
+
+    # Strict instruction: "Respond **only** with the final answer."
     select_prompt = f"""
     From the following responses for the user query:
     {query}
-    select the best fitting candidate and formulate a final anser for the user.
-
-    {candidates_with_feedback[0].message.content}
-
-
-    {candidates_with_feedback[1].message.content}
-
-
-    {candidates_with_feedback[2].message.content}
-
+    {candidate_sections}
+    Choose the best candidate and output the final answer in the language of the query.
+    **Do NOT** mention candidate numbers, strengths, weaknesses, or any other commentary.
+    Just give the final answer—nothing else.
     """
-    return select_prompt
+    return select_prompt.strip()
diff --git a/router.py b/router.py
index fa70991..fb3a0fb 100644
--- a/router.py
+++ b/router.py
@@ -441,6 +441,155 @@ async def decrement_usage(endpoint: str, model: str) -> None:
     # usage_counts.pop(endpoint, None)
     await publish_snapshot()
 
+async def _make_chat_request(endpoint: str, model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
+    """
+    Helper function to make a chat request to a specific endpoint.
+    Handles client creation, usage tracking, and request execution against the given endpoint.
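+
+    Usage sketch (illustrative only; the endpoint URL and model name are
+    placeholders, not values defined by this patch):
+
+        resp = await _make_chat_request("http://localhost:11434", "llama3:latest",
+                                        [{"role": "user", "content": "Hello"}])
+        print(resp.message.content)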
+ """ + is_openai_endpoint = "/v1" in endpoint + if is_openai_endpoint: + if ":latest" in model: + model = model.split(":latest")[0] + if messages: + messages = transform_images_to_data_urls(messages) + params = { + "messages": messages, + "model": model, + } + optional_params = { + "tools": tools, + "stream": stream, + "stream_options": {"include_usage": True} if stream else None, + "max_tokens": options.get("num_predict") if options and "num_predict" in options else None, + "frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None, + "presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None, + "seed": options.get("seed") if options and "seed" in options else None, + "stop": options.get("stop") if options and "stop" in options else None, + "top_p": options.get("top_p") if options and "top_p" in options else None, + "temperature": options.get("temperature") if options and "temperature" in options else None, + "response_format": {"type": "json_schema", "json_schema": format} if format is not None else None + } + params.update({k: v for k, v in optional_params.items() if v is not None}) + oclient = openai.AsyncOpenAI(base_url=endpoint, default_headers=default_headers, api_key=config.api_keys[endpoint]) + else: + client = ollama.AsyncClient(host=endpoint) + + await increment_usage(endpoint, model) + + try: + if is_openai_endpoint: + start_ts = time.perf_counter() + response = await oclient.chat.completions.create(**params) + if stream: + # For streaming, we need to collect all chunks + chunks = [] + async for chunk in response: + chunks.append(chunk) + if chunk.usage is not None: + prompt_tok = chunk.usage.prompt_tokens or 0 + comp_tok = chunk.usage.completion_tokens or 0 + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, model, prompt_tok, comp_tok)) + # Convert to Ollama format + if chunks: + response = rechunk.openai_chat_completion2ollama(chunks[-1], stream, start_ts) + else: + prompt_tok = response.usage.prompt_tokens or 0 + comp_tok = response.usage.completion_tokens or 0 + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, model, prompt_tok, comp_tok)) + response = rechunk.openai_chat_completion2ollama(response, stream, start_ts) + else: + response = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive) + if stream: + # For streaming, collect all chunks + chunks = [] + async for chunk in response: + chunks.append(chunk) + prompt_tok = chunk.prompt_eval_count or 0 + comp_tok = chunk.eval_count or 0 + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, model, prompt_tok, comp_tok)) + if chunks: + response = chunks[-1] + else: + prompt_tok = response.prompt_eval_count or 0 + comp_tok = response.eval_count or 0 + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, model, prompt_tok, comp_tok)) + + return response + finally: + await decrement_usage(endpoint, model) + +def get_last_user_content(messages): + """ + Given a list of dicts (e.g., messages from an API), + return the 'content' of the last dict whose 'role' is 'user'. + If no such dict exists, return None. 
+ """ + # Reverse iterate so we stop at the first match + for msg in reversed(messages): + if msg.get("role") == "user": + return msg.get("content") + return None + +async def _make_moe_requests(model: str, messages: list, tools=None, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse: + """ + Helper function to make MOE (Multiple Opinions Ensemble) requests. + Generates 3 responses, 3 critiques, and returns the final selected response. + """ + query = get_last_user_content(messages) + if not query: + raise ValueError("No user query found in messages") + + if options is None: + options = {} + options["temperature"] = 1 + + moe_reqs = [] + + # Generate 3 responses + response1_endpoint = await choose_endpoint(model) + response1_task = asyncio.create_task(_make_chat_request(response1_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) + await asyncio.sleep(0.01) # Small delay to allow usage count to update + + response2_endpoint = await choose_endpoint(model) + response2_task = asyncio.create_task(_make_chat_request(response2_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) + await asyncio.sleep(0.01) # Small delay to allow usage count to update + + response3_endpoint = await choose_endpoint(model) + response3_task = asyncio.create_task(_make_chat_request(response3_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) + await asyncio.sleep(0.01) # Small delay to allow usage count to update + + responses = await asyncio.gather(response1_task, response2_task, response3_task) + + for n, r in enumerate(responses): + moe_req = enhance.moe(query, n, r.message.content) + moe_reqs.append(moe_req) + + # Generate 3 critiques + critique1_endpoint = await choose_endpoint(model) + critique1_task = asyncio.create_task(_make_chat_request(critique1_endpoint, model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) + await asyncio.sleep(0.01) # Small delay to allow usage count to update + + critique2_endpoint = await choose_endpoint(model) + critique2_task = asyncio.create_task(_make_chat_request(critique2_endpoint, model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) + await asyncio.sleep(0.01) # Small delay to allow usage count to update + + critique3_endpoint = await choose_endpoint(model) + critique3_task = asyncio.create_task(_make_chat_request(critique3_endpoint, model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) + await asyncio.sleep(0.01) # Small delay to allow usage count to update + + critiques = await asyncio.gather(critique1_task, critique2_task, critique3_task) + + # Select final response + m = enhance.moe_select_candidate(query, critiques) + + # Generate final response + final_endpoint = await choose_endpoint(model) + return await _make_chat_request(final_endpoint, model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive) + def iso8601_ns(): ns = time.time_ns() sec, ns_rem = divmod(ns, 1_000_000_000) @@ -687,7 +836,7 @@ async def choose_endpoint(model: str) -> str: if model in models ] - # 6️⃣ + # 6️⃣ if not candidate_endpoints: if 
":latest" in model: #ollama naming convention not applicable to openai model_without_latest = model.split(":latest")[0] @@ -696,7 +845,9 @@ async def choose_endpoint(model: str) -> str: if model_without_latest in models and is_ext_openai_endpoint(ep) ] if not candidate_endpoints: - model = model + ":latest" + # Only add :latest suffix if model doesn't already have a version suffix + if ":" not in model: + model = model + ":latest" candidate_endpoints = [ ep for ep, models in zip(config.endpoints, advertised_sets) if model in models @@ -862,17 +1013,6 @@ async def chat_proxy(request: Request): """ Proxy a chat request to Ollama and stream the endpoint reply. """ - def last_user_content(messages): - """ - Given a list of dicts (e.g., messages from an API), - return the 'content' of the last dict whose 'role' is 'user'. - If no such dict exists, return None. - """ - # Reverse iterate so we stop at the first match - for msg in reversed(messages): - if msg.get("role") == "user": - return msg.get("content") - return None # 1. Parse and validate request try: body_bytes = await request.body() @@ -947,17 +1087,8 @@ async def chat_proxy(request: Request): async_gen = await oclient.chat.completions.create(**params) else: if opt == True: - query = last_user_content(messages) - if query: - options["temperature"] = 1 - moe_reqs = [] - responses = await asyncio.gather(*[client.chat(model=model, messages=messages, tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive) for _ in range(0,3)]) - for n,r in enumerate(responses): - moe_req = enhance.moe(query, n, r.message.content) - moe_reqs.append(moe_req) - critiques = await asyncio.gather(*[client.chat(model=model, messages=[{"role": "user", "content": moe_req}], tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive) for moe_req in moe_reqs]) - m = enhance.moe_select_candiadate(query, critiques) - async_gen = await client.chat(model=model, messages=[{"role": "user", "content": m}], tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive) + # Use the dedicated MOE helper function + async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive) else: async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive) if stream == True: