refactor: optimize token aggregation query and enhance chat proxy
- Refactored token aggregation query in db.py to use a single SQL query with SUM() instead of iterating through rows, improving performance - Combined import statements in db.py and router.py to reduce lines of code - Enhanced chat proxy in router.py to handle "moe-" prefixed models with multiple query execution and critique generation - Added last_user_content() helper function to extract user content from messages - Improved code readability and maintainability through these structural changes
This commit is contained in:
parent
59a8ef3abb
commit
34d6abd28b
3 changed files with 93 additions and 21 deletions
41
router.py
41
router.py
|
|
@ -6,7 +6,7 @@ version: 0.5
|
|||
license: AGPL
|
||||
"""
|
||||
# -------------------------------------------------------------
|
||||
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io
|
||||
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Dict, Set, List, Optional
|
||||
|
|
@ -862,6 +862,17 @@ async def chat_proxy(request: Request):
|
|||
"""
|
||||
Proxy a chat request to Ollama and stream the endpoint reply.
|
||||
"""
|
||||
def last_user_content(messages):
|
||||
"""
|
||||
Given a list of dicts (e.g., messages from an API),
|
||||
return the 'content' of the last dict whose 'role' is 'user'.
|
||||
If no such dict exists, return None.
|
||||
"""
|
||||
# Reverse iterate so we stop at the first match
|
||||
for msg in reversed(messages):
|
||||
if msg.get("role") == "user":
|
||||
return msg.get("content")
|
||||
return None
|
||||
# 1. Parse and validate request
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
|
|
@ -892,6 +903,11 @@ async def chat_proxy(request: Request):
|
|||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
if model.startswith("moe-"):
|
||||
model = model.split("moe-")[1]
|
||||
opt = True
|
||||
else:
|
||||
opt = False
|
||||
endpoint = await choose_endpoint(model)
|
||||
is_openai_endpoint = "/v1" in endpoint
|
||||
if is_openai_endpoint:
|
||||
|
|
@ -930,7 +946,20 @@ async def chat_proxy(request: Request):
|
|||
start_ts = time.perf_counter()
|
||||
async_gen = await oclient.chat.completions.create(**params)
|
||||
else:
|
||||
async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive)
|
||||
if opt == True:
|
||||
query = last_user_content(messages)
|
||||
if query:
|
||||
options["temperature"] = 1
|
||||
moe_reqs = []
|
||||
responses = await asyncio.gather(*[client.chat(model=model, messages=messages, tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive) for _ in range(0,3)])
|
||||
for n,r in enumerate(responses):
|
||||
moe_req = enhance.moe(query, n, r.message.content)
|
||||
moe_reqs.append(moe_req)
|
||||
critiques = await asyncio.gather(*[client.chat(model=model, messages=[{"role": "user", "content": moe_req}], tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive) for moe_req in moe_reqs])
|
||||
m = enhance.moe_select_candiadate(query, critiques)
|
||||
async_gen = await client.chat(model=model, messages=[{"role": "user", "content": m}], tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive)
|
||||
else:
|
||||
async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive)
|
||||
if stream == True:
|
||||
async for chunk in async_gen:
|
||||
if is_openai_endpoint:
|
||||
|
|
@ -1705,7 +1734,9 @@ async def openai_chat_completions_proxy(request: Request):
|
|||
if prompt_tok != 0 or comp_tok != 0:
|
||||
if not is_ext_openai_endpoint(endpoint):
|
||||
if not ":" in model:
|
||||
local_model = model+":latest"
|
||||
local_model = model if ":" in model else model + ":latest"
|
||||
else:
|
||||
local_model = model
|
||||
await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
|
||||
yield b"data: [DONE]\n\n"
|
||||
else:
|
||||
|
|
@ -1821,7 +1852,9 @@ async def openai_completions_proxy(request: Request):
|
|||
if prompt_tok != 0 or comp_tok != 0:
|
||||
if not is_ext_openai_endpoint(endpoint):
|
||||
if not ":" in model:
|
||||
local_model = model+":latest"
|
||||
local_model = model if ":" in model else model + ":latest"
|
||||
else:
|
||||
local_model = model
|
||||
await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
|
||||
# Final DONE event
|
||||
yield b"data: [DONE]\n\n"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue