feat: Add tool call normalization and streaming delta accumulation
Adds support for correctly handling tool calls in chat requests. Normalizes tool call data (ensuring IDs, types, and JSON arguments) in non-streaming mode and accumulates OpenAI-style deltas during streaming to build the final Ollama response.
This commit is contained in:
parent
4892998abc
commit
9875eb977a
2 changed files with 108 additions and 20 deletions
126
router.py
126
router.py
|
|
@ -868,6 +868,7 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
|
|||
model = model.split(":latest")[0]
|
||||
if messages:
|
||||
messages = transform_images_to_data_urls(messages)
|
||||
messages = transform_tool_calls_to_openai(messages)
|
||||
params = {
|
||||
"messages": messages,
|
||||
"model": model,
|
||||
|
|
@ -899,8 +900,10 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
|
|||
if stream:
|
||||
# For streaming, we need to collect all chunks
|
||||
chunks = []
|
||||
tc_acc = {} # accumulate tool-call deltas
|
||||
async for chunk in response:
|
||||
chunks.append(chunk)
|
||||
_accumulate_openai_tc_delta(chunk, tc_acc)
|
||||
if chunk.usage is not None:
|
||||
prompt_tok = chunk.usage.prompt_tokens or 0
|
||||
comp_tok = chunk.usage.completion_tokens or 0
|
||||
|
|
@ -909,6 +912,9 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
|
|||
# Convert to Ollama format
|
||||
if chunks:
|
||||
response = rechunk.openai_chat_completion2ollama(chunks[-1], stream, start_ts)
|
||||
# Inject fully-accumulated tool calls into the final response
|
||||
if tc_acc and response.message:
|
||||
response.message.tool_calls = _build_ollama_tool_calls(tc_acc)
|
||||
else:
|
||||
prompt_tok = response.usage.prompt_tokens or 0
|
||||
comp_tok = response.usage.completion_tokens or 0
|
||||
|
|
@ -1062,6 +1068,39 @@ def resize_image_if_needed(image_data):
|
|||
print(f"Error processing image: {e}")
|
||||
return None
|
||||
|
||||
def transform_tool_calls_to_openai(message_list):
    """Normalize tool-call messages in place to the OpenAI chat format.

    For every assistant message that carries ``tool_calls``:

    - a missing ``type`` is set to ``"function"``;
    - a missing or empty ``id`` is replaced by a generated ``call_<hex>`` id;
    - ``function.arguments`` given as a dict is serialized to a JSON string.

    For every tool-role message without a ``tool_call_id`` key, the id is
    taken from the oldest not-yet-answered assistant tool call that has the
    same function name (read from the message's ``name`` or ``tool_name``).

    Args:
        message_list: List of chat message dicts. Mutated in place.

    Returns:
        The same ``message_list`` object, for call chaining.
    """
    # Unanswered tool-call ids per function name, oldest first.  A queue per
    # name (rather than a single slot) keeps several calls to the same
    # function within one assistant turn matched to their replies in order.
    pending_ids = {}

    for msg in message_list:
        role = msg.get("role")

        if role == "assistant":
            # ``tool_calls`` may be absent or explicitly None.
            for tc in msg.get("tool_calls") or []:
                if "type" not in tc:
                    tc["type"] = "function"
                if not tc.get("id"):
                    tc["id"] = f"call_{secrets.token_hex(16)}"

                func = tc.get("function") or {}
                if isinstance(func.get("arguments"), dict):
                    func["arguments"] = orjson.dumps(func["arguments"]).decode("utf-8")

                name = func.get("name")
                if name:
                    pending_ids.setdefault(name, []).append(tc["id"])

        elif role == "tool":
            if "tool_call_id" in msg:
                # This reply already names its call: retire that id so it is
                # not handed to a later, unrelated tool message.
                answered = msg["tool_call_id"]
                for queue in pending_ids.values():
                    if answered in queue:
                        queue.remove(answered)
                        break
                continue

            # Match by name against the preceding assistant tool calls.
            name = msg.get("name") or msg.get("tool_name")
            queue = pending_ids.get(name) if name else None
            if queue:
                msg["tool_call_id"] = queue.pop(0)

    return message_list
|
||||
|
||||
def transform_images_to_data_urls(message_list):
|
||||
for message in message_list:
|
||||
if "images" in message:
|
||||
|
|
@ -1089,6 +1128,51 @@ def transform_images_to_data_urls(message_list):
|
|||
|
||||
return message_list
|
||||
|
||||
def _accumulate_openai_tc_delta(chunk, accumulator: dict) -> None:
    """Fold the tool-call deltas of one OpenAI streaming chunk into ``accumulator``.

    ``accumulator`` maps a tool-call *index* to
    ``{"id": str, "name": str | None, "arguments": str}``, where ``arguments``
    is the concatenation of every JSON fragment seen so far for that index.
    Only ``chunk.choices[0]`` is inspected; chunks without choices or without
    tool-call deltas leave ``accumulator`` untouched.

    Args:
        chunk: Streaming chunk object exposing ``choices[0].delta.tool_calls``.
        accumulator: Dict updated in place across successive chunks.
    """
    if not chunk.choices:
        return
    delta = chunk.choices[0].delta
    tc_deltas = getattr(delta, "tool_calls", None)
    if not tc_deltas:
        return

    for position, tc in enumerate(tc_deltas):
        idx = getattr(tc, "index", None)
        if idx is None:
            # Defensive: if a backend omits ``index``, fall back to the
            # delta's position in this chunk so the keys stay sortable ints.
            idx = position

        func = getattr(tc, "function", None)
        tc_id = getattr(tc, "id", None)
        name = getattr(func, "name", None)

        entry = accumulator.get(idx)
        if entry is None:
            # First delta for this index: make sure the call always has an id.
            entry = accumulator[idx] = {
                "id": tc_id or f"call_{secrets.token_hex(16)}",
                "name": name,
                "arguments": "",
            }
        else:
            # Later deltas only overwrite fields they actually carry.
            if tc_id:
                entry["id"] = tc_id
            if name:
                entry["name"] = name

        fragment = getattr(func, "arguments", None)
        if fragment:
            entry["arguments"] += fragment
|
||||
|
||||
def _build_ollama_tool_calls(accumulator: dict) -> list | None:
|
||||
"""Convert accumulated tool-call data into Ollama-format tool_calls list."""
|
||||
if not accumulator:
|
||||
return None
|
||||
result = []
|
||||
for idx in sorted(accumulator.keys()):
|
||||
tc = accumulator[idx]
|
||||
try:
|
||||
args = orjson.loads(tc["arguments"]) if tc["arguments"] else {}
|
||||
except (orjson.JSONDecodeError, TypeError):
|
||||
args = {}
|
||||
result.append(ollama.Message.ToolCall(
|
||||
function=ollama.Message.ToolCall.Function(name=tc["name"], arguments=args)
|
||||
))
|
||||
return result
|
||||
|
||||
class rechunk:
|
||||
def openai_chat_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.ChatResponse:
|
||||
now = time.perf_counter()
|
||||
|
|
@ -1099,12 +1183,12 @@ class rechunk:
|
|||
done=True,
|
||||
done_reason='stop',
|
||||
total_duration=int((now - start_ts) * 1_000_000_000),
|
||||
load_duration=100000,
|
||||
load_duration=100000,
|
||||
prompt_eval_count=int(chunk.usage.prompt_tokens),
|
||||
prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)),
|
||||
prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)),
|
||||
eval_count=int(chunk.usage.completion_tokens),
|
||||
eval_duration=int((now - start_ts) * 1_000_000_000),
|
||||
message={"role": "assistant"}
|
||||
message=ollama.Message(role="assistant", content=""),
|
||||
)
|
||||
with_thinking = chunk.choices[0] if chunk.choices[0] else None
|
||||
if stream == True:
|
||||
|
|
@ -1116,24 +1200,22 @@ class rechunk:
|
|||
role = chunk.choices[0].message.role or "assistant"
|
||||
content = chunk.choices[0].message.content or ''
|
||||
# Convert OpenAI tool_calls to Ollama format
|
||||
# In streaming mode, tool_calls arrive as partial deltas across multiple chunks
|
||||
# (name only in first delta, arguments as incremental JSON fragments).
|
||||
# Callers must accumulate deltas and inject the final result; skip here.
|
||||
ollama_tool_calls = None
|
||||
if stream:
|
||||
raw_tool_calls = getattr(with_thinking.delta, "tool_calls", None) if with_thinking else None
|
||||
else:
|
||||
if not stream:
|
||||
raw_tool_calls = getattr(with_thinking.message, "tool_calls", None) if with_thinking else None
|
||||
if raw_tool_calls:
|
||||
ollama_tool_calls = []
|
||||
for tc in raw_tool_calls:
|
||||
try:
|
||||
args = orjson.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else (tc.function.arguments or {})
|
||||
except (orjson.JSONDecodeError, TypeError):
|
||||
args = {}
|
||||
ollama_tool_calls.append({
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": args,
|
||||
}
|
||||
})
|
||||
if raw_tool_calls:
|
||||
ollama_tool_calls = []
|
||||
for tc in raw_tool_calls:
|
||||
try:
|
||||
args = orjson.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else (tc.function.arguments or {})
|
||||
except (orjson.JSONDecodeError, TypeError):
|
||||
args = {}
|
||||
ollama_tool_calls.append(ollama.Message.ToolCall(
|
||||
function=ollama.Message.ToolCall.Function(name=tc.function.name, arguments=args)
|
||||
))
|
||||
assistant_msg = ollama.Message(
|
||||
role=role,
|
||||
content=content,
|
||||
|
|
@ -1528,6 +1610,7 @@ async def chat_proxy(request: Request):
|
|||
model = model[0]
|
||||
if messages:
|
||||
messages = transform_images_to_data_urls(messages)
|
||||
messages = transform_tool_calls_to_openai(messages)
|
||||
params = {
|
||||
"messages": messages,
|
||||
"model": model,
|
||||
|
|
@ -1564,9 +1647,14 @@ async def chat_proxy(request: Request):
|
|||
else:
|
||||
async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive)
|
||||
if stream == True:
|
||||
tc_acc = {} # accumulate OpenAI tool-call deltas across chunks
|
||||
async for chunk in async_gen:
|
||||
if use_openai:
|
||||
_accumulate_openai_tc_delta(chunk, tc_acc)
|
||||
chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts)
|
||||
# Inject fully-accumulated tool calls only into the final chunk
|
||||
if chunk.done and tc_acc and chunk.message:
|
||||
chunk.message.tool_calls = _build_ollama_tool_calls(tc_acc)
|
||||
# `chunk` can be a dict or a pydantic model – dump to JSON safely
|
||||
prompt_tok = chunk.prompt_eval_count or 0
|
||||
comp_tok = chunk.eval_count or 0
|
||||
|
|
|
|||
|
|
@ -863,7 +863,7 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
|
|||
|
||||
const formatUntil = (value) => {
|
||||
if (value === null || value === undefined || value === "") {
|
||||
return "Forever";
|
||||
return "∞";
|
||||
}
|
||||
|
||||
let targetTime;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue