diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/auto_pin.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/auto_pin.py
new file mode 100644
index 000000000..cb20eb011
--- /dev/null
+++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/auto_pin.py
@@ -0,0 +1,95 @@
+"""Resolve the auto-pin for the *initial* turn config.
+
+Auto-pin (``selected_llm_config_id=0``) picks the best eligible LLM config for
+this thread / search space / user, optionally filtered to vision-capable
+configs when the turn carries images.
+
+Errors classified here:
+
+  * ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` — the auto-pin pool has no
+    vision-capable cfg for an image-bearing turn. The same gate fires later
+    in ``llm_capability`` for explicit selections; mapping both to the same
+    code keeps the FE error UI consistent.
+  * ``SERVER_ERROR`` — any other ``ValueError`` from the resolver.
+
+This module owns *initial* pin resolution; the rate-limit recovery loop has
+its own narrower auto-pin call (with ``exclude_config_ids``) in
+``flows/shared/rate_limit_recovery``.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Literal
+
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.observability import otel as ot
+from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
+
+
+@dataclass
+class AutoPinResult:
+    """Outcome of ``resolve_initial_auto_pin``.
+
+    ``llm_config_id`` is set when ``error`` is ``None``; ``error`` carries the
+    classified user-facing message plus error code/kind so the orchestrator can
+    emit one terminal-error SSE frame.
+    """
+
+    llm_config_id: int | None
+    error: tuple[str, str, Literal["user_error", "server_error"]] | None
+
+
+async def resolve_initial_auto_pin(
+    session: AsyncSession,
+    *,
+    chat_id: int,
+    search_space_id: int,
+    user_id: str | None,
+    selected_llm_config_id: int,
+    requires_image_input: bool,
+    requested_llm_config_id: int,
+) -> AutoPinResult:
+    """Run the resolver and classify any ``ValueError`` for the SSE error path."""
+    try:
+        pinned = await resolve_or_get_pinned_llm_config_id(
+            session,
+            thread_id=chat_id,
+            search_space_id=search_space_id,
+            user_id=user_id,
+            selected_llm_config_id=selected_llm_config_id,
+            requires_image_input=requires_image_input,
+        )
+        ot.add_event(
+            "model.pin.resolved",
+            {
+                "pin.requested_id": requested_llm_config_id,
+                "pin.resolved_id": pinned.resolved_llm_config_id,
+                "pin.requires_image_input": requires_image_input,
+            },
+        )
+        return AutoPinResult(
+            llm_config_id=pinned.resolved_llm_config_id, error=None
+        )
+    except ValueError as pin_error:
+        # The "no vision-capable cfg" path raises a ValueError whose message
+        # we map to the friendly image-input SSE error so the user sees the
+        # same message regardless of whether the gate fired in the resolver or
+        # in ``llm_capability.assert_vision_capability_for_image_turn``.
+        is_vision_failure = (
+            requires_image_input and "vision-capable" in str(pin_error)
+        )
+        error_code = (
+            "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
+            if is_vision_failure
+            else "SERVER_ERROR"
+        )
+        error_kind: Literal["user_error", "server_error"] = (
+            "user_error" if is_vision_failure else "server_error"
+        )
+        if is_vision_failure:
+            ot.add_event("quota.denied", {"quota.code": error_code})
+        return AutoPinResult(
+            llm_config_id=None, error=(str(pin_error), error_code, error_kind)
+        )
diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/initial_thinking_step.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/initial_thinking_step.py
new file mode 100644
index 000000000..c860e517e
--- /dev/null
+++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/initial_thinking_step.py
@@ -0,0 +1,101 @@
+"""Build and emit the first ``thinking-1`` step for a new-chat turn.
+
+The step title and "Processing X" items are derived from what the user sent
+(text snippet, image count, mentioned doc titles) so the FE can render a
+meaningful placeholder while the agent stream warms up.
+
+``thinking-1`` is the canonical id for this step — every subsequent
+``thinking-N`` produced by ``stream_agent_events`` folds into the same
+singleton ``data-thinking-steps`` part on the FE.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Iterator
+from dataclasses import dataclass
+from typing import Any
+
+from app.db import SurfsenseDocsDocument
+from app.services.new_streaming_service import VercelStreamingService
+
+
+@dataclass
+class InitialThinkingStep:
+    """Resolved fields passed both into the SSE frame and the builder hook.
+
+    ``items`` is the bullet list under the step title; ``title`` is the
+    one-line step header. ``step_id`` is hard-coded ``thinking-1`` so the FE
+    Timeline can de-duplicate against the prior assistant message on resume.
+    """
+
+    step_id: str
+    title: str
+    items: list[str]
+
+
+def build_initial_thinking_step(
+    *,
+    user_query: str,
+    user_image_data_urls: list[str] | None,
+    mentioned_surfsense_docs: list[SurfsenseDocsDocument],
+) -> InitialThinkingStep:
+    """Derive the ``thinking-1`` title and its single status item.
+
+    Title and verb switch to "Analyzing" when SurfSense docs are mentioned.
+    The item shows the first 80 chars of the query (else an image count, else
+    ``(message)``) followed by the doc title (one doc) or the doc count.
+    """
+    if mentioned_surfsense_docs:
+        title = "Analyzing referenced content"
+        action_verb = "Analyzing"
+    else:
+        title = "Understanding your request"
+        action_verb = "Processing"
+
+    processing_parts: list[str] = []
+    if user_query.strip():
+        query_text = user_query[:80] + ("..." if len(user_query) > 80 else "")
+        processing_parts.append(query_text)
+    elif user_image_data_urls:
+        processing_parts.append(f"[{len(user_image_data_urls)} image(s)]")
+    else:
+        processing_parts.append("(message)")
+
+    if mentioned_surfsense_docs:
+        doc_names: list[str] = []
+        for doc in mentioned_surfsense_docs:
+            t = doc.title
+            if len(t) > 30:
+                t = t[:27] + "..."
+            doc_names.append(t)
+        if len(doc_names) == 1:
+            processing_parts.append(f"[{doc_names[0]}]")
+        else:
+            processing_parts.append(f"[{len(doc_names)} docs]")
+
+    items = [f"{action_verb}: {' '.join(processing_parts)}"]
+    return InitialThinkingStep(step_id="thinking-1", title=title, items=items)
+
+
+def iter_initial_thinking_step_frame(
+    step: InitialThinkingStep,
+    *,
+    streaming_service: VercelStreamingService,
+    content_builder: Any | None,
+) -> Iterator[str]:
+    """Drive both the SSE emission and the builder hook for the initial step.
+
+    The FE folds this step into the same singleton ``data-thinking-steps`` part
+    as everything the agent stream emits later, so we mirror that fold
+    server-side by driving the builder lifecycle ourselves.
+    """
+    if content_builder is not None:
+        content_builder.on_thinking_step(
+            step.step_id, step.title, "in_progress", step.items
+        )
+    yield streaming_service.format_thinking_step(
+        step_id=step.step_id,
+        title=step.title,
+        status="in_progress",
+        items=step.items,
+    )
diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/input_state.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/input_state.py
new file mode 100644
index 000000000..fb171c244
--- /dev/null
+++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/input_state.py
@@ -0,0 +1,271 @@
+r"""Assemble the LangGraph ``input_state`` for the new-chat turn.
+
+Pipeline:
+
+  1. **History bootstrap** — only for cloned chats with no LangGraph checkpoint
+     yet; flips the per-thread ``needs_history_bootstrap`` flag back to False
+     once the rows are loaded.
+  2. **Mentioned SurfSense docs** — eager-load chunks so the formatter has the
+     full content without a second roundtrip.
+  3. **Recent reports** — top 3 by id desc with non-null content, so the LLM
+     can resolve ``report_id`` for versioning without spelunking history.
+  4. **@-mention resolve** (cloud mode) — substitute ``@title`` tokens in the
+     query with canonical ``\`/documents/...\``` paths the LLM expects.
+  5. **Context block render** — XML-wrap surfsense docs + reports, prepend to
+     the rewritten query, optionally prefix with display name for SEARCH_SPACE
+     visibility.
+  6. **HumanMessage** — multimodal content if images are attached.
+
+Returns the assembled ``input_state`` dict plus side-channel data the
+orchestrator needs downstream (``accepted_folder_ids`` for runtime context;
+``mentioned_surfsense_docs`` for the initial thinking step).
+"""
+
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from typing import Any
+
+from langchain_core.messages import HumanMessage
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.future import select
+from sqlalchemy.orm import selectinload
+
+from app.agents.new_chat.filesystem_selection import FilesystemMode
+from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in_text
+from app.db import (
+    ChatVisibility,
+    NewChatThread,
+    Report,
+    SurfsenseDocsDocument,
+)
+from app.tasks.chat.streaming.context.mentioned_docs import (
+    format_mentioned_surfsense_docs_as_context,
+)
+from app.utils.content_utils import bootstrap_history_from_db
+from app.utils.user_message_multimodal import build_human_message_content
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class NewChatInputState:
+    """Everything ``build_new_chat_input_state`` produces.
+
+    ``input_state`` is fed straight to the agent. ``accepted_folder_ids``
+    feeds the runtime context (the resolver may have dropped some chips).
+    ``mentioned_surfsense_docs`` is consumed by the initial thinking-step
+    builder for the FE placeholder before the agent stream starts.
+    """
+
+    input_state: dict[str, Any]
+    accepted_folder_ids: list[int]
+    mentioned_surfsense_docs: list[SurfsenseDocsDocument]
+
+
+async def build_new_chat_input_state(
+    session: AsyncSession,
+    *,
+    chat_id: int,
+    search_space_id: int,
+    user_query: str,
+    user_image_data_urls: list[str] | None,
+    mentioned_document_ids: list[int] | None,
+    mentioned_surfsense_doc_ids: list[int] | None,
+    mentioned_folder_ids: list[int] | None,
+    mentioned_documents: list[dict[str, Any]] | None,
+    needs_history_bootstrap: bool,
+    thread_visibility: ChatVisibility,
+    current_user_display_name: str | None,
+    filesystem_mode: str,
+    request_id: str | None,
+    turn_id: str,
+) -> NewChatInputState:
+    """Assemble the agent ``input_state`` plus side-channel data for one turn.
+
+    Side effect: when ``needs_history_bootstrap`` is set, the thread's flag is
+    reset to False and committed on ``session`` after history is loaded.
+    """
+    langchain_messages: list[Any] = []
+
+    if needs_history_bootstrap:
+        langchain_messages = await bootstrap_history_from_db(
+            session, chat_id, thread_visibility=thread_visibility
+        )
+        thread_result = await session.execute(
+            select(NewChatThread).filter(NewChatThread.id == chat_id)
+        )
+        thread = thread_result.scalars().first()
+        if thread:
+            thread.needs_history_bootstrap = False
+            await session.commit()
+
+    mentioned_surfsense_docs: list[SurfsenseDocsDocument] = []
+    if mentioned_surfsense_doc_ids:
+        result = await session.execute(
+            select(SurfsenseDocsDocument)
+            .options(selectinload(SurfsenseDocsDocument.chunks))
+            .filter(SurfsenseDocsDocument.id.in_(mentioned_surfsense_doc_ids))
+        )
+        mentioned_surfsense_docs = list(result.scalars().all())
+
+    # Top 3 reports keyed by id desc (newest first) with content present,
+    # surfaced inline so the LLM resolves ``report_id`` for versioning without
+    # digging through conversation history.
+    recent_reports_result = await session.execute(
+        select(Report)
+        .filter(
+            Report.thread_id == chat_id,
+            Report.content.isnot(None),
+        )
+        .order_by(Report.id.desc())
+        .limit(3)
+    )
+    recent_reports = list(recent_reports_result.scalars().all())
+
+    agent_user_query, accepted_folder_ids = await _resolve_mentions_for_query(
+        session,
+        search_space_id=search_space_id,
+        user_query=user_query,
+        filesystem_mode=filesystem_mode,
+        mentioned_document_ids=mentioned_document_ids,
+        mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
+        mentioned_folder_ids=mentioned_folder_ids,
+        mentioned_documents=mentioned_documents,
+    )
+
+    final_query = _render_query_with_context(
+        agent_user_query=agent_user_query,
+        mentioned_surfsense_docs=mentioned_surfsense_docs,
+        recent_reports=recent_reports,
+    )
+
+    if thread_visibility == ChatVisibility.SEARCH_SPACE and current_user_display_name:
+        final_query = f"**[{current_user_display_name}]:** {final_query}"
+
+    human_content = build_human_message_content(
+        final_query, list(user_image_data_urls or ())
+    )
+    langchain_messages.append(HumanMessage(content=human_content))
+
+    input_state = {
+        "messages": langchain_messages,
+        "search_space_id": search_space_id,
+        "request_id": request_id or "unknown",
+        "turn_id": turn_id,
+    }
+
+    return NewChatInputState(
+        input_state=input_state,
+        accepted_folder_ids=accepted_folder_ids,
+        mentioned_surfsense_docs=mentioned_surfsense_docs,
+    )
+
+
+async def _resolve_mentions_for_query(
+    session: AsyncSession,
+    *,
+    search_space_id: int,
+    user_query: str,
+    filesystem_mode: str,
+    mentioned_document_ids: list[int] | None,
+    mentioned_surfsense_doc_ids: list[int] | None,
+    mentioned_folder_ids: list[int] | None,
+    mentioned_documents: list[dict[str, Any]] | None,
+) -> tuple[str, list[int]]:
+    r"""Resolve @-mention chips and rewrite the user query to canonical paths.
+
+    Cloud mode only: local-folder mode keeps the legacy ``@title`` text path
+    (mention support there is a follow-up task — the path scheme is
+    mount-rooted and the picker UI both need separate work).
+
+    The substitution lands in the returned ``agent_user_query`` ONLY — the
+    original ``user_query`` (with ``@title`` tokens) flows untouched into
+    ``persist_user_turn`` so chip rendering on reload still works
+    (``UserTextPart`` → ``parseMentionSegments`` matches ``@title``, not
+    ``\`/documents/...\```). It also feeds the human-readable surfaces — SSE
+    "Processing X" status, auto thread title, memory seed — which all want
+    what the user typed.
+    """
+    agent_user_query = user_query
+    accepted_folder_ids: list[int] = []
+
+    has_any_mention = bool(
+        mentioned_document_ids
+        or mentioned_surfsense_doc_ids
+        or mentioned_folder_ids
+        or mentioned_documents
+    )
+    if filesystem_mode != FilesystemMode.CLOUD.value or not has_any_mention:
+        return agent_user_query, accepted_folder_ids
+
+    from app.schemas.new_chat import MentionedDocumentInfo
+
+    chip_objs: list[MentionedDocumentInfo] | None = None
+    if mentioned_documents:
+        chip_objs = []
+        for raw in mentioned_documents:
+            if isinstance(raw, MentionedDocumentInfo):
+                chip_objs.append(raw)
+                continue
+            try:
+                chip_objs.append(MentionedDocumentInfo.model_validate(raw))
+            except Exception:
+                logger.debug(
+                    "stream_new_chat: dropping malformed mention chip %r", raw
+                )
+
+    resolved = await resolve_mentions(
+        session,
+        search_space_id=search_space_id,
+        mentioned_documents=chip_objs,
+        mentioned_document_ids=mentioned_document_ids,
+        mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
+        mentioned_folder_ids=mentioned_folder_ids,
+    )
+    agent_user_query = substitute_in_text(user_query, resolved.token_to_path)
+    accepted_folder_ids = resolved.mentioned_folder_ids
+    return agent_user_query, accepted_folder_ids
+
+
+def _render_query_with_context(
+    *,
+    agent_user_query: str,
+    mentioned_surfsense_docs: list[SurfsenseDocsDocument],
+    recent_reports: list[Report],
+) -> str:
+    """Prepend surfsense-docs + recent-reports XML blocks to the user query."""
+    context_parts: list[str] = []
+
+    if mentioned_surfsense_docs:
+        context_parts.append(
+            format_mentioned_surfsense_docs_as_context(mentioned_surfsense_docs)
+        )
+
+    if recent_reports:
+        report_lines: list[str] = []
+        for r in recent_reports:
+            report_lines.append(
+                f' - report_id={r.id}, title="{r.title}", '
+                f'style="{r.report_style or "detailed"}"'
+            )
+        reports_listing = "\n".join(report_lines)
+        context_parts.append(
+            # NOTE(review): the docstring promises an XML block, but the opening
+            # "\n" and closing "" literals carry no tag — confirm the wrapper.
+            "\n"
+            "Previously generated reports in this conversation:\n"
+            f"{reports_listing}\n\n"
+            "If the user wants to MODIFY, REVISE, UPDATE, or ADD to one of "
+            "these reports, set parent_report_id to the relevant report_id above.\n"
+            "If the user wants a completely NEW report on a different topic, "
+            "leave parent_report_id unset.\n"
+            ""
+        )
+
+    if context_parts:
+        context = "\n\n".join(context_parts)
+        return f"{context}\n\n{agent_user_query}"
+
+    return agent_user_query
diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py
new file mode 100644
index 000000000..ff5a56eec
--- /dev/null
+++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py
@@ -0,0 +1,62 @@
+"""Vision-capability gate for image-bearing turns.
+
+Capability safety net for explicit (non-auto-pin) selections: a turn carrying
+user-uploaded images cannot be routed to a chat config that LiteLLM's
+authoritative model map *explicitly* marks as text-only (``supports_vision``
+set to False). The check is intentionally narrow — it only fires when LiteLLM
+is *certain* the model can't accept image input; unknown / unmapped /
+vision-capable models pass through.
+
+Without this guard a known-text-only model would 404 at the provider with
+``"No endpoints found that support image input"``, surfacing as an opaque
+``SERVER_ERROR`` SSE chunk; failing here lets us return a friendly message that
+tells the user what to change.
+"""
+
+from __future__ import annotations
+
+from app.agents.new_chat.llm_config import AgentConfig
+from app.observability import otel as ot
+
+
+def check_image_input_capability(
+    *,
+    user_image_data_urls: list[str] | None,
+    agent_config: AgentConfig | None,
+) -> tuple[str, str] | None:
+    """Return ``(user_message, error_code)`` when the gate trips, else ``None``.
+
+    The caller emits one terminal-error SSE frame on a non-``None`` return.
+    """
+    if not (user_image_data_urls and agent_config is not None):
+        return None
+
+    from app.services.provider_capabilities import is_known_text_only_chat_model
+
+    agent_litellm_params = agent_config.litellm_params or {}
+    agent_base_model = (
+        agent_litellm_params.get("base_model")
+        if isinstance(agent_litellm_params, dict)
+        else None
+    )
+    if not is_known_text_only_chat_model(
+        provider=agent_config.provider,
+        model_name=agent_config.model_name,
+        base_model=agent_base_model,
+        custom_provider=agent_config.custom_provider,
+    ):
+        return None
+
+    model_label = agent_config.config_name or agent_config.model_name or "model"
+    ot.add_event(
+        "quota.denied", {"quota.code": "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"}
+    )
+    return (
+        (
+            f"The selected model ({model_label}) does not support "
+            "image input. Switch to a vision-capable model "
+            "(e.g. GPT-4o, Claude, Gemini) or remove the image "
+            "attachment and try again."
+        ),
+        "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT",
+    )
diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/persistence_spawn.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/persistence_spawn.py
new file mode 100644
index 000000000..9ea5d2ad6
--- /dev/null
+++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/persistence_spawn.py
@@ -0,0 +1,129 @@
+"""Concurrent persistence tasks spawned right after the initial validation gate.
+
+These run *during* the rest of the pre-stream setup so we don't serialize
+their latency against agent construction. Awaiting them at the SSE message-id
+yield sites preserves the ghost-thread protection (the user-row INSERT must
+succeed before any LLM streaming begins).
+
+The ``set_ai_responding`` flag flip runs fully fire-and-forget on its own
+shielded session — failures only delay the "AI is responding…" UI flag, not
+the response itself.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from typing import Any
+from uuid import UUID
+
+from app.db import shielded_async_session
+from app.services.chat_session_state_service import set_ai_responding
+from app.tasks.chat.persistence import (
+    persist_assistant_shell,
+    persist_user_turn,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def spawn_set_ai_responding_bg(
+    *,
+    chat_id: int,
+    user_id: str | None,
+    background_tasks: set[asyncio.Task[Any]],
+) -> None:
+    """Fire-and-forget: flip the per-thread AI-responding flag on its own session.
+
+    Errors are swallowed and logged — the worst case is a stale UI flag, which
+    is preferable to delaying the SSE stream behind a flag write.
+    """
+    if not user_id:
+        return
+
+    async def _bg_set_ai_responding() -> None:
+        try:
+            async with shielded_async_session() as s:
+                await set_ai_responding(s, chat_id, UUID(user_id))
+        except Exception:
+            logger.warning(
+                "set_ai_responding failed (chat_id=%s)",
+                chat_id,
+                exc_info=True,
+            )
+
+    t = asyncio.create_task(_bg_set_ai_responding())
+    background_tasks.add(t)
+    t.add_done_callback(background_tasks.discard)
+
+
+def spawn_persist_user_task(
+    *,
+    chat_id: int,
+    user_id: str | None,
+    turn_id: str,
+    user_query: str,
+    user_image_data_urls: list[str] | None,
+    mentioned_documents: list[dict[str, Any]] | None,
+    background_tasks: set[asyncio.Task[Any]],
+) -> asyncio.Task[int | None]:
+    """Spawn the user-row INSERT; await at the user-message-id yield site."""
+    task = asyncio.create_task(
+        persist_user_turn(
+            chat_id=chat_id,
+            user_id=user_id,
+            turn_id=turn_id,
+            user_query=user_query,
+            user_image_data_urls=user_image_data_urls,
+            mentioned_documents=mentioned_documents,
+        )
+    )
+    background_tasks.add(task)
+    task.add_done_callback(background_tasks.discard)
+    return task
+
+
+def spawn_persist_assistant_shell_task(
+    *,
+    chat_id: int,
+    user_id: str | None,
+    turn_id: str,
+    background_tasks: set[asyncio.Task[Any]],
+) -> asyncio.Task[int | None]:
+    """Spawn the assistant-shell INSERT; await at the assistant-message-id yield site."""
+    task = asyncio.create_task(
+        persist_assistant_shell(
+            chat_id=chat_id,
+            user_id=user_id,
+            turn_id=turn_id,
+        )
+    )
+    background_tasks.add(task)
+    task.add_done_callback(background_tasks.discard)
+    return task
+
+
+async def await_persist_task(
+    task: asyncio.Task[int | None] | None,
+    *,
+    chat_id: int,
+    turn_id: str,
+    log_label: str,
+) -> int | None:
+    """Join a spawned persistence task with ``shield`` + uniform error handling.
+
+    ``shield`` keeps the DB write alive if the SSE generator is cancelled by
+    client disconnect mid-await. Returns ``None`` on failure; the caller
+    abort-paths the turn with a friendly error SSE.
+    """
+    if task is None:
+        return None
+    try:
+        return await asyncio.shield(task)
+    except asyncio.CancelledError:
+        raise
+    except Exception:
+        logger.exception(
+            "%s failed (chat_id=%s, turn_id=%s)", log_label, chat_id, turn_id
+        )
+        return None
diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/runtime_context.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/runtime_context.py
new file mode 100644
index 000000000..1f11be1fe
--- /dev/null
+++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/runtime_context.py
@@ -0,0 +1,41 @@
+"""Build the per-invocation ``SurfSenseContextSchema`` for a new-chat turn.
+
+Carries the per-turn read inputs that middlewares read via
+``runtime.context.*`` instead of from their ``__init__`` closures, so the same
+compiled-agent instance can serve multiple turns with different
+mention lists / request ids / turn ids without rebuilding the graph.
+"""
+
+from __future__ import annotations
+
+from app.agents.new_chat.context import SurfSenseContextSchema
+
+
+def build_new_chat_runtime_context(
+    *,
+    search_space_id: int,
+    mentioned_document_ids: list[int] | None,
+    accepted_folder_ids: list[int],
+    mentioned_folder_ids: list[int] | None,
+    request_id: str | None,
+    turn_id: str,
+) -> SurfSenseContextSchema:
+    """``mentioned_document_ids`` is consumed by ``KnowledgePriorityMiddleware``.
+
+    ``accepted_folder_ids`` (post-resolve) wins over the raw
+    ``mentioned_folder_ids`` from the request: the resolver drops chips that
+    pointed at deleted folders or folders the caller can't see, so middlewares
+    only get authorized ids.
+    """
+    return SurfSenseContextSchema(
+        search_space_id=search_space_id,
+        mentioned_document_ids=list(mentioned_document_ids or []),
+        # NOTE(review): an empty ``accepted_folder_ids`` falls back to the raw
+        # request ids, so "resolver dropped every chip" cannot be told apart
+        # from "resolver did not run" — confirm this fallback is intended.
+        mentioned_folder_ids=list(
+            accepted_folder_ids or mentioned_folder_ids or []
+        ),
+        request_id=request_id,
+        turn_id=turn_id,
+    )
diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py
new file mode 100644
index 000000000..11312110f
--- /dev/null
+++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py
@@ -0,0 +1,239 @@
+"""Background thread-title generation (first-response only).
+
+The first assistant response in a thread gets a short auto-generated title
+inserted into ``new_chat_threads.title``. We:
+
+  1. Spawn the generation as an ``asyncio.Task`` so it runs in parallel with
+     the agent stream (no extra TTFT).
+  2. Probe inside the task (on its own shielded session) whether this is
+     actually the first response — newer turns short-circuit to ``None``.
+  3. Inject the resulting ``thread-title-update`` SSE frame on the first agent
+     event after the task completes (mid-stream interlock), or right before
+     the finish frames (post-stream join) if the task hadn't finished yet.
+
+Usage tokens come directly off the response (LiteLLM's async callback fires
+via fire-and-forget ``create_task``, so the ``TokenTrackingCallback`` would
+run too late). We also blank the per-task accumulator so the late callback
+doesn't double-count.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from typing import TYPE_CHECKING, Any
+
+from sqlalchemy.future import select
+
+from app.db import NewChatMessage, NewChatThread, shielded_async_session
+from app.prompts import TITLE_GENERATION_PROMPT
+from app.services.new_streaming_service import VercelStreamingService
+
+if TYPE_CHECKING:
+    from app.agents.new_chat.llm_config import AgentConfig
+    from app.services.token_tracking_service import TokenAccumulator
+
+
+logger = logging.getLogger(__name__)
+
+
+def spawn_title_task(
+    *,
+    chat_id: int,
+    user_query: str,
+    user_image_data_urls: list[str] | None,
+    assistant_message_id: int | None,
+    llm: Any,
+    agent_config: AgentConfig | None,
+) -> asyncio.Task[tuple[str | None, dict | None]] | None:
+    """Spawn ``_generate_title``; returns ``None`` when prerequisites aren't met.
+
+    Title gen is gated on a real ``assistant_message_id`` so a stream that
+    aborts before persistence can never leave a thread with a title and no
+    anchoring rows.
+    """
+    if assistant_message_id is None:
+        return None
+    return asyncio.create_task(
+        _generate_title(
+            chat_id=chat_id,
+            user_query=user_query,
+            user_image_data_urls=user_image_data_urls,
+            assistant_message_id=assistant_message_id,
+            llm=llm,
+            agent_config=agent_config,
+        )
+    )
+
+
+async def _generate_title(
+    *,
+    chat_id: int,
+    user_query: str,
+    user_image_data_urls: list[str] | None,
+    assistant_message_id: int,
+    llm: Any,
+    agent_config: AgentConfig | None,
+) -> tuple[str | None, dict | None]:
+    """Probe is-first-response, then call ``acompletion``. Returns ``(title, usage)``."""
+    try:
+        from litellm import acompletion
+
+        from app.services.llm_router_service import LLMRouterService
+        from app.services.provider_api_base import resolve_api_base
+        from app.services.token_tracking_service import _turn_accumulator
+
+        # Excludes this turn's own assistant row (pre-written by
+        # ``persist_assistant_shell``) — without the ``!=`` filter the gate
+        # would false-negative on every turn after the first.
+        try:
+            async with shielded_async_session() as probe_session:
+                probe_result = await probe_session.execute(
+                    select(NewChatMessage.id)
+                    .filter(
+                        NewChatMessage.thread_id == chat_id,
+                        NewChatMessage.role == "assistant",
+                        NewChatMessage.id != assistant_message_id,
+                    )
+                    .limit(1)
+                )
+                is_first_response = probe_result.scalars().first() is None
+        except Exception:
+            logger.warning(
+                "[TitleGen] first-response probe failed (chat_id=%s)",
+                chat_id,
+                exc_info=True,
+            )
+            return None, None
+
+        if not is_first_response:
+            return None, None
+
+        _turn_accumulator.set(None)
+
+        title_seed = user_query.strip() or (
+            f"[{len(user_image_data_urls or [])} image(s)]"
+            if user_image_data_urls
+            else ""
+        )
+        prompt = TITLE_GENERATION_PROMPT.replace(
+            "{user_query}", title_seed[:500] or "(message)"
+        )
+        messages = [{"role": "user", "content": prompt}]
+
+        if getattr(llm, "model", None) == "auto":
+            router = LLMRouterService.get_router()
+            response = await router.acompletion(model="auto", messages=messages)
+        else:
+            # Apply the same ``api_base`` cascade chat / vision / image-gen
+            # call sites use so we never inherit ``litellm.api_base``
+            # (commonly set by ``AZURE_OPENAI_ENDPOINT``) when the chat
+            # config itself ships an empty ``api_base``. Without this the
+            # title-gen on an OpenRouter chat config would 404 against the
+            # inherited Azure endpoint — see ``provider_api_base`` for the
+            # same bug repro on the image-gen / vision paths.
+            raw_model = getattr(llm, "model", "") or ""
+            provider_prefix = (
+                raw_model.split("/", 1)[0] if "/" in raw_model else None
+            )
+            provider_value = (
+                agent_config.provider if agent_config is not None else None
+            )
+            title_api_base = resolve_api_base(
+                provider=provider_value,
+                provider_prefix=provider_prefix,
+                config_api_base=getattr(llm, "api_base", None),
+            )
+            response = await acompletion(
+                model=raw_model,
+                messages=messages,
+                api_key=getattr(llm, "api_key", None),
+                api_base=title_api_base,
+            )
+
+        usage_info = None
+        usage = getattr(response, "usage", None)
+        if usage:
+            raw_model = getattr(llm, "model", "") or ""
+            model_name = (
+                raw_model.split("/", 1)[-1]
+                if "/" in raw_model
+                else (raw_model or response.model or "unknown")
+            )
+            usage_info = {
+                "model": model_name,
+                "prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0,
+                "completion_tokens": getattr(usage, "completion_tokens", 0) or 0,
+                "total_tokens": getattr(usage, "total_tokens", 0) or 0,
+            }
+
+        # ``content`` may be None; coalesce so the usage collected above is
+        # still returned instead of being dropped by the outer ``except``.
+        raw_title = (response.choices[0].message.content or "").strip()
+        if raw_title and len(raw_title) <= 100:
+            return raw_title.strip("\"'"), usage_info
+        return None, usage_info
+    except Exception:
+        logger.exception("[TitleGen] _generate_title failed")
+        return None, None
+
+
+async def maybe_emit_title_update(
+    *,
+    title_task: asyncio.Task[tuple[str | None, dict | None]] | None,
+    title_emitted: bool,
+    chat_id: int,
+    accumulator: TokenAccumulator,
+    streaming_service: VercelStreamingService,
+):
+    """Inject one ``thread-title-update`` SSE if the task completed.
+
+    Yields the SSE frame (when applicable). Returns nothing; the orchestrator
+    flips ``title_emitted`` itself after iterating so we don't fight Python's
+    nonlocal-in-generator semantics.
+    """
+    if title_task is None or title_emitted or not title_task.done():
+        return
+    generated_title, title_usage = title_task.result()
+    if title_usage:
+        accumulator.add(**title_usage)
+    if generated_title:
+        async with shielded_async_session() as title_session:
+            title_thread_result = await title_session.execute(
+                select(NewChatThread).filter(NewChatThread.id == chat_id)
+            )
+            title_thread = title_thread_result.scalars().first()
+            if title_thread:
+                title_thread.title = generated_title
+                await title_session.commit()
+        yield streaming_service.format_thread_title_update(chat_id, generated_title)
+
+
+async def await_pending_title_update(
+    *,
+    title_task: asyncio.Task[tuple[str | None, dict | None]] | None,
+    title_emitted: bool,
+    chat_id: int,
+    accumulator: TokenAccumulator,
+    streaming_service: VercelStreamingService,
+):
+    """If the task hadn't completed during the stream, await it now and emit.
+
+    Used right before the finish frames in the success path. Mirror of
+    ``maybe_emit_title_update`` but unconditionally awaits.
+    """
+    if title_task is None or title_emitted:
+        return
+    generated_title, title_usage = await title_task
+    if title_usage:
+        accumulator.add(**title_usage)
+    if generated_title:
+        async with shielded_async_session() as title_session:
+            title_thread_result = await title_session.execute(
+                select(NewChatThread).filter(NewChatThread.id == chat_id)
+            )
+            title_thread = title_thread_result.scalars().first()
+            if title_thread:
+                title_thread.title = generated_title
+                await title_session.commit()
+        yield streaming_service.format_thread_title_update(chat_id, generated_title)