187 lines
6.9 KiB
Python
187 lines
6.9 KiB
Python
"""Message-shape transforms used across the chat/completions paths.
|
|
|
|
Covers conversions in both directions between Ollama's native message format and the
|
|
OpenAI Chat Completions format:
|
|
* tool-call normalization (Ollama → OpenAI),
|
|
* lists of base64-encoded images → OpenAI multimodal ``image_url`` parts,
|
|
* trailing-assistant prefill strip (rejected by Claude/OpenAI),
|
|
* streaming tool_calls accumulation across deltas,
|
|
* logprobs translation (OpenAI choice → Ollama ``Logprob``).
|
|
"""
|
|
import secrets
|
|
|
|
import ollama
|
|
import orjson
|
|
from ollama._types import TokenLogprob, Logprob
|
|
|
|
from images import is_base64, resize_image_if_needed
|
|
|
|
|
|
def get_last_user_content(messages):
    """Return the ``content`` of the most recent ``user`` message.

    ``messages`` is a sequence of dict-like messages scanned from the end, so
    only the last entry whose ``role`` is ``"user"`` matters. ``None`` is
    returned when there is no such entry (or when that entry has no
    ``content`` key).
    """
    user_contents = (
        entry.get("content")
        for entry in reversed(messages)
        if entry.get("role") == "user"
    )
    return next(user_contents, None)
|
|
|
|
|
|
def _strip_assistant_prefill(messages: list) -> list:
|
|
"""Remove a trailing assistant message used as prefill.
|
|
OpenAI-compatible endpoints (including Claude) do not support prefill and
|
|
will reject requests where the last message has role 'assistant'."""
|
|
if messages and messages[-1].get("role") == "assistant":
|
|
return messages[:-1]
|
|
return messages
|
|
|
|
|
|
def transform_tool_calls_to_openai(message_list):
    """Normalize tool calls in place so they satisfy the OpenAI Chat format.

    For every assistant message that carries ``tool_calls``:

    * each tool call without a ``type`` gets ``"type": "function"``;
    * each tool call without an ``id`` gets a random ``call_<32 hex>`` id;
    * dict ``function.arguments`` are serialized to a JSON string.

    Tool-role messages that lack a ``tool_call_id`` are then linked to a
    preceding assistant tool call by function name (``name`` or
    ``tool_name``). When one assistant message calls the same function more
    than once, the ids are handed to the following tool messages in call
    order.

    Args:
        message_list: list of message dicts; mutated in place.

    Returns:
        The same ``message_list`` object.
    """
    # function name -> ids still waiting for a tool-role reply, oldest first
    pending_ids = {}

    for msg in message_list:
        role = msg.get("role")

        if role == "assistant" and "tool_calls" in msg:
            issued = {}
            # ``tool_calls`` may be present but null (e.g. serialized SDK objects).
            for tc in msg["tool_calls"] or []:
                if not tc.get("type"):
                    tc["type"] = "function"
                if not tc.get("id"):
                    tc["id"] = f"call_{secrets.token_hex(16)}"

                func = tc.get("function") or {}
                if isinstance(func.get("arguments"), dict):
                    # OpenAI requires arguments as a JSON string, not an object.
                    func["arguments"] = orjson.dumps(func["arguments"]).decode("utf-8")

                name = func.get("name")
                if name:
                    # Keep every id per name, so parallel calls to the same
                    # function do not overwrite each other.
                    issued.setdefault(name, []).append(tc["id"])

            # A newer assistant turn supersedes unanswered ids for the same name.
            pending_ids.update(issued)

        elif role == "tool" and "tool_call_id" not in msg:
            name = msg.get("name") or msg.get("tool_name")
            queue = pending_ids.get(name) if name else None
            if queue:
                msg["tool_call_id"] = queue.pop(0)

    return message_list
