refactor: optimize token aggregation query and enhance chat proxy
- Refactored token aggregation query in db.py to use a single SQL query with SUM() instead of iterating through rows, improving performance
- Combined import statements in db.py and router.py to reduce lines of code
- Enhanced chat proxy in router.py to handle "moe-" prefixed models with multiple query execution and critique generation
- Added last_user_content() helper function to extract user content from messages
- Improved code readability and maintainability through these structural changes
This commit is contained in:
parent
59a8ef3abb
commit
34d6abd28b
3 changed files with 93 additions and 21 deletions
35
db.py
35
db.py
|
|
@ -1,5 +1,4 @@
|
|||
import aiosqlite
|
||||
import asyncio
|
||||
import aiosqlite, asyncio
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
|
@ -182,22 +181,24 @@ class TokenDatabase:
|
|||
"""Get token counts for a specific model, aggregated across all endpoints."""
|
||||
db = await self._get_connection()
|
||||
async with self._operation_lock:
|
||||
async with db.execute('SELECT endpoint, model, input_tokens, output_tokens, total_tokens FROM token_counts WHERE model = ?', (model,)) as cursor:
|
||||
total_input = 0
|
||||
total_output = 0
|
||||
total_tokens = 0
|
||||
async for row in cursor:
|
||||
total_input += row[2]
|
||||
total_output += row[3]
|
||||
total_tokens += row[4]
|
||||
|
||||
if total_input > 0 or total_output > 0:
|
||||
async with db.execute('''
|
||||
SELECT
|
||||
'aggregated' as endpoint,
|
||||
? as model,
|
||||
SUM(input_tokens) as input_tokens,
|
||||
SUM(output_tokens) as output_tokens,
|
||||
SUM(total_tokens) as total_tokens
|
||||
FROM token_counts
|
||||
WHERE model = ?
|
||||
''', (model, model)) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row is not None:
|
||||
return {
|
||||
'endpoint': 'aggregated',
|
||||
'model': model,
|
||||
'input_tokens': total_input,
|
||||
'output_tokens': total_output,
|
||||
'total_tokens': total_tokens
|
||||
'endpoint': row[0],
|
||||
'model': row[1],
|
||||
'input_tokens': row[2],
|
||||
'output_tokens': row[3],
|
||||
'total_tokens': row[4]
|
||||
}
|
||||
return None
|
||||
|
||||
|
|
|
|||
38
enhance.py
Normal file
38
enhance.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
class feedback(BaseModel):
    """Critique payload linking a generated critique back to one candidate.

    NOTE(review): lowercase class name violates PEP 8 (PascalCase), but it is
    part of the module's public interface, so it is kept as-is.
    """
    # Index of the candidate response this critique refers to (see moe()).
    query_id: int
    # Free-text critique content.
    content: str
|
||||
|
||||
def moe(query: str, query_id: int, response: str) -> str:
    """Build a critique prompt for one candidate response.

    Used by the "moe-" (mixture-of-experts) chat path: each of the N
    candidate responses is wrapped in this prompt and sent back to the
    model to obtain a critique.

    Args:
        query: The original user query.
        query_id: Index identifying which candidate response this is.
        response: The candidate assistant response to critique.

    Returns:
        The critique prompt string.
    """
    # Fixed spelling/grammar in the instruction text ("critizise",
    # "strength and weakness") so the model receives a clean instruction.
    moe_prompt = f"""
User query: {query}
query_id: {query_id}

The following is an assistant response to the original user query. Analyse the response, then criticize the response by discussing both strengths and weaknesses of the response.

<assistant_response>
{response}
</assistant_response>
"""
    return moe_prompt
|
||||
|
||||
def moe_select_candiadate(query: str, candidates_with_feedback: list) -> str:
    """Build a selection prompt asking the model to pick the best candidate
    and formulate the final answer for the user.

    Generalized from the original hard-coded three candidates: any number of
    candidates is now accepted (the original raised IndexError for fewer
    than three and silently ignored extras).

    NOTE(review): the misspelled name ("candiadate") is kept because callers
    reference the function by this exact name.

    Args:
        query: The original user query.
        candidates_with_feedback: Objects exposing ``.message.content``
            (e.g. ollama chat responses) — one per candidate.

    Returns:
        The selection prompt string.
    """
    # One <candidate_i> block per response, in order.
    blocks = [
        f"<candidate_{i}>\n{cand.message.content}\n</candidate_{i}>"
        for i, cand in enumerate(candidates_with_feedback)
    ]
    candidates_text = "\n\n".join(blocks)
    # "anser" -> "answer" typo fixed in the emitted prompt text.
    select_prompt = f"""
From the following responses for the user query: {query}
select the best fitting candidate and formulate a final answer for the user.

{candidates_text}
"""
    return select_prompt
