diff --git a/db.py b/db.py index aaea508..24c8480 100644 --- a/db.py +++ b/db.py @@ -1,5 +1,4 @@ -import aiosqlite -import asyncio +import aiosqlite, asyncio from typing import Optional from pathlib import Path from datetime import datetime, timezone @@ -182,22 +181,24 @@ class TokenDatabase: """Get token counts for a specific model, aggregated across all endpoints.""" db = await self._get_connection() async with self._operation_lock: - async with db.execute('SELECT endpoint, model, input_tokens, output_tokens, total_tokens FROM token_counts WHERE model = ?', (model,)) as cursor: - total_input = 0 - total_output = 0 - total_tokens = 0 - async for row in cursor: - total_input += row[2] - total_output += row[3] - total_tokens += row[4] - - if total_input > 0 or total_output > 0: + async with db.execute(''' + SELECT + 'aggregated' as endpoint, + ? as model, + SUM(input_tokens) as input_tokens, + SUM(output_tokens) as output_tokens, + SUM(total_tokens) as total_tokens + FROM token_counts + WHERE model = ? + ''', (model, model)) as cursor: + row = await cursor.fetchone() + if row is not None: return { - 'endpoint': 'aggregated', - 'model': model, - 'input_tokens': total_input, - 'output_tokens': total_output, - 'total_tokens': total_tokens + 'endpoint': row[0], + 'model': row[1], + 'input_tokens': row[2], + 'output_tokens': row[3], + 'total_tokens': row[4] } return None diff --git a/enhance.py b/enhance.py new file mode 100644 index 0000000..bb1551e --- /dev/null +++ b/enhance.py @@ -0,0 +1,38 @@ +from pydantic import BaseModel + +class feedback(BaseModel): + query_id: int + content: str + +def moe(query: str, query_id: int, response: str) -> str: + moe_prompt = f""" + User query: {query} + query_id: {query_id} + + The following is an assistant response to the original user query. Analyse the response, then criticize the response by discussing both strengths and weaknesses of the response. 
+ + + {response} + + """ + return moe_prompt + +def moe_select_candiadate(query: str, candidates_with_feedback: list[str]) -> str: + select_prompt = f""" + From the following responses for the user query: {query} + select the best fitting candidate and formulate a final answer for the user. + + + {candidates_with_feedback[0].message.content} + + + + {candidates_with_feedback[1].message.content} + + + + {candidates_with_feedback[2].message.content} + + """ + return select_prompt + diff --git a/router.py b/router.py index 3b60770..20377e7 100644 --- a/router.py +++ b/router.py @@ -6,7 +6,7 @@ version: 0.5 license: AGPL """ # ------------------------------------------------------------- -import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io +import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance from datetime import datetime, timezone from pathlib import Path from typing import Dict, Set, List, Optional @@ -862,6 +862,17 @@ async def chat_proxy(request: Request): """ Proxy a chat request to Ollama and stream the endpoint reply. """ + def last_user_content(messages): + """ + Given a list of dicts (e.g., messages from an API), + return the 'content' of the last dict whose 'role' is 'user'. + If no such dict exists, return None. + """ + # Reverse iterate so we stop at the first match + for msg in reversed(messages): + if msg.get("role") == "user": + return msg.get("content") + return None # 1. Parse and validate request try: body_bytes = await request.body() @@ -892,6 +903,11 @@ async def chat_proxy(request: Request): raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. 
Endpoint logic + if model.startswith("moe-"): + model = model.split("moe-")[1] + opt = True + else: + opt = False endpoint = await choose_endpoint(model) is_openai_endpoint = "/v1" in endpoint if is_openai_endpoint: @@ -930,7 +946,20 @@ async def chat_proxy(request: Request): start_ts = time.perf_counter() async_gen = await oclient.chat.completions.create(**params) else: - async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive) + if opt == True: + query = last_user_content(messages) + if query: + options["temperature"] = 1 + moe_reqs = [] + responses = await asyncio.gather(*[client.chat(model=model, messages=messages, tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive) for _ in range(0,3)]) + for n,r in enumerate(responses): + moe_req = enhance.moe(query, n, r.message.content) + moe_reqs.append(moe_req) + critiques = await asyncio.gather(*[client.chat(model=model, messages=[{"role": "user", "content": moe_req}], tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive) for moe_req in moe_reqs]) + m = enhance.moe_select_candiadate(query, critiques) + async_gen = await client.chat(model=model, messages=[{"role": "user", "content": m}], tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive) + else: + async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive) if stream == True: async for chunk in async_gen: if is_openai_endpoint: @@ -1705,7 +1734,9 @@ async def openai_chat_completions_proxy(request: Request): if prompt_tok != 0 or comp_tok != 0: if not is_ext_openai_endpoint(endpoint): if not ":" in model: - local_model = model+":latest" + local_model = model if ":" in model else model + ":latest" + else: + local_model = model await 
token_queue.put((endpoint, local_model, prompt_tok, comp_tok)) yield b"data: [DONE]\n\n" else: @@ -1821,7 +1852,9 @@ async def openai_completions_proxy(request: Request): if prompt_tok != 0 or comp_tok != 0: if not is_ext_openai_endpoint(endpoint): if not ":" in model: - local_model = model+":latest" + local_model = model if ":" in model else model + ":latest" + else: + local_model = model await token_queue.put((endpoint, local_model, prompt_tok, comp_tok)) # Final DONE event yield b"data: [DONE]\n\n"