|
|
|
|
|
|
def transform_images_to_data_urls(message_list):
    """Move Ollama ``images`` lists into OpenAI multimodal ``content`` parts, in place.

    For each message with an ``images`` key, the key is removed. If its value
    is a list, every entry must pass ``is_base64`` (otherwise ``ValueError``).
    Each entry goes through ``resize_image_if_needed``. When that returns a
    truthy value, the result is appended as an ``image_url`` part whose URL is
    ``data:image/png;base64,<data>``.

    The message's existing text is preserved. A non-empty string ``content``
    becomes a leading ``text`` part, and a list ``content`` is extended with
    the image parts. When no image part is produced (empty list, or every image
    skipped), ``content`` is left untouched.

    Args:
        message_list: list of message dicts; mutated in place.

    Returns:
        The same ``message_list`` object.

    Raises:
        ValueError: if an image entry is not a valid base64 string.
    """
    for message in message_list:
        if "images" not in message:
            continue
        images = message.pop("images")
        if not isinstance(images, list):
            continue

        image_parts = []
        # TODO: downsize quality further if images are too big for the model's context window
        for image in images:
            if not is_base64(image):
                raise ValueError("Image string is not a valid base64 encoded string.")
            resized_image = resize_image_if_needed(image)
            if resized_image:
                # NOTE(review): the MIME type is always image/png; confirm that
                # resize_image_if_needed re-encodes to PNG (or only sees PNG input).
                image_parts.append({
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{resized_image}"},
                })

        if not image_parts:
            # Nothing to attach: keep the original content rather than wiping it.
            continue

        content = message.get("content")
        if isinstance(content, list):
            # Already OpenAI-style parts: append the images after them.
            message["content"] = [*content, *image_parts]
        elif isinstance(content, str) and content:
            # Keep the user's prompt text alongside the images.
            message["content"] = [{"type": "text", "text": content}, *image_parts]
        else:
            message["content"] = image_parts

    return message_list
|
|
|
|
|
|
def _strip_images_from_messages(messages: list) -> list:
|
|
"""Remove image_url parts from message content, keeping only text."""
|
|
result = []
|
|
for msg in messages:
|
|
content = msg.get("content")
|
|
if isinstance(content, list):
|
|
text_only = [p for p in content if p.get("type") != "image_url"]
|
|
if len(text_only) == 1 and text_only[0].get("type") == "text":
|
|
content = text_only[0]["text"]
|
|
else:
|
|
content = text_only
|
|
result.append({**msg, "content": content})
|
|
else:
|
|
result.append(msg)
|
|
return result
|
|
|
|
|
|
def _accumulate_openai_tc_delta(chunk, accumulator: dict) -> None:
|
|
"""Accumulate tool_call deltas from a single OpenAI streaming chunk.
|
|
|
|
``accumulator`` is a dict mapping tool-call *index* to
|
|
``{"id": str, "name": str, "arguments": str}`` where ``arguments``
|
|
is the concatenation of all JSON fragments seen so far.
|
|
"""
|
|
if not chunk.choices:
|
|
return
|
|
delta = chunk.choices[0].delta
|
|
tc_deltas = getattr(delta, "tool_calls", None)
|
|
if not tc_deltas:
|
|
return
|
|
for tc in tc_deltas:
|
|
idx = tc.index
|
|
if idx not in accumulator:
|
|
accumulator[idx] = {
|
|
"id": getattr(tc, "id", None) or f"call_{secrets.token_hex(16)}",
|
|
"name": tc.function.name if tc.function else None,
|
|
"arguments": "",
|
|
}
|
|
else:
|
|
if getattr(tc, "id", None):
|
|
accumulator[idx]["id"] = tc.id
|
|
if tc.function and tc.function.name:
|
|
accumulator[idx]["name"] = tc.function.name
|
|
if tc.function and tc.function.arguments:
|
|
accumulator[idx]["arguments"] += tc.function.arguments
|
|
|
|
|
|
def _build_ollama_tool_calls(accumulator: dict) -> list | None:
|
|
"""Convert accumulated tool-call data into Ollama-format tool_calls list."""
|
|
if not accumulator:
|
|
return None
|
|
result = []
|
|
for idx in sorted(accumulator.keys()):
|
|
tc = accumulator[idx]
|
|
try:
|
|
args = orjson.loads(tc["arguments"]) if tc["arguments"] else {}
|
|
except (orjson.JSONDecodeError, TypeError):
|
|
args = {}
|
|
result.append(ollama.Message.ToolCall(
|
|
function=ollama.Message.ToolCall.Function(name=tc["name"], arguments=args)
|
|
))
|
|
return result
|
|
|
|
|
|
def _convert_openai_logprobs(choice) -> list | None:
|
|
"""Convert OpenAI logprobs from a choice into Ollama Logprob objects."""
|
|
lp = getattr(choice, "logprobs", None)
|
|
if lp is None:
|
|
return None
|
|
content = getattr(lp, "content", None)
|
|
if not content:
|
|
return None
|
|
result = []
|
|
for entry in content:
|
|
top = [
|
|
TokenLogprob(token=alt.token, logprob=alt.logprob)
|
|
for alt in (entry.top_logprobs or [])
|
|
]
|
|
result.append(Logprob(
|
|
token=entry.token,
|
|
logprob=entry.logprob,
|
|
top_logprobs=top or None,
|
|
))
|
|
return result
|