Merge remote-tracking branch 'upstream/dev' into feat/split-auto-free-premium

This commit is contained in:
Anish Sarkar 2026-04-30 16:23:05 +05:30
commit 872065f90d
15 changed files with 1857 additions and 545 deletions

View file

@ -183,3 +183,46 @@ class TestDefensive:
assert out["text"] == ""
assert out["reasoning"] == ""
assert out["tool_call_chunks"] == []
class TestIdlessContinuationChunks:
"""Per LangChain ``ToolCallChunk`` semantics, the FIRST chunk for a
tool call carries id+name; later chunks for the same call have
``id=None, name=None`` and only ``args`` + ``index``. Live tool-call
argument streaming relies on those idless continuation chunks
flowing through ``_extract_chunk_parts`` UNTOUCHED so the upstream
chunk-emission loop can still route them by ``index``.
"""
def test_idless_continuation_chunk_preserved_verbatim(self) -> None:
chunk = _FakeChunk(
tool_call_chunks=[
{"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
]
)
out = _extract_chunk_parts(chunk)
assert len(out["tool_call_chunks"]) == 1
tcc = out["tool_call_chunks"][0]
assert tcc.get("id") is None
assert tcc.get("name") is None
assert tcc.get("args") == '_path":"/x"}'
assert tcc.get("index") == 0
def test_first_then_idless_sequence_preserves_index(self) -> None:
"""Both chunks for the same call share an ``index`` key — the
index-routing loop in ``stream_new_chat`` depends on it."""
first = _FakeChunk(
tool_call_chunks=[
{"id": "lc-1", "name": "write_file", "args": '{"file', "index": 0}
]
)
cont = _FakeChunk(
tool_call_chunks=[
{"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
]
)
out_first = _extract_chunk_parts(first)
out_cont = _extract_chunk_parts(cont)
assert out_first["tool_call_chunks"][0]["index"] == 0
assert out_cont["tool_call_chunks"][0]["index"] == 0
assert out_cont["tool_call_chunks"][0].get("id") is None

View file

@ -0,0 +1,527 @@
"""Unit tests for live tool-call argument streaming.
Pins the wire format that ``_stream_agent_events`` emits when
``SURFSENSE_ENABLE_STREAM_PARITY_V2=true``: ``tool-input-start``
``tool-input-delta``... ``tool-input-available`` ``tool-output-available``
all keyed by the same LangChain ``tool_call.id``.
Identity is tracked in ``index_to_meta`` (per-chunk ``index``) and
``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are private to
``_stream_agent_events`` so we exercise them via the public wire output.
These tests also lock in the legacy / parity_v2-OFF behaviour so the
synthetic ``call_<run_id>`` shape stays stable for older clients.
"""
from __future__ import annotations
import json
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from typing import Any
import pytest
import app.tasks.chat.stream_new_chat as stream_module
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.stream_new_chat import (
StreamResult,
_legacy_match_lc_id,
_stream_agent_events,
)
pytestmark = pytest.mark.unit
@dataclass
class _FakeChunk:
"""Minimal stand-in for ``AIMessageChunk``."""
content: Any = ""
additional_kwargs: dict[str, Any] = field(default_factory=dict)
tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
@dataclass
class _FakeToolMessage:
"""Stand-in for ``ToolMessage`` returned by ``on_tool_end``."""
content: Any
tool_call_id: str | None = None
class _FakeAgentState:
"""Stand-in for ``StateSnapshot`` returned by ``aget_state``."""
def __init__(self) -> None:
# Empty values keeps the cloud-fallback safety-net branch a no-op,
# and an empty ``tasks`` list keeps the post-stream interrupt
# check a no-op too.
self.values: dict[str, Any] = {}
self.tasks: list[Any] = []
class _FakeAgent:
"""Replays a list of ``astream_events`` events."""
def __init__(self, events: list[dict[str, Any]]) -> None:
self._events = events
async def astream_events( # type: ignore[no-untyped-def]
self, _input_data: Any, *, config: dict[str, Any], version: str
) -> AsyncGenerator[dict[str, Any], None]:
del config, version # unused, contract-compatible
for ev in self._events:
yield ev
async def aget_state(self, _config: dict[str, Any]) -> _FakeAgentState:
# Called once after astream_events drains so the cloud-fallback
# safety net can inspect staged filesystem work. The fake stays
# empty so the safety net is a no-op.
return _FakeAgentState()
def _model_stream(
*,
text: str = "",
reasoning: str = "",
tool_call_chunks: list[dict[str, Any]] | None = None,
tags: list[str] | None = None,
) -> dict[str, Any]:
return (
{
"event": "on_chat_model_stream",
"tags": tags or [],
"data": {
"chunk": _FakeChunk(
content=text,
tool_call_chunks=list(tool_call_chunks or []),
)
},
# reasoning piggybacks via additional_kwargs path; if needed,
# override content to a typed-block list. Most tests just check
# tool_call_chunks routing so this is fine.
}
if not reasoning
else {
"event": "on_chat_model_stream",
"tags": tags or [],
"data": {
"chunk": _FakeChunk(
content=text,
additional_kwargs={"reasoning_content": reasoning},
tool_call_chunks=list(tool_call_chunks or []),
)
},
}
)
def _tool_start(
*,
name: str,
run_id: str,
input_payload: dict[str, Any] | None = None,
) -> dict[str, Any]:
return {
"event": "on_tool_start",
"name": name,
"run_id": run_id,
"data": {"input": input_payload or {}},
}
def _tool_end(
*,
name: str,
run_id: str,
tool_call_id: str | None = None,
output: Any = "ok",
) -> dict[str, Any]:
return {
"event": "on_tool_end",
"name": name,
"run_id": run_id,
"data": {
"output": _FakeToolMessage(
content=json.dumps(output) if not isinstance(output, str) else output,
tool_call_id=tool_call_id,
)
},
}
@pytest.fixture
def parity_v2_on(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
stream_module,
"get_flags",
lambda: AgentFeatureFlags(enable_stream_parity_v2=True),
)
@pytest.fixture
def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
stream_module,
"get_flags",
lambda: AgentFeatureFlags(enable_stream_parity_v2=False),
)
async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Run ``_stream_agent_events`` against a fake agent and return the
SSE payloads (parsed JSON) it yielded.
"""
agent = _FakeAgent(events)
service = VercelStreamingService()
result = StreamResult()
config = {"configurable": {"thread_id": "test-thread"}}
sse_lines: list[str] = []
async for sse in _stream_agent_events(
agent, config, {}, service, result, step_prefix="thinking"
):
sse_lines.append(sse)
parsed: list[dict[str, Any]] = []
for line in sse_lines:
if not line.startswith("data: "):
continue
body = line[len("data: ") :].rstrip("\n")
if not body or body == "[DONE]":
continue
try:
parsed.append(json.loads(body))
except json.JSONDecodeError:
continue
return parsed
def _types(payloads: list[dict[str, Any]]) -> list[str]:
return [p.get("type", "?") for p in payloads]
def _of_type(payloads: list[dict[str, Any]], type_name: str) -> list[dict[str, Any]]:
return [p for p in payloads if p.get("type") == type_name]
# ---------------------------------------------------------------------------
# Helper: ``_legacy_match_lc_id`` is a pure refactor; assert behaviour.
# ---------------------------------------------------------------------------
class TestLegacyMatch:
def test_pops_first_id_bearing_chunk_with_matching_name(self) -> None:
chunks: list[dict[str, Any]] = [
{"id": "x1", "name": "ls"},
{"id": "y1", "name": "write_file"},
]
runs: dict[str, str] = {}
result = _legacy_match_lc_id(chunks, "write_file", "run-1", runs)
assert result == "y1"
assert chunks == [{"id": "x1", "name": "ls"}]
assert runs == {"run-1": "y1"}
def test_falls_back_to_any_id_bearing_when_name_mismatches(self) -> None:
chunks: list[dict[str, Any]] = [{"id": "anon", "name": None}]
runs: dict[str, str] = {}
out = _legacy_match_lc_id(chunks, "ls", "run-2", runs)
assert out == "anon"
assert chunks == []
def test_returns_none_when_no_id_bearing_chunk(self) -> None:
chunks: list[dict[str, Any]] = [{"id": None, "name": None}]
runs: dict[str, str] = {}
assert _legacy_match_lc_id(chunks, "ls", "run-3", runs) is None
assert chunks == [{"id": None, "name": None}]
assert runs == {}
# ---------------------------------------------------------------------------
# parity_v2 wire format tests.
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None:
"""First chunk carries id+name; later idless chunks at the same
``index`` merge into the SAME ``tool-input-start`` ui id and emit
one ``tool-input-delta`` per chunk."""
events = [
_model_stream(
tool_call_chunks=[
{"id": "lc-1", "name": "write_file", "args": '{"file', "index": 0}
],
),
_model_stream(
tool_call_chunks=[
{"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
],
),
_tool_start(
name="write_file", run_id="run-A", input_payload={"file_path": "/x"}
),
_tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
]
payloads = await _drain(events)
starts = _of_type(payloads, "tool-input-start")
deltas = _of_type(payloads, "tool-input-delta")
available = _of_type(payloads, "tool-input-available")
output = _of_type(payloads, "tool-output-available")
assert len(starts) == 1
assert starts[0]["toolCallId"] == "lc-1"
assert starts[0]["toolName"] == "write_file"
assert starts[0]["langchainToolCallId"] == "lc-1"
assert [d["inputTextDelta"] for d in deltas] == ['{"file', '_path":"/x"}']
assert all(d["toolCallId"] == "lc-1" for d in deltas)
assert len(available) == 1
assert available[0]["toolCallId"] == "lc-1"
assert len(output) == 1
assert output[0]["toolCallId"] == "lc-1"
@pytest.mark.asyncio
async def test_two_interleaved_tool_calls_route_by_index(
parity_v2_on: None,
) -> None:
"""Two same-name calls with distinct indices keep their deltas
routed to the right card."""
events = [
_model_stream(
tool_call_chunks=[
{"id": "lc-A", "name": "write_file", "args": '{"a":1', "index": 0},
{"id": "lc-B", "name": "write_file", "args": '{"b":2', "index": 1},
]
),
_model_stream(
tool_call_chunks=[
{"id": None, "name": None, "args": "}", "index": 0},
{"id": None, "name": None, "args": "}", "index": 1},
]
),
_tool_start(name="write_file", run_id="run-A", input_payload={"a": 1}),
_tool_end(name="write_file", run_id="run-A", tool_call_id="lc-A"),
_tool_start(name="write_file", run_id="run-B", input_payload={"b": 2}),
_tool_end(name="write_file", run_id="run-B", tool_call_id="lc-B"),
]
payloads = await _drain(events)
starts = _of_type(payloads, "tool-input-start")
deltas = _of_type(payloads, "tool-input-delta")
output = _of_type(payloads, "tool-output-available")
assert {s["toolCallId"] for s in starts} == {"lc-A", "lc-B"}
by_id: dict[str, list[str]] = {"lc-A": [], "lc-B": []}
for d in deltas:
by_id[d["toolCallId"]].append(d["inputTextDelta"])
assert by_id["lc-A"] == ['{"a":1', "}"]
assert by_id["lc-B"] == ['{"b":2', "}"]
assert {o["toolCallId"] for o in output} == {"lc-A", "lc-B"}
@pytest.mark.asyncio
async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None:
"""Whatever id ``tool-input-start`` chose must be the SAME id used
on ``tool-input-available`` AND ``tool-output-available``."""
events = [
_model_stream(
tool_call_chunks=[
{"id": "lc-9", "name": "ls", "args": '{"path":"/"}', "index": 0}
]
),
_tool_start(name="ls", run_id="run-X", input_payload={"path": "/"}),
_tool_end(name="ls", run_id="run-X", tool_call_id="lc-9"),
]
payloads = await _drain(events)
relevant = [
p
for p in payloads
if p.get("type")
in {"tool-input-start", "tool-input-available", "tool-output-available"}
]
assert {p["toolCallId"] for p in relevant} == {"lc-9"}
@pytest.mark.asyncio
async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None:
"""When the chunk-emission loop already fired ``tool-input-start``
for this run, ``on_tool_start`` MUST NOT emit a second one."""
events = [
_model_stream(
tool_call_chunks=[
{"id": "lc-1", "name": "write_file", "args": "{}", "index": 0}
]
),
_tool_start(name="write_file", run_id="run-A", input_payload={}),
_tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
]
payloads = await _drain(events)
starts = _of_type(payloads, "tool-input-start")
assert len(starts) == 1
assert starts[0]["toolCallId"] == "lc-1"
@pytest.mark.asyncio
async def test_active_text_closes_before_early_tool_input_start(
parity_v2_on: None,
) -> None:
"""Streaming a text-delta then a tool-call chunk in subsequent
chunks: the wire MUST contain ``text-end`` before the FIRST
``tool-input-start`` (clean part boundary on the frontend)."""
events = [
_model_stream(text="Working on it"),
_model_stream(
tool_call_chunks=[
{"id": "lc-1", "name": "write_file", "args": "{}", "index": 0}
]
),
_tool_start(name="write_file", run_id="run-A", input_payload={}),
_tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
]
types = _types(await _drain(events))
text_end_idx = types.index("text-end")
start_idx = types.index("tool-input-start")
assert text_end_idx < start_idx
@pytest.mark.asyncio
async def test_mixed_text_and_tool_chunk_preserve_order(
parity_v2_on: None,
) -> None:
"""One AIMessageChunk that carries BOTH ``text`` content AND
``tool_call_chunks`` should emit the text delta FIRST, then close
text, then ``tool-input-start``+``tool-input-delta``."""
events = [
_model_stream(
text="I'll update it",
tool_call_chunks=[
{
"id": "lc-1",
"name": "write_file",
"args": '{"file_path":"/x"}',
"index": 0,
}
],
),
_tool_start(
name="write_file", run_id="run-A", input_payload={"file_path": "/x"}
),
_tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
]
types = _types(await _drain(events))
# text-start … text-delta … text-end … tool-input-start … tool-input-delta
assert types.index("text-start") < types.index("text-delta")
assert types.index("text-delta") < types.index("text-end")
assert types.index("text-end") < types.index("tool-input-start")
assert types.index("tool-input-start") < types.index("tool-input-delta")
@pytest.mark.asyncio
async def test_parity_v2_off_preserves_legacy_shape(
parity_v2_off: None,
) -> None:
"""When the flag is OFF, no deltas are emitted and the ``toolCallId``
is ``call_<run_id>`` (NOT the lc id)."""
events = [
_model_stream(
tool_call_chunks=[
{"id": "lc-1", "name": "ls", "args": '{"path":"/"}', "index": 0}
]
),
_tool_start(name="ls", run_id="run-A", input_payload={"path": "/"}),
_tool_end(name="ls", run_id="run-A", tool_call_id="lc-1"),
]
payloads = await _drain(events)
assert _of_type(payloads, "tool-input-delta") == []
starts = _of_type(payloads, "tool-input-start")
assert len(starts) == 1
assert starts[0]["toolCallId"].startswith("call_run-A")
# No ``langchainToolCallId`` propagation on ``tool-input-start`` in
# legacy mode (the start event fires before the ToolMessage is
# available, so we can't extract the authoritative LangChain id yet).
assert "langchainToolCallId" not in starts[0]
output = _of_type(payloads, "tool-output-available")
assert output[0]["toolCallId"].startswith("call_run-A")
# ``tool-output-available`` MUST carry ``langchainToolCallId`` even
# in legacy mode: the chat tool card uses it to backfill the
# LangChain id and join against the ``data-action-log`` SSE event
# (keyed by ``lc_tool_call_id``) so the inline Revert button can
# light up. Sourced from the returned ``ToolMessage.tool_call_id``,
# which is populated regardless of feature-flag state.
assert output[0]["langchainToolCallId"] == "lc-1"
@pytest.mark.asyncio
async def test_skip_append_prevents_stale_id_reuse(
parity_v2_on: None,
) -> None:
"""Two same-name tools: the SECOND tool's ``langchainToolCallId``
must NOT come from the first tool's chunk (``pending_tool_call_chunks``
must stay empty for indexed-registered chunks)."""
events = [
_model_stream(
tool_call_chunks=[
{"id": "lc-A", "name": "write_file", "args": "{}", "index": 0},
{"id": "lc-B", "name": "write_file", "args": "{}", "index": 1},
]
),
_tool_start(name="write_file", run_id="run-1", input_payload={}),
_tool_end(name="write_file", run_id="run-1", tool_call_id="lc-A"),
_tool_start(name="write_file", run_id="run-2", input_payload={}),
_tool_end(name="write_file", run_id="run-2", tool_call_id="lc-B"),
]
payloads = await _drain(events)
starts = _of_type(payloads, "tool-input-start")
# Two distinct lc ids, each its own card.
assert {s["toolCallId"] for s in starts} == {"lc-A", "lc-B"}
# Each tool-output-available landed on its respective card.
output = _of_type(payloads, "tool-output-available")
assert {o["toolCallId"] for o in output} == {"lc-A", "lc-B"}
@pytest.mark.asyncio
async def test_registration_waits_for_both_id_and_name(
parity_v2_on: None,
) -> None:
"""An id-only chunk (no name yet) must NOT emit ``tool-input-start``."""
events = [
_model_stream(
tool_call_chunks=[{"id": "lc-1", "name": None, "args": "", "index": 0}]
),
]
payloads = await _drain(events)
assert _of_type(payloads, "tool-input-start") == []
@pytest.mark.asyncio
async def test_unmatched_fallback_still_attaches_lc_id(
parity_v2_on: None,
) -> None:
"""parity_v2 ON, but the provider didn't include an ``index``: the
legacy fallback path must still emit ``tool-input-start`` with the
matching ``langchainToolCallId``."""
events = [
# No index on the chunk → not registered into index_to_meta;
# falls through to ``pending_tool_call_chunks`` so the legacy
# match path can pop it at on_tool_start.
_model_stream(tool_call_chunks=[{"id": "lc-orphan", "name": "ls", "args": ""}]),
_tool_start(name="ls", run_id="run-1", input_payload={"path": "/"}),
_tool_end(name="ls", run_id="run-1", tool_call_id="lc-orphan"),
]
payloads = await _drain(events)
starts = _of_type(payloads, "tool-input-start")
assert len(starts) == 1
assert starts[0]["toolCallId"].startswith("call_run-1")
assert starts[0]["langchainToolCallId"] == "lc-orphan"