Mirror of https://github.com/MODSetter/SurfSense.git (synced 2026-05-05 13:52:40 +02:00)
feat: moved chat persistence to Server Side
parent 2e1b9b5582
commit 19b6e0a025
19 changed files with 4515 additions and 390 deletions
0
surfsense_backend/tests/integration/chat/__init__.py
Normal file
@ -0,0 +1,573 @@
"""Integration tests for the cross-writer integration between the
|
||||
streaming chat task and the legacy ``POST /threads/{id}/messages``
|
||||
(``append_message``) round-trip.
|
||||
|
||||
Two scenarios anchor the contract introduced by the server-side
|
||||
persistence rework:
|
||||
|
||||
(a) **Tool-heavy turn streamed to completion.**
|
||||
|
||||
Drives :class:`AssistantContentBuilder` with synthetic SSE events
|
||||
that mirror what ``_stream_agent_events`` emits for a turn that
|
||||
interleaves text, reasoning, a tool call (start/delta/available/
|
||||
output), and a final text block. Then runs
|
||||
:func:`finalize_assistant_turn` and asserts:
|
||||
|
||||
* ``new_chat_messages.content`` JSONB matches the
|
||||
``ContentPart[]`` shape the FE history loader expects, with full
|
||||
``args``/``argsText``/``result``/``langchainToolCallId`` for the
|
||||
tool call.
|
||||
* Exactly one ``token_usage`` row exists keyed on the assistant
|
||||
``message_id``.
|
||||
|
||||
(b) **Stale FE ``appendMessage`` after server finalize.**
|
||||
|
||||
Verifies the recovery branch of the ``append_message`` route now
|
||||
returns the SERVER's authoritative ``ContentPart[]`` (not the FE's
|
||||
stale payload) when the partial unique index from migration 141
|
||||
blocks the FE's INSERT, and that the ``ON CONFLICT DO NOTHING``
|
||||
clause from migration 142 stops the route from writing a duplicate
|
||||
``token_usage`` row.
|
||||
"""

from __future__ import annotations

import json
from contextlib import asynccontextmanager

import pytest
import pytest_asyncio
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import (
    ChatVisibility,
    NewChatMessage,
    NewChatMessageRole,
    NewChatThread,
    SearchSpace,
    TokenUsage,
    User,
)
from app.routes import new_chat_routes
from app.services.token_tracking_service import TurnTokenAccumulator
from app.tasks.chat import persistence as persistence_module
from app.tasks.chat.content_builder import AssistantContentBuilder
from app.tasks.chat.persistence import (
    finalize_assistant_turn,
    persist_assistant_shell,
)

pytestmark = pytest.mark.integration
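

# The tests in this module lean on two partial unique indexes that the module
# docstring names (migrations 141 and 142) but whose DDL is not restated
# here. The statements below are a SKETCH of their assumed shape; apart from
# ``uq_token_usage_message_id`` the index names and the exact predicates are
# assumptions, not copied from the real migrations. They exist only to make
# the idempotency contract easier to follow.
_ASSUMED_MIGRATION_141_INDEX_SQL = """
CREATE UNIQUE INDEX uq_new_chat_messages_thread_turn_role
    ON new_chat_messages (thread_id, turn_id, role)
    WHERE turn_id IS NOT NULL;
"""
_ASSUMED_MIGRATION_142_INDEX_SQL = """
CREATE UNIQUE INDEX uq_token_usage_message_id
    ON token_usage (message_id)
    WHERE message_id IS NOT NULL;
"""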


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest_asyncio.fixture
async def db_thread(
    db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
) -> NewChatThread:
    thread = NewChatThread(
        title="Test Chat",
        search_space_id=db_search_space.id,
        created_by_id=db_user.id,
        visibility=ChatVisibility.PRIVATE,
    )
    db_session.add(thread)
    await db_session.flush()
    return thread


@pytest.fixture
def patched_shielded_session(monkeypatch, db_session: AsyncSession):
    """Route persistence helpers to the test's savepoint-bound session.

    Mirrors the helper from ``test_persistence.py`` so the helpers'
    internal ``ws.commit()`` / ``ws.rollback()`` resolve to SAVEPOINT
    operations on the test transaction instead of touching real
    autocommit boundaries.
    """

    @asynccontextmanager
    async def _fake_shielded_session():
        yield db_session

    monkeypatch.setattr(
        persistence_module,
        "shielded_async_session",
        _fake_shielded_session,
    )
    return db_session
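

# A minimal sketch (an assumption, not this repo's real conftest) of how the
# savepoint-bound ``db_session`` fixture these tests rely on is typically
# built with SQLAlchemy 2.x. ``join_transaction_mode="create_savepoint"`` is
# what turns the helpers' internal ``commit()`` calls into SAVEPOINT releases,
# so the outer rollback at the end of each test wipes every staged row.
@asynccontextmanager
async def _example_savepoint_session(async_engine):
    async with async_engine.connect() as conn:
        outer = await conn.begin()
        session = AsyncSession(
            bind=conn,
            join_transaction_mode="create_savepoint",
            expire_on_commit=False,
        )
        try:
            yield session
        finally:
            await session.close()
            await outer.rollback()  # discard everything the test staged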


@pytest.fixture
def bypass_permission_checks(monkeypatch):
    """Replace RBAC + thread access checks with no-ops.

    The append_message route under test calls ``check_permission`` and
    ``check_thread_access``; those rely on a SearchSpaceMembership row
    that the existing integration fixtures don't create. The contract
    we want to verify here is the ``IntegrityError`` -> recovery branch,
    not the RBAC plumbing — so stub them.
    """

    async def _allow(*_args, **_kwargs):
        return True

    monkeypatch.setattr(new_chat_routes, "check_permission", _allow)
    monkeypatch.setattr(new_chat_routes, "check_thread_access", _allow)
    return None


class _FakeRequest:
    """Minimal Request stand-in used by ``append_message``.

    The route only calls ``await request.json()`` — keep the surface
    area tight so this doesn't accidentally hide future signature
    changes that we *would* want to break the test.
    """

    def __init__(self, body: dict):
        self._body = body

    async def json(self) -> dict:
        return self._body


def _build_tool_heavy_content() -> list[dict]:
    """Drive ``AssistantContentBuilder`` through a tool-heavy turn.

    Produces the same ``ContentPart[]`` shape the streaming layer would
    persist if ``_stream_agent_events`` ran a turn with: opening
    reasoning -> text -> tool call (input start/delta/available/output)
    -> closing text. Centralised here so the (a) and (b) scenarios use
    the same authoritative payload.
    """
    builder = AssistantContentBuilder()

    builder.on_reasoning_start("r1")
    builder.on_reasoning_delta("r1", "Let me look up ")
    builder.on_reasoning_delta("r1", "the file listing.")
    builder.on_reasoning_end("r1")

    builder.on_text_start("t1")
    builder.on_text_delta("t1", "Sure, listing files in ")
    builder.on_text_delta("t1", "/.")
    builder.on_text_end("t1")

    builder.on_tool_input_start(
        "tool_call_ui_1",
        tool_name="ls",
        langchain_tool_call_id="lc_call_xyz",
    )
    builder.on_tool_input_delta("tool_call_ui_1", '{"path"')
    builder.on_tool_input_delta("tool_call_ui_1", ': "/"}')
    builder.on_tool_input_available(
        "tool_call_ui_1",
        tool_name="ls",
        args={"path": "/"},
        langchain_tool_call_id="lc_call_xyz",
    )
    builder.on_tool_output_available(
        "tool_call_ui_1",
        output={"files": ["a.txt", "b.txt"]},
        langchain_tool_call_id="lc_call_xyz",
    )

    builder.on_text_start("t2")
    builder.on_text_delta("t2", "Found 2 files: a.txt and b.txt.")
    builder.on_text_end("t2")

    return builder.snapshot()
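

# Roughly the snapshot the helper above is expected to return. This constant
# is an illustration only: the tool-call keys are pinned by the assertions in
# ``TestToolHeavyTurnFinalize`` below, while the exact reasoning-part key
# layout is an assumption.
_ILLUSTRATIVE_TOOL_HEAVY_SNAPSHOT = [
    {"type": "reasoning", "text": "Let me look up the file listing."},
    {"type": "text", "text": "Sure, listing files in /."},
    {
        "type": "tool-call",
        "toolCallId": "tool_call_ui_1",
        "toolName": "ls",
        "args": {"path": "/"},
        "argsText": json.dumps({"path": "/"}, indent=2),
        "result": {"files": ["a.txt", "b.txt"]},
        "langchainToolCallId": "lc_call_xyz",
    },
    {"type": "text", "text": "Found 2 files: a.txt and b.txt."},
]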
|
||||
|
||||
|
||||
def _accumulator_with_one_call() -> TurnTokenAccumulator:
|
||||
acc = TurnTokenAccumulator()
|
||||
acc.add(
|
||||
model="gpt-4o-mini",
|
||||
prompt_tokens=200,
|
||||
completion_tokens=80,
|
||||
total_tokens=280,
|
||||
cost_micros=22222,
|
||||
)
|
||||
return acc
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# (a) Tool-heavy stream finalize
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToolHeavyTurnFinalize:
|
||||
async def test_full_tool_call_persisted_and_one_token_usage_row(
|
||||
self,
|
||||
db_session,
|
||||
db_user,
|
||||
db_thread,
|
||||
db_search_space,
|
||||
patched_shielded_session,
|
||||
):
|
||||
"""End-to-end seam: builder snapshot -> finalize -> DB row.
|
||||
|
||||
Matches the production flow's *content* invariant: whatever
|
||||
``AssistantContentBuilder.snapshot()`` produces is what the
|
||||
streaming layer hands to ``finalize_assistant_turn``, so this
|
||||
test catches any drift between the JSONB shape the builder
|
||||
emits and the one the FE history loader expects.
|
||||
"""
|
||||
thread_id = db_thread.id
|
||||
user_id_str = str(db_user.id)
|
||||
search_space_id = db_search_space.id
|
||||
turn_id = f"{thread_id}:tool_heavy"
|
||||
|
||||
msg_id = await persist_assistant_shell(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
assert msg_id is not None
|
||||
|
||||
snapshot = _build_tool_heavy_content()
|
||||
# Sanity-check the snapshot before we hand it to the DB so a
|
||||
# builder regression surfaces here, not deep inside an opaque
|
||||
# JSONB diff.
|
||||
assert any(p.get("type") == "reasoning" for p in snapshot)
|
||||
text_parts = [p for p in snapshot if p.get("type") == "text"]
|
||||
assert len(text_parts) == 2
|
||||
tool_parts = [p for p in snapshot if p.get("type") == "tool-call"]
|
||||
assert len(tool_parts) == 1
|
||||
tool_part = tool_parts[0]
|
||||
assert tool_part["toolCallId"] == "tool_call_ui_1"
|
||||
assert tool_part["toolName"] == "ls"
|
||||
assert tool_part["args"] == {"path": "/"}
|
||||
# ``argsText`` ends up as the pretty-printed final args (the
|
||||
# ``tool-input-available`` event replaces the streamed deltas
|
||||
# with ``json.dumps(args, indent=2)`` to match the FE's
|
||||
# post-stream rendering).
|
||||
assert tool_part["argsText"] == '{\n "path": "/"\n}'
|
||||
assert tool_part["result"] == {"files": ["a.txt", "b.txt"]}
|
||||
# ``langchainToolCallId`` is the agent-side correlation id used
|
||||
# by the regenerate path; a missing one breaks
|
||||
# edit-from-tool-call later.
|
||||
assert tool_part["langchainToolCallId"] == "lc_call_xyz"
|
||||
|
||||
await finalize_assistant_turn(
|
||||
message_id=msg_id,
|
||||
chat_id=thread_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
content=snapshot,
|
||||
accumulator=_accumulator_with_one_call(),
|
||||
)
|
||||
|
||||
# ``content`` must round-trip byte-for-byte through the JSONB
|
||||
# column. SQLAlchemy doesn't auto-refresh the row that survived
|
||||
# the savepoint commit, so refresh explicitly.
|
||||
row = await db_session.get(NewChatMessage, msg_id)
|
||||
await db_session.refresh(row)
|
||||
|
||||
# The history loader reads ``content`` straight into the FE's
|
||||
# parts array, so a strict equality comparison is the right
|
||||
# invariant here.
|
||||
assert row.content == snapshot
|
||||
# Tool-call parts must JSON-serialise cleanly — nothing in
|
||||
# ``args`` / ``argsText`` / ``result`` should accidentally be a
|
||||
# non-JSON-safe value (datetime, set, custom class).
|
||||
assert json.dumps(row.content)
|
||||
|
||||
usage_count = (
|
||||
await db_session.execute(
|
||||
select(func.count())
|
||||
.select_from(TokenUsage)
|
||||
.where(TokenUsage.message_id == msg_id)
|
||||
)
|
||||
).scalar_one()
|
||||
assert usage_count == 1
|
||||
|
||||
usage = (
|
||||
await db_session.execute(
|
||||
select(TokenUsage).where(TokenUsage.message_id == msg_id)
|
||||
)
|
||||
).scalar_one()
|
||||
assert usage.usage_type == "chat"
|
||||
assert usage.prompt_tokens == 200
|
||||
assert usage.completion_tokens == 80
|
||||
assert usage.total_tokens == 280
|
||||
assert usage.cost_micros == 22222
|
||||
assert usage.thread_id == thread_id
|
||||
assert usage.search_space_id == search_space_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# (b) FE appendMessage after server finalize
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAppendMessageRecoveryAfterFinalize:
|
||||
async def test_returns_server_content_and_does_not_duplicate_token_usage(
|
||||
self,
|
||||
db_session,
|
||||
db_user,
|
||||
db_thread,
|
||||
db_search_space,
|
||||
patched_shielded_session,
|
||||
bypass_permission_checks,
|
||||
):
|
||||
"""FE's stale ``appendMessage`` after server finalize.
|
||||
|
||||
The frontend used to be the authoritative writer for assistant
|
||||
``content``. Now the server is. When the legacy FE round-trip
|
||||
fires *after* the server has already finalized:
|
||||
|
||||
* the route's INSERT trips the (thread_id, turn_id, role)
|
||||
partial unique index from migration 141,
|
||||
* the recovery branch fetches the existing row and returns
|
||||
*its* ``content`` — discarding the FE payload — so the
|
||||
history loader reads the rich server payload (full tool
|
||||
args, argsText, langchainToolCallId, etc.) on next page
|
||||
reload,
|
||||
* the route's optional ``token_usage`` insert is keyed on the
|
||||
partial unique index from migration 142 so it silently
|
||||
no-ops if the server already wrote one.
|
||||
"""
|
||||
thread_id = db_thread.id
|
||||
user_id_str = str(db_user.id)
|
||||
search_space_id = db_search_space.id
|
||||
turn_id = f"{thread_id}:fe_late_append"
|
||||
|
||||
# Step 1: server stream completes. Server-built rich content is
|
||||
# finalized.
|
||||
msg_id = await persist_assistant_shell(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
assert msg_id is not None
|
||||
|
||||
server_content = _build_tool_heavy_content()
|
||||
await finalize_assistant_turn(
|
||||
message_id=msg_id,
|
||||
chat_id=thread_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
content=server_content,
|
||||
accumulator=_accumulator_with_one_call(),
|
||||
)
|
||||
|
||||
# Step 2: simulate the legacy FE ``appendMessage`` round-trip
|
||||
# arriving with stale, lossy content (missing tool args, etc.)
|
||||
# plus a ``token_usage`` body.
|
||||
fe_stale_content = [
|
||||
{"type": "text", "text": "Found 2 files: a.txt and b.txt."},
|
||||
]
|
||||
fe_request_body = {
|
||||
"role": "assistant",
|
||||
"content": fe_stale_content,
|
||||
"turn_id": turn_id,
|
||||
"token_usage": {
|
||||
"prompt_tokens": 999,
|
||||
"completion_tokens": 999,
|
||||
"total_tokens": 1998,
|
||||
"cost_micros": 88888,
|
||||
"usage": {"any": "thing"},
|
||||
"call_details": {"calls": []},
|
||||
},
|
||||
}
|
||||
request = _FakeRequest(fe_request_body)
|
||||
|
||||
# ``db_user`` is bound to ``db_session``. The route's
|
||||
# IntegrityError branch calls ``session.rollback()``, which
|
||||
# expires every ORM row attached to the session including this
|
||||
# user — historically causing ``user.id`` to lazy-load
|
||||
# out-of-greenlet and crash the request with ``MissingGreenlet``
|
||||
# (observed in production logs at /api/v1/threads/531/messages
|
||||
# 2026-05-04). The route now captures ``user.id`` to a primitive
|
||||
# UUID at the top of the handler, so the rollback can't reach
|
||||
# it. Pass the *attached* user here on purpose — that's the
|
||||
# production scenario, and this test is the regression guard
|
||||
# against that bug returning.
|
||||
response = await new_chat_routes.append_message(
|
||||
thread_id=thread_id,
|
||||
request=request,
|
||||
session=db_session,
|
||||
user=db_user,
|
||||
)
|
||||
|
||||
# Response must echo the SERVER's rich payload, not the FE's
|
||||
# stale snapshot. This is the user-visible part of the
|
||||
# contract: history reload + ThreadHistoryAdapter.append both
|
||||
# read from the same authoritative source.
|
||||
assert response.id == msg_id
|
||||
assert response.role == NewChatMessageRole.ASSISTANT
|
||||
assert response.turn_id == turn_id
|
||||
assert response.content == server_content
|
||||
assert response.content != fe_stale_content
|
||||
|
||||
# The on-disk row must agree with the response.
|
||||
row = await db_session.get(NewChatMessage, msg_id)
|
||||
await db_session.refresh(row)
|
||||
assert row.content == server_content
|
||||
|
||||
# ``token_usage``: exactly one row, with the *server's* values
|
||||
# (the FE's much larger token counts must not have overwritten
|
||||
# them).
|
||||
usage_count = (
|
||||
await db_session.execute(
|
||||
select(func.count())
|
||||
.select_from(TokenUsage)
|
||||
.where(TokenUsage.message_id == msg_id)
|
||||
)
|
||||
).scalar_one()
|
||||
assert usage_count == 1
|
||||
|
||||
usage = (
|
||||
await db_session.execute(
|
||||
select(TokenUsage).where(TokenUsage.message_id == msg_id)
|
||||
)
|
||||
).scalar_one()
|
||||
assert usage.cost_micros == 22222 # Server's value, not 88888
|
||||
assert usage.total_tokens == 280 # Server's value, not 1998
|
||||
|
||||
async def test_legacy_fe_first_appendmessage_then_server_no_dupe(
|
||||
self,
|
||||
db_session,
|
||||
db_user,
|
||||
db_thread,
|
||||
db_search_space,
|
||||
patched_shielded_session,
|
||||
bypass_permission_checks,
|
||||
):
|
||||
"""Inverse race: legacy FE writes first, server finalize second.
|
||||
|
||||
Some clients still post ``appendMessage`` before the streaming
|
||||
``finally`` runs. The contract is symmetric: whichever writer
|
||||
loses the (thread_id, turn_id, role) race silently lets the
|
||||
winner keep its content. In particular the *server's*
|
||||
finalize must NOT raise — it must look up the existing row and
|
||||
UPDATE its content with the server-built payload (which is
|
||||
always richer/more authoritative than whatever the FE
|
||||
snapshot held).
|
||||
"""
|
||||
thread_id = db_thread.id
|
||||
user_id_str = str(db_user.id)
|
||||
search_space_id = db_search_space.id
|
||||
turn_id = f"{thread_id}:fe_first"
|
||||
|
||||
# Step 1: legacy FE appendMessage lands first. No prior shell
|
||||
# row exists; the route does the INSERT itself.
|
||||
fe_request_body = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "early FE write"}],
|
||||
"turn_id": turn_id,
|
||||
}
|
||||
fe_response = await new_chat_routes.append_message(
|
||||
thread_id=thread_id,
|
||||
request=_FakeRequest(fe_request_body),
|
||||
session=db_session,
|
||||
user=db_user,
|
||||
)
|
||||
assert fe_response.role == NewChatMessageRole.ASSISTANT
|
||||
|
||||
# Step 2: server stream's persist_assistant_shell now races
|
||||
# behind. It must adopt the existing row id, not raise.
|
||||
adopted_id = await persist_assistant_shell(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
assert adopted_id == fe_response.id
|
||||
|
||||
# Step 3: server finalize then overwrites the FE's stub with
|
||||
# the rich content (which is the correct, more authoritative
|
||||
# payload).
|
||||
server_content = _build_tool_heavy_content()
|
||||
await finalize_assistant_turn(
|
||||
message_id=adopted_id,
|
||||
chat_id=thread_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
content=server_content,
|
||||
accumulator=_accumulator_with_one_call(),
|
||||
)
|
||||
|
||||
# Final state: one row, server content, one token_usage row.
|
||||
msg_count = (
|
||||
await db_session.execute(
|
||||
select(func.count())
|
||||
.select_from(NewChatMessage)
|
||||
.where(
|
||||
NewChatMessage.thread_id == thread_id,
|
||||
NewChatMessage.turn_id == turn_id,
|
||||
NewChatMessage.role == NewChatMessageRole.ASSISTANT,
|
||||
)
|
||||
)
|
||||
).scalar_one()
|
||||
assert msg_count == 1
|
||||
|
||||
row = await db_session.get(NewChatMessage, adopted_id)
|
||||
await db_session.refresh(row)
|
||||
assert row.content == server_content
|
||||
|
||||
usage_count = (
|
||||
await db_session.execute(
|
||||
select(func.count())
|
||||
.select_from(TokenUsage)
|
||||
.where(TokenUsage.message_id == adopted_id)
|
||||
)
|
||||
).scalar_one()
|
||||
assert usage_count == 1
|
||||
|
||||
async def test_appendmessage_without_turn_id_legacy_400(
|
||||
self,
|
||||
db_session,
|
||||
db_user,
|
||||
db_thread,
|
||||
patched_shielded_session,
|
||||
bypass_permission_checks,
|
||||
):
|
||||
"""Defensive: a bare appendMessage with no turn_id and no
|
||||
existing row is just a normal INSERT — must succeed. But if a
|
||||
row with the same role already exists in this thread *without*
|
||||
turn_id collisions, the route should fall through to the
|
||||
legacy 400 path on a foreign-key / unrelated IntegrityError
|
||||
(we don't ship that bug today, but pin the behaviour so a
|
||||
future schema change can't silently regress it).
|
||||
"""
|
||||
thread_id = db_thread.id
|
||||
|
||||
# Bare appendMessage with no turn_id — should just succeed
|
||||
# without invoking the recovery branch.
|
||||
ok_response = await new_chat_routes.append_message(
|
||||
thread_id=thread_id,
|
||||
request=_FakeRequest(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "hi"}],
|
||||
}
|
||||
),
|
||||
session=db_session,
|
||||
user=db_user,
|
||||
)
|
||||
assert ok_response.role == NewChatMessageRole.USER
|
||||
assert ok_response.turn_id is None
|
||||
|
||||
# Sanity: the route did NOT silently swallow the missing
|
||||
# turn_id by routing through the unique-index recovery branch
|
||||
# — it took the happy path.
|
||||
msg_count = (
|
||||
await db_session.execute(
|
||||
select(func.count())
|
||||
.select_from(NewChatMessage)
|
||||
.where(
|
||||
NewChatMessage.thread_id == thread_id,
|
||||
NewChatMessage.role == NewChatMessageRole.USER,
|
||||
)
|
||||
)
|
||||
).scalar_one()
|
||||
assert msg_count == 1
332
surfsense_backend/tests/integration/chat/test_message_id_sse.py
Normal file
@ -0,0 +1,332 @@
"""Integration tests for the SSE-based message ID handshake.
|
||||
|
||||
The streaming generators (``stream_new_chat`` / ``stream_resume_chat``)
|
||||
emit two new events after their respective persistence helpers resolve
|
||||
the canonical ``new_chat_messages.id``:
|
||||
|
||||
* ``data-user-message-id`` — emitted only by ``stream_new_chat``,
|
||||
AFTER ``persist_user_turn`` and BEFORE any LLM streaming.
|
||||
* ``data-assistant-message-id`` — emitted by both
|
||||
``stream_new_chat`` and ``stream_resume_chat``, AFTER
|
||||
``persist_assistant_shell`` and BEFORE any LLM streaming.
|
||||
|
||||
The frontend renames its optimistic ``msg-user-XXX`` /
|
||||
``msg-assistant-XXX`` placeholder ids to ``msg-{db_id}`` upon receiving
|
||||
these events. This test suite anchors three contracts:
|
||||
|
||||
1. ``format_data`` produces SSE bytes in the precise shape
|
||||
``data: {"type":"data-<suffix>","data":{...}}\\n\\n`` that the FE's
|
||||
``readSSEStream`` consumer parses (matches ``surfsense_web/lib/chat/streaming-state.ts``).
|
||||
2. The ``message_id`` carried in the SSE payload exactly equals the
|
||||
primary key the persistence helper inserted into
|
||||
``new_chat_messages`` — so the FE rename produces ``msg-{real_pk}``,
|
||||
which in turn unlocks DB-id-gated UI (comments, edit-from-message).
|
||||
3. The same ``message_id`` is used for the ``token_usage.message_id``
|
||||
foreign key, so ``finalize_assistant_turn``'s row binds correctly.
|
||||
|
||||
Direct end-to-end testing of ``stream_new_chat`` requires a fully
|
||||
mocked agent + LLM stack (out-of-scope here); those flows are covered
|
||||
by the harness-driven integration tests under
|
||||
``tests/integration/agents/new_chat/`` plus the assertion in
|
||||
``test_persistence.py`` that the helpers themselves return ``int``
|
||||
ids. The contracts above close the remaining gap between the persist
|
||||
helpers and the bytes that ship to the FE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import (
|
||||
ChatVisibility,
|
||||
NewChatMessage,
|
||||
NewChatMessageRole,
|
||||
NewChatThread,
|
||||
SearchSpace,
|
||||
User,
|
||||
)
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.tasks.chat import persistence as persistence_module
|
||||
from app.tasks.chat.persistence import (
|
||||
persist_assistant_shell,
|
||||
persist_user_turn,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.integration
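

# For orientation: on the wire, the two handshake frames described in the
# module docstring look roughly like the pair below for a single new-chat
# turn. The ids and turn_id values are illustrative; the envelope shape
# itself is what ``TestSSEByteShape`` pins. Each frame is a single
# ``data: <json>`` line terminated by a blank line.
_EXAMPLE_HANDSHAKE_FRAMES = (
    'data: {"type": "data-user-message-id",'
    ' "data": {"message_id": 1843, "turn_id": "533:1762900000000"}}\n\n',
    'data: {"type": "data-assistant-message-id",'
    ' "data": {"message_id": 1844, "turn_id": "533:1762900000000"}}\n\n',
)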
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures (mirror test_persistence.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_thread(
|
||||
db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
|
||||
) -> NewChatThread:
|
||||
thread = NewChatThread(
|
||||
title="Test Chat",
|
||||
search_space_id=db_search_space.id,
|
||||
created_by_id=db_user.id,
|
||||
visibility=ChatVisibility.PRIVATE,
|
||||
)
|
||||
db_session.add(thread)
|
||||
await db_session.flush()
|
||||
return thread
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_shielded_session(monkeypatch, db_session: AsyncSession):
|
||||
"""Route persistence helpers to the test's savepoint-bound session."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def _fake_shielded_session():
|
||||
yield db_session
|
||||
|
||||
monkeypatch.setattr(
|
||||
persistence_module,
|
||||
"shielded_async_session",
|
||||
_fake_shielded_session,
|
||||
)
|
||||
return db_session
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# (1) SSE byte-shape contract
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_sse_data_line(blob: str) -> dict:
    """Unwrap a single ``data: <json>\\n\\n`` SSE frame.

    Raises if there's more than one frame or the prefix is wrong — keeps
    the parser strict so a regression in ``format_data`` produces a
    test failure here, not in a downstream consumer.
    """
    assert blob.endswith("\n\n"), f"missing terminator: {blob!r}"
    line = blob.removesuffix("\n\n")
    assert line.startswith("data: "), f"missing data prefix: {line!r}"
    return json.loads(line.removeprefix("data: "))
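

# A minimal sketch of the inverse operation: what ``format_data`` has to emit
# for the strict parser above to accept it. This is an assumption for
# readability, NOT the real ``VercelStreamingService`` implementation.
def _example_format_data(suffix: str, payload: dict) -> str:
    envelope = {"type": f"data-{suffix}", "data": payload}
    return f"data: {json.dumps(envelope)}\n\n"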
|
||||
|
||||
|
||||
class TestSSEByteShape:
|
||||
def test_data_user_message_id_byte_shape(self):
|
||||
"""``format_data("user-message-id", {...})`` must produce the
|
||||
exact wire format the FE's
|
||||
``readSSEStream`` -> ``data-user-message-id`` case parses.
|
||||
"""
|
||||
svc = VercelStreamingService()
|
||||
blob = svc.format_data(
|
||||
"user-message-id",
|
||||
{"message_id": 1843, "turn_id": "533:1762900000000"},
|
||||
)
|
||||
envelope = _parse_sse_data_line(blob)
|
||||
assert envelope == {
|
||||
"type": "data-user-message-id",
|
||||
"data": {"message_id": 1843, "turn_id": "533:1762900000000"},
|
||||
}
|
||||
|
||||
def test_data_assistant_message_id_byte_shape(self):
|
||||
svc = VercelStreamingService()
|
||||
blob = svc.format_data(
|
||||
"assistant-message-id",
|
||||
{"message_id": 1844, "turn_id": "533:1762900000000"},
|
||||
)
|
||||
envelope = _parse_sse_data_line(blob)
|
||||
assert envelope == {
|
||||
"type": "data-assistant-message-id",
|
||||
"data": {"message_id": 1844, "turn_id": "533:1762900000000"},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# (2) Helper-id <-> DB-pk coherence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHandshakeIdMatchesDB:
|
||||
"""The SSE handshake's correctness hinges on the integer in
|
||||
``data-{user,assistant}-message-id`` being the EXACT primary key
|
||||
the persistence helper inserted. If they ever diverge, the FE
|
||||
rename produces ``msg-{wrong_id}``, comments break (regex match
|
||||
fails), and downstream features (edit, regenerate) target the
|
||||
wrong row. Anchor it here.
|
||||
"""
|
||||
|
||||
async def test_user_message_id_matches_new_chat_messages_pk(
|
||||
self,
|
||||
db_session,
|
||||
db_user,
|
||||
db_thread,
|
||||
patched_shielded_session,
|
||||
):
|
||||
thread_id = db_thread.id
|
||||
user_id_str = str(db_user.id)
|
||||
turn_id = f"{thread_id}:9000"
|
||||
|
||||
# The streaming generator passes this same value into
|
||||
# ``streaming_service.format_data("user-message-id", {...})``.
|
||||
msg_id_from_helper = await persist_user_turn(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
user_query="hello",
|
||||
)
|
||||
assert isinstance(msg_id_from_helper, int)
|
||||
|
||||
# Look up the row the helper inserted via
|
||||
# ``(thread_id, turn_id, role)`` — the same composite the FE
|
||||
# uses to identify a turn — and confirm the PK matches.
|
||||
row = (
|
||||
await db_session.execute(
|
||||
select(NewChatMessage).where(
|
||||
NewChatMessage.thread_id == thread_id,
|
||||
NewChatMessage.turn_id == turn_id,
|
||||
NewChatMessage.role == NewChatMessageRole.USER,
|
||||
)
|
||||
)
|
||||
).scalar_one()
|
||||
assert row.id == msg_id_from_helper
|
||||
|
||||
# The byte-stream the FE actually receives — confirms the
|
||||
# round-trip from the helper return value to the SSE payload.
|
||||
svc = VercelStreamingService()
|
||||
envelope = _parse_sse_data_line(
|
||||
svc.format_data(
|
||||
"user-message-id",
|
||||
{"message_id": msg_id_from_helper, "turn_id": turn_id},
|
||||
)
|
||||
)
|
||||
assert envelope["data"]["message_id"] == row.id
|
||||
|
||||
async def test_assistant_message_id_matches_new_chat_messages_pk(
|
||||
self,
|
||||
db_session,
|
||||
db_user,
|
||||
db_thread,
|
||||
patched_shielded_session,
|
||||
):
|
||||
thread_id = db_thread.id
|
||||
user_id_str = str(db_user.id)
|
||||
turn_id = f"{thread_id}:9100"
|
||||
|
||||
msg_id_from_helper = await persist_assistant_shell(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
assert isinstance(msg_id_from_helper, int)
|
||||
|
||||
row = (
|
||||
await db_session.execute(
|
||||
select(NewChatMessage).where(
|
||||
NewChatMessage.thread_id == thread_id,
|
||||
NewChatMessage.turn_id == turn_id,
|
||||
NewChatMessage.role == NewChatMessageRole.ASSISTANT,
|
||||
)
|
||||
)
|
||||
).scalar_one()
|
||||
assert row.id == msg_id_from_helper
|
||||
|
||||
svc = VercelStreamingService()
|
||||
envelope = _parse_sse_data_line(
|
||||
svc.format_data(
|
||||
"assistant-message-id",
|
||||
{"message_id": msg_id_from_helper, "turn_id": turn_id},
|
||||
)
|
||||
)
|
||||
assert envelope["data"]["message_id"] == row.id
|
||||
|
||||
async def test_handshake_ids_for_full_turn_are_distinct_and_paired(
|
||||
self,
|
||||
db_session,
|
||||
db_user,
|
||||
db_thread,
|
||||
patched_shielded_session,
|
||||
):
|
||||
"""Sanity: a full new-chat turn's two SSE events carry two
|
||||
DIFFERENT ids (user row PK ≠ assistant row PK), both anchored
|
||||
to the SAME ``turn_id`` in the DB. This pairing is what
|
||||
``finalize_assistant_turn`` and the regenerate / edit flows
|
||||
rely on.
|
||||
"""
|
||||
thread_id = db_thread.id
|
||||
user_id_str = str(db_user.id)
|
||||
turn_id = f"{thread_id}:9200"
|
||||
|
||||
user_msg_id = await persist_user_turn(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
user_query="hi",
|
||||
)
|
||||
assistant_msg_id = await persist_assistant_shell(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
assert user_msg_id is not None and assistant_msg_id is not None
|
||||
assert user_msg_id != assistant_msg_id
|
||||
|
||||
rows = (
|
||||
(
|
||||
await db_session.execute(
|
||||
select(NewChatMessage)
|
||||
.where(
|
||||
NewChatMessage.thread_id == thread_id,
|
||||
NewChatMessage.turn_id == turn_id,
|
||||
)
|
||||
.order_by(NewChatMessage.id)
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
assert len(rows) == 2
|
||||
ids_by_role = {r.role: r.id for r in rows}
|
||||
assert ids_by_role[NewChatMessageRole.USER] == user_msg_id
|
||||
assert ids_by_role[NewChatMessageRole.ASSISTANT] == assistant_msg_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# (3) Parse helpers used by the FE — sanity-check our payload shape
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPayloadShapeMatchesFEReader:
|
||||
"""The FE's ``readStreamedMessageId`` (in
|
||||
``surfsense_web/lib/chat/stream-side-effects.ts``) requires:
|
||||
|
||||
* ``message_id`` is a ``number`` (rejects null / string / NaN).
|
||||
* ``turn_id`` is an optional non-empty string (else it's coerced
|
||||
to ``null``).
|
||||
|
||||
These tests exercise the BE side of that contract by inspecting
|
||||
``format_data`` output shapes that the FE consumes verbatim.
|
||||
"""
|
||||
|
||||
def test_message_id_is_serialised_as_a_json_number(self):
|
||||
svc = VercelStreamingService()
|
||||
envelope = _parse_sse_data_line(
|
||||
svc.format_data("user-message-id", {"message_id": 42, "turn_id": "t"})
|
||||
)
|
||||
assert isinstance(envelope["data"]["message_id"], int)
|
||||
assert envelope["data"]["message_id"] == 42
|
||||
|
||||
    def test_turn_id_round_trips_as_string(self):
        svc = VercelStreamingService()
        # The actual format used in production: f"{chat_id}:{int(time.time()*1000)}"
        production_turn_id = "533:1762900000000"
        envelope = _parse_sse_data_line(
            svc.format_data(
                "assistant-message-id",
                {"message_id": 1, "turn_id": production_turn_id},
            )
        )
        assert envelope["data"]["turn_id"] == production_turn_id
747
surfsense_backend/tests/integration/chat/test_persistence.py
Normal file
@ -0,0 +1,747 @@
"""Integration tests for ``app.tasks.chat.persistence``.
|
||||
|
||||
Verifies the DB-side guarantees the streaming chat task relies on:
|
||||
|
||||
* ``persist_assistant_shell`` is idempotent against the
|
||||
``(thread_id, turn_id, ASSISTANT)`` partial unique index from
|
||||
migration 141. Two calls with the same ``turn_id`` return the SAME
|
||||
``message_id`` and never create a duplicate ``new_chat_messages`` row.
|
||||
* ``finalize_assistant_turn`` writes a status-marker payload when given
|
||||
empty content, never raises, and is safe to call twice on the same
|
||||
``message_id`` — the partial unique index from migration 142
|
||||
(``uq_token_usage_message_id``) prevents the second insert from
|
||||
producing a duplicate ``token_usage`` row.
|
||||
* The same ``ON CONFLICT DO NOTHING`` invariant covers the cross-writer
|
||||
race where ``finalize_assistant_turn`` and the ``append_message``
|
||||
recovery branch both target the same ``message_id``.
|
||||
|
||||
All tests run inside the conftest's outer-transaction-with-savepoint
|
||||
fixture so commits inside the helpers (which open their own
|
||||
``shielded_async_session``) are released as savepoints and rolled back
|
||||
at test end. We monkey-patch ``shielded_async_session`` to yield the
|
||||
same pooled test session so the integration transaction stays
|
||||
in-scope.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import (
|
||||
ChatVisibility,
|
||||
NewChatMessage,
|
||||
NewChatMessageRole,
|
||||
NewChatThread,
|
||||
SearchSpace,
|
||||
TokenUsage,
|
||||
User,
|
||||
)
|
||||
from app.services.token_tracking_service import TurnTokenAccumulator
|
||||
from app.tasks.chat import persistence as persistence_module
|
||||
from app.tasks.chat.persistence import (
|
||||
finalize_assistant_turn,
|
||||
persist_assistant_shell,
|
||||
persist_user_turn,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.integration
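

# A minimal sketch of the idempotent-insert pattern the helpers are exercised
# against in this module: ``INSERT ... ON CONFLICT DO NOTHING RETURNING id``
# followed by a SELECT fallback when the insert was a no-op. This is an
# assumption for readability, NOT the real body of ``persist_assistant_shell``
# (the real column set, defaults, and error handling will differ).
async def _example_idempotent_shell_insert(
    ws: AsyncSession, thread_id: int, turn_id: str
) -> int:
    stmt = (
        pg_insert(NewChatMessage)
        .values(
            thread_id=thread_id,
            turn_id=turn_id,
            role=NewChatMessageRole.ASSISTANT,
            content=[{"type": "text", "text": ""}],
        )
        .on_conflict_do_nothing()
        .returning(NewChatMessage.id)
    )
    new_id = (await ws.execute(stmt)).scalar_one_or_none()
    if new_id is not None:
        return new_id
    # Conflict: another writer already owns this turn, so adopt its row id.
    existing = await ws.execute(
        select(NewChatMessage.id).where(
            NewChatMessage.thread_id == thread_id,
            NewChatMessage.turn_id == turn_id,
            NewChatMessage.role == NewChatMessageRole.ASSISTANT,
        )
    )
    return existing.scalar_one()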
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_thread(
|
||||
db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
|
||||
) -> NewChatThread:
|
||||
thread = NewChatThread(
|
||||
title="Test Chat",
|
||||
search_space_id=db_search_space.id,
|
||||
created_by_id=db_user.id,
|
||||
visibility=ChatVisibility.PRIVATE,
|
||||
)
|
||||
db_session.add(thread)
|
||||
await db_session.flush()
|
||||
return thread
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_shielded_session(monkeypatch, db_session: AsyncSession):
|
||||
"""Route persistence helpers to the test's savepoint-bound session.
|
||||
|
||||
The persistence helpers use ``async with shielded_async_session() as
|
||||
ws`` and call ``ws.commit()`` internally. Inside the conftest's
|
||||
``join_transaction_mode="create_savepoint"`` setup, those commits
|
||||
release a SAVEPOINT instead of committing the outer transaction —
|
||||
so the test session can see helper-staged rows immediately and the
|
||||
outer rollback at end of test wipes them.
|
||||
"""
|
||||
|
||||
@asynccontextmanager
|
||||
async def _fake_shielded_session():
|
||||
yield db_session
|
||||
# Do NOT close — the outer fixture owns the session lifecycle.
|
||||
|
||||
monkeypatch.setattr(
|
||||
persistence_module,
|
||||
"shielded_async_session",
|
||||
_fake_shielded_session,
|
||||
)
|
||||
return db_session
|
||||
|
||||
|
||||
def _accumulator_with_one_call() -> TurnTokenAccumulator:
|
||||
acc = TurnTokenAccumulator()
|
||||
acc.add(
|
||||
model="gpt-4o-mini",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
total_tokens=150,
|
||||
cost_micros=12345,
|
||||
)
|
||||
return acc
|
||||
|
||||
|
||||
async def _count_assistant_rows(
|
||||
session: AsyncSession, thread_id: int, turn_id: str
|
||||
) -> int:
|
||||
result = await session.execute(
|
||||
select(func.count())
|
||||
.select_from(NewChatMessage)
|
||||
.where(
|
||||
NewChatMessage.thread_id == thread_id,
|
||||
NewChatMessage.turn_id == turn_id,
|
||||
NewChatMessage.role == NewChatMessageRole.ASSISTANT,
|
||||
)
|
||||
)
|
||||
return int(result.scalar_one())
|
||||
|
||||
|
||||
async def _count_token_usage_rows(session: AsyncSession, message_id: int) -> int:
|
||||
result = await session.execute(
|
||||
select(func.count())
|
||||
.select_from(TokenUsage)
|
||||
.where(TokenUsage.message_id == message_id)
|
||||
)
|
||||
return int(result.scalar_one())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# persist_assistant_shell
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPersistAssistantShell:
|
||||
async def test_first_call_inserts_empty_shell_and_returns_id(
|
||||
self,
|
||||
db_session,
|
||||
db_user,
|
||||
db_thread,
|
||||
patched_shielded_session,
|
||||
):
|
||||
# Capture primitive ids before any persistence helper runs:
|
||||
# the helpers commit/rollback the shared test session, which
|
||||
# can detach ORM rows mid-test.
|
||||
thread_id = db_thread.id
|
||||
user_id_str = str(db_user.id)
|
||||
turn_id = f"{thread_id}:1000"
|
||||
|
||||
msg_id = await persist_assistant_shell(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
assert msg_id is not None and isinstance(msg_id, int)
|
||||
|
||||
row = await db_session.get(NewChatMessage, msg_id)
|
||||
assert row is not None
|
||||
assert row.thread_id == thread_id
|
||||
assert row.role == NewChatMessageRole.ASSISTANT
|
||||
assert row.turn_id == turn_id
|
||||
# Empty shell payload — finalize_assistant_turn overwrites later.
|
||||
assert row.content == [{"type": "text", "text": ""}]
|
||||
|
||||
async def test_second_call_with_same_turn_id_returns_same_id(
|
||||
self,
|
||||
db_session,
|
||||
db_user,
|
||||
db_thread,
|
||||
patched_shielded_session,
|
||||
):
|
||||
# Capture primitive ids before any persistence helper runs:
|
||||
# the helpers commit/rollback the shared test session, which
|
||||
# can detach ORM rows mid-test.
|
||||
thread_id = db_thread.id
|
||||
user_id_str = str(db_user.id)
|
||||
turn_id = f"{thread_id}:2000"
|
||||
|
||||
first_id = await persist_assistant_shell(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
second_id = await persist_assistant_shell(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
|
||||
assert first_id is not None
|
||||
assert first_id == second_id
|
||||
# Exactly one row in the DB for this turn.
|
||||
assert await _count_assistant_rows(db_session, thread_id, turn_id) == 1
|
||||
|
||||
async def test_missing_turn_id_returns_none(
|
||||
self,
|
||||
db_user,
|
||||
db_thread,
|
||||
patched_shielded_session,
|
||||
):
|
||||
thread_id = db_thread.id
|
||||
user_id_str = str(db_user.id)
|
||||
|
||||
msg_id = await persist_assistant_shell(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id="",
|
||||
)
|
||||
assert msg_id is None
|
||||
|
||||
async def test_after_persist_user_turn_resolves_assistant_id(
|
||||
self,
|
||||
db_session,
|
||||
db_user,
|
||||
db_thread,
|
||||
patched_shielded_session,
|
||||
):
|
||||
thread_id = db_thread.id
|
||||
user_id_str = str(db_user.id)
|
||||
turn_id = f"{thread_id}:3000"
|
||||
|
||||
# The streaming layer always calls persist_user_turn first, so
|
||||
# smoke-test the canonical sequence.
|
||||
user_msg_id = await persist_user_turn(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
user_query="hello",
|
||||
)
|
||||
assert isinstance(user_msg_id, int)
|
||||
|
||||
msg_id = await persist_assistant_shell(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
assert msg_id is not None
|
||||
# User row + assistant shell row = 2 rows for this turn.
|
||||
result = await db_session.execute(
|
||||
select(func.count())
|
||||
.select_from(NewChatMessage)
|
||||
.where(
|
||||
NewChatMessage.thread_id == thread_id,
|
||||
NewChatMessage.turn_id == turn_id,
|
||||
)
|
||||
)
|
||||
assert result.scalar_one() == 2
|
||||
|
||||
async def test_double_call_with_same_turn_id_uses_on_conflict(
|
||||
self,
|
||||
db_session,
|
||||
db_user,
|
||||
db_thread,
|
||||
patched_shielded_session,
|
||||
):
|
||||
"""Verifies the ON CONFLICT DO NOTHING path on the assistant
|
||||
shell does not raise ``IntegrityError`` even when the second
|
||||
writer races the first within a tight loop. ``test_second_call_with_same_turn_id_returns_same_id``
|
||||
already covers the same-id semantics; this test additionally
|
||||
asserts neither call raises so the debugger's
|
||||
``raise-on-IntegrityError`` setting won't pause the streaming
|
||||
path under contention.
|
||||
"""
|
||||
thread_id = db_thread.id
|
||||
user_id_str = str(db_user.id)
|
||||
turn_id = f"{thread_id}:3500"
|
||||
|
||||
# Both calls go through ``pg_insert(...).on_conflict_do_nothing``;
|
||||
# the second one returns RETURNING=∅ and falls into the SELECT
|
||||
# branch. Neither path raises.
|
||||
first_id = await persist_assistant_shell(
|
||||
chat_id=thread_id, user_id=user_id_str, turn_id=turn_id
|
||||
)
|
||||
second_id = await persist_assistant_shell(
|
||||
chat_id=thread_id, user_id=user_id_str, turn_id=turn_id
|
||||
)
|
||||
assert first_id is not None
|
||||
assert first_id == second_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# persist_user_turn
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPersistUserTurn:
|
||||
async def test_returns_message_id_on_first_insert(
|
||||
self,
|
||||
db_session,
|
||||
db_user,
|
||||
db_thread,
|
||||
patched_shielded_session,
|
||||
):
|
||||
thread_id = db_thread.id
|
||||
user_id_str = str(db_user.id)
|
||||
turn_id = f"{thread_id}:8000"
|
||||
|
||||
msg_id = await persist_user_turn(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
user_query="hello",
|
||||
)
|
||||
assert isinstance(msg_id, int) and msg_id > 0
|
||||
|
||||
row = await db_session.get(NewChatMessage, msg_id)
|
||||
assert row is not None
|
||||
assert row.thread_id == thread_id
|
||||
assert row.role == NewChatMessageRole.USER
|
||||
assert row.turn_id == turn_id
|
||||
assert row.content == [{"type": "text", "text": "hello"}]
|
||||
|
||||
async def test_returns_existing_id_on_conflict(
|
||||
self,
|
||||
db_session,
|
||||
db_user,
|
||||
db_thread,
|
||||
patched_shielded_session,
|
||||
):
|
||||
thread_id = db_thread.id
|
||||
user_id_str = str(db_user.id)
|
||||
turn_id = f"{thread_id}:8100"
|
||||
|
||||
first_id = await persist_user_turn(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
user_query="hello",
|
||||
)
|
||||
# Second call simulates a legacy FE ``appendMessage`` racing the
|
||||
# SSE stream: ON CONFLICT DO NOTHING short-circuits at the DB
|
||||
# level, the helper recovers the existing id via SELECT, and
|
||||
# crucially does NOT raise ``IntegrityError`` (the debugger
|
||||
# would otherwise pause on it).
|
||||
second_id = await persist_user_turn(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
user_query="ignored on conflict",
|
||||
)
|
||||
assert first_id is not None
|
||||
assert first_id == second_id
|
||||
|
||||
# Exactly one user row for this turn.
|
||||
count = await db_session.execute(
|
||||
select(func.count())
|
||||
.select_from(NewChatMessage)
|
||||
.where(
|
||||
NewChatMessage.thread_id == thread_id,
|
||||
NewChatMessage.turn_id == turn_id,
|
||||
NewChatMessage.role == NewChatMessageRole.USER,
|
||||
)
|
||||
)
|
||||
assert count.scalar_one() == 1
|
||||
|
||||
async def test_embeds_mentioned_documents_part(
|
||||
self,
|
||||
db_session,
|
||||
db_user,
|
||||
db_thread,
|
||||
patched_shielded_session,
|
||||
):
|
||||
"""The full ``{id, title, document_type}`` triple forwarded by
|
||||
the FE must round-trip into a single ``mentioned-documents``
|
||||
ContentPart on the persisted user message — the history loader
|
||||
renders the chips on reload from this part directly.
|
||||
"""
|
||||
thread_id = db_thread.id
|
||||
user_id_str = str(db_user.id)
|
||||
turn_id = f"{thread_id}:8200"
|
||||
|
||||
mentioned = [
|
||||
{"id": 11, "title": "Alpha", "document_type": "GENERAL"},
|
||||
{"id": 22, "title": "Beta", "document_type": "GENERAL"},
|
||||
]
|
||||
msg_id = await persist_user_turn(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
user_query="hello",
|
||||
mentioned_documents=mentioned,
|
||||
)
|
||||
assert isinstance(msg_id, int)
|
||||
|
||||
row = await db_session.get(NewChatMessage, msg_id)
|
||||
assert row is not None
|
||||
# Content is a 2-part list: text + mentioned-documents.
|
||||
assert isinstance(row.content, list)
|
||||
assert row.content[0] == {"type": "text", "text": "hello"}
|
||||
assert row.content[1] == {
|
||||
"type": "mentioned-documents",
|
||||
"documents": [
|
||||
{"id": 11, "title": "Alpha", "document_type": "GENERAL"},
|
||||
{"id": 22, "title": "Beta", "document_type": "GENERAL"},
|
||||
],
|
||||
}
|
||||
|
||||
async def test_skips_mentioned_documents_when_empty_or_invalid(
|
||||
self,
|
||||
db_session,
|
||||
db_user,
|
||||
db_thread,
|
||||
patched_shielded_session,
|
||||
):
|
||||
"""Empty list and entries missing required fields are dropped;
|
||||
a ``mentioned-documents`` part is only emitted when at least
|
||||
one normalised entry survived.
|
||||
"""
|
||||
thread_id = db_thread.id
|
||||
user_id_str = str(db_user.id)
|
||||
turn_id_empty = f"{thread_id}:8300"
|
||||
turn_id_invalid = f"{thread_id}:8301"
|
||||
|
||||
msg_id_empty = await persist_user_turn(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id_empty,
|
||||
user_query="hi",
|
||||
mentioned_documents=[],
|
||||
)
|
||||
assert isinstance(msg_id_empty, int)
|
||||
row_empty = await db_session.get(NewChatMessage, msg_id_empty)
|
||||
assert row_empty is not None
|
||||
assert row_empty.content == [{"type": "text", "text": "hi"}]
|
||||
|
||||
# Each entry missing one required field — all skipped.
|
||||
msg_id_invalid = await persist_user_turn(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id_invalid,
|
||||
user_query="hi",
|
||||
mentioned_documents=[
|
||||
{"title": "no id", "document_type": "GENERAL"}, # missing id
|
||||
{"id": 99, "document_type": "GENERAL"}, # missing title
|
||||
{"id": 100, "title": "no type"}, # missing document_type
|
||||
],
|
||||
)
|
||||
assert isinstance(msg_id_invalid, int)
|
||||
row_invalid = await db_session.get(NewChatMessage, msg_id_invalid)
|
||||
assert row_invalid is not None
|
||||
assert row_invalid.content == [{"type": "text", "text": "hi"}]
|
||||
|
||||
async def test_missing_turn_id_returns_none(
|
||||
self,
|
||||
db_user,
|
||||
db_thread,
|
||||
patched_shielded_session,
|
||||
):
|
||||
thread_id = db_thread.id
|
||||
user_id_str = str(db_user.id)
|
||||
|
||||
msg_id = await persist_user_turn(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id="",
|
||||
user_query="hello",
|
||||
)
|
||||
assert msg_id is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# finalize_assistant_turn
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFinalizeAssistantTurn:
|
||||
async def test_writes_content_and_token_usage(
|
||||
self,
|
||||
db_session,
|
||||
db_user,
|
||||
db_thread,
|
||||
db_search_space,
|
||||
patched_shielded_session,
|
||||
):
|
||||
thread_id = db_thread.id
|
||||
user_id_uuid = db_user.id
|
||||
user_id_str = str(user_id_uuid)
|
||||
search_space_id = db_search_space.id
|
||||
turn_id = f"{thread_id}:4000"
|
||||
|
||||
msg_id = await persist_assistant_shell(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
assert msg_id is not None
|
||||
|
||||
rich_content = [
|
||||
{"type": "text", "text": "Hello world"},
|
||||
{
|
||||
"type": "tool-call",
|
||||
"toolCallId": "call_x",
|
||||
"toolName": "ls",
|
||||
"args": {"path": "/"},
|
||||
"argsText": '{\n "path": "/"\n}',
|
||||
"result": {"files": []},
|
||||
"langchainToolCallId": "lc_x",
|
||||
},
|
||||
]
|
||||
await finalize_assistant_turn(
|
||||
message_id=msg_id,
|
||||
chat_id=thread_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
content=rich_content,
|
||||
accumulator=_accumulator_with_one_call(),
|
||||
)
|
||||
|
||||
row = await db_session.get(NewChatMessage, msg_id)
|
||||
await db_session.refresh(row)
|
||||
assert row.content == rich_content
|
||||
|
||||
# Exactly one token_usage row keyed on this message_id.
|
||||
usage_rows = (
|
||||
(
|
||||
await db_session.execute(
|
||||
select(TokenUsage).where(TokenUsage.message_id == msg_id)
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
assert len(usage_rows) == 1
|
||||
usage = usage_rows[0]
|
||||
assert usage.usage_type == "chat"
|
||||
assert usage.prompt_tokens == 100
|
||||
assert usage.completion_tokens == 50
|
||||
assert usage.total_tokens == 150
|
||||
assert usage.cost_micros == 12345
|
||||
assert usage.thread_id == thread_id
|
||||
assert usage.search_space_id == search_space_id
|
||||
|
||||
async def test_empty_content_writes_status_marker(
|
||||
self,
|
||||
db_session,
|
||||
db_user,
|
||||
db_thread,
|
||||
db_search_space,
|
||||
patched_shielded_session,
|
||||
):
|
||||
thread_id = db_thread.id
|
||||
user_id_str = str(db_user.id)
|
||||
search_space_id = db_search_space.id
|
||||
turn_id = f"{thread_id}:5000"
|
||||
|
||||
msg_id = await persist_assistant_shell(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
assert msg_id is not None
|
||||
|
||||
# Pure tool-call turn that aborted before any output, or
|
||||
# interrupt before any event arrived — empty list.
|
||||
await finalize_assistant_turn(
|
||||
message_id=msg_id,
|
||||
chat_id=thread_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
content=[],
|
||||
accumulator=None,
|
||||
)
|
||||
|
||||
row = await db_session.get(NewChatMessage, msg_id)
|
||||
await db_session.refresh(row)
|
||||
assert row.content == [{"type": "status", "text": "(no text response)"}]
|
||||
|
||||
async def test_double_call_safe_via_on_conflict(
|
||||
self,
|
||||
db_session,
|
||||
db_user,
|
||||
db_thread,
|
||||
db_search_space,
|
||||
patched_shielded_session,
|
||||
):
|
||||
thread_id = db_thread.id
|
||||
user_id_str = str(db_user.id)
|
||||
search_space_id = db_search_space.id
|
||||
turn_id = f"{thread_id}:6000"
|
||||
|
||||
msg_id = await persist_assistant_shell(
|
||||
chat_id=thread_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
assert msg_id is not None
|
||||
|
||||
first_acc = _accumulator_with_one_call()
|
||||
await finalize_assistant_turn(
|
||||
message_id=msg_id,
|
||||
chat_id=thread_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
content=[{"type": "text", "text": "first finalize"}],
|
||||
accumulator=first_acc,
|
||||
)
|
||||
|
||||
# Simulate a follow-up finalize (e.g., resume retry within the
|
||||
# shielded finally block firing twice). Different content, but
|
||||
# ON CONFLICT DO NOTHING on token_usage means the cost from the
|
||||
# first finalize stays authoritative.
|
||||
second_acc = TurnTokenAccumulator()
|
||||
second_acc.add(
|
||||
model="gpt-4o-mini",
|
||||
prompt_tokens=999,
|
||||
completion_tokens=999,
|
||||
total_tokens=1998,
|
||||
cost_micros=99999,
|
||||
)
|
||||
await finalize_assistant_turn(
|
||||
message_id=msg_id,
|
||||
chat_id=thread_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id_str,
|
||||
turn_id=turn_id,
|
||||
content=[{"type": "text", "text": "second finalize"}],
|
||||
accumulator=second_acc,
|
||||
)
|
||||
|
||||
# Content was overwritten by the second UPDATE.
|
||||
row = await db_session.get(NewChatMessage, msg_id)
|
||||
await db_session.refresh(row)
|
||||
assert row.content == [{"type": "text", "text": "second finalize"}]
|
||||
|
||||
# But token_usage stayed at exactly one row, preserving the
|
||||
# first finalize's authoritative cost.
|
||||
assert await _count_token_usage_rows(db_session, msg_id) == 1
|
||||
usage = (
|
||||
await db_session.execute(
|
||||
select(TokenUsage).where(TokenUsage.message_id == msg_id)
|
||||
)
|
||||
).scalar_one()
|
||||
assert usage.cost_micros == 12345 # First finalize's value
|
||||
|
||||
    async def test_append_message_style_insert_after_finalize_no_dupe(
        self,
        db_session,
        db_user,
        db_thread,
        db_search_space,
        patched_shielded_session,
    ):
        """Cross-writer race: ``append_message`` arrives after ``finalize_assistant_turn``.

        Both target the same ``message_id``; the partial unique index
        ``uq_token_usage_message_id`` (migration 142) makes the second
        insert a no-op via ``ON CONFLICT DO NOTHING``.
        """
        from sqlalchemy import text as sa_text

        thread_id = db_thread.id
        user_uuid = db_user.id
        user_id_str = str(user_uuid)
        search_space_id = db_search_space.id
        turn_id = f"{thread_id}:7000"

        msg_id = await persist_assistant_shell(
            chat_id=thread_id,
            user_id=user_id_str,
            turn_id=turn_id,
        )
        assert msg_id is not None

        await finalize_assistant_turn(
            message_id=msg_id,
            chat_id=thread_id,
            search_space_id=search_space_id,
            user_id=user_id_str,
            turn_id=turn_id,
            content=[{"type": "text", "text": "from server"}],
            accumulator=_accumulator_with_one_call(),
        )

        # Now simulate the FE's append_message branch firing AFTER —
        # the same INSERT ... ON CONFLICT DO NOTHING shape used by the
        # route handler, keyed on the migration-142 partial unique
        # index.
        late_insert = (
            pg_insert(TokenUsage)
            .values(
                usage_type="chat",
                prompt_tokens=42,
                completion_tokens=42,
                total_tokens=84,
                cost_micros=1,
                model_breakdown=None,
                call_details=None,
                thread_id=thread_id,
                message_id=msg_id,
                search_space_id=search_space_id,
                user_id=user_uuid,
            )
            .on_conflict_do_nothing(
                index_elements=["message_id"],
                index_where=sa_text("message_id IS NOT NULL"),
            )
        )
        await db_session.execute(late_insert)
        await db_session.flush()

        # Still exactly one row, with the original (server) cost value.
        assert await _count_token_usage_rows(db_session, msg_id) == 1
        usage = (
            await db_session.execute(
                select(TokenUsage).where(TokenUsage.message_id == msg_id)
            )
        ).scalar_one()
        assert usage.cost_micros == 12345
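The on_conflict_do_nothing(index_elements=["message_id"], index_where=...) arbiter above only works if a matching partial unique index exists on token_usage.message_id. A rough Alembic-style sketch of that index, assuming the name uq_token_usage_message_id from the docstring; the actual migration is not part of this diff and the details here are assumptions:

import sqlalchemy as sa
from alembic import op


def upgrade() -> None:
    # Partial unique index: at most one token_usage row per chat message.
    op.create_index(
        "uq_token_usage_message_id",
        "token_usage",
        ["message_id"],
        unique=True,
        postgresql_where=sa.text("message_id IS NOT NULL"),
    )


def downgrade() -> None:
    op.drop_index("uq_token_usage_message_id", table_name="token_usage")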
    async def test_helper_never_raises_on_missing_message_id(
        self,
        db_session,
        db_user,
        db_thread,
        db_search_space,
        patched_shielded_session,
    ):
        thread_id = db_thread.id
        user_id_str = str(db_user.id)
        search_space_id = db_search_space.id

        # message_id that doesn't exist — finalize must log+return,
        # never raise (called from shielded finally).
        await finalize_assistant_turn(
            message_id=999_999_999,
            chat_id=thread_id,
            search_space_id=search_space_id,
            user_id=user_id_str,
            turn_id="anything",
            content=[{"type": "text", "text": "x"}],
            accumulator=_accumulator_with_one_call(),
        )
        # If we got here without an exception, the test passes.
        # Sanity: no token_usage row created (FK to message would have
        # been rejected anyway, but ON CONFLICT path may swallow
        # FK errors as well; check directly).
        assert await _count_token_usage_rows(db_session, 999_999_999) == 0
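_accumulator_with_one_call() is defined earlier in this file, outside the excerpt shown here. Only cost_micros=12345 is pinned by the assertions above; a plausible shape, with the remaining values as illustrative assumptions:

def _accumulator_with_one_call() -> TurnTokenAccumulator:
    # Hypothetical reconstruction; only cost_micros=12345 is asserted above.
    acc = TurnTokenAccumulator()
    acc.add(
        model="gpt-4o-mini",
        prompt_tokens=100,
        completion_tokens=50,
        total_tokens=150,
        cost_micros=12345,
    )
    return acc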
526
surfsense_backend/tests/unit/tasks/chat/test_content_builder.py
Normal file

@@ -0,0 +1,526 @@
"""Unit tests for ``AssistantContentBuilder``.

Pins the in-memory ``ContentPart[]`` projection so the JSONB the server
persists matches what the frontend renders live (see
``surfsense_web/lib/chat/streaming-state.ts``). Every test asserts both
the structural shape of ``snapshot()`` and that the snapshot is
``json.dumps``-safe (the streaming finally block writes it directly to
``new_chat_messages.content`` without an explicit serialization round
trip).
"""

from __future__ import annotations

import json

import pytest

from app.tasks.chat.content_builder import AssistantContentBuilder

pytestmark = pytest.mark.unit


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _assert_jsonb_safe(parts: list[dict]) -> None:
    """Sanity check: any snapshot must round-trip through ``json.dumps``."""
    serialized = json.dumps(parts)
    assert json.loads(serialized) == parts
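For orientation, a representative ContentPart[] snapshot in the shape the tests below pin, assembled from their assertions rather than captured from a real stream; it is exactly the kind of value _assert_jsonb_safe guards:

_EXAMPLE_PARTS = [
    {
        "type": "data-thinking-steps",
        "data": {"steps": [
            {"id": "step-1", "title": "Analyzing", "status": "completed", "items": ["item-a"]},
        ]},
    },
    {"type": "text", "text": "Searching..."},
    {
        "type": "tool-call",
        "toolCallId": "call_run123",
        "toolName": "web_search",
        "args": {"query": "surfsense"},
        "argsText": json.dumps({"query": "surfsense"}, indent=2, ensure_ascii=False),
        "langchainToolCallId": "lc_tool_abc",
        "result": {"status": "completed", "citations": {}},
    },
    {"type": "data-step-separator", "data": {"stepIndex": 0}},
]
_assert_jsonb_safe(_EXAMPLE_PARTS)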
# ---------------------------------------------------------------------------
# Text turns
# ---------------------------------------------------------------------------


class TestTextOnly:
    def test_single_text_block_collapses_consecutive_deltas(self):
        b = AssistantContentBuilder()
        b.on_text_start("text-1")
        b.on_text_delta("text-1", "Hello")
        b.on_text_delta("text-1", " ")
        b.on_text_delta("text-1", "world")
        b.on_text_end("text-1")

        snap = b.snapshot()
        assert snap == [{"type": "text", "text": "Hello world"}]
        assert not b.is_empty()
        _assert_jsonb_safe(snap)

    def test_empty_text_start_end_pair_leaves_no_part(self):
        # Mirrors the FE: a text-start without any deltas should
        # not materialise an empty ``{"type":"text","text":""}`` part.
        b = AssistantContentBuilder()
        b.on_text_start("text-1")
        b.on_text_end("text-1")

        assert b.snapshot() == []
        assert b.is_empty()

    def test_text_after_text_end_starts_fresh_part(self):
        b = AssistantContentBuilder()
        b.on_text_start("text-1")
        b.on_text_delta("text-1", "first")
        b.on_text_end("text-1")

        b.on_text_start("text-2")
        b.on_text_delta("text-2", "second")
        b.on_text_end("text-2")

        snap = b.snapshot()
        assert snap == [
            {"type": "text", "text": "first"},
            {"type": "text", "text": "second"},
        ]


class TestReasoningThenText:
    def test_reasoning_followed_by_text_yields_two_parts_in_order(self):
        b = AssistantContentBuilder()
        b.on_reasoning_start("r-1")
        b.on_reasoning_delta("r-1", "Considering options...")
        b.on_reasoning_end("r-1")

        b.on_text_start("text-1")
        b.on_text_delta("text-1", "The answer is 42.")
        b.on_text_end("text-1")

        snap = b.snapshot()
        assert snap == [
            {"type": "reasoning", "text": "Considering options..."},
            {"type": "text", "text": "The answer is 42."},
        ]
        _assert_jsonb_safe(snap)

    def test_text_delta_after_reasoning_implicitly_closes_reasoning(self):
        # Mirrors FE ``appendText``: a text delta arriving while a
        # reasoning part is "active" still produces a fresh text
        # part, never appends into the reasoning block.
        b = AssistantContentBuilder()
        b.on_reasoning_start("r-1")
        b.on_reasoning_delta("r-1", "thinking")
        # No explicit reasoning_end — text delta should close it.
        b.on_text_delta("text-1", "answer")

        snap = b.snapshot()
        assert snap == [
            {"type": "reasoning", "text": "thinking"},
            {"type": "text", "text": "answer"},
        ]
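A minimal sketch of the append discipline these two classes pin, assuming the builder keeps track of which part kind is currently open (the field names here are assumptions, not the real AssistantContentBuilder internals): a text delta never lands inside an open reasoning part; it implicitly closes it and starts a fresh text part.

class _TextReasoningSketch:
    def __init__(self) -> None:
        self.parts: list[dict] = []
        self._open_kind: str | None = None  # "text", "reasoning", or None

    def on_reasoning_delta(self, _block_id: str, delta: str) -> None:
        if self._open_kind != "reasoning":
            self.parts.append({"type": "reasoning", "text": ""})
            self._open_kind = "reasoning"
        self.parts[-1]["text"] += delta

    def on_text_delta(self, _block_id: str, delta: str) -> None:
        # Implicitly closes any open reasoning part.
        if self._open_kind != "text":
            self.parts.append({"type": "text", "text": ""})
            self._open_kind = "text"
        self.parts[-1]["text"] += delta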
# ---------------------------------------------------------------------------
# Tool calls
# ---------------------------------------------------------------------------


class TestToolHeavyTurn:
    def test_full_tool_lifecycle_produces_complete_tool_call_part(self):
        b = AssistantContentBuilder()
        # Some narration before the tool fires.
        b.on_text_start("text-1")
        b.on_text_delta("text-1", "Searching...")
        b.on_text_end("text-1")

        b.on_tool_input_start(
            ui_id="call_run123",
            tool_name="web_search",
            langchain_tool_call_id="lc_tool_abc",
        )
        b.on_tool_input_delta("call_run123", '{"query":')
        b.on_tool_input_delta("call_run123", '"surfsense"}')
        b.on_tool_input_available(
            ui_id="call_run123",
            tool_name="web_search",
            args={"query": "surfsense"},
            langchain_tool_call_id="lc_tool_abc",
        )
        b.on_tool_output_available(
            ui_id="call_run123",
            output={"status": "completed", "citations": {}},
            langchain_tool_call_id="lc_tool_abc",
        )

        snap = b.snapshot()
        assert snap[0] == {"type": "text", "text": "Searching..."}
        tool_part = snap[1]
        assert tool_part["type"] == "tool-call"
        assert tool_part["toolCallId"] == "call_run123"
        assert tool_part["toolName"] == "web_search"
        assert tool_part["args"] == {"query": "surfsense"}
        # ``argsText`` is the pretty-printed final JSON, not the raw
        # streaming buffer (FE ``stream-pipeline.ts:128``).
        assert tool_part["argsText"] == json.dumps(
            {"query": "surfsense"}, indent=2, ensure_ascii=False
        )
        assert tool_part["langchainToolCallId"] == "lc_tool_abc"
        assert tool_part["result"] == {"status": "completed", "citations": {}}
        _assert_jsonb_safe(snap)

    def test_tool_input_available_without_prior_start_creates_card(self):
        # Legacy / parity_v2-OFF path: tool-input-available may be
        # emitted without a prior tool-input-start (no streamed
        # tool_call_chunks). The card should still be created.
        b = AssistantContentBuilder()
        b.on_tool_input_available(
            ui_id="call_run42",
            tool_name="grep",
            args={"pattern": "TODO"},
            langchain_tool_call_id="lc_x",
        )
        b.on_tool_output_available(
            ui_id="call_run42",
            output={"matches": 3},
            langchain_tool_call_id="lc_x",
        )

        snap = b.snapshot()
        assert len(snap) == 1
        part = snap[0]
        assert part["type"] == "tool-call"
        assert part["toolCallId"] == "call_run42"
        assert part["args"] == {"pattern": "TODO"}
        assert part["langchainToolCallId"] == "lc_x"
        assert part["result"] == {"matches": 3}

    def test_tool_input_start_idempotent_for_same_ui_id(self):
        # parity_v2: tool-input-start can fire from BOTH the chunk
        # registration path AND the canonical ``on_tool_start`` path.
        # The second call must not create a duplicate part.
        b = AssistantContentBuilder()
        b.on_tool_input_start("call_x", "ls", "lc_x")
        b.on_tool_input_start("call_x", "ls", "lc_x")
        snap = b.snapshot()
        assert len(snap) == 1

    def test_tool_input_delta_without_prior_start_is_silently_dropped(self):
        b = AssistantContentBuilder()
        b.on_tool_input_delta("call_unknown", '{"orphan": "delta"}')
        assert b.snapshot() == []

    def test_langchain_tool_call_id_backfills_only_when_absent(self):
        b = AssistantContentBuilder()
        b.on_tool_input_start("call_x", "ls", "lc_first")
        # Late event must NOT clobber an already-set lc id.
        b.on_tool_input_start("call_x", "ls", "lc_late")
        snap = b.snapshot()
        assert snap[0]["langchainToolCallId"] == "lc_first"

    def test_args_text_streaming_buffer_reflects_concatenation(self):
        b = AssistantContentBuilder()
        b.on_tool_input_start("call_x", "save_doc", "lc_y")
        b.on_tool_input_delta("call_x", '{"title":')
        b.on_tool_input_delta("call_x", '"Hi"}')
        # Snapshot mid-stream should see the partial buffer (the FE
        # tolerates invalid JSON and renders it as-is).
        mid = b.snapshot()
        assert mid[0]["argsText"] == '{"title":"Hi"}'
        # Then tool-input-available replaces with pretty-printed.
        b.on_tool_input_available(
            "call_x",
            "save_doc",
            {"title": "Hi"},
            "lc_y",
        )
        final = b.snapshot()
        assert final[0]["argsText"] == json.dumps(
            {"title": "Hi"}, indent=2, ensure_ascii=False
        )
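A compact sketch of the tool-call card lifecycle the class above pins (the internal representation is an assumption): argsText starts as the raw streamed buffer, is replaced by pretty-printed JSON once the full args arrive, and the result is attached last.

def _tool_card_lifecycle_sketch() -> dict:
    card = {
        "type": "tool-call",
        "toolCallId": "call_x",
        "toolName": "save_doc",
        "argsText": "",
        "langchainToolCallId": "lc_y",
    }
    # tool-input-delta: concatenate the raw streamed fragments as-is.
    for fragment in ('{"title":', '"Hi"}'):
        card["argsText"] += fragment
    # tool-input-available: keep the parsed args and pretty-print argsText.
    args = {"title": "Hi"}
    card["args"] = args
    card["argsText"] = json.dumps(args, indent=2, ensure_ascii=False)
    # tool-output-available: attach the result.
    card["result"] = {"ok": True}
    return card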
# ---------------------------------------------------------------------------
# Thinking steps & separators
# ---------------------------------------------------------------------------


class TestThinkingSteps:
    def test_first_thinking_step_unshifts_singleton_to_index_zero(self):
        b = AssistantContentBuilder()
        b.on_text_start("text-1")
        b.on_text_delta("text-1", "Hello")
        b.on_text_end("text-1")

        b.on_thinking_step("step-1", "Analyzing", "in_progress", ["item-a"])

        snap = b.snapshot()
        # Singleton goes to index 0 (FE ``updateThinkingSteps`` unshift).
        assert snap[0]["type"] == "data-thinking-steps"
        assert snap[0]["data"]["steps"] == [
            {
                "id": "step-1",
                "title": "Analyzing",
                "status": "in_progress",
                "items": ["item-a"],
            }
        ]
        assert snap[1] == {"type": "text", "text": "Hello"}

    def test_subsequent_thinking_steps_mutate_the_singleton_in_place(self):
        b = AssistantContentBuilder()
        b.on_thinking_step("step-1", "Analyzing", "in_progress", [])
        b.on_thinking_step("step-2", "Searching", "in_progress", ["q"])
        b.on_thinking_step("step-1", "Analyzing", "completed", ["done"])

        snap = b.snapshot()
        assert len([p for p in snap if p["type"] == "data-thinking-steps"]) == 1
        steps = snap[0]["data"]["steps"]
        assert len(steps) == 2
        assert steps[0]["id"] == "step-1"
        assert steps[0]["status"] == "completed"
        assert steps[0]["items"] == ["done"]
        assert steps[1]["id"] == "step-2"

    def test_thinking_step_with_text_continues_appending_to_text(self):
        b = AssistantContentBuilder()
        b.on_text_start("text-1")
        b.on_text_delta("text-1", "first")

        # Thinking step inserts at index 0, bumps text idx from 0 to 1.
        b.on_thinking_step("step-1", "Working", "in_progress", [])
        b.on_text_delta("text-1", " second")

        snap = b.snapshot()
        text_parts = [p for p in snap if p["type"] == "text"]
        assert text_parts == [{"type": "text", "text": "first second"}]

    def test_thinking_step_without_id_is_dropped(self):
        b = AssistantContentBuilder()
        b.on_thinking_step("", "noop", "in_progress", None)
        assert b.snapshot() == []
        assert b.is_empty()


class TestStepSeparators:
    def test_separator_no_op_before_any_content(self):
        b = AssistantContentBuilder()
        b.on_step_separator()
        assert b.snapshot() == []

    def test_separator_after_text_appends_with_step_index_zero(self):
        b = AssistantContentBuilder()
        b.on_text_start("text-1")
        b.on_text_delta("text-1", "first")
        b.on_text_end("text-1")

        b.on_step_separator()

        snap = b.snapshot()
        assert snap[-1] == {
            "type": "data-step-separator",
            "data": {"stepIndex": 0},
        }

    def test_consecutive_separators_collapse_to_one(self):
        b = AssistantContentBuilder()
        b.on_text_delta("text-1", "x")
        b.on_step_separator()
        b.on_step_separator()  # No-op: previous part is already a separator.
        snap = b.snapshot()
        assert sum(1 for p in snap if p["type"] == "data-step-separator") == 1

    def test_step_index_increments_across_separators(self):
        b = AssistantContentBuilder()
        b.on_text_delta("text-1", "a")
        b.on_step_separator()
        b.on_text_delta("text-2", "b")
        b.on_step_separator()
        snap = b.snapshot()
        seps = [p for p in snap if p["type"] == "data-step-separator"]
        assert [s["data"]["stepIndex"] for s in seps] == [0, 1]
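The singleton behaviour pinned above can be summarised in a few lines; this sketch mirrors the FE updateThinkingSteps unshift, with the internals as assumptions rather than the builder's actual code:

def _upsert_thinking_step(parts: list[dict], step: dict) -> None:
    singleton = next((p for p in parts if p["type"] == "data-thinking-steps"), None)
    if singleton is None:
        singleton = {"type": "data-thinking-steps", "data": {"steps": []}}
        parts.insert(0, singleton)  # unshift: the singleton always sits at index 0
    steps = singleton["data"]["steps"]
    for i, existing in enumerate(steps):
        if existing["id"] == step["id"]:
            steps[i] = step  # update in place, keep its position
            return
    steps.append(step)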
# ---------------------------------------------------------------------------
# Interruption handling
# ---------------------------------------------------------------------------


class TestMarkInterrupted:
    def test_running_tool_calls_get_state_aborted(self):
        b = AssistantContentBuilder()
        b.on_tool_input_start("call_a", "ls", "lc_a")
        b.on_tool_input_available("call_a", "ls", {"path": "/"}, "lc_a")
        # No tool-output-available — simulates client disconnect mid-tool.

        b.mark_interrupted()

        snap = b.snapshot()
        assert snap[0]["state"] == "aborted"
        assert "result" not in snap[0]

    def test_completed_tool_calls_are_not_marked_aborted(self):
        b = AssistantContentBuilder()
        b.on_tool_input_start("call_a", "ls", "lc_a")
        b.on_tool_input_available("call_a", "ls", {"path": "/"}, "lc_a")
        b.on_tool_output_available("call_a", {"files": []}, "lc_a")

        b.mark_interrupted()

        snap = b.snapshot()
        assert "state" not in snap[0]
        assert snap[0]["result"] == {"files": []}

    def test_open_text_block_keeps_accumulated_content(self):
        b = AssistantContentBuilder()
        b.on_text_start("text-1")
        b.on_text_delta("text-1", "partial")
        # No on_text_end — disconnect mid-stream.

        b.mark_interrupted()

        snap = b.snapshot()
        assert snap == [{"type": "text", "text": "partial"}]
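The interruption contract above is small enough to state directly; a sketch under the assumption that a tool-call part without a result key is still running:

def _mark_interrupted_sketch(parts: list[dict]) -> None:
    for part in parts:
        if part.get("type") == "tool-call" and "result" not in part:
            part["state"] = "aborted"  # finished calls keep their result and gain no state key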
# ---------------------------------------------------------------------------
# is_empty / snapshot semantics
# ---------------------------------------------------------------------------


class TestIsEmpty:
    def test_fresh_builder_is_empty(self):
        assert AssistantContentBuilder().is_empty()

    def test_text_part_breaks_emptiness(self):
        b = AssistantContentBuilder()
        b.on_text_delta("text-1", "x")
        assert not b.is_empty()

    def test_tool_call_breaks_emptiness(self):
        b = AssistantContentBuilder()
        b.on_tool_input_start("call_x", "ls", None)
        assert not b.is_empty()

    def test_thinking_step_alone_does_not_break_emptiness(self):
        # Mirrors the "status marker fallback" semantic: a turn that
        # only emitted a thinking step before being interrupted should
        # still be treated as empty for finalize_assistant_turn's
        # status-marker substitution.
        b = AssistantContentBuilder()
        b.on_thinking_step("step-1", "Working", "in_progress", [])
        assert b.is_empty()

    def test_step_separator_alone_does_not_break_emptiness(self):
        b = AssistantContentBuilder()
        # Force a separator (it would normally no-op without content,
        # but we simulate the underlying state to verify is_empty is
        # not fooled by a stray separator).
        b.parts.append({"type": "data-step-separator", "data": {"stepIndex": 0}})
        assert b.is_empty()


class TestSnapshotSemantics:
    def test_snapshot_is_deep_copied_so_mutations_do_not_leak(self):
        b = AssistantContentBuilder()
        b.on_tool_input_start("call_x", "ls", "lc_x")
        b.on_tool_input_available("call_x", "ls", {"path": "/"}, "lc_x")
        snap = b.snapshot()
        # Mutate the returned snapshot — original should be untouched.
        snap[0]["args"]["mutated"] = True
        snap[0]["state"] = "tampered"

        again = b.snapshot()
        assert "mutated" not in again[0]["args"]
        assert "state" not in again[0]

    def test_snapshot_round_trips_through_json(self):
        b = AssistantContentBuilder()
        b.on_thinking_step("step-1", "Analyzing", "in_progress", ["item"])
        b.on_text_delta("text-1", "answer")
        b.on_tool_input_start("call_x", "ls", "lc_x")
        b.on_tool_input_available("call_x", "ls", {"path": "/"}, "lc_x")
        b.on_tool_output_available("call_x", {"files": ["a.txt"]}, "lc_x")
        b.on_step_separator()
        snap = b.snapshot()

        encoded = json.dumps(snap)
        assert json.loads(encoded) == snap
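The two classes above pin that snapshot() hands back an isolated deep copy and that is_empty() ignores thinking-step and separator parts, so the status-marker fallback in finalize_assistant_turn can kick in. A sketch under those assumptions (not the builder's actual code):

import copy


def _snapshot_sketch(parts: list[dict]) -> list[dict]:
    # Deep copy so callers cannot mutate builder state through the snapshot.
    return copy.deepcopy(parts)


def _is_empty_sketch(parts: list[dict]) -> bool:
    meaningful = {"text", "reasoning", "tool-call"}
    return not any(p.get("type") in meaningful for p in parts)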
class TestStats:
    """``stats()`` is the perf-log handle for [PERF] [stream_*]
    finalize_payload lines. Pin the schema so an ops dashboard can
    rely on these keys being present and meaningful.
    """

    def test_fresh_builder_reports_all_zeros(self):
        b = AssistantContentBuilder()
        s = b.stats()
        assert s == {
            "parts": 0,
            "bytes": 2,  # ``[]`` is two bytes
            "text": 0,
            "reasoning": 0,
            "tool_calls": 0,
            "tool_calls_completed": 0,
            "tool_calls_aborted": 0,
            "thinking_step_parts": 0,
            "step_separators": 0,
        }

    def test_counts_each_part_type_independently(self):
        b = AssistantContentBuilder()
        b.on_text_start("t1")
        b.on_text_delta("t1", "hi")
        b.on_text_end("t1")
        b.on_reasoning_start("r1")
        b.on_reasoning_delta("r1", "thinking")
        b.on_reasoning_end("r1")
        b.on_thinking_step("step-1", "Analyzing", "completed", ["item"])
        b.on_step_separator()
        b.on_tool_input_start("call_done", "ls", "lc_done")
        b.on_tool_input_available("call_done", "ls", {}, "lc_done")
        b.on_tool_output_available("call_done", {"ok": True}, "lc_done")
        b.on_tool_input_start("call_running", "rm", "lc_running")
        b.on_tool_input_available("call_running", "rm", {}, "lc_running")

        s = b.stats()
        assert s["text"] == 1
        assert s["reasoning"] == 1
        assert s["tool_calls"] == 2
        assert s["tool_calls_completed"] == 1
        assert s["tool_calls_aborted"] == 0
        assert s["thinking_step_parts"] == 1
        assert s["step_separators"] == 1
        assert s["parts"] == sum(
            [
                s["text"],
                s["reasoning"],
                s["tool_calls"],
                s["thinking_step_parts"],
                s["step_separators"],
            ]
        )
        assert s["bytes"] > 0

    def test_mark_interrupted_flips_running_calls_to_aborted_in_stats(self):
        b = AssistantContentBuilder()
        b.on_tool_input_start("call_done", "ls", "lc_done")
        b.on_tool_input_available("call_done", "ls", {}, "lc_done")
        b.on_tool_output_available("call_done", {"ok": True}, "lc_done")
        b.on_tool_input_start("call_running", "rm", "lc_running")
        b.on_tool_input_available("call_running", "rm", {}, "lc_running")

        # Pre-interrupt: one completed, one still running (no result).
        pre = b.stats()
        assert pre["tool_calls_completed"] == 1
        assert pre["tool_calls_aborted"] == 0

        b.mark_interrupted()
        post = b.stats()
        assert post["tool_calls_completed"] == 1
        assert post["tool_calls_aborted"] == 1
        assert post["tool_calls"] == 2

    def test_bytes_reflects_jsonb_payload_size(self):
        # Each text-delta adds bytes monotonically — useful for catching
        # an unbounded delta buffer regression in the perf signal.
        b = AssistantContentBuilder()
        b.on_text_start("t1")
        b.on_text_delta("t1", "x" * 10)
        small = b.stats()["bytes"]
        b.on_text_delta("t1", "x" * 1000)
        large = b.stats()["bytes"]
        assert large > small + 900
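A sketch of the stats() schema pinned by TestStats, assuming bytes is the length of the JSON-serialised parts (an empty list serialises to "[]", i.e. two bytes) and that the per-type counters partition parts; the real implementation may differ:

import json


def _stats_sketch(parts: list[dict]) -> dict:
    tool_calls = [p for p in parts if p.get("type") == "tool-call"]
    return {
        "parts": len(parts),
        "bytes": len(json.dumps(parts).encode("utf-8")),
        "text": sum(1 for p in parts if p.get("type") == "text"),
        "reasoning": sum(1 for p in parts if p.get("type") == "reasoning"),
        "tool_calls": len(tool_calls),
        "tool_calls_completed": sum(1 for p in tool_calls if "result" in p),
        "tool_calls_aborted": sum(1 for p in tool_calls if p.get("state") == "aborted"),
        "thinking_step_parts": sum(1 for p in parts if p.get("type") == "data-thinking-steps"),
        "step_separators": sum(1 for p in parts if p.get("type") == "data-step-separator"),
    }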
@@ -457,6 +457,9 @@ def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows()
    source = page_path.read_text(encoding="utf-8")

    # Each flow tracks the accepted boundary and passes it into shared terminal handling.
    # The acceptance boundary is still meaningful post-refactor: it gates
    # local-state cleanup (onPreAcceptFailure path) and lets the shared
    # terminal handler distinguish pre-accept aborts from in-stream errors.
    assert "let newAccepted = false;" in source
    assert "let resumeAccepted = false;" in source
    assert "let regenerateAccepted = false;" in source

@@ -464,12 +467,23 @@ def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows()
    assert "accepted: resumeAccepted," in source
    assert "accepted: regenerateAccepted," in source

    # Pre-accept abort in resume/regenerate exits without persistence.
    assert "if (!resumeAccepted) return;" in source
    assert "if (!regenerateAccepted) return;" in source
    # NOTE: The FE-side persistence guards previously asserted here
    # ("if (!resumeAccepted) return;", "if (!regenerateAccepted) return;",
    # "if (newAccepted && !userPersisted) {") have been intentionally
    # removed by the SSE-based message-id handshake refactor. Persistence
    # is now server-authoritative: persist_user_turn / persist_assistant_shell
    # run inside stream_new_chat / stream_resume_chat unconditionally and
    # the FE consumes data-user-message-id / data-assistant-message-id
    # SSE events to learn the canonical primary keys. There is therefore
    # no FE call-site to guard, and the shared terminal handler relies
    # purely on the `accepted` field above (forwarded to onAbort /
    # onAcceptedStreamError) to drive UI cleanup. See
    # tests/integration/chat/test_message_id_sse.py for the new
    # cross-tier ID coherence guarantees.

    # New flow persists only when accepted and not already persisted.
    assert "if (newAccepted && !userPersisted) {" in source
    # The TURN_CANCELLING / THREAD_BUSY retry plumbing is independent
    # of the persistence refactor and must still exist on every
    # start-stream fetch.
    assert "const fetchWithTurnCancellingRetry = useCallback(" in source
    assert "computeFallbackTurnCancellingRetryDelay" in source
    assert 'withMeta.errorCode === "TURN_CANCELLING"' in source
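The NOTE above relies on the data-user-message-id / data-assistant-message-id SSE events. A rough sketch of what the server-side emission could look like, assuming only the event names from the comment; the payload shape, field names, and helper are assumptions about the streaming task, not its actual code:

import json


def _message_id_handshake_events(user_message_id: int, assistant_message_id: int) -> list[str]:
    events = [
        {"type": "data-user-message-id", "data": {"messageId": user_message_id}},
        {"type": "data-assistant-message-id", "data": {"messageId": assistant_message_id}},
    ]
    # Rendered as SSE `data:` frames for the FE stream consumer.
    return [f"data: {json.dumps(event)}\n\n" for event in events]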