refactor: optimize token aggregation query and enhance chat proxy

- Refactored token aggregation query in db.py to use a single SQL query with SUM() instead of iterating through rows, improving performance
- Combined import statements in db.py and router.py to reduce lines of code
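- Enhanced the chat proxy in router.py to handle "moe-"-prefixed models: the request is run three times in parallel, a critique is generated for each candidate response, and a final answer is produced from the selected candidate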
- Added last_user_content() helper function to extract the content of the most recent user message from the message list
- Improved code readability and maintainability through these structural changes
This commit is contained in:
Alpha Nerd 2025-12-13 11:58:49 +01:00
parent 59a8ef3abb
commit 34d6abd28b
3 changed files with 93 additions and 21 deletions

View file

@ -6,7 +6,7 @@ version: 0.5
license: AGPL
"""
# -------------------------------------------------------------
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Set, List, Optional
@ -862,6 +862,17 @@ async def chat_proxy(request: Request):
"""
Proxy a chat request to Ollama and stream the endpoint reply.
"""
def last_user_content(messages):
    """
    Return the 'content' of the most recent message whose 'role' is 'user'.

    Args:
        messages: A sequence of message mappings (each exposing ``.get``),
            ordered oldest to newest. May be ``None`` or empty.

    Returns:
        The value stored under 'content' in the last user message, or
        ``None`` when ``messages`` is ``None``/empty or holds no user message.
    """
    # Treat a missing message list like an empty one so callers get None
    # instead of a TypeError from reversed(None).
    # Walk newest-to-oldest so the first hit is the latest user turn.
    for msg in reversed(messages or []):
        if msg.get("role") == "user":
            return msg.get("content")
    return None
# 1. Parse and validate request
try:
body_bytes = await request.body()
@ -892,6 +903,11 @@ async def chat_proxy(request: Request):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
if model.startswith("moe-"):
model = model.split("moe-")[1]
opt = True
else:
opt = False
endpoint = await choose_endpoint(model)
is_openai_endpoint = "/v1" in endpoint
if is_openai_endpoint:
@ -930,7 +946,20 @@ async def chat_proxy(request: Request):
start_ts = time.perf_counter()
async_gen = await oclient.chat.completions.create(**params)
else:
async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive)
if opt == True:
query = last_user_content(messages)
if query:
options["temperature"] = 1
moe_reqs = []
responses = await asyncio.gather(*[client.chat(model=model, messages=messages, tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive) for _ in range(0,3)])
for n,r in enumerate(responses):
moe_req = enhance.moe(query, n, r.message.content)
moe_reqs.append(moe_req)
critiques = await asyncio.gather(*[client.chat(model=model, messages=[{"role": "user", "content": moe_req}], tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive) for moe_req in moe_reqs])
m = enhance.moe_select_candiadate(query, critiques)
async_gen = await client.chat(model=model, messages=[{"role": "user", "content": m}], tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive)
else:
async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive)
if stream == True:
async for chunk in async_gen:
if is_openai_endpoint:
@ -1705,7 +1734,9 @@ async def openai_chat_completions_proxy(request: Request):
if prompt_tok != 0 or comp_tok != 0:
if not is_ext_openai_endpoint(endpoint):
if not ":" in model:
local_model = model+":latest"
local_model = model if ":" in model else model + ":latest"
else:
local_model = model
await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
yield b"data: [DONE]\n\n"
else:
@ -1821,7 +1852,9 @@ async def openai_completions_proxy(request: Request):
if prompt_tok != 0 or comp_tok != 0:
if not is_ext_openai_endpoint(endpoint):
if not ":" in model:
local_model = model+":latest"
local_model = model if ":" in model else model + ":latest"
else:
local_model = model
await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
# Final DONE event
yield b"data: [DONE]\n\n"