refactor(chat): add streaming/flows/new_chat/ per-concern leaf modules

Seven focused modules that the upcoming new_chat orchestrator
composes:

* auto_pin: resolve_initial_auto_pin selects the initial config (with
  vision-capable filtering and error classification).
* llm_capability: check_image_input_capability blocks routing an
  image-bearing turn to a known text-only model.
* runtime_context: build_new_chat_runtime_context assembles the
  SurfSenseContextSchema for a new-chat turn.
* persistence_spawn: spawn_set_ai_responding_bg, spawn_persist_user_task,
  spawn_persist_assistant_shell_task, and await_persist_task background
  the four pre-stream DB writes so they overlap with agent build.
* initial_thinking_step: build_initial_thinking_step +
  iter_initial_thinking_step_frame produce the very first thinking-1 SSE
  step ("Understanding your request" / "Analyzing referenced content").
* title_gen: spawn_title_task + maybe_emit_title_update +
  await_pending_title_update background the thread-title generator and
  interleave its update into the stream when ready.
* input_state: build_new_chat_input_state assembles the LangGraph
  input_state (history bootstrap, mentions resolution, context blocks,
  human-message construction). The heavy one.

Add-only; no orchestrator yet (next commit).
This commit is contained in:
CREDO23 2026-05-25 21:49:45 +02:00
parent 21bddc73a7
commit 927009745e
7 changed files with 920 additions and 0 deletions

View file

@ -0,0 +1,95 @@
"""Resolve the auto-pin for the *initial* turn config.
Auto-pin (``selected_llm_config_id=0``) picks the best eligible LLM config for
this thread / search space / user, optionally filtered to vision-capable
configs when the turn carries images.
Errors classified here:
* ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` the auto-pin pool has no
vision-capable cfg for an image-bearing turn. The same gate fires later
in ``llm_capability`` for explicit selections; mapping both to the same
code keeps the FE error UI consistent.
* ``SERVER_ERROR`` any other ``ValueError`` from the resolver.
This module owns *initial* pin resolution; the rate-limit recovery loop has
its own narrower auto-pin call (with ``exclude_config_ids``) in
``flows/shared/rate_limit_recovery``.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal
from sqlalchemy.ext.asyncio import AsyncSession
from app.observability import otel as ot
from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
@dataclass
class AutoPinResult:
"""Outcome of ``resolve_initial_auto_pin``.
``llm_config_id`` is set when ``error`` is ``None``; ``error`` carries the
classified user-facing message plus error code/kind so the orchestrator can
emit one terminal-error SSE frame.
"""
llm_config_id: int | None
error: tuple[str, str, Literal["user_error", "server_error"]] | None
async def resolve_initial_auto_pin(
session: AsyncSession,
*,
chat_id: int,
search_space_id: int,
user_id: str | None,
selected_llm_config_id: int,
requires_image_input: bool,
requested_llm_config_id: int,
) -> AutoPinResult:
"""Run the resolver and classify any ``ValueError`` for the SSE error path."""
try:
pinned = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=selected_llm_config_id,
requires_image_input=requires_image_input,
)
ot.add_event(
"model.pin.resolved",
{
"pin.requested_id": requested_llm_config_id,
"pin.resolved_id": pinned.resolved_llm_config_id,
"pin.requires_image_input": requires_image_input,
},
)
return AutoPinResult(
llm_config_id=pinned.resolved_llm_config_id, error=None
)
except ValueError as pin_error:
# The "no vision-capable cfg" path raises a ValueError whose message
# we map to the friendly image-input SSE error so the user sees the
# same message regardless of whether the gate fired in the resolver or
# in ``llm_capability.assert_vision_capability_for_image_turn``.
is_vision_failure = (
requires_image_input and "vision-capable" in str(pin_error)
)
error_code = (
"MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
if is_vision_failure
else "SERVER_ERROR"
)
error_kind: Literal["user_error", "server_error"] = (
"user_error" if is_vision_failure else "server_error"
)
if is_vision_failure:
ot.add_event("quota.denied", {"quota.code": error_code})
return AutoPinResult(
llm_config_id=None, error=(str(pin_error), error_code, error_kind)
)

View file

@ -0,0 +1,95 @@
"""Build and emit the first ``thinking-1`` step for a new-chat turn.
The step title and "Processing X" items are derived from what the user sent
(text snippet, image count, mentioned doc titles) so the FE can render a
meaningful placeholder while the agent stream warms up.
``thinking-1`` is the canonical id for this step every subsequent
``thinking-N`` produced by ``stream_agent_events`` folds into the same
singleton ``data-thinking-steps`` part on the FE.
"""
from __future__ import annotations
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any
from app.db import SurfsenseDocsDocument
from app.services.new_streaming_service import VercelStreamingService
@dataclass
class InitialThinkingStep:
"""Resolved fields passed both into the SSE frame and the builder hook.
``items`` is the bullet list under the step title; ``title`` is the
one-line step header. ``step_id`` is hard-coded ``thinking-1`` so the FE
Timeline can de-duplicate against the prior assistant message on resume.
"""
step_id: str
title: str
items: list[str]
def build_initial_thinking_step(
*,
user_query: str,
user_image_data_urls: list[str] | None,
mentioned_surfsense_docs: list[SurfsenseDocsDocument],
) -> InitialThinkingStep:
if mentioned_surfsense_docs:
title = "Analyzing referenced content"
action_verb = "Analyzing"
else:
title = "Understanding your request"
action_verb = "Processing"
processing_parts: list[str] = []
if user_query.strip():
query_text = user_query[:80] + ("..." if len(user_query) > 80 else "")
processing_parts.append(query_text)
elif user_image_data_urls:
processing_parts.append(f"[{len(user_image_data_urls)} image(s)]")
else:
processing_parts.append("(message)")
if mentioned_surfsense_docs:
doc_names: list[str] = []
for doc in mentioned_surfsense_docs:
t = doc.title
if len(t) > 30:
t = t[:27] + "..."
doc_names.append(t)
if len(doc_names) == 1:
processing_parts.append(f"[{doc_names[0]}]")
else:
processing_parts.append(f"[{len(doc_names)} docs]")
items = [f"{action_verb}: {' '.join(processing_parts)}"]
return InitialThinkingStep(step_id="thinking-1", title=title, items=items)
def iter_initial_thinking_step_frame(
step: InitialThinkingStep,
*,
streaming_service: VercelStreamingService,
content_builder: Any | None,
) -> Iterator[str]:
"""Drive both the SSE emission and the builder hook for the initial step.
The FE folds this step into the same singleton ``data-thinking-steps`` part
as everything the agent stream emits later, so we mirror that fold
server-side by driving the builder lifecycle ourselves.
"""
if content_builder is not None:
content_builder.on_thinking_step(
step.step_id, step.title, "in_progress", step.items
)
yield streaming_service.format_thinking_step(
step_id=step.step_id,
title=step.title,
status="in_progress",
items=step.items,
)

View file

@ -0,0 +1,264 @@
r"""Assemble the LangGraph ``input_state`` for the new-chat turn.
Pipeline:
1. **History bootstrap** only for cloned chats with no LangGraph checkpoint
yet; flips the per-thread ``needs_history_bootstrap`` flag back to False
once the rows are loaded.
2. **Mentioned SurfSense docs** eager-load chunks so the formatter has the
full content without a second roundtrip.
3. **Recent reports** top 3 by id desc with non-null content, so the LLM
can resolve ``report_id`` for versioning without spelunking history.
4. **@-mention resolve** (cloud mode) substitute ``@title`` tokens in the
query with canonical ``\`/documents/...\``` paths the LLM expects.
5. **Context block render** XML-wrap surfsense docs + reports, prepend to
the rewritten query, optionally prefix with display name for SEARCH_SPACE
visibility.
6. **HumanMessage** multimodal content if images are attached.
Returns the assembled ``input_state`` dict plus side-channel data the
orchestrator needs downstream (``accepted_folder_ids`` for runtime context;
``mentioned_surfsense_docs`` for the initial thinking step).
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Any
from langchain_core.messages import HumanMessage
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in_text
from app.db import (
ChatVisibility,
NewChatThread,
Report,
SurfsenseDocsDocument,
)
from app.tasks.chat.streaming.context.mentioned_docs import (
format_mentioned_surfsense_docs_as_context,
)
from app.utils.content_utils import bootstrap_history_from_db
from app.utils.user_message_multimodal import build_human_message_content
logger = logging.getLogger(__name__)
@dataclass
class NewChatInputState:
"""Everything ``build_new_chat_input_state`` produces.
``input_state`` is fed straight to the agent. ``accepted_folder_ids``
feeds the runtime context (the resolver may have dropped some chips).
``mentioned_surfsense_docs`` is consumed by the initial thinking-step
builder for the FE placeholder before the agent stream starts.
"""
input_state: dict[str, Any]
accepted_folder_ids: list[int]
mentioned_surfsense_docs: list[SurfsenseDocsDocument]
async def build_new_chat_input_state(
session: AsyncSession,
*,
chat_id: int,
search_space_id: int,
user_query: str,
user_image_data_urls: list[str] | None,
mentioned_document_ids: list[int] | None,
mentioned_surfsense_doc_ids: list[int] | None,
mentioned_folder_ids: list[int] | None,
mentioned_documents: list[dict[str, Any]] | None,
needs_history_bootstrap: bool,
thread_visibility: ChatVisibility,
current_user_display_name: str | None,
filesystem_mode: str,
request_id: str | None,
turn_id: str,
) -> NewChatInputState:
langchain_messages: list[Any] = []
if needs_history_bootstrap:
langchain_messages = await bootstrap_history_from_db(
session, chat_id, thread_visibility=thread_visibility
)
thread_result = await session.execute(
select(NewChatThread).filter(NewChatThread.id == chat_id)
)
thread = thread_result.scalars().first()
if thread:
thread.needs_history_bootstrap = False
await session.commit()
mentioned_surfsense_docs: list[SurfsenseDocsDocument] = []
if mentioned_surfsense_doc_ids:
result = await session.execute(
select(SurfsenseDocsDocument)
.options(selectinload(SurfsenseDocsDocument.chunks))
.filter(SurfsenseDocsDocument.id.in_(mentioned_surfsense_doc_ids))
)
mentioned_surfsense_docs = list(result.scalars().all())
# Top 3 reports keyed by id desc (newest first) with content present,
# surfaced inline so the LLM resolves ``report_id`` for versioning without
# digging through conversation history.
recent_reports_result = await session.execute(
select(Report)
.filter(
Report.thread_id == chat_id,
Report.content.isnot(None),
)
.order_by(Report.id.desc())
.limit(3)
)
recent_reports = list(recent_reports_result.scalars().all())
agent_user_query, accepted_folder_ids = await _resolve_mentions_for_query(
session,
search_space_id=search_space_id,
user_query=user_query,
filesystem_mode=filesystem_mode,
mentioned_document_ids=mentioned_document_ids,
mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
mentioned_folder_ids=mentioned_folder_ids,
mentioned_documents=mentioned_documents,
)
final_query = _render_query_with_context(
agent_user_query=agent_user_query,
mentioned_surfsense_docs=mentioned_surfsense_docs,
recent_reports=recent_reports,
)
if thread_visibility == ChatVisibility.SEARCH_SPACE and current_user_display_name:
final_query = f"**[{current_user_display_name}]:** {final_query}"
human_content = build_human_message_content(
final_query, list(user_image_data_urls or ())
)
langchain_messages.append(HumanMessage(content=human_content))
input_state = {
"messages": langchain_messages,
"search_space_id": search_space_id,
"request_id": request_id or "unknown",
"turn_id": turn_id,
}
return NewChatInputState(
input_state=input_state,
accepted_folder_ids=accepted_folder_ids,
mentioned_surfsense_docs=mentioned_surfsense_docs,
)
async def _resolve_mentions_for_query(
session: AsyncSession,
*,
search_space_id: int,
user_query: str,
filesystem_mode: str,
mentioned_document_ids: list[int] | None,
mentioned_surfsense_doc_ids: list[int] | None,
mentioned_folder_ids: list[int] | None,
mentioned_documents: list[dict[str, Any]] | None,
) -> tuple[str, list[int]]:
r"""Resolve @-mention chips and rewrite the user query to canonical paths.
Cloud mode only: local-folder mode keeps the legacy ``@title`` text path
(mention support there is a follow-up task the path scheme is
mount-rooted and the picker UI both need separate work).
The substitution lands in the returned ``agent_user_query`` ONLY the
original ``user_query`` (with ``@title`` tokens) flows untouched into
``persist_user_turn`` so chip rendering on reload still works
(``UserTextPart`` ``parseMentionSegments`` matches ``@title``, not
``\`/documents/...\```). It also feeds the human-readable surfaces SSE
"Processing X" status, auto thread title, memory seed which all want
what the user typed.
"""
agent_user_query = user_query
accepted_folder_ids: list[int] = []
has_any_mention = bool(
mentioned_document_ids
or mentioned_surfsense_doc_ids
or mentioned_folder_ids
or mentioned_documents
)
if filesystem_mode != FilesystemMode.CLOUD.value or not has_any_mention:
return agent_user_query, accepted_folder_ids
from app.schemas.new_chat import MentionedDocumentInfo
chip_objs: list[MentionedDocumentInfo] | None = None
if mentioned_documents:
chip_objs = []
for raw in mentioned_documents:
if isinstance(raw, MentionedDocumentInfo):
chip_objs.append(raw)
continue
try:
chip_objs.append(MentionedDocumentInfo.model_validate(raw))
except Exception:
logger.debug(
"stream_new_chat: dropping malformed mention chip %r", raw
)
resolved = await resolve_mentions(
session,
search_space_id=search_space_id,
mentioned_documents=chip_objs,
mentioned_document_ids=mentioned_document_ids,
mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
mentioned_folder_ids=mentioned_folder_ids,
)
agent_user_query = substitute_in_text(user_query, resolved.token_to_path)
accepted_folder_ids = resolved.mentioned_folder_ids
return agent_user_query, accepted_folder_ids
def _render_query_with_context(
*,
agent_user_query: str,
mentioned_surfsense_docs: list[SurfsenseDocsDocument],
recent_reports: list[Report],
) -> str:
"""Prepend surfsense-docs + recent-reports XML blocks to the user query."""
context_parts: list[str] = []
if mentioned_surfsense_docs:
context_parts.append(
format_mentioned_surfsense_docs_as_context(mentioned_surfsense_docs)
)
if recent_reports:
report_lines: list[str] = []
for r in recent_reports:
report_lines.append(
f' - report_id={r.id}, title="{r.title}", '
f'style="{r.report_style or "detailed"}"'
)
reports_listing = "\n".join(report_lines)
context_parts.append(
"<report_context>\n"
"Previously generated reports in this conversation:\n"
f"{reports_listing}\n\n"
"If the user wants to MODIFY, REVISE, UPDATE, or ADD to one of "
"these reports, set parent_report_id to the relevant report_id above.\n"
"If the user wants a completely NEW report on a different topic, "
"leave parent_report_id unset.\n"
"</report_context>"
)
if context_parts:
context = "\n\n".join(context_parts)
return f"{context}\n\n<user_query>{agent_user_query}</user_query>"
return agent_user_query

View file

@ -0,0 +1,62 @@
"""Vision-capability gate for image-bearing turns.
Capability safety net for explicit (non-auto-pin) selections: a turn carrying
user-uploaded images cannot be routed to a chat config that LiteLLM's
authoritative model map *explicitly* marks as text-only (``supports_vision``
set to False). The check is intentionally narrow it only fires when LiteLLM
is *certain* the model can't accept image input; unknown / unmapped /
vision-capable models pass through.
Without this guard a known-text-only model would 404 at the provider with
``"No endpoints found that support image input"``, surfacing as an opaque
``SERVER_ERROR`` SSE chunk; failing here lets us return a friendly message that
tells the user what to change.
"""
from __future__ import annotations
from app.agents.new_chat.llm_config import AgentConfig
from app.observability import otel as ot
def check_image_input_capability(
*,
user_image_data_urls: list[str] | None,
agent_config: AgentConfig | None,
) -> tuple[str, str] | None:
"""Return ``(user_message, error_code)`` when the gate trips, else ``None``.
The caller emits one terminal-error SSE frame on a non-``None`` return.
"""
if not (user_image_data_urls and agent_config is not None):
return None
from app.services.provider_capabilities import is_known_text_only_chat_model
agent_litellm_params = agent_config.litellm_params or {}
agent_base_model = (
agent_litellm_params.get("base_model")
if isinstance(agent_litellm_params, dict)
else None
)
if not is_known_text_only_chat_model(
provider=agent_config.provider,
model_name=agent_config.model_name,
base_model=agent_base_model,
custom_provider=agent_config.custom_provider,
):
return None
model_label = agent_config.config_name or agent_config.model_name or "model"
ot.add_event(
"quota.denied", {"quota.code": "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"}
)
return (
(
f"The selected model ({model_label}) does not support "
"image input. Switch to a vision-capable model "
"(e.g. GPT-4o, Claude, Gemini) or remove the image "
"attachment and try again."
),
"MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT",
)

View file

@ -0,0 +1,129 @@
"""Concurrent persistence tasks spawned right after the initial validation gate.
These run *during* the rest of the pre-stream setup so we don't serialize
their latency against agent construction. Awaiting them at the SSE message-id
yield sites preserves the ghost-thread protection (the user-row INSERT must
succeed before any LLM streaming begins).
The ``set_ai_responding`` flag flip runs fully fire-and-forget on its own
shielded session failures only delay the "AI is responding…" UI flag, not
the response itself.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any
from uuid import UUID
from app.db import shielded_async_session
from app.services.chat_session_state_service import set_ai_responding
from app.tasks.chat.persistence import (
persist_assistant_shell,
persist_user_turn,
)
logger = logging.getLogger(__name__)
def spawn_set_ai_responding_bg(
*,
chat_id: int,
user_id: str | None,
background_tasks: set[asyncio.Task[Any]],
) -> None:
"""Fire-and-forget: flip the per-thread AI-responding flag on its own session.
Errors are swallowed and logged the worst case is a stale UI flag, which
is preferable to delaying the SSE stream behind a flag write.
"""
if not user_id:
return
async def _bg_set_ai_responding() -> None:
try:
async with shielded_async_session() as s:
await set_ai_responding(s, chat_id, UUID(user_id))
except Exception:
logger.warning(
"set_ai_responding failed (chat_id=%s)",
chat_id,
exc_info=True,
)
t = asyncio.create_task(_bg_set_ai_responding())
background_tasks.add(t)
t.add_done_callback(background_tasks.discard)
def spawn_persist_user_task(
*,
chat_id: int,
user_id: str | None,
turn_id: str,
user_query: str,
user_image_data_urls: list[str] | None,
mentioned_documents: list[dict[str, Any]] | None,
background_tasks: set[asyncio.Task[Any]],
) -> asyncio.Task[int | None]:
"""Spawn the user-row INSERT; await at the user-message-id yield site."""
task = asyncio.create_task(
persist_user_turn(
chat_id=chat_id,
user_id=user_id,
turn_id=turn_id,
user_query=user_query,
user_image_data_urls=user_image_data_urls,
mentioned_documents=mentioned_documents,
)
)
background_tasks.add(task)
task.add_done_callback(background_tasks.discard)
return task
def spawn_persist_assistant_shell_task(
*,
chat_id: int,
user_id: str | None,
turn_id: str,
background_tasks: set[asyncio.Task[Any]],
) -> asyncio.Task[int | None]:
"""Spawn the assistant-shell INSERT; await at the assistant-message-id yield site."""
task = asyncio.create_task(
persist_assistant_shell(
chat_id=chat_id,
user_id=user_id,
turn_id=turn_id,
)
)
background_tasks.add(task)
task.add_done_callback(background_tasks.discard)
return task
async def await_persist_task(
task: asyncio.Task[int | None] | None,
*,
chat_id: int,
turn_id: str,
log_label: str,
) -> int | None:
"""Join a spawned persistence task with ``shield`` + uniform error handling.
``shield`` keeps the DB write alive if the SSE generator is cancelled by
client disconnect mid-await. Returns ``None`` on failure; the caller
abort-paths the turn with a friendly error SSE.
"""
if task is None:
return None
try:
return await asyncio.shield(task)
except asyncio.CancelledError:
raise
except Exception:
logger.exception(
"%s failed (chat_id=%s, turn_id=%s)", log_label, chat_id, turn_id
)
return None

View file

@ -0,0 +1,38 @@
"""Build the per-invocation ``SurfSenseContextSchema`` for a new-chat turn.
Carries the per-turn read inputs that middlewares read via
``runtime.context.*`` instead of from their ``__init__`` closures, so the same
compiled-agent instance can serve multiple turns with different
mention lists / request ids / turn ids without rebuilding the graph.
"""
from __future__ import annotations
from app.agents.new_chat.context import SurfSenseContextSchema
def build_new_chat_runtime_context(
*,
search_space_id: int,
mentioned_document_ids: list[int] | None,
accepted_folder_ids: list[int],
mentioned_folder_ids: list[int] | None,
request_id: str | None,
turn_id: str,
) -> SurfSenseContextSchema:
"""``mentioned_document_ids`` is consumed by ``KnowledgePriorityMiddleware``.
``accepted_folder_ids`` (post-resolve) wins over the raw
``mentioned_folder_ids`` from the request: the resolver drops chips that
pointed at deleted folders or folders the caller can't see, so middlewares
only get authorized ids.
"""
return SurfSenseContextSchema(
search_space_id=search_space_id,
mentioned_document_ids=list(mentioned_document_ids or []),
mentioned_folder_ids=list(
accepted_folder_ids or mentioned_folder_ids or []
),
request_id=request_id,
turn_id=turn_id,
)

View file

@ -0,0 +1,237 @@
"""Background thread-title generation (first-response only).
The first assistant response in a thread gets a short auto-generated title
inserted into ``new_chat_threads.title``. We:
1. Spawn the generation as an ``asyncio.Task`` so it runs in parallel with
the agent stream (no extra TTFT).
2. Probe inside the task (on its own shielded session) whether this is
actually the first response newer turns short-circuit to ``None``.
3. Inject the resulting ``thread-title-update`` SSE frame on the first agent
event after the task completes (mid-stream interlock), or right before
the finish frames (post-stream join) if the task hadn't finished yet.
Usage tokens come directly off the response (LiteLLM's async callback fires
via fire-and-forget ``create_task``, so the ``TokenTrackingCallback`` would
run too late). We also blank the per-task accumulator so the late callback
doesn't double-count.
"""
from __future__ import annotations
import asyncio
import logging
from typing import TYPE_CHECKING, Any
from sqlalchemy.future import select
from app.db import NewChatMessage, NewChatThread, shielded_async_session
from app.prompts import TITLE_GENERATION_PROMPT
from app.services.new_streaming_service import VercelStreamingService
if TYPE_CHECKING:
from app.agents.new_chat.llm_config import AgentConfig
from app.services.token_tracking_service import TokenAccumulator
logger = logging.getLogger(__name__)
def spawn_title_task(
*,
chat_id: int,
user_query: str,
user_image_data_urls: list[str] | None,
assistant_message_id: int | None,
llm: Any,
agent_config: AgentConfig | None,
) -> asyncio.Task[tuple[str | None, dict | None]] | None:
"""Spawn ``_generate_title``; returns ``None`` when prerequisites aren't met.
Title gen is gated on a real ``assistant_message_id`` so a stream that
aborts before persistence can never leave a thread with a title and no
anchoring rows.
"""
if assistant_message_id is None:
return None
return asyncio.create_task(
_generate_title(
chat_id=chat_id,
user_query=user_query,
user_image_data_urls=user_image_data_urls,
assistant_message_id=assistant_message_id,
llm=llm,
agent_config=agent_config,
)
)
async def _generate_title(
*,
chat_id: int,
user_query: str,
user_image_data_urls: list[str] | None,
assistant_message_id: int,
llm: Any,
agent_config: AgentConfig | None,
) -> tuple[str | None, dict | None]:
"""Probe is-first-response, then call ``acompletion``. Returns ``(title, usage)``."""
try:
from litellm import acompletion
from app.services.llm_router_service import LLMRouterService
from app.services.provider_api_base import resolve_api_base
from app.services.token_tracking_service import _turn_accumulator
# Excludes this turn's own assistant row (pre-written by
# ``persist_assistant_shell``) — without the ``!=`` filter the gate
# would false-negative on every turn after the first.
try:
async with shielded_async_session() as probe_session:
probe_result = await probe_session.execute(
select(NewChatMessage.id)
.filter(
NewChatMessage.thread_id == chat_id,
NewChatMessage.role == "assistant",
NewChatMessage.id != assistant_message_id,
)
.limit(1)
)
is_first_response = probe_result.scalars().first() is None
except Exception:
logger.warning(
"[TitleGen] first-response probe failed (chat_id=%s)",
chat_id,
exc_info=True,
)
return None, None
if not is_first_response:
return None, None
_turn_accumulator.set(None)
title_seed = user_query.strip() or (
f"[{len(user_image_data_urls or [])} image(s)]"
if user_image_data_urls
else ""
)
prompt = TITLE_GENERATION_PROMPT.replace(
"{user_query}", title_seed[:500] or "(message)"
)
messages = [{"role": "user", "content": prompt}]
if getattr(llm, "model", None) == "auto":
router = LLMRouterService.get_router()
response = await router.acompletion(model="auto", messages=messages)
else:
# Apply the same ``api_base`` cascade chat / vision / image-gen
# call sites use so we never inherit ``litellm.api_base``
# (commonly set by ``AZURE_OPENAI_ENDPOINT``) when the chat
# config itself ships an empty ``api_base``. Without this the
# title-gen on an OpenRouter chat config would 404 against the
# inherited Azure endpoint — see ``provider_api_base`` for the
# same bug repro on the image-gen / vision paths.
raw_model = getattr(llm, "model", "") or ""
provider_prefix = (
raw_model.split("/", 1)[0] if "/" in raw_model else None
)
provider_value = (
agent_config.provider if agent_config is not None else None
)
title_api_base = resolve_api_base(
provider=provider_value,
provider_prefix=provider_prefix,
config_api_base=getattr(llm, "api_base", None),
)
response = await acompletion(
model=raw_model,
messages=messages,
api_key=getattr(llm, "api_key", None),
api_base=title_api_base,
)
usage_info = None
usage = getattr(response, "usage", None)
if usage:
raw_model = getattr(llm, "model", "") or ""
model_name = (
raw_model.split("/", 1)[-1]
if "/" in raw_model
else (raw_model or response.model or "unknown")
)
usage_info = {
"model": model_name,
"prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0,
"completion_tokens": getattr(usage, "completion_tokens", 0) or 0,
"total_tokens": getattr(usage, "total_tokens", 0) or 0,
}
raw_title = response.choices[0].message.content.strip()
if raw_title and len(raw_title) <= 100:
return raw_title.strip("\"'"), usage_info
return None, usage_info
except Exception:
logger.exception("[TitleGen] _generate_title failed")
return None, None
async def maybe_emit_title_update(
*,
title_task: asyncio.Task[tuple[str | None, dict | None]] | None,
title_emitted: bool,
chat_id: int,
accumulator: TokenAccumulator,
streaming_service: VercelStreamingService,
):
"""Inject one ``thread-title-update`` SSE if the task completed.
Yields the SSE frame (when applicable). Returns nothing; the orchestrator
flips ``title_emitted`` itself after iterating so we don't fight Python's
nonlocal-in-generator semantics.
"""
if title_task is None or title_emitted or not title_task.done():
return
generated_title, title_usage = title_task.result()
if title_usage:
accumulator.add(**title_usage)
if generated_title:
async with shielded_async_session() as title_session:
title_thread_result = await title_session.execute(
select(NewChatThread).filter(NewChatThread.id == chat_id)
)
title_thread = title_thread_result.scalars().first()
if title_thread:
title_thread.title = generated_title
await title_session.commit()
yield streaming_service.format_thread_title_update(chat_id, generated_title)
async def await_pending_title_update(
*,
title_task: asyncio.Task[tuple[str | None, dict | None]] | None,
title_emitted: bool,
chat_id: int,
accumulator: TokenAccumulator,
streaming_service: VercelStreamingService,
):
"""If the task hadn't completed during the stream, await it now and emit.
Used right before the finish frames in the success path. Mirror of
``maybe_emit_title_update`` but unconditionally awaits.
"""
if title_task is None or title_emitted:
return
generated_title, title_usage = await title_task
if title_usage:
accumulator.add(**title_usage)
if generated_title:
async with shielded_async_session() as title_session:
title_thread_result = await title_session.execute(
select(NewChatThread).filter(NewChatThread.id == chat_id)
)
title_thread = title_thread_result.scalars().first()
if title_thread:
title_thread.title = generated_title
await title_session.commit()
yield streaming_service.format_thread_title_update(chat_id, generated_title)