fix(enhance.py): correct typo in function name from 'moe_select_candiadate' to 'moe_select_candidate'
feat(router.py): add helper function _make_chat_request to handle enhancement chat requests to endpoints
This commit is contained in:
parent
5eb5490d16
commit
19a13cc613
2 changed files with 152 additions and 23 deletions
|
|
@ -17,7 +17,7 @@ def moe(query: str, query_id: int, response: str) -> str:
|
|||
"""
|
||||
return moe_prompt
|
||||
|
||||
def moe_select_candiadate(query: str, candidates_with_feedback: list[str]) -> str:
|
||||
def moe_select_candidate(query: str, candidates_with_feedback: list[str]) -> str:
|
||||
select_prompt = f"""
|
||||
From the following responses for the user query: {query}
|
||||
select the best fitting candidate and formulate a final anser for the user.
|
||||
|
|
|
|||
173
router.py
173
router.py
|
|
@ -441,6 +441,155 @@ async def decrement_usage(endpoint: str, model: str) -> None:
|
|||
# usage_counts.pop(endpoint, None)
|
||||
await publish_snapshot()
|
||||
|
||||
async def _make_chat_request(endpoint: str, model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
    """
    Send one chat request to ``endpoint`` and return a single response object.

    An endpoint whose URL contains "/v1" is treated as OpenAI-compatible and
    is called through ``openai.AsyncOpenAI``; every other endpoint is called
    through ``ollama.AsyncClient``. The request is bracketed by
    ``increment_usage`` / ``decrement_usage``, and any non-zero token counts
    are put on ``token_queue`` as
    ``(endpoint, model, prompt_tokens, completion_tokens)``.

    Args:
        endpoint: Base URL of the backend to call (chosen by the caller).
        model: Model name. For OpenAI-compatible endpoints everything from
            the first ":latest" onwards is dropped.
        messages: Chat messages. For OpenAI-compatible endpoints, non-empty
            message lists go through ``transform_images_to_data_urls`` first.
        tools: Forwarded to the backend (left out of the OpenAI call if None).
        stream: Ask the backend for a streamed reply. The stream is fully
            consumed here and only its last chunk is returned.
        think: Forwarded to ``ollama.AsyncClient.chat`` only.
        format: Passed as ``format`` to Ollama, or wrapped into a
            ``json_schema`` ``response_format`` for OpenAI.
        options: Ollama-style options. For OpenAI the keys ``num_predict``,
            ``frequency_penalty``, ``presence_penalty``, ``seed``, ``stop``,
            ``top_p`` and ``temperature`` are mapped onto request parameters.
        keep_alive: Forwarded to ``ollama.AsyncClient.chat`` only.

    Returns:
        The backend reply. OpenAI replies are passed through
        ``rechunk.openai_chat_completion2ollama`` before being returned. If a
        streamed reply yields no chunks, the raw stream object is returned.
    """
    is_openai_endpoint = "/v1" in endpoint
    if is_openai_endpoint:
        if ":latest" in model:
            model = model.split(":latest")[0]
        if messages:
            messages = transform_images_to_data_urls(messages)
        # An absent/empty options value simply yields None for every lookup.
        opts = options or {}
        params = {
            "messages": messages,
            "model": model,
        }
        optional_params = {
            "tools": tools,
            "stream": stream,
            "stream_options": {"include_usage": True} if stream else None,
            "max_tokens": opts.get("num_predict"),
            "frequency_penalty": opts.get("frequency_penalty"),
            "presence_penalty": opts.get("presence_penalty"),
            "seed": opts.get("seed"),
            "stop": opts.get("stop"),
            "top_p": opts.get("top_p"),
            "temperature": opts.get("temperature"),
            "response_format": {"type": "json_schema", "json_schema": format} if format is not None else None,
        }
        # Only send parameters that were actually provided.
        params.update({k: v for k, v in optional_params.items() if v is not None})
        oclient = openai.AsyncOpenAI(base_url=endpoint, default_headers=default_headers, api_key=config.api_keys[endpoint])
    else:
        client = ollama.AsyncClient(host=endpoint)

    async def _record_tokens(prompt_tok, comp_tok):
        # Queue a token-usage record, skipping reports where both counts are zero/None.
        prompt_tok = prompt_tok or 0
        comp_tok = comp_tok or 0
        if prompt_tok != 0 or comp_tok != 0:
            await token_queue.put((endpoint, model, prompt_tok, comp_tok))

    await increment_usage(endpoint, model)

    try:
        if is_openai_endpoint:
            start_ts = time.perf_counter()
            response = await oclient.chat.completions.create(**params)
            if stream:
                # Drain the stream; only the last chunk is needed afterwards.
                # NOTE(review): content carried by earlier chunks is not merged
                # into the returned object — confirm this is intended for
                # callers that pass stream=True.
                last_chunk = None
                async for chunk in response:
                    last_chunk = chunk
                    if chunk.usage is not None:
                        await _record_tokens(chunk.usage.prompt_tokens, chunk.usage.completion_tokens)
                # Convert to Ollama format
                if last_chunk is not None:
                    response = rechunk.openai_chat_completion2ollama(last_chunk, stream, start_ts)
            else:
                # `usage` is optional on a chat completion; skip accounting
                # instead of raising AttributeError when the backend omits it.
                if response.usage is not None:
                    await _record_tokens(response.usage.prompt_tokens, response.usage.completion_tokens)
                response = rechunk.openai_chat_completion2ollama(response, stream, start_ts)
        else:
            response = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive)
            if stream:
                # Drain the stream, reporting token counts per chunk; the last
                # chunk becomes the return value.
                last_chunk = None
                async for chunk in response:
                    last_chunk = chunk
                    await _record_tokens(chunk.prompt_eval_count, chunk.eval_count)
                if last_chunk is not None:
                    response = last_chunk
            else:
                await _record_tokens(response.prompt_eval_count, response.eval_count)

        return response
    finally:
        # Always release the usage slot, even when the request failed.
        await decrement_usage(endpoint, model)
|
||||
|
||||
def get_last_user_content(messages):
    """
    Return the 'content' of the most recent message whose 'role' is 'user'.

    `messages` is a list of dicts (e.g., messages from an API). The list is
    scanned from the end, so the first match found is the latest user
    message. Returns None when no message has the 'user' role.
    """
    user_contents = (
        entry.get("content")
        for entry in reversed(messages)
        if entry.get("role") == "user"
    )
    # The generator is lazy, so scanning stops at the first (i.e. latest) match.
    return next(user_contents, None)
|
||||
|
||||
async def _make_moe_requests(model: str, messages: list, tools=None, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
    """
    Helper function to make MOE (Multiple Opinions Ensemble) requests.

    Generates 3 responses to `messages`, builds one critique prompt per
    response with ``enhance.moe``, requests 3 critiques, builds a selection
    prompt with ``enhance.moe_select_candidate`` and returns the reply to
    that prompt. Every request is non-streaming, uses an endpoint picked by
    ``choose_endpoint(model)`` and is sent with temperature 1.

    Args:
        model: Model name used for every request of the ensemble.
        messages: Chat messages; the last 'user' message is the query.
        tools, think, format, keep_alive: Forwarded unchanged to every
            ``_make_chat_request`` call.
        options: Request options. Not modified; a shallow copy with
            "temperature" set to 1 is used instead.

    Returns:
        The response of the final (selection) request.

    Raises:
        ValueError: If `messages` contains no non-empty user query.
    """
    import copy

    query = get_last_user_content(messages)
    if not query:
        raise ValueError("No user query found in messages")

    # Copy before forcing the temperature so the caller's options object is
    # left untouched (the copy keeps the original container type).
    options = {} if options is None else copy.copy(options)
    options["temperature"] = 1

    num_candidates = 3

    async def _fan_out(conversations):
        # Start one request per conversation, each on a freshly chosen endpoint,
        # then wait for all of them and return the replies in order.
        tasks = []
        for conversation in conversations:
            endpoint = await choose_endpoint(model)
            tasks.append(asyncio.create_task(_make_chat_request(endpoint, model, conversation, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)))
            await asyncio.sleep(0.01)  # Small delay to allow usage count to update
        return await asyncio.gather(*tasks)

    # Generate 3 responses
    responses = await _fan_out([messages] * num_candidates)

    moe_reqs = [enhance.moe(query, n, r.message.content) for n, r in enumerate(responses)]

    # Generate 3 critiques
    critiques = await _fan_out([[{"role": "user", "content": moe_req}] for moe_req in moe_reqs])

    # Select final response
    # NOTE(review): `critiques` holds the objects returned by
    # _make_chat_request, not plain strings — confirm that
    # enhance.moe_select_candidate expects that.
    m = enhance.moe_select_candidate(query, critiques)

    # Generate final response
    final_endpoint = await choose_endpoint(model)
    return await _make_chat_request(final_endpoint, model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)
|
||||
|
||||
def iso8601_ns():
|
||||
ns = time.time_ns()
|
||||
sec, ns_rem = divmod(ns, 1_000_000_000)
|
||||
|
|
@ -864,17 +1013,6 @@ async def chat_proxy(request: Request):
|
|||
"""
|
||||
Proxy a chat request to Ollama and stream the endpoint reply.
|
||||
"""
|
||||
def last_user_content(messages):
|
||||
"""
|
||||
Given a list of dicts (e.g., messages from an API),
|
||||
return the 'content' of the last dict whose 'role' is 'user'.
|
||||
If no such dict exists, return None.
|
||||
"""
|
||||
# Reverse iterate so we stop at the first match
|
||||
for msg in reversed(messages):
|
||||
if msg.get("role") == "user":
|
||||
return msg.get("content")
|
||||
return None
|
||||
# 1. Parse and validate request
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
|
|
@ -949,17 +1087,8 @@ async def chat_proxy(request: Request):
|
|||
async_gen = await oclient.chat.completions.create(**params)
|
||||
else:
|
||||
if opt == True:
|
||||
query = last_user_content(messages)
|
||||
if query:
|
||||
options["temperature"] = 1
|
||||
moe_reqs = []
|
||||
responses = await asyncio.gather(*[client.chat(model=model, messages=messages, tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive) for _ in range(0,3)])
|
||||
for n,r in enumerate(responses):
|
||||
moe_req = enhance.moe(query, n, r.message.content)
|
||||
moe_reqs.append(moe_req)
|
||||
critiques = await asyncio.gather(*[client.chat(model=model, messages=[{"role": "user", "content": moe_req}], tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive) for moe_req in moe_reqs])
|
||||
m = enhance.moe_select_candiadate(query, critiques)
|
||||
async_gen = await client.chat(model=model, messages=[{"role": "user", "content": m}], tools=tools, stream=False, think=think, format=_format, options=options, keep_alive=keep_alive)
|
||||
# Use the dedicated MOE helper function
|
||||
async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive)
|
||||
else:
|
||||
async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive)
|
||||
if stream == True:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue