mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-06 06:12:40 +02:00
747 lines
24 KiB
Python
747 lines
24 KiB
Python
"""Integration tests for ``app.tasks.chat.persistence``.
|
|
|
|
Verifies the DB-side guarantees the streaming chat task relies on:
|
|
|
|
* ``persist_assistant_shell`` is idempotent against the
|
|
``(thread_id, turn_id, ASSISTANT)`` partial unique index from
|
|
migration 141. Two calls with the same ``turn_id`` return the SAME
|
|
``message_id`` and never create a duplicate ``new_chat_messages`` row.
|
|
* ``finalize_assistant_turn`` writes a status-marker payload when given
|
|
empty content, never raises, and is safe to call twice on the same
|
|
``message_id`` — the partial unique index from migration 142
|
|
(``uq_token_usage_message_id``) prevents the second insert from
|
|
producing a duplicate ``token_usage`` row.
|
|
* The same ``ON CONFLICT DO NOTHING`` invariant covers the cross-writer
|
|
race where ``finalize_assistant_turn`` and the ``append_message``
|
|
recovery branch both target the same ``message_id``.
|
|
|
|
All tests run inside the conftest's outer-transaction-with-savepoint
|
|
fixture so commits inside the helpers (which open their own
|
|
``shielded_async_session``) are released as savepoints and rolled back
|
|
at test end. We monkey-patch ``shielded_async_session`` to yield the
|
|
same pooled test session so the integration transaction stays
|
|
in-scope.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from contextlib import asynccontextmanager
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from sqlalchemy import func, select
|
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.db import (
|
|
ChatVisibility,
|
|
NewChatMessage,
|
|
NewChatMessageRole,
|
|
NewChatThread,
|
|
SearchSpace,
|
|
TokenUsage,
|
|
User,
|
|
)
|
|
from app.services.token_tracking_service import TurnTokenAccumulator
|
|
from app.tasks.chat import persistence as persistence_module
|
|
from app.tasks.chat.persistence import (
|
|
finalize_assistant_turn,
|
|
persist_assistant_shell,
|
|
persist_user_turn,
|
|
)
|
|
|
|
pytestmark = pytest.mark.integration
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest_asyncio.fixture
async def db_thread(
    db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
) -> NewChatThread:
    """Provide a private chat thread owned by ``db_user`` in ``db_search_space``.

    The thread is added to the test session and flushed (not committed);
    the tests in this module read ``.id`` from the returned row.
    """
    thread_fields = {
        "title": "Test Chat",
        "search_space_id": db_search_space.id,
        "created_by_id": db_user.id,
        "visibility": ChatVisibility.PRIVATE,
    }
    chat_thread = NewChatThread(**thread_fields)
    db_session.add(chat_thread)
    await db_session.flush()
    return chat_thread
|
|
|
|
|
|
@pytest.fixture
def patched_shielded_session(monkeypatch, db_session: AsyncSession):
    """Make the persistence helpers run on the test's savepoint-bound session.

    The helpers enter ``async with shielded_async_session() as ws`` and
    call ``ws.commit()`` themselves. Under the conftest's
    ``join_transaction_mode="create_savepoint"`` configuration such a
    commit only releases a SAVEPOINT rather than committing the outer
    transaction, so rows staged by a helper are immediately visible to
    the test session and are discarded by the outer rollback when the
    test finishes.

    Returns the same ``db_session`` the helpers will receive.
    """

    @asynccontextmanager
    async def _yield_test_session():
        # Deliberately no cleanup after the yield: the outer fixture owns
        # the session lifecycle, so it must NOT be closed here.
        yield db_session

    monkeypatch.setattr(
        persistence_module, "shielded_async_session", _yield_test_session
    )
    return db_session
|
|
|
|
|
|
def _accumulator_with_one_call() -> TurnTokenAccumulator:
    """Return a new accumulator with exactly one ``gpt-4o-mini`` call added."""
    single_call = {
        "model": "gpt-4o-mini",
        "prompt_tokens": 100,
        "completion_tokens": 50,
        "total_tokens": 150,
        "cost_micros": 12345,
    }
    accumulator = TurnTokenAccumulator()
    accumulator.add(**single_call)
    return accumulator
|
|
|
|
|
|
async def _count_assistant_rows(
    session: AsyncSession, thread_id: int, turn_id: str
) -> int:
    """Count ASSISTANT-role ``NewChatMessage`` rows for one thread/turn pair."""
    turn_filters = (
        NewChatMessage.thread_id == thread_id,
        NewChatMessage.turn_id == turn_id,
        NewChatMessage.role == NewChatMessageRole.ASSISTANT,
    )
    count_stmt = (
        select(func.count()).select_from(NewChatMessage).where(*turn_filters)
    )
    outcome = await session.execute(count_stmt)
    return int(outcome.scalar_one())
|
|
|
|
|
|
async def _count_token_usage_rows(session: AsyncSession, message_id: int) -> int:
    """Count ``TokenUsage`` rows whose ``message_id`` equals ``message_id``."""
    count_stmt = select(func.count()).select_from(TokenUsage)
    count_stmt = count_stmt.where(TokenUsage.message_id == message_id)
    outcome = await session.execute(count_stmt)
    return int(outcome.scalar_one())
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# persist_assistant_shell
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPersistAssistantShell:
    """Checks that ``persist_assistant_shell`` yields one row per turn."""

    async def test_first_call_inserts_empty_shell_and_returns_id(
        self,
        db_session,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        """A first call stores an empty assistant row and returns its id."""
        # Read plain ids off the ORM fixtures up front: the helpers
        # commit/rollback the shared test session, which can detach
        # ORM rows mid-test.
        chat_id = db_thread.id
        owner_id = str(db_user.id)
        turn_key = f"{chat_id}:1000"

        shell_id = await persist_assistant_shell(
            chat_id=chat_id, user_id=owner_id, turn_id=turn_key
        )
        assert shell_id is not None
        assert isinstance(shell_id, int)

        stored = await db_session.get(NewChatMessage, shell_id)
        assert stored is not None
        assert stored.thread_id == chat_id
        assert stored.role == NewChatMessageRole.ASSISTANT
        assert stored.turn_id == turn_key
        # The shell starts out with an empty text part;
        # finalize_assistant_turn overwrites it later.
        assert stored.content == [{"type": "text", "text": ""}]

    async def test_second_call_with_same_turn_id_returns_same_id(
        self,
        db_session,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        """Repeating the call for a turn returns the id of the first row."""
        # Read plain ids off the ORM fixtures up front: the helpers
        # commit/rollback the shared test session, which can detach
        # ORM rows mid-test.
        chat_id = db_thread.id
        owner_id = str(db_user.id)
        turn_key = f"{chat_id}:2000"
        call_kwargs = {"chat_id": chat_id, "user_id": owner_id, "turn_id": turn_key}

        original_id = await persist_assistant_shell(**call_kwargs)
        repeated_id = await persist_assistant_shell(**call_kwargs)

        assert original_id is not None
        assert original_id == repeated_id
        # The turn must still map to exactly one assistant row.
        assistant_rows = await _count_assistant_rows(db_session, chat_id, turn_key)
        assert assistant_rows == 1

    async def test_missing_turn_id_returns_none(
        self,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        """An empty ``turn_id`` makes the helper return ``None``."""
        chat_id = db_thread.id
        owner_id = str(db_user.id)

        shell_id = await persist_assistant_shell(
            chat_id=chat_id, user_id=owner_id, turn_id=""
        )
        assert shell_id is None

    async def test_after_persist_user_turn_resolves_assistant_id(
        self,
        db_session,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        """User turn followed by assistant shell leaves two rows for the turn."""
        chat_id = db_thread.id
        owner_id = str(db_user.id)
        turn_key = f"{chat_id}:3000"

        # Smoke-test the canonical order: the streaming layer always
        # calls persist_user_turn before the assistant shell.
        user_row_id = await persist_user_turn(
            chat_id=chat_id,
            user_id=owner_id,
            turn_id=turn_key,
            user_query="hello",
        )
        assert isinstance(user_row_id, int)

        shell_id = await persist_assistant_shell(
            chat_id=chat_id, user_id=owner_id, turn_id=turn_key
        )
        assert shell_id is not None

        # One user row plus one assistant shell row = 2 rows for this turn.
        turn_rows_stmt = (
            select(func.count())
            .select_from(NewChatMessage)
            .where(
                NewChatMessage.thread_id == chat_id,
                NewChatMessage.turn_id == turn_key,
            )
        )
        turn_rows = await db_session.execute(turn_rows_stmt)
        assert turn_rows.scalar_one() == 2

    async def test_double_call_with_same_turn_id_uses_on_conflict(
        self,
        db_session,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        """Two back-to-back calls for one turn must both return normally.

        ``test_second_call_with_same_turn_id_returns_same_id`` already
        pins the same-id semantics. This test additionally documents
        that the ON CONFLICT DO NOTHING path never surfaces an
        ``IntegrityError``, so a debugger configured to
        ``raise-on-IntegrityError`` will not pause the streaming path
        under contention.
        """
        chat_id = db_thread.id
        owner_id = str(db_user.id)
        turn_key = f"{chat_id}:3500"
        call_kwargs = {"chat_id": chat_id, "user_id": owner_id, "turn_id": turn_key}

        # Both calls go through ``pg_insert(...).on_conflict_do_nothing``;
        # the second gets an empty RETURNING and falls into the SELECT
        # branch. Neither path raises.
        original_id = await persist_assistant_shell(**call_kwargs)
        repeated_id = await persist_assistant_shell(**call_kwargs)

        assert original_id is not None
        assert original_id == repeated_id
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# persist_user_turn
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPersistUserTurn:
    """Checks the user-message rows written by ``persist_user_turn``."""

    async def test_returns_message_id_on_first_insert(
        self,
        db_session,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        """A first call stores a USER row holding the query text."""
        chat_id = db_thread.id
        owner_id = str(db_user.id)
        turn_key = f"{chat_id}:8000"

        user_row_id = await persist_user_turn(
            chat_id=chat_id,
            user_id=owner_id,
            turn_id=turn_key,
            user_query="hello",
        )
        assert isinstance(user_row_id, int)
        assert user_row_id > 0

        stored = await db_session.get(NewChatMessage, user_row_id)
        assert stored is not None
        assert stored.thread_id == chat_id
        assert stored.role == NewChatMessageRole.USER
        assert stored.turn_id == turn_key
        assert stored.content == [{"type": "text", "text": "hello"}]

    async def test_returns_existing_id_on_conflict(
        self,
        db_session,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        """A repeated call for the same turn returns the original row id."""
        chat_id = db_thread.id
        owner_id = str(db_user.id)
        turn_key = f"{chat_id}:8100"

        original_id = await persist_user_turn(
            chat_id=chat_id,
            user_id=owner_id,
            turn_id=turn_key,
            user_query="hello",
        )
        # The repeat stands in for a legacy FE ``appendMessage`` racing
        # the SSE stream: ON CONFLICT DO NOTHING short-circuits at the
        # DB level, the helper recovers the existing id via SELECT, and
        # crucially does NOT raise ``IntegrityError`` (the debugger
        # would otherwise pause on it).
        repeated_id = await persist_user_turn(
            chat_id=chat_id,
            user_id=owner_id,
            turn_id=turn_key,
            user_query="ignored on conflict",
        )
        assert original_id is not None
        assert original_id == repeated_id

        # The turn must still map to exactly one user row.
        user_rows_stmt = (
            select(func.count())
            .select_from(NewChatMessage)
            .where(
                NewChatMessage.thread_id == chat_id,
                NewChatMessage.turn_id == turn_key,
                NewChatMessage.role == NewChatMessageRole.USER,
            )
        )
        user_rows = await db_session.execute(user_rows_stmt)
        assert user_rows.scalar_one() == 1

    async def test_embeds_mentioned_documents_part(
        self,
        db_session,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        """Mentioned documents round-trip into one ``mentioned-documents`` part.

        The full ``{id, title, document_type}`` triple forwarded by the
        FE must land in a single ContentPart on the persisted user
        message; the history loader renders the chips on reload from
        that part directly.
        """
        chat_id = db_thread.id
        owner_id = str(db_user.id)
        turn_key = f"{chat_id}:8200"

        forwarded_docs = [
            {"id": 11, "title": "Alpha", "document_type": "GENERAL"},
            {"id": 22, "title": "Beta", "document_type": "GENERAL"},
        ]
        user_row_id = await persist_user_turn(
            chat_id=chat_id,
            user_id=owner_id,
            turn_id=turn_key,
            user_query="hello",
            mentioned_documents=forwarded_docs,
        )
        assert isinstance(user_row_id, int)

        stored = await db_session.get(NewChatMessage, user_row_id)
        assert stored is not None
        # Expect a 2-part list: the text part, then mentioned-documents.
        assert isinstance(stored.content, list)
        text_part = stored.content[0]
        documents_part = stored.content[1]
        assert text_part == {"type": "text", "text": "hello"}
        assert documents_part == {
            "type": "mentioned-documents",
            "documents": [
                {"id": 11, "title": "Alpha", "document_type": "GENERAL"},
                {"id": 22, "title": "Beta", "document_type": "GENERAL"},
            ],
        }

    async def test_skips_mentioned_documents_when_empty_or_invalid(
        self,
        db_session,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        """No ``mentioned-documents`` part is stored without a valid entry.

        An empty list, and entries missing a required field, are
        dropped; the part is emitted only when at least one normalised
        entry survives.
        """
        chat_id = db_thread.id
        owner_id = str(db_user.id)
        empty_turn_key = f"{chat_id}:8300"
        invalid_turn_key = f"{chat_id}:8301"

        empty_row_id = await persist_user_turn(
            chat_id=chat_id,
            user_id=owner_id,
            turn_id=empty_turn_key,
            user_query="hi",
            mentioned_documents=[],
        )
        assert isinstance(empty_row_id, int)
        empty_row = await db_session.get(NewChatMessage, empty_row_id)
        assert empty_row is not None
        assert empty_row.content == [{"type": "text", "text": "hi"}]

        # Every entry lacks one required field, so all are skipped.
        incomplete_docs = [
            {"title": "no id", "document_type": "GENERAL"},  # missing id
            {"id": 99, "document_type": "GENERAL"},  # missing title
            {"id": 100, "title": "no type"},  # missing document_type
        ]
        invalid_row_id = await persist_user_turn(
            chat_id=chat_id,
            user_id=owner_id,
            turn_id=invalid_turn_key,
            user_query="hi",
            mentioned_documents=incomplete_docs,
        )
        assert isinstance(invalid_row_id, int)
        invalid_row = await db_session.get(NewChatMessage, invalid_row_id)
        assert invalid_row is not None
        assert invalid_row.content == [{"type": "text", "text": "hi"}]

    async def test_missing_turn_id_returns_none(
        self,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        """An empty ``turn_id`` makes the helper return ``None``."""
        chat_id = db_thread.id
        owner_id = str(db_user.id)

        user_row_id = await persist_user_turn(
            chat_id=chat_id,
            user_id=owner_id,
            turn_id="",
            user_query="hello",
        )
        assert user_row_id is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# finalize_assistant_turn
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestFinalizeAssistantTurn:
    """Integration checks for ``finalize_assistant_turn``.

    Most tests create an assistant shell row with
    ``persist_assistant_shell`` first and then finalize it. Primitive
    ids are read from the ORM fixtures before any helper runs, because
    the helpers commit on the shared test session.
    """

    async def test_writes_content_and_token_usage(
        self,
        db_session,
        db_user,
        db_thread,
        db_search_space,
        patched_shielded_session,
    ):
        """Finalizing stores the given content and one ``token_usage`` row."""
        thread_id = db_thread.id
        user_id_str = str(db_user.id)
        search_space_id = db_search_space.id
        turn_id = f"{thread_id}:4000"

        msg_id = await persist_assistant_shell(
            chat_id=thread_id,
            user_id=user_id_str,
            turn_id=turn_id,
        )
        assert msg_id is not None

        rich_content = [
            {"type": "text", "text": "Hello world"},
            {
                "type": "tool-call",
                "toolCallId": "call_x",
                "toolName": "ls",
                "args": {"path": "/"},
                "argsText": '{\n "path": "/"\n}',
                "result": {"files": []},
                "langchainToolCallId": "lc_x",
            },
        ]
        await finalize_assistant_turn(
            message_id=msg_id,
            chat_id=thread_id,
            search_space_id=search_space_id,
            user_id=user_id_str,
            turn_id=turn_id,
            content=rich_content,
            accumulator=_accumulator_with_one_call(),
        )

        row = await db_session.get(NewChatMessage, msg_id)
        # Guard before refresh(): a missing row should fail as a clear
        # assertion, not as an opaque error raised from refresh(None).
        assert row is not None
        # Reload from the DB so the assertion sees what the helper wrote,
        # not a stale identity-map copy.
        await db_session.refresh(row)
        assert row.content == rich_content

        # Exactly one token_usage row keyed on this message_id.
        usage_result = await db_session.execute(
            select(TokenUsage).where(TokenUsage.message_id == msg_id)
        )
        usage_rows = usage_result.scalars().all()
        assert len(usage_rows) == 1
        usage = usage_rows[0]
        assert usage.usage_type == "chat"
        assert usage.prompt_tokens == 100
        assert usage.completion_tokens == 50
        assert usage.total_tokens == 150
        assert usage.cost_micros == 12345
        assert usage.thread_id == thread_id
        assert usage.search_space_id == search_space_id

    async def test_empty_content_writes_status_marker(
        self,
        db_session,
        db_user,
        db_thread,
        db_search_space,
        patched_shielded_session,
    ):
        """Empty content is replaced by a ``status`` marker part."""
        thread_id = db_thread.id
        user_id_str = str(db_user.id)
        search_space_id = db_search_space.id
        turn_id = f"{thread_id}:5000"

        msg_id = await persist_assistant_shell(
            chat_id=thread_id,
            user_id=user_id_str,
            turn_id=turn_id,
        )
        assert msg_id is not None

        # Pure tool-call turn that aborted before any output, or
        # interrupt before any event arrived — empty list.
        await finalize_assistant_turn(
            message_id=msg_id,
            chat_id=thread_id,
            search_space_id=search_space_id,
            user_id=user_id_str,
            turn_id=turn_id,
            content=[],
            accumulator=None,
        )

        row = await db_session.get(NewChatMessage, msg_id)
        # Guard before refresh(): fail clearly if the row is missing.
        assert row is not None
        await db_session.refresh(row)
        assert row.content == [{"type": "status", "text": "(no text response)"}]

    async def test_double_call_safe_via_on_conflict(
        self,
        db_session,
        db_user,
        db_thread,
        db_search_space,
        patched_shielded_session,
    ):
        """A second finalize updates content but keeps the first usage row."""
        thread_id = db_thread.id
        user_id_str = str(db_user.id)
        search_space_id = db_search_space.id
        turn_id = f"{thread_id}:6000"

        msg_id = await persist_assistant_shell(
            chat_id=thread_id,
            user_id=user_id_str,
            turn_id=turn_id,
        )
        assert msg_id is not None

        first_acc = _accumulator_with_one_call()
        await finalize_assistant_turn(
            message_id=msg_id,
            chat_id=thread_id,
            search_space_id=search_space_id,
            user_id=user_id_str,
            turn_id=turn_id,
            content=[{"type": "text", "text": "first finalize"}],
            accumulator=first_acc,
        )

        # Simulate a follow-up finalize (e.g., resume retry within the
        # shielded finally block firing twice). Different content, but
        # ON CONFLICT DO NOTHING on token_usage means the cost from the
        # first finalize stays authoritative.
        second_acc = TurnTokenAccumulator()
        second_acc.add(
            model="gpt-4o-mini",
            prompt_tokens=999,
            completion_tokens=999,
            total_tokens=1998,
            cost_micros=99999,
        )
        await finalize_assistant_turn(
            message_id=msg_id,
            chat_id=thread_id,
            search_space_id=search_space_id,
            user_id=user_id_str,
            turn_id=turn_id,
            content=[{"type": "text", "text": "second finalize"}],
            accumulator=second_acc,
        )

        # Content was overwritten by the second UPDATE.
        row = await db_session.get(NewChatMessage, msg_id)
        # Guard before refresh(): fail clearly if the row is missing.
        assert row is not None
        await db_session.refresh(row)
        assert row.content == [{"type": "text", "text": "second finalize"}]

        # But token_usage stayed at exactly one row, preserving the
        # first finalize's authoritative cost.
        assert await _count_token_usage_rows(db_session, msg_id) == 1
        usage = (
            await db_session.execute(
                select(TokenUsage).where(TokenUsage.message_id == msg_id)
            )
        ).scalar_one()
        assert usage.cost_micros == 12345  # First finalize's value

    async def test_append_message_style_insert_after_finalize_no_dupe(
        self,
        db_session,
        db_user,
        db_thread,
        db_search_space,
        patched_shielded_session,
    ):
        """Cross-writer race: ``append_message`` arrives after ``finalize_assistant_turn``.

        Both target the same ``message_id``; the partial unique index
        ``uq_token_usage_message_id`` (migration 142) makes the second
        insert a no-op via ``ON CONFLICT DO NOTHING``.
        """
        # Local import: only this test needs a raw SQL fragment for the
        # partial-index predicate.
        from sqlalchemy import text as sa_text

        thread_id = db_thread.id
        user_uuid = db_user.id
        user_id_str = str(user_uuid)
        search_space_id = db_search_space.id
        turn_id = f"{thread_id}:7000"

        msg_id = await persist_assistant_shell(
            chat_id=thread_id,
            user_id=user_id_str,
            turn_id=turn_id,
        )
        assert msg_id is not None

        await finalize_assistant_turn(
            message_id=msg_id,
            chat_id=thread_id,
            search_space_id=search_space_id,
            user_id=user_id_str,
            turn_id=turn_id,
            content=[{"type": "text", "text": "from server"}],
            accumulator=_accumulator_with_one_call(),
        )

        # Now simulate the FE's append_message branch firing AFTER —
        # the same INSERT ... ON CONFLICT DO NOTHING shape used by the
        # route handler, keyed on the migration-142 partial unique
        # index.
        late_insert = (
            pg_insert(TokenUsage)
            .values(
                usage_type="chat",
                prompt_tokens=42,
                completion_tokens=42,
                total_tokens=84,
                cost_micros=1,
                model_breakdown=None,
                call_details=None,
                thread_id=thread_id,
                message_id=msg_id,
                search_space_id=search_space_id,
                user_id=user_uuid,
            )
            .on_conflict_do_nothing(
                index_elements=["message_id"],
                # The predicate must match the partial index so
                # PostgreSQL can infer it as the conflict arbiter.
                index_where=sa_text("message_id IS NOT NULL"),
            )
        )
        await db_session.execute(late_insert)
        await db_session.flush()

        # Still exactly one row, with the original (server) cost value.
        assert await _count_token_usage_rows(db_session, msg_id) == 1
        usage = (
            await db_session.execute(
                select(TokenUsage).where(TokenUsage.message_id == msg_id)
            )
        ).scalar_one()
        assert usage.cost_micros == 12345

    async def test_helper_never_raises_on_missing_message_id(
        self,
        db_session,
        db_user,
        db_thread,
        db_search_space,
        patched_shielded_session,
    ):
        """Finalizing a non-existent ``message_id`` returns without raising."""
        thread_id = db_thread.id
        user_id_str = str(db_user.id)
        search_space_id = db_search_space.id
        missing_message_id = 999_999_999

        # message_id that doesn't exist — finalize must log+return,
        # never raise (called from shielded finally).
        await finalize_assistant_turn(
            message_id=missing_message_id,
            chat_id=thread_id,
            search_space_id=search_space_id,
            user_id=user_id_str,
            turn_id="anything",
            content=[{"type": "text", "text": "x"}],
            accumulator=_accumulator_with_one_call(),
        )
        # Reaching this line without an exception is the main check.
        # Sanity: no token_usage row may exist for the missing message.
        # ON CONFLICT DO NOTHING only covers unique/exclusion conflicts,
        # so it would not hide an FK violation by itself; presumably the
        # helper swallows such errors (NOTE(review): confirm against
        # the helper), so verify the table directly.
        assert await _count_token_usage_rows(db_session, missing_message_id) == 0
|