refactor: optimize token aggregation query and enhance chat proxy

- Refactored token aggregation query in db.py to use a single SQL query with SUM() instead of iterating through rows, improving performance
- Combined import statements in db.py and router.py to reduce lines of code
- Enhanced chat proxy in router.py to handle "moe-" prefixed models with multiple query execution and critique generation
- Added last_user_content() helper function to extract user content from messages
- Improved code readability and maintainability through these structural changes
This commit is contained in:
Alpha Nerd 2025-12-13 11:58:49 +01:00
parent 59a8ef3abb
commit 34d6abd28b
3 changed files with 93 additions and 21 deletions

35
db.py
View file

@ -1,5 +1,4 @@
import aiosqlite
import asyncio
import aiosqlite, asyncio
from typing import Optional
from pathlib import Path
from datetime import datetime, timezone
@ -182,22 +181,24 @@ class TokenDatabase:
"""Get token counts for a specific model, aggregated across all endpoints."""
db = await self._get_connection()
async with self._operation_lock:
async with db.execute('SELECT endpoint, model, input_tokens, output_tokens, total_tokens FROM token_counts WHERE model = ?', (model,)) as cursor:
total_input = 0
total_output = 0
total_tokens = 0
async for row in cursor:
total_input += row[2]
total_output += row[3]
total_tokens += row[4]
if total_input > 0 or total_output > 0:
async with db.execute('''
SELECT
'aggregated' as endpoint,
? as model,
SUM(input_tokens) as input_tokens,
SUM(output_tokens) as output_tokens,
SUM(total_tokens) as total_tokens
FROM token_counts
WHERE model = ?
''', (model, model)) as cursor:
row = await cursor.fetchone()
if row is not None:
return {
'endpoint': 'aggregated',
'model': model,
'input_tokens': total_input,
'output_tokens': total_output,
'total_tokens': total_tokens
'endpoint': row[0],
'model': row[1],
'input_tokens': row[2],
'output_tokens': row[3],
'total_tokens': row[4]
}
return None

38
enhance.py Normal file
View file

@ -0,0 +1,38 @@
from pydantic import BaseModel
class feedback(BaseModel):
    """Request/response payload tying a critique back to a specific query.

    NOTE(review): lowercase class name violates PEP 8 (PascalCase), but it is
    part of the module's public interface — renaming would break importers.
    """
    # index of the candidate/query this feedback refers to
    query_id: int
    # free-text critique body
    content: str
def moe(query: str, query_id: int, response: str) -> str:
    """Build a critique prompt for one mixture-of-experts candidate response.

    Args:
        query: the original user query.
        query_id: index identifying which candidate response is being critiqued.
        response: the assistant response text to analyse.

    Returns:
        A prompt string instructing the model to criticise the response,
        with the candidate wrapped in <assistant_response> tags.
    """
    # Fixed spelling in the prompt ("critizise" -> "criticise",
    # "strength and weakness" -> "strengths and weaknesses") so the
    # instruction sent to the model reads cleanly.
    moe_prompt = f"""
User query: {query}
query_id: {query_id}
The following is an assistant response to the original user query. Analyse the response, then criticise the response by discussing both strengths and weaknesses of the response.
<assistant_response>
{response}
</assistant_response>
"""
    return moe_prompt
def moe_select_candiadate(query: str, candidates_with_feedback: list) -> str:
    """Build a selection prompt listing every candidate response.

    Args:
        query: the original user query.
        candidates_with_feedback: chat responses — objects exposing
            ``.message.content`` (the original annotation ``list[str]`` was
            wrong: elements were already dereferenced via ``.message.content``).

    Returns:
        A prompt asking the model to pick the best candidate and write a
        final answer, each candidate wrapped in <candidate_n> tags.

    NOTE(review): the misspelled function name is kept — callers reference
    ``enhance.moe_select_candiadate``.
    """
    # Generalized from hard-coded indices [0], [1], [2]: now any number of
    # candidates works (the old code raised IndexError for fewer than 3).
    header = f"""
From the following responses for the user query: {query}
select the best fitting candidate and formulate a final answer for the user.
"""
    blocks = "".join(
        f"<candidate_{n}>\n{cand.message.content}\n</candidate_{n}>\n"
        for n, cand in enumerate(candidates_with_feedback)
    )
    return header + blocks

View file

@ -6,7 +6,7 @@ version: 0.5
license: AGPL
"""
# -------------------------------------------------------------
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Set, List, Optional
@ -862,6 +862,17 @@ async def chat_proxy(request: Request):
"""
Proxy a chat request to Ollama and stream the endpoint reply.
"""
def last_user_content(messages):
    """Return the 'content' of the most recent message whose role is 'user'.

    Scans the message list from newest to oldest; returns None when no
    user message exists.
    """
    # next() with a default gives us the first match of the reversed
    # scan without an explicit loop-and-return.
    return next(
        (m.get("content") for m in reversed(messages) if m.get("role") == "user"),
        None,
    )
# 1. Parse and validate request
try:
body_bytes = await request.body()
@ -892,6 +903,11 @@ async def chat_proxy(request: Request):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
if model.startswith("moe-"):
model = model.split("moe-")[1]
opt = True
else:
opt = False
endpoint = await choose_endpoint(model)
is_openai_endpoint = "/v1" in endpoint
if is_openai_endpoint:
@ -930,7 +946,20 @@ async def chat_proxy(request: Request):
start_ts = time.perf_counter()
async_gen = await oclient.chat.completions.create(**params)
else:
async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive)
if opt == True:
query = last_user_content(messages)
if query:
options["temperature"] = 1
moe_reqs = []
responses = await asyncio.gather(*[client.chat(model=model, messages=messages, tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive) for _ in range(0,3)])
for n,r in enumerate(responses):
moe_req = enhance.moe(query, n, r.message.content)
moe_reqs.append(moe_req)
critiques = await asyncio.gather(*[client.chat(model=model, messages=[{"role": "user", "content": moe_req}], tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive) for moe_req in moe_reqs])
m = enhance.moe_select_candiadate(query, critiques)
async_gen = await client.chat(model=model, messages=[{"role": "user", "content": m}], tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive)
else:
async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive)
if stream == True:
async for chunk in async_gen:
if is_openai_endpoint:
@ -1705,7 +1734,9 @@ async def openai_chat_completions_proxy(request: Request):
if prompt_tok != 0 or comp_tok != 0:
if not is_ext_openai_endpoint(endpoint):
if not ":" in model:
local_model = model+":latest"
local_model = model if ":" in model else model + ":latest"
else:
local_model = model
await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
yield b"data: [DONE]\n\n"
else:
@ -1821,7 +1852,9 @@ async def openai_completions_proxy(request: Request):
if prompt_tok != 0 or comp_tok != 0:
if not is_ext_openai_endpoint(endpoint):
if not ":" in model:
local_model = model+":latest"
local_model = model if ":" in model else model + ":latest"
else:
local_model = model
await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
# Final DONE event
yield b"data: [DONE]\n\n"