|
||||
|
||||
41
router.py
41
router.py
|
|
@ -6,7 +6,7 @@ version: 0.5
|
|||
license: AGPL
|
||||
"""
|
||||
# -------------------------------------------------------------
|
||||
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io
|
||||
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Dict, Set, List, Optional
|
||||
|
|
@ -862,6 +862,17 @@ async def chat_proxy(request: Request):
|
|||
"""
|
||||
Proxy a chat request to Ollama and stream the endpoint reply.
|
||||
"""
|
||||
def last_user_content(messages):
    """Return the 'content' of the most recent message whose 'role' is
    'user'. Return None when no user message exists (or when the matching
    message has no 'content' key)."""
    # Scan newest-to-oldest; next() stops at the first user message.
    user_msgs = (m for m in reversed(messages) if m.get("role") == "user")
    latest = next(user_msgs, None)
    return None if latest is None else latest.get("content")
|
||||
# 1. Parse and validate request
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
|
|
@ -892,6 +903,11 @@ async def chat_proxy(request: Request):
|
|||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
if model.startswith("moe-"):
|
||||
model = model.split("moe-")[1]
|
||||
opt = True
|
||||
else:
|
||||
opt = False
|
||||
endpoint = await choose_endpoint(model)
|
||||
is_openai_endpoint = "/v1" in endpoint
|
||||
if is_openai_endpoint:
|
||||
|
|
@ -930,7 +946,20 @@ async def chat_proxy(request: Request):
|
|||
start_ts = time.perf_counter()
|
||||
async_gen = await oclient.chat.completions.create(**params)
|
||||
else:
|
||||
async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive)
|
||||
if opt == True:
|
||||
query = last_user_content(messages)
|
||||
if query:
|
||||
options["temperature"] = 1
|
||||
moe_reqs = []
|
||||
responses = await asyncio.gather(*[client.chat(model=model, messages=messages, tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive) for _ in range(0,3)])
|
||||
for n,r in enumerate(responses):
|
||||
moe_req = enhance.moe(query, n, r.message.content)
|
||||
moe_reqs.append(moe_req)
|
||||
critiques = await asyncio.gather(*[client.chat(model=model, messages=[{"role": "user", "content": moe_req}], tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive) for moe_req in moe_reqs])
|
||||
m = enhance.moe_select_candiadate(query, critiques)
|
||||
async_gen = await client.chat(model=model, messages=[{"role": "user", "content": m}], tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive)
|
||||
else:
|
||||
async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive)
|
||||
if stream == True:
|
||||
async for chunk in async_gen:
|
||||
if is_openai_endpoint:
|
||||
|
|
@ -1705,7 +1734,9 @@ async def openai_chat_completions_proxy(request: Request):
|
|||
if prompt_tok != 0 or comp_tok != 0:
|
||||
if not is_ext_openai_endpoint(endpoint):
|
||||
if not ":" in model:
|
||||
local_model = model+":latest"
|
||||
local_model = model if ":" in model else model + ":latest"
|
||||
else:
|
||||
local_model = model
|
||||
await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
|
||||
yield b"data: [DONE]\n\n"
|
||||
else:
|
||||
|
|
@ -1821,7 +1852,9 @@ async def openai_completions_proxy(request: Request):
|
|||
if prompt_tok != 0 or comp_tok != 0:
|
||||
if not is_ext_openai_endpoint(endpoint):
|
||||
if not ":" in model:
|
||||
local_model = model+":latest"
|
||||
local_model = model if ":" in model else model + ":latest"
|
||||
else:
|
||||
local_model = model
|
||||
await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
|
||||
# Final DONE event
|
||||
yield b"data: [DONE]\n\n"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue