Merge remote-tracking branch 'upstream/dev' into feat/split-auto-free-premium

This commit is contained in:
Anish Sarkar 2026-04-30 16:23:05 +05:30
commit 872065f90d
15 changed files with 1857 additions and 545 deletions

View file

@ -599,8 +599,17 @@ class VercelStreamingService:
Format the start of tool input streaming.
Args:
tool_call_id: The unique tool call identifier (synthetic, derived
from LangGraph ``run_id`` so the frontend has a stable card id).
tool_call_id: The unique tool call identifier. May be EITHER the
synthetic ``call_<run_id>`` id derived from LangGraph
``run_id`` (legacy / ``SURFSENSE_ENABLE_STREAM_PARITY_V2``
OFF, or the unmatched-fallback path under parity_v2) OR
the authoritative LangChain ``tool_call.id`` (parity_v2
path: when the provider streams ``tool_call_chunks`` we
register the ``index`` and reuse the lc-id as the card
id so live ``tool-input-delta`` events can be routed
without a downstream join). Either way, the same id is
preserved across ``tool-input-start`` / ``-delta`` /
``-available`` / ``tool-output-available`` for one call.
tool_name: The name of the tool being called.
langchain_tool_call_id: Optional authoritative LangChain
``tool_call.id``. When set, surfaces as

View file

@ -473,6 +473,42 @@ def _emit_stream_terminal_error(
return streaming_service.format_error(message, error_code=error_code)
def _legacy_match_lc_id(
pending_tool_call_chunks: list[dict[str, Any]],
tool_name: str,
run_id: str,
lc_tool_call_id_by_run: dict[str, str],
) -> str | None:
"""Best-effort match a buffered ``tool_call_chunk`` to a tool name.
Pure extract of the legacy in-line match used at ``on_tool_start`` for
parity_v2-OFF and unmatched (chunk path didn't register an index for
this call) tools. Pops the next id-bearing chunk whose ``name``
matches ``tool_name`` (or any id-bearing chunk as a fallback) and
returns its id. Mutates ``pending_tool_call_chunks`` and
``lc_tool_call_id_by_run`` in place.
"""
matched_idx: int | None = None
for idx, tcc in enumerate(pending_tool_call_chunks):
if tcc.get("name") == tool_name and tcc.get("id"):
matched_idx = idx
break
if matched_idx is None:
for idx, tcc in enumerate(pending_tool_call_chunks):
if tcc.get("id"):
matched_idx = idx
break
if matched_idx is None:
return None
matched = pending_tool_call_chunks.pop(matched_idx)
candidate = matched.get("id")
if isinstance(candidate, str) and candidate:
if run_id:
lc_tool_call_id_by_run[run_id] = candidate
return candidate
return None
async def _stream_agent_events(
agent: Any,
config: dict[str, Any],
@ -538,10 +574,28 @@ async def _stream_agent_events(
# ``tool_call_chunks`` from ``on_chat_model_stream``, key them by
# name, and pop the next unconsumed entry at ``on_tool_start``. The
# authoritative id is later filled in at ``on_tool_end`` from
# ``ToolMessage.tool_call_id``.
# ``ToolMessage.tool_call_id``. Under parity_v2 we ALSO short-circuit
# this list for chunks that already registered into ``index_to_meta``
# below — so this list is reserved for the parity_v2-OFF / unmatched
# fallback path only and never re-pops a chunk we already streamed.
pending_tool_call_chunks: list[dict[str, Any]] = []
lc_tool_call_id_by_run: dict[str, str] = {}
# parity_v2 only: live tool-call argument streaming. ``index_to_meta``
# is keyed by the chunk's ``index`` field — LangChain
# ``ToolCallChunk``s for the same call share an index but only the
# first chunk carries id+name (subsequent ones are id=None,
# name=None, args="<delta>"). We register an index when both id and
# name are observed on a chunk (per ToolCallChunk semantics they
# arrive together on the first chunk), then route every later chunk
# at that index to the same ``ui_id`` as a ``tool-input-delta``.
# ``ui_tool_call_id_by_run`` maps LangGraph ``run_id`` to the
# ``ui_id`` used for that call's ``tool-input-start`` so the matching
# ``tool-output-available`` (emitted from ``on_tool_end``) lands on
# the same card.
index_to_meta: dict[int, dict[str, str]] = {}
ui_tool_call_id_by_run: dict[str, str] = {}
# Per-tool-end mutable cache for the LangChain tool_call_id resolved
# at ``on_tool_end``. ``_emit_tool_output`` reads this so every
# ``format_tool_output_available`` call automatically carries the
@ -587,13 +641,6 @@ async def _stream_agent_events(
continue
parts = _extract_chunk_parts(chunk)
# Accumulate any tool_call_chunks for best-effort
# correlation with ``on_tool_start`` below. We don't emit
# anything here; the matching is done at tool-start time.
if parity_v2 and parts["tool_call_chunks"]:
for tcc in parts["tool_call_chunks"]:
pending_tool_call_chunks.append(tcc)
reasoning_delta = parts["reasoning"]
text_delta = parts["text"]
@ -639,6 +686,71 @@ async def _stream_agent_events(
yield streaming_service.format_text_delta(current_text_id, text_delta)
accumulated_text += text_delta
# Live tool-call argument streaming. Runs AFTER text/reasoning
# processing so chunks containing both stay in their natural
# wire order (text → text-end → tool-input-start). Active
# text/reasoning are closed inside the registration branch
# before ``tool-input-start`` so the frontend sees a clean
# part boundary even when providers interleave.
if parity_v2 and parts["tool_call_chunks"]:
for tcc in parts["tool_call_chunks"]:
idx = tcc.get("index")
# Register this index when we first see id+name
# TOGETHER. Per LangChain ToolCallChunk semantics the
# first chunk for a tool call carries both fields
# together; later chunks have id=None, name=None and
# only ``args``. Requiring BOTH keeps wire
# ``tool-input-start`` always carrying a real
# toolName (assistant-ui's typed tool-part dispatch
# keys off it).
if idx is not None and idx not in index_to_meta:
lc_id = tcc.get("id")
name = tcc.get("name")
if lc_id and name:
ui_id = lc_id
# Close active text/reasoning so wire
# ordering stays clean even on providers
# that interleave text and tool-call chunks
# within the same stream window.
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
current_text_id = None
if current_reasoning_id is not None:
yield streaming_service.format_reasoning_end(
current_reasoning_id
)
current_reasoning_id = None
index_to_meta[idx] = {
"ui_id": ui_id,
"lc_id": lc_id,
"name": name,
}
yield streaming_service.format_tool_input_start(
ui_id,
name,
langchain_tool_call_id=lc_id,
)
# Emit args delta for any chunk at a registered
# index (including idless continuations). Once an
# index is owned by ``index_to_meta`` we DO NOT
# append to ``pending_tool_call_chunks`` — that list
# is reserved for the parity_v2-OFF / unmatched
# fallback path so it never re-pops chunks already
# consumed here (skip-append).
meta = index_to_meta.get(idx) if idx is not None else None
if meta:
args_chunk = tcc.get("args") or ""
if args_chunk:
yield streaming_service.format_tool_input_delta(
meta["ui_id"], args_chunk
)
else:
pending_tool_call_chunks.append(tcc)
elif event_type == "on_tool_start":
active_tool_depth += 1
tool_name = event.get("name", "unknown_tool")
@ -969,44 +1081,65 @@ async def _stream_agent_events(
status="in_progress",
)
tool_call_id = (
f"call_{run_id[:32]}"
if run_id
else streaming_service.generate_tool_call_id()
)
# Best-effort attach the LangChain ``tool_call_id``. We
# pop the first chunk in ``pending_tool_call_chunks`` whose
# name matches; if none match (the chunked args may not yet
# carry a ``name`` field, or the model skipped the chunked
# form) we leave ``langchainToolCallId`` unset for now and
# fill it in authoritatively at ``on_tool_end`` from
# ``ToolMessage.tool_call_id``.
langchain_tool_call_id: str | None = None
if parity_v2 and pending_tool_call_chunks:
matched_idx: int | None = None
for idx, tcc in enumerate(pending_tool_call_chunks):
if tcc.get("name") == tool_name and tcc.get("id"):
matched_idx = idx
# Resolve the card identity. If the chunk-emission loop
# already registered an ``index`` for this tool call (parity_v2
# path), reuse the same ui_id so the card sees:
# tool-input-start → deltas… → tool-input-available →
# tool-output-available all keyed by lc_id. Otherwise fall
# back to the synthetic ``call_<run_id>`` id and the legacy
# best-effort match against ``pending_tool_call_chunks``.
matched_meta: dict[str, str] | None = None
if parity_v2:
# FIFO over indices 0,1,2…; first unassigned same-name
# match wins. Handles parallel same-name calls (e.g. two
# write_file calls) deterministically as long as the
# model interleaves on_tool_start in the same order it
# streamed the args.
taken_ui_ids = set(ui_tool_call_id_by_run.values())
for meta in index_to_meta.values():
if meta["name"] == tool_name and meta["ui_id"] not in taken_ui_ids:
matched_meta = meta
break
if matched_idx is None:
for idx, tcc in enumerate(pending_tool_call_chunks):
if tcc.get("id"):
matched_idx = idx
break
if matched_idx is not None:
matched = pending_tool_call_chunks.pop(matched_idx)
candidate = matched.get("id")
if isinstance(candidate, str) and candidate:
langchain_tool_call_id = candidate
if run_id:
lc_tool_call_id_by_run[run_id] = candidate
yield streaming_service.format_tool_input_start(
tool_call_id,
tool_name,
langchain_tool_call_id=langchain_tool_call_id,
)
tool_call_id: str
langchain_tool_call_id: str | None = None
if matched_meta is not None:
tool_call_id = matched_meta["ui_id"]
langchain_tool_call_id = matched_meta["lc_id"]
# ``tool-input-start`` already fired during chunk
# emission — skip the duplicate. No pruning is needed
# because the chunk-emission loop intentionally never
# appends registered-index chunks to
# ``pending_tool_call_chunks`` (skip-append).
if run_id:
lc_tool_call_id_by_run[run_id] = matched_meta["lc_id"]
else:
tool_call_id = (
f"call_{run_id[:32]}"
if run_id
else streaming_service.generate_tool_call_id()
)
# Legacy fallback: parity_v2 OFF, or parity_v2 ON but the
# provider didn't stream tool_call_chunks for this call
# (no index registered). Run the existing best-effort
# match BEFORE emitting start so we still attach an
# authoritative ``langchainToolCallId`` when possible.
if parity_v2:
langchain_tool_call_id = _legacy_match_lc_id(
pending_tool_call_chunks,
tool_name,
run_id,
lc_tool_call_id_by_run,
)
yield streaming_service.format_tool_input_start(
tool_call_id,
tool_name,
langchain_tool_call_id=langchain_tool_call_id,
)
if run_id:
ui_tool_call_id_by_run[run_id] = tool_call_id
# Sanitize tool_input: strip runtime-injected non-serializable
# values (e.g. LangChain ToolRuntime) before sending over SSE.
if isinstance(tool_input, dict):
@ -1059,7 +1192,15 @@ async def _stream_agent_events(
result.write_succeeded = True
result.verification_succeeded = True
tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown"
# Look up the SAME card id used at on_tool_start (either the
# parity_v2 lc-id-derived ui_id or the legacy synthetic
# ``call_<run_id>``) so the output event always lands on the
# same card as start/delta/available. Fallback preserves the
# legacy synthetic shape for parity_v2-OFF / unknown-run paths.
tool_call_id = ui_tool_call_id_by_run.get(
run_id,
f"call_{run_id[:32]}" if run_id else "call_unknown",
)
original_step_id = tool_step_ids.get(
run_id, f"{step_prefix}-unknown-{run_id[:8]}"
)
@ -1070,17 +1211,22 @@ async def _stream_agent_events(
# at ``on_tool_start`` time (kept in ``lc_tool_call_id_by_run``)
# if the output isn't a ToolMessage. The value is stored in
# ``current_lc_tool_call_id`` so ``_emit_tool_output``
# picks it up for every output emit below. Stays None when
# parity_v2 is off so legacy emit paths are untouched.
# picks it up for every output emit below.
#
# Emitted in BOTH parity_v2 and legacy modes: the chat tool
# card needs the LangChain id to match against the
# ``data-action-log`` SSE event (keyed by ``lc_tool_call_id``)
# so the inline Revert button can light up. Reading
# ``raw_output.tool_call_id`` is a cheap, non-mutating attribute
# access that is safe regardless of feature-flag state.
current_lc_tool_call_id["value"] = None
if parity_v2:
authoritative = getattr(raw_output, "tool_call_id", None)
if isinstance(authoritative, str) and authoritative:
current_lc_tool_call_id["value"] = authoritative
if run_id:
lc_tool_call_id_by_run[run_id] = authoritative
elif run_id and run_id in lc_tool_call_id_by_run:
current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id]
authoritative = getattr(raw_output, "tool_call_id", None)
if isinstance(authoritative, str) and authoritative:
current_lc_tool_call_id["value"] = authoritative
if run_id:
lc_tool_call_id_by_run[run_id] = authoritative
elif run_id and run_id in lc_tool_call_id_by_run:
current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id]
if tool_name == "read_file":
yield streaming_service.format_thinking_step(

View file

@ -183,3 +183,46 @@ class TestDefensive:
assert out["text"] == ""
assert out["reasoning"] == ""
assert out["tool_call_chunks"] == []
class TestIdlessContinuationChunks:
"""Per LangChain ``ToolCallChunk`` semantics, the FIRST chunk for a
tool call carries id+name; later chunks for the same call have
``id=None, name=None`` and only ``args`` + ``index``. Live tool-call
argument streaming relies on those idless continuation chunks
flowing through ``_extract_chunk_parts`` UNTOUCHED so the upstream
chunk-emission loop can still route them by ``index``.
"""
def test_idless_continuation_chunk_preserved_verbatim(self) -> None:
chunk = _FakeChunk(
tool_call_chunks=[
{"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
]
)
out = _extract_chunk_parts(chunk)
assert len(out["tool_call_chunks"]) == 1
tcc = out["tool_call_chunks"][0]
assert tcc.get("id") is None
assert tcc.get("name") is None
assert tcc.get("args") == '_path":"/x"}'
assert tcc.get("index") == 0
def test_first_then_idless_sequence_preserves_index(self) -> None:
"""Both chunks for the same call share an ``index`` key — the
index-routing loop in ``stream_new_chat`` depends on it."""
first = _FakeChunk(
tool_call_chunks=[
{"id": "lc-1", "name": "write_file", "args": '{"file', "index": 0}
]
)
cont = _FakeChunk(
tool_call_chunks=[
{"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
]
)
out_first = _extract_chunk_parts(first)
out_cont = _extract_chunk_parts(cont)
assert out_first["tool_call_chunks"][0]["index"] == 0
assert out_cont["tool_call_chunks"][0]["index"] == 0
assert out_cont["tool_call_chunks"][0].get("id") is None

View file

@ -0,0 +1,527 @@
"""Unit tests for live tool-call argument streaming.
Pins the wire format that ``_stream_agent_events`` emits when
``SURFSENSE_ENABLE_STREAM_PARITY_V2=true``: ``tool-input-start``
``tool-input-delta``... ``tool-input-available`` ``tool-output-available``
all keyed by the same LangChain ``tool_call.id``.
Identity is tracked in ``index_to_meta`` (per-chunk ``index``) and
``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are private to
``_stream_agent_events`` so we exercise them via the public wire output.
These tests also lock in the legacy / parity_v2-OFF behaviour so the
synthetic ``call_<run_id>`` shape stays stable for older clients.
"""
from __future__ import annotations
import json
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from typing import Any
import pytest
import app.tasks.chat.stream_new_chat as stream_module
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.stream_new_chat import (
StreamResult,
_legacy_match_lc_id,
_stream_agent_events,
)
pytestmark = pytest.mark.unit
@dataclass
class _FakeChunk:
"""Minimal stand-in for ``AIMessageChunk``."""
content: Any = ""
additional_kwargs: dict[str, Any] = field(default_factory=dict)
tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
@dataclass
class _FakeToolMessage:
"""Stand-in for ``ToolMessage`` returned by ``on_tool_end``."""
content: Any
tool_call_id: str | None = None
class _FakeAgentState:
"""Stand-in for ``StateSnapshot`` returned by ``aget_state``."""
def __init__(self) -> None:
# Empty values keeps the cloud-fallback safety-net branch a no-op,
# and an empty ``tasks`` list keeps the post-stream interrupt
# check a no-op too.
self.values: dict[str, Any] = {}
self.tasks: list[Any] = []
class _FakeAgent:
"""Replays a list of ``astream_events`` events."""
def __init__(self, events: list[dict[str, Any]]) -> None:
self._events = events
async def astream_events( # type: ignore[no-untyped-def]
self, _input_data: Any, *, config: dict[str, Any], version: str
) -> AsyncGenerator[dict[str, Any], None]:
del config, version # unused, contract-compatible
for ev in self._events:
yield ev
async def aget_state(self, _config: dict[str, Any]) -> _FakeAgentState:
# Called once after astream_events drains so the cloud-fallback
# safety net can inspect staged filesystem work. The fake stays
# empty so the safety net is a no-op.
return _FakeAgentState()
def _model_stream(
*,
text: str = "",
reasoning: str = "",
tool_call_chunks: list[dict[str, Any]] | None = None,
tags: list[str] | None = None,
) -> dict[str, Any]:
return (
{
"event": "on_chat_model_stream",
"tags": tags or [],
"data": {
"chunk": _FakeChunk(
content=text,
tool_call_chunks=list(tool_call_chunks or []),
)
},
# reasoning piggybacks via additional_kwargs path; if needed,
# override content to a typed-block list. Most tests just check
# tool_call_chunks routing so this is fine.
}
if not reasoning
else {
"event": "on_chat_model_stream",
"tags": tags or [],
"data": {
"chunk": _FakeChunk(
content=text,
additional_kwargs={"reasoning_content": reasoning},
tool_call_chunks=list(tool_call_chunks or []),
)
},
}
)
def _tool_start(
*,
name: str,
run_id: str,
input_payload: dict[str, Any] | None = None,
) -> dict[str, Any]:
return {
"event": "on_tool_start",
"name": name,
"run_id": run_id,
"data": {"input": input_payload or {}},
}
def _tool_end(
*,
name: str,
run_id: str,
tool_call_id: str | None = None,
output: Any = "ok",
) -> dict[str, Any]:
return {
"event": "on_tool_end",
"name": name,
"run_id": run_id,
"data": {
"output": _FakeToolMessage(
content=json.dumps(output) if not isinstance(output, str) else output,
tool_call_id=tool_call_id,
)
},
}
@pytest.fixture
def parity_v2_on(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
stream_module,
"get_flags",
lambda: AgentFeatureFlags(enable_stream_parity_v2=True),
)
@pytest.fixture
def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
stream_module,
"get_flags",
lambda: AgentFeatureFlags(enable_stream_parity_v2=False),
)
async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Run ``_stream_agent_events`` against a fake agent and return the
SSE payloads (parsed JSON) it yielded.
"""
agent = _FakeAgent(events)
service = VercelStreamingService()
result = StreamResult()
config = {"configurable": {"thread_id": "test-thread"}}
sse_lines: list[str] = []
async for sse in _stream_agent_events(
agent, config, {}, service, result, step_prefix="thinking"
):
sse_lines.append(sse)
parsed: list[dict[str, Any]] = []
for line in sse_lines:
if not line.startswith("data: "):
continue
body = line[len("data: ") :].rstrip("\n")
if not body or body == "[DONE]":
continue
try:
parsed.append(json.loads(body))
except json.JSONDecodeError:
continue
return parsed
def _types(payloads: list[dict[str, Any]]) -> list[str]:
return [p.get("type", "?") for p in payloads]
def _of_type(payloads: list[dict[str, Any]], type_name: str) -> list[dict[str, Any]]:
return [p for p in payloads if p.get("type") == type_name]
# ---------------------------------------------------------------------------
# Helper: ``_legacy_match_lc_id`` is a pure refactor; assert behaviour.
# ---------------------------------------------------------------------------
class TestLegacyMatch:
def test_pops_first_id_bearing_chunk_with_matching_name(self) -> None:
chunks: list[dict[str, Any]] = [
{"id": "x1", "name": "ls"},
{"id": "y1", "name": "write_file"},
]
runs: dict[str, str] = {}
result = _legacy_match_lc_id(chunks, "write_file", "run-1", runs)
assert result == "y1"
assert chunks == [{"id": "x1", "name": "ls"}]
assert runs == {"run-1": "y1"}
def test_falls_back_to_any_id_bearing_when_name_mismatches(self) -> None:
chunks: list[dict[str, Any]] = [{"id": "anon", "name": None}]
runs: dict[str, str] = {}
out = _legacy_match_lc_id(chunks, "ls", "run-2", runs)
assert out == "anon"
assert chunks == []
def test_returns_none_when_no_id_bearing_chunk(self) -> None:
chunks: list[dict[str, Any]] = [{"id": None, "name": None}]
runs: dict[str, str] = {}
assert _legacy_match_lc_id(chunks, "ls", "run-3", runs) is None
assert chunks == [{"id": None, "name": None}]
assert runs == {}
# ---------------------------------------------------------------------------
# parity_v2 wire format tests.
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None:
"""First chunk carries id+name; later idless chunks at the same
``index`` merge into the SAME ``tool-input-start`` ui id and emit
one ``tool-input-delta`` per chunk."""
events = [
_model_stream(
tool_call_chunks=[
{"id": "lc-1", "name": "write_file", "args": '{"file', "index": 0}
],
),
_model_stream(
tool_call_chunks=[
{"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
],
),
_tool_start(
name="write_file", run_id="run-A", input_payload={"file_path": "/x"}
),
_tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
]
payloads = await _drain(events)
starts = _of_type(payloads, "tool-input-start")
deltas = _of_type(payloads, "tool-input-delta")
available = _of_type(payloads, "tool-input-available")
output = _of_type(payloads, "tool-output-available")
assert len(starts) == 1
assert starts[0]["toolCallId"] == "lc-1"
assert starts[0]["toolName"] == "write_file"
assert starts[0]["langchainToolCallId"] == "lc-1"
assert [d["inputTextDelta"] for d in deltas] == ['{"file', '_path":"/x"}']
assert all(d["toolCallId"] == "lc-1" for d in deltas)
assert len(available) == 1
assert available[0]["toolCallId"] == "lc-1"
assert len(output) == 1
assert output[0]["toolCallId"] == "lc-1"
@pytest.mark.asyncio
async def test_two_interleaved_tool_calls_route_by_index(
parity_v2_on: None,
) -> None:
"""Two same-name calls with distinct indices keep their deltas
routed to the right card."""
events = [
_model_stream(
tool_call_chunks=[
{"id": "lc-A", "name": "write_file", "args": '{"a":1', "index": 0},
{"id": "lc-B", "name": "write_file", "args": '{"b":2', "index": 1},
]
),
_model_stream(
tool_call_chunks=[
{"id": None, "name": None, "args": "}", "index": 0},
{"id": None, "name": None, "args": "}", "index": 1},
]
),
_tool_start(name="write_file", run_id="run-A", input_payload={"a": 1}),
_tool_end(name="write_file", run_id="run-A", tool_call_id="lc-A"),
_tool_start(name="write_file", run_id="run-B", input_payload={"b": 2}),
_tool_end(name="write_file", run_id="run-B", tool_call_id="lc-B"),
]
payloads = await _drain(events)
starts = _of_type(payloads, "tool-input-start")
deltas = _of_type(payloads, "tool-input-delta")
output = _of_type(payloads, "tool-output-available")
assert {s["toolCallId"] for s in starts} == {"lc-A", "lc-B"}
by_id: dict[str, list[str]] = {"lc-A": [], "lc-B": []}
for d in deltas:
by_id[d["toolCallId"]].append(d["inputTextDelta"])
assert by_id["lc-A"] == ['{"a":1', "}"]
assert by_id["lc-B"] == ['{"b":2', "}"]
assert {o["toolCallId"] for o in output} == {"lc-A", "lc-B"}
@pytest.mark.asyncio
async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None:
"""Whatever id ``tool-input-start`` chose must be the SAME id used
on ``tool-input-available`` AND ``tool-output-available``."""
events = [
_model_stream(
tool_call_chunks=[
{"id": "lc-9", "name": "ls", "args": '{"path":"/"}', "index": 0}
]
),
_tool_start(name="ls", run_id="run-X", input_payload={"path": "/"}),
_tool_end(name="ls", run_id="run-X", tool_call_id="lc-9"),
]
payloads = await _drain(events)
relevant = [
p
for p in payloads
if p.get("type")
in {"tool-input-start", "tool-input-available", "tool-output-available"}
]
assert {p["toolCallId"] for p in relevant} == {"lc-9"}
@pytest.mark.asyncio
async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None:
"""When the chunk-emission loop already fired ``tool-input-start``
for this run, ``on_tool_start`` MUST NOT emit a second one."""
events = [
_model_stream(
tool_call_chunks=[
{"id": "lc-1", "name": "write_file", "args": "{}", "index": 0}
]
),
_tool_start(name="write_file", run_id="run-A", input_payload={}),
_tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
]
payloads = await _drain(events)
starts = _of_type(payloads, "tool-input-start")
assert len(starts) == 1
assert starts[0]["toolCallId"] == "lc-1"
@pytest.mark.asyncio
async def test_active_text_closes_before_early_tool_input_start(
parity_v2_on: None,
) -> None:
"""Streaming a text-delta then a tool-call chunk in subsequent
chunks: the wire MUST contain ``text-end`` before the FIRST
``tool-input-start`` (clean part boundary on the frontend)."""
events = [
_model_stream(text="Working on it"),
_model_stream(
tool_call_chunks=[
{"id": "lc-1", "name": "write_file", "args": "{}", "index": 0}
]
),
_tool_start(name="write_file", run_id="run-A", input_payload={}),
_tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
]
types = _types(await _drain(events))
text_end_idx = types.index("text-end")
start_idx = types.index("tool-input-start")
assert text_end_idx < start_idx
@pytest.mark.asyncio
async def test_mixed_text_and_tool_chunk_preserve_order(
parity_v2_on: None,
) -> None:
"""One AIMessageChunk that carries BOTH ``text`` content AND
``tool_call_chunks`` should emit the text delta FIRST, then close
text, then ``tool-input-start``+``tool-input-delta``."""
events = [
_model_stream(
text="I'll update it",
tool_call_chunks=[
{
"id": "lc-1",
"name": "write_file",
"args": '{"file_path":"/x"}',
"index": 0,
}
],
),
_tool_start(
name="write_file", run_id="run-A", input_payload={"file_path": "/x"}
),
_tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
]
types = _types(await _drain(events))
# text-start … text-delta … text-end … tool-input-start … tool-input-delta
assert types.index("text-start") < types.index("text-delta")
assert types.index("text-delta") < types.index("text-end")
assert types.index("text-end") < types.index("tool-input-start")
assert types.index("tool-input-start") < types.index("tool-input-delta")
@pytest.mark.asyncio
async def test_parity_v2_off_preserves_legacy_shape(
parity_v2_off: None,
) -> None:
"""When the flag is OFF, no deltas are emitted and the ``toolCallId``
is ``call_<run_id>`` (NOT the lc id)."""
events = [
_model_stream(
tool_call_chunks=[
{"id": "lc-1", "name": "ls", "args": '{"path":"/"}', "index": 0}
]
),
_tool_start(name="ls", run_id="run-A", input_payload={"path": "/"}),
_tool_end(name="ls", run_id="run-A", tool_call_id="lc-1"),
]
payloads = await _drain(events)
assert _of_type(payloads, "tool-input-delta") == []
starts = _of_type(payloads, "tool-input-start")
assert len(starts) == 1
assert starts[0]["toolCallId"].startswith("call_run-A")
# No ``langchainToolCallId`` propagation on ``tool-input-start`` in
# legacy mode (the start event fires before the ToolMessage is
# available, so we can't extract the authoritative LangChain id yet).
assert "langchainToolCallId" not in starts[0]
output = _of_type(payloads, "tool-output-available")
assert output[0]["toolCallId"].startswith("call_run-A")
# ``tool-output-available`` MUST carry ``langchainToolCallId`` even
# in legacy mode: the chat tool card uses it to backfill the
# LangChain id and join against the ``data-action-log`` SSE event
# (keyed by ``lc_tool_call_id``) so the inline Revert button can
# light up. Sourced from the returned ``ToolMessage.tool_call_id``,
# which is populated regardless of feature-flag state.
assert output[0]["langchainToolCallId"] == "lc-1"
@pytest.mark.asyncio
async def test_skip_append_prevents_stale_id_reuse(
parity_v2_on: None,
) -> None:
"""Two same-name tools: the SECOND tool's ``langchainToolCallId``
must NOT come from the first tool's chunk (``pending_tool_call_chunks``
must stay empty for indexed-registered chunks)."""
events = [
_model_stream(
tool_call_chunks=[
{"id": "lc-A", "name": "write_file", "args": "{}", "index": 0},
{"id": "lc-B", "name": "write_file", "args": "{}", "index": 1},
]
),
_tool_start(name="write_file", run_id="run-1", input_payload={}),
_tool_end(name="write_file", run_id="run-1", tool_call_id="lc-A"),
_tool_start(name="write_file", run_id="run-2", input_payload={}),
_tool_end(name="write_file", run_id="run-2", tool_call_id="lc-B"),
]
payloads = await _drain(events)
starts = _of_type(payloads, "tool-input-start")
# Two distinct lc ids, each its own card.
assert {s["toolCallId"] for s in starts} == {"lc-A", "lc-B"}
# Each tool-output-available landed on its respective card.
output = _of_type(payloads, "tool-output-available")
assert {o["toolCallId"] for o in output} == {"lc-A", "lc-B"}
@pytest.mark.asyncio
async def test_registration_waits_for_both_id_and_name(
parity_v2_on: None,
) -> None:
"""An id-only chunk (no name yet) must NOT emit ``tool-input-start``."""
events = [
_model_stream(
tool_call_chunks=[{"id": "lc-1", "name": None, "args": "", "index": 0}]
),
]
payloads = await _drain(events)
assert _of_type(payloads, "tool-input-start") == []
@pytest.mark.asyncio
async def test_unmatched_fallback_still_attaches_lc_id(
parity_v2_on: None,
) -> None:
"""parity_v2 ON, but the provider didn't include an ``index``: the
legacy fallback path must still emit ``tool-input-start`` with the
matching ``langchainToolCallId``."""
events = [
# No index on the chunk → not registered into index_to_meta;
# falls through to ``pending_tool_call_chunks`` so the legacy
# match path can pop it at on_tool_start.
_model_stream(tool_call_chunks=[{"id": "lc-orphan", "name": "ls", "args": ""}]),
_tool_start(name="ls", run_id="run-1", input_payload={"path": "/"}),
_tool_end(name="ls", run_id="run-1", tool_call_id="lc-orphan"),
]
payloads = await _drain(events)
starts = _of_type(payloads, "tool-input-start")
assert len(starts) == 1
assert starts[0]["toolCallId"].startswith("call_run-1")
assert starts[0]["langchainToolCallId"] == "lc-orphan"