nomyo-router/requests/messages.py

187 lines
6.9 KiB
Python

"""Message-shape transforms used across the chat/completions paths.
Covers conversions in both directions between Ollama's native message
format and the OpenAI Chat Completions format:
* tool-call normalization (Ollama → OpenAI),
* images encoded as base64 lists → OpenAI multimodal ``image_url`` parts,
* trailing-assistant prefill strip (rejected by Claude/OpenAI),
* streaming tool_calls accumulation across deltas,
* logprobs translation (OpenAI choice → Ollama ``Logprob``).
"""
import secrets
import ollama
import orjson
from ollama._types import TokenLogprob, Logprob
from images import is_base64, resize_image_if_needed
def get_last_user_content(messages):
    """Return the ``content`` of the most recent user message.

    ``messages`` is a list of dicts (e.g. messages from an API).  The list
    is scanned from the end and the ``content`` value of the first entry
    whose ``role`` is ``"user"`` is returned.  ``None`` is returned when
    there is no such entry, or when the matching entry has no ``content``
    key.
    """
    # Walking backwards means the first hit is the latest user turn.
    user_contents = (
        entry.get("content")
        for entry in reversed(messages)
        if entry.get("role") == "user"
    )
    return next(user_contents, None)
def _strip_assistant_prefill(messages: list) -> list:
    """Drop a trailing assistant message that acts as a prefill.

    OpenAI-compatible endpoints (including Claude) do not support prefill
    and will reject requests whose last message has role 'assistant'.

    Returns a new list without that final message; in every other case the
    input is returned as-is (same object).
    """
    if not messages:
        return messages
    ends_with_assistant = messages[-1].get("role") == "assistant"
    return messages[:-1] if ends_with_assistant else messages
def transform_tool_calls_to_openai(message_list):
    """Normalize tool-call messages, in place, to the OpenAI format.

    For every assistant message that carries ``tool_calls``:

    - each tool call gets ``"type": "function"`` when it has no type;
    - each tool call gets a generated ``call_<hex>`` id when it has no id;
    - ``function.arguments`` given as a dict is serialized to a JSON string.

    For every ``tool``-role message without a ``tool_call_id``, the id of the
    oldest still-unanswered assistant tool call with the same function name
    (read from the message's ``name`` or ``tool_name``) is assigned.  Ids are
    handed out first-in-first-out per function name, so several calls to the
    same function each get their own result message instead of overwriting
    one another.  A tool message that already carries a ``tool_call_id``
    retires that id so it is not handed out a second time.

    Args:
        message_list: list of message dicts; the dicts are mutated in place.

    Returns:
        The same ``message_list`` object.
    """
    # Function name -> ids of tool calls still waiting for a tool-role reply,
    # oldest first.  These lists stay tiny, so list.pop(0) is fine here.
    pending_ids = {}
    for msg in message_list:
        role = msg.get("role")
        if role == "assistant" and "tool_calls" in msg:
            # ``tool_calls`` may be present but null; treat that as empty.
            for tc in msg["tool_calls"] or ():
                if not tc.get("type"):
                    tc["type"] = "function"
                if not tc.get("id"):
                    tc["id"] = f"call_{secrets.token_hex(16)}"
                func = tc.get("function") or {}
                if isinstance(func.get("arguments"), dict):
                    func["arguments"] = orjson.dumps(func["arguments"]).decode("utf-8")
                # Remember the id for the matching tool-role message.
                name = func.get("name")
                if name:
                    pending_ids.setdefault(name, []).append(tc["id"])
        elif role == "tool":
            claimed = msg.get("tool_call_id")
            if claimed:
                # Already linked: make sure this id is not reused later.
                for queue in pending_ids.values():
                    if claimed in queue:
                        queue.remove(claimed)
                        break
                continue
            # Try to match by name against a preceding assistant tool call.
            name = msg.get("name") or msg.get("tool_name")
            queue = pending_ids.get(name) if name else None
            if queue:
                msg["tool_call_id"] = queue.pop(0)
    return message_list
def transform_images_to_data_urls(message_list):
    """Convert Ollama-style ``images`` lists into OpenAI ``image_url`` parts.

    Every message that has an ``images`` key has that key removed.  When its
    value is a list, each entry is checked with ``is_base64``, passed through
    ``resize_image_if_needed`` and, if that returns a truthy value, turned
    into an ``image_url`` content part whose URL is
    ``data:image/png;base64,<resized image>``.

    Existing content is preserved rather than overwritten: a non-empty string
    becomes a leading ``text`` part, an existing list of parts is kept, and
    the image parts are appended after it.  When no image part is produced
    (empty list, non-list value, or every resize result falsy) the message
    content is left untouched.

    Args:
        message_list: list of message dicts; the dicts are mutated in place.

    Returns:
        The same ``message_list`` object.

    Raises:
        ValueError: if an image entry is rejected by ``is_base64``.
    """
    for message in message_list:
        if "images" not in message:
            continue
        images = message.pop("images")
        if not isinstance(images, list):
            continue

        image_parts = []
        # TODO: quality downsize if images are too big to fit into model context window size
        for image in images:
            if not is_base64(image):
                raise ValueError("Image string is not a valid base64 encoded string.")
            resized_image = resize_image_if_needed(image)
            if resized_image:
                # NOTE(review): the URL is always labelled image/png; presumably
                # resize_image_if_needed returns PNG data - confirm.
                image_parts.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{resized_image}"
                    }
                })

        if not image_parts:
            # Nothing to attach, so keep whatever content the message had.
            continue

        existing = message.get("content")
        if isinstance(existing, list):
            parts = list(existing)
        elif isinstance(existing, str) and existing:
            # Keep the prompt text alongside the images instead of dropping it.
            parts = [{"type": "text", "text": existing}]
        else:
            parts = []
        message["content"] = parts + image_parts
    return message_list
def _strip_images_from_messages(messages: list) -> list:
    """Return the messages with every ``image_url`` content part removed.

    Messages whose ``content`` is not a list are passed through as the same
    object.  For list content a shallow copy of the message is produced:
    if exactly one part remains and it is a ``text`` part, ``content``
    becomes that part's text string; otherwise it is the list of remaining
    parts.
    """
    def _without_images(message):
        parts = message.get("content")
        if not isinstance(parts, list):
            return message
        kept = [part for part in parts if part.get("type") != "image_url"]
        sole_text = len(kept) == 1 and kept[0].get("type") == "text"
        return {**message, "content": kept[0]["text"] if sole_text else kept}

    return [_without_images(message) for message in messages]
def _accumulate_openai_tc_delta(chunk, accumulator: dict) -> None:
    """Fold the tool_call deltas of one OpenAI streaming chunk into ``accumulator``.

    ``accumulator`` maps a tool-call *index* to
    ``{"id": str, "name": str, "arguments": str}``; ``arguments`` is the
    concatenation of every JSON fragment seen so far for that index.  Only
    the first choice of the chunk is inspected.  The dict is updated in
    place and nothing is returned.
    """
    if not chunk.choices:
        return
    deltas = getattr(chunk.choices[0].delta, "tool_calls", None)
    if not deltas:
        return

    for piece in deltas:
        slot = piece.index
        if slot in accumulator:
            # Later fragments may still supply (or replace) the id and name.
            state = accumulator[slot]
            piece_id = getattr(piece, "id", None)
            if piece_id:
                state["id"] = piece_id
            if piece.function and piece.function.name:
                state["name"] = piece.function.name
        else:
            # First fragment for this index; invent an id if none was sent.
            accumulator[slot] = {
                "id": getattr(piece, "id", None) or f"call_{secrets.token_hex(16)}",
                "name": piece.function.name if piece.function else None,
                "arguments": "",
            }
        # Argument fragments are appended for new and existing entries alike.
        if piece.function and piece.function.arguments:
            accumulator[slot]["arguments"] += piece.function.arguments
def _build_ollama_tool_calls(accumulator: dict) -> list | None:
"""Convert accumulated tool-call data into Ollama-format tool_calls list."""
if not accumulator:
return None
result = []
for idx in sorted(accumulator.keys()):
tc = accumulator[idx]
try:
args = orjson.loads(tc["arguments"]) if tc["arguments"] else {}
except (orjson.JSONDecodeError, TypeError):
args = {}
result.append(ollama.Message.ToolCall(
function=ollama.Message.ToolCall.Function(name=tc["name"], arguments=args)
))
return result
def _convert_openai_logprobs(choice) -> list | None:
"""Convert OpenAI logprobs from a choice into Ollama Logprob objects."""
lp = getattr(choice, "logprobs", None)
if lp is None:
return None
content = getattr(lp, "content", None)
if not content:
return None
result = []
for entry in content:
top = [
TokenLogprob(token=alt.token, logprob=alt.logprob)
for alt in (entry.top_logprobs or [])
]
result.append(Logprob(
token=entry.token,
logprob=entry.logprob,
top_logprobs=top or None,
))
return result