mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-29 19:35:20 +02:00
refactor(chat): add streaming/flows/new_chat/ per-concern leaf modules
Seven focused modules that the upcoming new_chat orchestrator
composes:
* auto_pin: resolve_initial_auto_pin selects the initial config (with
vision-capable filtering and error classification).
* llm_capability: check_image_input_capability blocks routing an
image-bearing turn to a known text-only model.
* runtime_context: build_new_chat_runtime_context assembles the
SurfSenseContextSchema for a new-chat turn.
* persistence_spawn: spawn_set_ai_responding_bg, spawn_persist_user_task,
spawn_persist_assistant_shell_task, and await_persist_task background
the four pre-stream DB writes so they overlap with agent build.
* initial_thinking_step: build_initial_thinking_step +
iter_initial_thinking_step_frame produce the very first thinking-1 SSE
step ("Understanding your request" / "Analyzing referenced content").
* title_gen: spawn_title_task + maybe_emit_title_update +
await_pending_title_update background the thread-title generator and
interleave its update into the stream when ready.
* input_state: build_new_chat_input_state assembles the LangGraph
input_state (history bootstrap, mentions resolution, context blocks,
human-message construction). The heavy one.
Add-only; no orchestrator yet (next commit).
This commit is contained in:
parent
21bddc73a7
commit
927009745e
7 changed files with 920 additions and 0 deletions
|
|
@ -0,0 +1,95 @@
|
|||
"""Resolve the auto-pin for the *initial* turn config.
|
||||
|
||||
Auto-pin (``selected_llm_config_id=0``) picks the best eligible LLM config for
|
||||
this thread / search space / user, optionally filtered to vision-capable
|
||||
configs when the turn carries images.
|
||||
|
||||
Errors classified here:
|
||||
|
||||
* ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` — the auto-pin pool has no
|
||||
vision-capable cfg for an image-bearing turn. The same gate fires later
|
||||
in ``llm_capability`` for explicit selections; mapping both to the same
|
||||
code keeps the FE error UI consistent.
|
||||
* ``SERVER_ERROR`` — any other ``ValueError`` from the resolver.
|
||||
|
||||
This module owns *initial* pin resolution; the rate-limit recovery loop has
|
||||
its own narrower auto-pin call (with ``exclude_config_ids``) in
|
||||
``flows/shared/rate_limit_recovery``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.observability import otel as ot
|
||||
from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoPinResult:
|
||||
"""Outcome of ``resolve_initial_auto_pin``.
|
||||
|
||||
``llm_config_id`` is set when ``error`` is ``None``; ``error`` carries the
|
||||
classified user-facing message plus error code/kind so the orchestrator can
|
||||
emit one terminal-error SSE frame.
|
||||
"""
|
||||
|
||||
llm_config_id: int | None
|
||||
error: tuple[str, str, Literal["user_error", "server_error"]] | None
|
||||
|
||||
|
||||
async def resolve_initial_auto_pin(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
chat_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
selected_llm_config_id: int,
|
||||
requires_image_input: bool,
|
||||
requested_llm_config_id: int,
|
||||
) -> AutoPinResult:
|
||||
"""Run the resolver and classify any ``ValueError`` for the SSE error path."""
|
||||
try:
|
||||
pinned = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=selected_llm_config_id,
|
||||
requires_image_input=requires_image_input,
|
||||
)
|
||||
ot.add_event(
|
||||
"model.pin.resolved",
|
||||
{
|
||||
"pin.requested_id": requested_llm_config_id,
|
||||
"pin.resolved_id": pinned.resolved_llm_config_id,
|
||||
"pin.requires_image_input": requires_image_input,
|
||||
},
|
||||
)
|
||||
return AutoPinResult(
|
||||
llm_config_id=pinned.resolved_llm_config_id, error=None
|
||||
)
|
||||
except ValueError as pin_error:
|
||||
# The "no vision-capable cfg" path raises a ValueError whose message
|
||||
# we map to the friendly image-input SSE error so the user sees the
|
||||
# same message regardless of whether the gate fired in the resolver or
|
||||
# in ``llm_capability.assert_vision_capability_for_image_turn``.
|
||||
is_vision_failure = (
|
||||
requires_image_input and "vision-capable" in str(pin_error)
|
||||
)
|
||||
error_code = (
|
||||
"MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
|
||||
if is_vision_failure
|
||||
else "SERVER_ERROR"
|
||||
)
|
||||
error_kind: Literal["user_error", "server_error"] = (
|
||||
"user_error" if is_vision_failure else "server_error"
|
||||
)
|
||||
if is_vision_failure:
|
||||
ot.add_event("quota.denied", {"quota.code": error_code})
|
||||
return AutoPinResult(
|
||||
llm_config_id=None, error=(str(pin_error), error_code, error_kind)
|
||||
)
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
"""Build and emit the first ``thinking-1`` step for a new-chat turn.
|
||||
|
||||
The step title and "Processing X" items are derived from what the user sent
|
||||
(text snippet, image count, mentioned doc titles) so the FE can render a
|
||||
meaningful placeholder while the agent stream warms up.
|
||||
|
||||
``thinking-1`` is the canonical id for this step — every subsequent
|
||||
``thinking-N`` produced by ``stream_agent_events`` folds into the same
|
||||
singleton ``data-thinking-steps`` part on the FE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from app.db import SurfsenseDocsDocument
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
||||
|
||||
@dataclass
|
||||
class InitialThinkingStep:
|
||||
"""Resolved fields passed both into the SSE frame and the builder hook.
|
||||
|
||||
``items`` is the bullet list under the step title; ``title`` is the
|
||||
one-line step header. ``step_id`` is hard-coded ``thinking-1`` so the FE
|
||||
Timeline can de-duplicate against the prior assistant message on resume.
|
||||
"""
|
||||
|
||||
step_id: str
|
||||
title: str
|
||||
items: list[str]
|
||||
|
||||
|
||||
def build_initial_thinking_step(
|
||||
*,
|
||||
user_query: str,
|
||||
user_image_data_urls: list[str] | None,
|
||||
mentioned_surfsense_docs: list[SurfsenseDocsDocument],
|
||||
) -> InitialThinkingStep:
|
||||
if mentioned_surfsense_docs:
|
||||
title = "Analyzing referenced content"
|
||||
action_verb = "Analyzing"
|
||||
else:
|
||||
title = "Understanding your request"
|
||||
action_verb = "Processing"
|
||||
|
||||
processing_parts: list[str] = []
|
||||
if user_query.strip():
|
||||
query_text = user_query[:80] + ("..." if len(user_query) > 80 else "")
|
||||
processing_parts.append(query_text)
|
||||
elif user_image_data_urls:
|
||||
processing_parts.append(f"[{len(user_image_data_urls)} image(s)]")
|
||||
else:
|
||||
processing_parts.append("(message)")
|
||||
|
||||
if mentioned_surfsense_docs:
|
||||
doc_names: list[str] = []
|
||||
for doc in mentioned_surfsense_docs:
|
||||
t = doc.title
|
||||
if len(t) > 30:
|
||||
t = t[:27] + "..."
|
||||
doc_names.append(t)
|
||||
if len(doc_names) == 1:
|
||||
processing_parts.append(f"[{doc_names[0]}]")
|
||||
else:
|
||||
processing_parts.append(f"[{len(doc_names)} docs]")
|
||||
|
||||
items = [f"{action_verb}: {' '.join(processing_parts)}"]
|
||||
return InitialThinkingStep(step_id="thinking-1", title=title, items=items)
|
||||
|
||||
|
||||
def iter_initial_thinking_step_frame(
|
||||
step: InitialThinkingStep,
|
||||
*,
|
||||
streaming_service: VercelStreamingService,
|
||||
content_builder: Any | None,
|
||||
) -> Iterator[str]:
|
||||
"""Drive both the SSE emission and the builder hook for the initial step.
|
||||
|
||||
The FE folds this step into the same singleton ``data-thinking-steps`` part
|
||||
as everything the agent stream emits later, so we mirror that fold
|
||||
server-side by driving the builder lifecycle ourselves.
|
||||
"""
|
||||
if content_builder is not None:
|
||||
content_builder.on_thinking_step(
|
||||
step.step_id, step.title, "in_progress", step.items
|
||||
)
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=step.step_id,
|
||||
title=step.title,
|
||||
status="in_progress",
|
||||
items=step.items,
|
||||
)
|
||||
|
|
@ -0,0 +1,264 @@
|
|||
r"""Assemble the LangGraph ``input_state`` for the new-chat turn.
|
||||
|
||||
Pipeline:
|
||||
|
||||
1. **History bootstrap** — only for cloned chats with no LangGraph checkpoint
|
||||
yet; flips the per-thread ``needs_history_bootstrap`` flag back to False
|
||||
once the rows are loaded.
|
||||
2. **Mentioned SurfSense docs** — eager-load chunks so the formatter has the
|
||||
full content without a second roundtrip.
|
||||
3. **Recent reports** — top 3 by id desc with non-null content, so the LLM
|
||||
can resolve ``report_id`` for versioning without spelunking history.
|
||||
4. **@-mention resolve** (cloud mode) — substitute ``@title`` tokens in the
|
||||
query with canonical ``\`/documents/...\``` paths the LLM expects.
|
||||
5. **Context block render** — XML-wrap surfsense docs + reports, prepend to
|
||||
the rewritten query, optionally prefix with display name for SEARCH_SPACE
|
||||
visibility.
|
||||
6. **HumanMessage** — multimodal content if images are attached.
|
||||
|
||||
Returns the assembled ``input_state`` dict plus side-channel data the
|
||||
orchestrator needs downstream (``accepted_folder_ids`` for runtime context;
|
||||
``mentioned_surfsense_docs`` for the initial thinking step).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in_text
|
||||
from app.db import (
|
||||
ChatVisibility,
|
||||
NewChatThread,
|
||||
Report,
|
||||
SurfsenseDocsDocument,
|
||||
)
|
||||
from app.tasks.chat.streaming.context.mentioned_docs import (
|
||||
format_mentioned_surfsense_docs_as_context,
|
||||
)
|
||||
from app.utils.content_utils import bootstrap_history_from_db
|
||||
from app.utils.user_message_multimodal import build_human_message_content
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NewChatInputState:
|
||||
"""Everything ``build_new_chat_input_state`` produces.
|
||||
|
||||
``input_state`` is fed straight to the agent. ``accepted_folder_ids``
|
||||
feeds the runtime context (the resolver may have dropped some chips).
|
||||
``mentioned_surfsense_docs`` is consumed by the initial thinking-step
|
||||
builder for the FE placeholder before the agent stream starts.
|
||||
"""
|
||||
|
||||
input_state: dict[str, Any]
|
||||
accepted_folder_ids: list[int]
|
||||
mentioned_surfsense_docs: list[SurfsenseDocsDocument]
|
||||
|
||||
|
||||
async def build_new_chat_input_state(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
chat_id: int,
|
||||
search_space_id: int,
|
||||
user_query: str,
|
||||
user_image_data_urls: list[str] | None,
|
||||
mentioned_document_ids: list[int] | None,
|
||||
mentioned_surfsense_doc_ids: list[int] | None,
|
||||
mentioned_folder_ids: list[int] | None,
|
||||
mentioned_documents: list[dict[str, Any]] | None,
|
||||
needs_history_bootstrap: bool,
|
||||
thread_visibility: ChatVisibility,
|
||||
current_user_display_name: str | None,
|
||||
filesystem_mode: str,
|
||||
request_id: str | None,
|
||||
turn_id: str,
|
||||
) -> NewChatInputState:
|
||||
langchain_messages: list[Any] = []
|
||||
|
||||
if needs_history_bootstrap:
|
||||
langchain_messages = await bootstrap_history_from_db(
|
||||
session, chat_id, thread_visibility=thread_visibility
|
||||
)
|
||||
thread_result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == chat_id)
|
||||
)
|
||||
thread = thread_result.scalars().first()
|
||||
if thread:
|
||||
thread.needs_history_bootstrap = False
|
||||
await session.commit()
|
||||
|
||||
mentioned_surfsense_docs: list[SurfsenseDocsDocument] = []
|
||||
if mentioned_surfsense_doc_ids:
|
||||
result = await session.execute(
|
||||
select(SurfsenseDocsDocument)
|
||||
.options(selectinload(SurfsenseDocsDocument.chunks))
|
||||
.filter(SurfsenseDocsDocument.id.in_(mentioned_surfsense_doc_ids))
|
||||
)
|
||||
mentioned_surfsense_docs = list(result.scalars().all())
|
||||
|
||||
# Top 3 reports keyed by id desc (newest first) with content present,
|
||||
# surfaced inline so the LLM resolves ``report_id`` for versioning without
|
||||
# digging through conversation history.
|
||||
recent_reports_result = await session.execute(
|
||||
select(Report)
|
||||
.filter(
|
||||
Report.thread_id == chat_id,
|
||||
Report.content.isnot(None),
|
||||
)
|
||||
.order_by(Report.id.desc())
|
||||
.limit(3)
|
||||
)
|
||||
recent_reports = list(recent_reports_result.scalars().all())
|
||||
|
||||
agent_user_query, accepted_folder_ids = await _resolve_mentions_for_query(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
user_query=user_query,
|
||||
filesystem_mode=filesystem_mode,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
|
||||
mentioned_folder_ids=mentioned_folder_ids,
|
||||
mentioned_documents=mentioned_documents,
|
||||
)
|
||||
|
||||
final_query = _render_query_with_context(
|
||||
agent_user_query=agent_user_query,
|
||||
mentioned_surfsense_docs=mentioned_surfsense_docs,
|
||||
recent_reports=recent_reports,
|
||||
)
|
||||
|
||||
if thread_visibility == ChatVisibility.SEARCH_SPACE and current_user_display_name:
|
||||
final_query = f"**[{current_user_display_name}]:** {final_query}"
|
||||
|
||||
human_content = build_human_message_content(
|
||||
final_query, list(user_image_data_urls or ())
|
||||
)
|
||||
langchain_messages.append(HumanMessage(content=human_content))
|
||||
|
||||
input_state = {
|
||||
"messages": langchain_messages,
|
||||
"search_space_id": search_space_id,
|
||||
"request_id": request_id or "unknown",
|
||||
"turn_id": turn_id,
|
||||
}
|
||||
|
||||
return NewChatInputState(
|
||||
input_state=input_state,
|
||||
accepted_folder_ids=accepted_folder_ids,
|
||||
mentioned_surfsense_docs=mentioned_surfsense_docs,
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_mentions_for_query(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
search_space_id: int,
|
||||
user_query: str,
|
||||
filesystem_mode: str,
|
||||
mentioned_document_ids: list[int] | None,
|
||||
mentioned_surfsense_doc_ids: list[int] | None,
|
||||
mentioned_folder_ids: list[int] | None,
|
||||
mentioned_documents: list[dict[str, Any]] | None,
|
||||
) -> tuple[str, list[int]]:
|
||||
r"""Resolve @-mention chips and rewrite the user query to canonical paths.
|
||||
|
||||
Cloud mode only: local-folder mode keeps the legacy ``@title`` text path
|
||||
(mention support there is a follow-up task — the path scheme is
|
||||
mount-rooted and the picker UI both need separate work).
|
||||
|
||||
The substitution lands in the returned ``agent_user_query`` ONLY — the
|
||||
original ``user_query`` (with ``@title`` tokens) flows untouched into
|
||||
``persist_user_turn`` so chip rendering on reload still works
|
||||
(``UserTextPart`` → ``parseMentionSegments`` matches ``@title``, not
|
||||
``\`/documents/...\```). It also feeds the human-readable surfaces — SSE
|
||||
"Processing X" status, auto thread title, memory seed — which all want
|
||||
what the user typed.
|
||||
"""
|
||||
agent_user_query = user_query
|
||||
accepted_folder_ids: list[int] = []
|
||||
|
||||
has_any_mention = bool(
|
||||
mentioned_document_ids
|
||||
or mentioned_surfsense_doc_ids
|
||||
or mentioned_folder_ids
|
||||
or mentioned_documents
|
||||
)
|
||||
if filesystem_mode != FilesystemMode.CLOUD.value or not has_any_mention:
|
||||
return agent_user_query, accepted_folder_ids
|
||||
|
||||
from app.schemas.new_chat import MentionedDocumentInfo
|
||||
|
||||
chip_objs: list[MentionedDocumentInfo] | None = None
|
||||
if mentioned_documents:
|
||||
chip_objs = []
|
||||
for raw in mentioned_documents:
|
||||
if isinstance(raw, MentionedDocumentInfo):
|
||||
chip_objs.append(raw)
|
||||
continue
|
||||
try:
|
||||
chip_objs.append(MentionedDocumentInfo.model_validate(raw))
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"stream_new_chat: dropping malformed mention chip %r", raw
|
||||
)
|
||||
|
||||
resolved = await resolve_mentions(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
mentioned_documents=chip_objs,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
|
||||
mentioned_folder_ids=mentioned_folder_ids,
|
||||
)
|
||||
agent_user_query = substitute_in_text(user_query, resolved.token_to_path)
|
||||
accepted_folder_ids = resolved.mentioned_folder_ids
|
||||
return agent_user_query, accepted_folder_ids
|
||||
|
||||
|
||||
def _render_query_with_context(
|
||||
*,
|
||||
agent_user_query: str,
|
||||
mentioned_surfsense_docs: list[SurfsenseDocsDocument],
|
||||
recent_reports: list[Report],
|
||||
) -> str:
|
||||
"""Prepend surfsense-docs + recent-reports XML blocks to the user query."""
|
||||
context_parts: list[str] = []
|
||||
|
||||
if mentioned_surfsense_docs:
|
||||
context_parts.append(
|
||||
format_mentioned_surfsense_docs_as_context(mentioned_surfsense_docs)
|
||||
)
|
||||
|
||||
if recent_reports:
|
||||
report_lines: list[str] = []
|
||||
for r in recent_reports:
|
||||
report_lines.append(
|
||||
f' - report_id={r.id}, title="{r.title}", '
|
||||
f'style="{r.report_style or "detailed"}"'
|
||||
)
|
||||
reports_listing = "\n".join(report_lines)
|
||||
context_parts.append(
|
||||
"<report_context>\n"
|
||||
"Previously generated reports in this conversation:\n"
|
||||
f"{reports_listing}\n\n"
|
||||
"If the user wants to MODIFY, REVISE, UPDATE, or ADD to one of "
|
||||
"these reports, set parent_report_id to the relevant report_id above.\n"
|
||||
"If the user wants a completely NEW report on a different topic, "
|
||||
"leave parent_report_id unset.\n"
|
||||
"</report_context>"
|
||||
)
|
||||
|
||||
if context_parts:
|
||||
context = "\n\n".join(context_parts)
|
||||
return f"{context}\n\n<user_query>{agent_user_query}</user_query>"
|
||||
|
||||
return agent_user_query
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
"""Vision-capability gate for image-bearing turns.
|
||||
|
||||
Capability safety net for explicit (non-auto-pin) selections: a turn carrying
|
||||
user-uploaded images cannot be routed to a chat config that LiteLLM's
|
||||
authoritative model map *explicitly* marks as text-only (``supports_vision``
|
||||
set to False). The check is intentionally narrow — it only fires when LiteLLM
|
||||
is *certain* the model can't accept image input; unknown / unmapped /
|
||||
vision-capable models pass through.
|
||||
|
||||
Without this guard a known-text-only model would 404 at the provider with
|
||||
``"No endpoints found that support image input"``, surfacing as an opaque
|
||||
``SERVER_ERROR`` SSE chunk; failing here lets us return a friendly message that
|
||||
tells the user what to change.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
from app.observability import otel as ot
|
||||
|
||||
|
||||
def check_image_input_capability(
|
||||
*,
|
||||
user_image_data_urls: list[str] | None,
|
||||
agent_config: AgentConfig | None,
|
||||
) -> tuple[str, str] | None:
|
||||
"""Return ``(user_message, error_code)`` when the gate trips, else ``None``.
|
||||
|
||||
The caller emits one terminal-error SSE frame on a non-``None`` return.
|
||||
"""
|
||||
if not (user_image_data_urls and agent_config is not None):
|
||||
return None
|
||||
|
||||
from app.services.provider_capabilities import is_known_text_only_chat_model
|
||||
|
||||
agent_litellm_params = agent_config.litellm_params or {}
|
||||
agent_base_model = (
|
||||
agent_litellm_params.get("base_model")
|
||||
if isinstance(agent_litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
if not is_known_text_only_chat_model(
|
||||
provider=agent_config.provider,
|
||||
model_name=agent_config.model_name,
|
||||
base_model=agent_base_model,
|
||||
custom_provider=agent_config.custom_provider,
|
||||
):
|
||||
return None
|
||||
|
||||
model_label = agent_config.config_name or agent_config.model_name or "model"
|
||||
ot.add_event(
|
||||
"quota.denied", {"quota.code": "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"}
|
||||
)
|
||||
return (
|
||||
(
|
||||
f"The selected model ({model_label}) does not support "
|
||||
"image input. Switch to a vision-capable model "
|
||||
"(e.g. GPT-4o, Claude, Gemini) or remove the image "
|
||||
"attachment and try again."
|
||||
),
|
||||
"MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT",
|
||||
)
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
"""Concurrent persistence tasks spawned right after the initial validation gate.
|
||||
|
||||
These run *during* the rest of the pre-stream setup so we don't serialize
|
||||
their latency against agent construction. Awaiting them at the SSE message-id
|
||||
yield sites preserves the ghost-thread protection (the user-row INSERT must
|
||||
succeed before any LLM streaming begins).
|
||||
|
||||
The ``set_ai_responding`` flag flip runs fully fire-and-forget on its own
|
||||
shielded session — failures only delay the "AI is responding…" UI flag, not
|
||||
the response itself.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from app.db import shielded_async_session
|
||||
from app.services.chat_session_state_service import set_ai_responding
|
||||
from app.tasks.chat.persistence import (
|
||||
persist_assistant_shell,
|
||||
persist_user_turn,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def spawn_set_ai_responding_bg(
|
||||
*,
|
||||
chat_id: int,
|
||||
user_id: str | None,
|
||||
background_tasks: set[asyncio.Task[Any]],
|
||||
) -> None:
|
||||
"""Fire-and-forget: flip the per-thread AI-responding flag on its own session.
|
||||
|
||||
Errors are swallowed and logged — the worst case is a stale UI flag, which
|
||||
is preferable to delaying the SSE stream behind a flag write.
|
||||
"""
|
||||
if not user_id:
|
||||
return
|
||||
|
||||
async def _bg_set_ai_responding() -> None:
|
||||
try:
|
||||
async with shielded_async_session() as s:
|
||||
await set_ai_responding(s, chat_id, UUID(user_id))
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"set_ai_responding failed (chat_id=%s)",
|
||||
chat_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
t = asyncio.create_task(_bg_set_ai_responding())
|
||||
background_tasks.add(t)
|
||||
t.add_done_callback(background_tasks.discard)
|
||||
|
||||
|
||||
def spawn_persist_user_task(
|
||||
*,
|
||||
chat_id: int,
|
||||
user_id: str | None,
|
||||
turn_id: str,
|
||||
user_query: str,
|
||||
user_image_data_urls: list[str] | None,
|
||||
mentioned_documents: list[dict[str, Any]] | None,
|
||||
background_tasks: set[asyncio.Task[Any]],
|
||||
) -> asyncio.Task[int | None]:
|
||||
"""Spawn the user-row INSERT; await at the user-message-id yield site."""
|
||||
task = asyncio.create_task(
|
||||
persist_user_turn(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
turn_id=turn_id,
|
||||
user_query=user_query,
|
||||
user_image_data_urls=user_image_data_urls,
|
||||
mentioned_documents=mentioned_documents,
|
||||
)
|
||||
)
|
||||
background_tasks.add(task)
|
||||
task.add_done_callback(background_tasks.discard)
|
||||
return task
|
||||
|
||||
|
||||
def spawn_persist_assistant_shell_task(
|
||||
*,
|
||||
chat_id: int,
|
||||
user_id: str | None,
|
||||
turn_id: str,
|
||||
background_tasks: set[asyncio.Task[Any]],
|
||||
) -> asyncio.Task[int | None]:
|
||||
"""Spawn the assistant-shell INSERT; await at the assistant-message-id yield site."""
|
||||
task = asyncio.create_task(
|
||||
persist_assistant_shell(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
)
|
||||
background_tasks.add(task)
|
||||
task.add_done_callback(background_tasks.discard)
|
||||
return task
|
||||
|
||||
|
||||
async def await_persist_task(
|
||||
task: asyncio.Task[int | None] | None,
|
||||
*,
|
||||
chat_id: int,
|
||||
turn_id: str,
|
||||
log_label: str,
|
||||
) -> int | None:
|
||||
"""Join a spawned persistence task with ``shield`` + uniform error handling.
|
||||
|
||||
``shield`` keeps the DB write alive if the SSE generator is cancelled by
|
||||
client disconnect mid-await. Returns ``None`` on failure; the caller
|
||||
abort-paths the turn with a friendly error SSE.
|
||||
"""
|
||||
if task is None:
|
||||
return None
|
||||
try:
|
||||
return await asyncio.shield(task)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"%s failed (chat_id=%s, turn_id=%s)", log_label, chat_id, turn_id
|
||||
)
|
||||
return None
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
"""Build the per-invocation ``SurfSenseContextSchema`` for a new-chat turn.
|
||||
|
||||
Carries the per-turn read inputs that middlewares read via
|
||||
``runtime.context.*`` instead of from their ``__init__`` closures, so the same
|
||||
compiled-agent instance can serve multiple turns with different
|
||||
mention lists / request ids / turn ids without rebuilding the graph.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
|
||||
|
||||
def build_new_chat_runtime_context(
|
||||
*,
|
||||
search_space_id: int,
|
||||
mentioned_document_ids: list[int] | None,
|
||||
accepted_folder_ids: list[int],
|
||||
mentioned_folder_ids: list[int] | None,
|
||||
request_id: str | None,
|
||||
turn_id: str,
|
||||
) -> SurfSenseContextSchema:
|
||||
"""``mentioned_document_ids`` is consumed by ``KnowledgePriorityMiddleware``.
|
||||
|
||||
``accepted_folder_ids`` (post-resolve) wins over the raw
|
||||
``mentioned_folder_ids`` from the request: the resolver drops chips that
|
||||
pointed at deleted folders or folders the caller can't see, so middlewares
|
||||
only get authorized ids.
|
||||
"""
|
||||
return SurfSenseContextSchema(
|
||||
search_space_id=search_space_id,
|
||||
mentioned_document_ids=list(mentioned_document_ids or []),
|
||||
mentioned_folder_ids=list(
|
||||
accepted_folder_ids or mentioned_folder_ids or []
|
||||
),
|
||||
request_id=request_id,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
|
|
@ -0,0 +1,237 @@
|
|||
"""Background thread-title generation (first-response only).
|
||||
|
||||
The first assistant response in a thread gets a short auto-generated title
|
||||
inserted into ``new_chat_threads.title``. We:
|
||||
|
||||
1. Spawn the generation as an ``asyncio.Task`` so it runs in parallel with
|
||||
the agent stream (no extra TTFT).
|
||||
2. Probe inside the task (on its own shielded session) whether this is
|
||||
actually the first response — newer turns short-circuit to ``None``.
|
||||
3. Inject the resulting ``thread-title-update`` SSE frame on the first agent
|
||||
event after the task completes (mid-stream interlock), or right before
|
||||
the finish frames (post-stream join) if the task hadn't finished yet.
|
||||
|
||||
Usage tokens come directly off the response (LiteLLM's async callback fires
|
||||
via fire-and-forget ``create_task``, so the ``TokenTrackingCallback`` would
|
||||
run too late). We also blank the per-task accumulator so the late callback
|
||||
doesn't double-count.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import NewChatMessage, NewChatThread, shielded_async_session
|
||||
from app.prompts import TITLE_GENERATION_PROMPT
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
from app.services.token_tracking_service import TokenAccumulator
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def spawn_title_task(
|
||||
*,
|
||||
chat_id: int,
|
||||
user_query: str,
|
||||
user_image_data_urls: list[str] | None,
|
||||
assistant_message_id: int | None,
|
||||
llm: Any,
|
||||
agent_config: AgentConfig | None,
|
||||
) -> asyncio.Task[tuple[str | None, dict | None]] | None:
|
||||
"""Spawn ``_generate_title``; returns ``None`` when prerequisites aren't met.
|
||||
|
||||
Title gen is gated on a real ``assistant_message_id`` so a stream that
|
||||
aborts before persistence can never leave a thread with a title and no
|
||||
anchoring rows.
|
||||
"""
|
||||
if assistant_message_id is None:
|
||||
return None
|
||||
return asyncio.create_task(
|
||||
_generate_title(
|
||||
chat_id=chat_id,
|
||||
user_query=user_query,
|
||||
user_image_data_urls=user_image_data_urls,
|
||||
assistant_message_id=assistant_message_id,
|
||||
llm=llm,
|
||||
agent_config=agent_config,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _generate_title(
|
||||
*,
|
||||
chat_id: int,
|
||||
user_query: str,
|
||||
user_image_data_urls: list[str] | None,
|
||||
assistant_message_id: int,
|
||||
llm: Any,
|
||||
agent_config: AgentConfig | None,
|
||||
) -> tuple[str | None, dict | None]:
|
||||
"""Probe is-first-response, then call ``acompletion``. Returns ``(title, usage)``."""
|
||||
try:
|
||||
from litellm import acompletion
|
||||
|
||||
from app.services.llm_router_service import LLMRouterService
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.services.token_tracking_service import _turn_accumulator
|
||||
|
||||
# Excludes this turn's own assistant row (pre-written by
|
||||
# ``persist_assistant_shell``) — without the ``!=`` filter the gate
|
||||
# would false-negative on every turn after the first.
|
||||
try:
|
||||
async with shielded_async_session() as probe_session:
|
||||
probe_result = await probe_session.execute(
|
||||
select(NewChatMessage.id)
|
||||
.filter(
|
||||
NewChatMessage.thread_id == chat_id,
|
||||
NewChatMessage.role == "assistant",
|
||||
NewChatMessage.id != assistant_message_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
is_first_response = probe_result.scalars().first() is None
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[TitleGen] first-response probe failed (chat_id=%s)",
|
||||
chat_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return None, None
|
||||
|
||||
if not is_first_response:
|
||||
return None, None
|
||||
|
||||
_turn_accumulator.set(None)
|
||||
|
||||
title_seed = user_query.strip() or (
|
||||
f"[{len(user_image_data_urls or [])} image(s)]"
|
||||
if user_image_data_urls
|
||||
else ""
|
||||
)
|
||||
prompt = TITLE_GENERATION_PROMPT.replace(
|
||||
"{user_query}", title_seed[:500] or "(message)"
|
||||
)
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
if getattr(llm, "model", None) == "auto":
|
||||
router = LLMRouterService.get_router()
|
||||
response = await router.acompletion(model="auto", messages=messages)
|
||||
else:
|
||||
# Apply the same ``api_base`` cascade chat / vision / image-gen
|
||||
# call sites use so we never inherit ``litellm.api_base``
|
||||
# (commonly set by ``AZURE_OPENAI_ENDPOINT``) when the chat
|
||||
# config itself ships an empty ``api_base``. Without this the
|
||||
# title-gen on an OpenRouter chat config would 404 against the
|
||||
# inherited Azure endpoint — see ``provider_api_base`` for the
|
||||
# same bug repro on the image-gen / vision paths.
|
||||
raw_model = getattr(llm, "model", "") or ""
|
||||
provider_prefix = (
|
||||
raw_model.split("/", 1)[0] if "/" in raw_model else None
|
||||
)
|
||||
provider_value = (
|
||||
agent_config.provider if agent_config is not None else None
|
||||
)
|
||||
title_api_base = resolve_api_base(
|
||||
provider=provider_value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=getattr(llm, "api_base", None),
|
||||
)
|
||||
response = await acompletion(
|
||||
model=raw_model,
|
||||
messages=messages,
|
||||
api_key=getattr(llm, "api_key", None),
|
||||
api_base=title_api_base,
|
||||
)
|
||||
|
||||
usage_info = None
|
||||
usage = getattr(response, "usage", None)
|
||||
if usage:
|
||||
raw_model = getattr(llm, "model", "") or ""
|
||||
model_name = (
|
||||
raw_model.split("/", 1)[-1]
|
||||
if "/" in raw_model
|
||||
else (raw_model or response.model or "unknown")
|
||||
)
|
||||
usage_info = {
|
||||
"model": model_name,
|
||||
"prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0,
|
||||
"completion_tokens": getattr(usage, "completion_tokens", 0) or 0,
|
||||
"total_tokens": getattr(usage, "total_tokens", 0) or 0,
|
||||
}
|
||||
|
||||
raw_title = response.choices[0].message.content.strip()
|
||||
if raw_title and len(raw_title) <= 100:
|
||||
return raw_title.strip("\"'"), usage_info
|
||||
return None, usage_info
|
||||
except Exception:
|
||||
logger.exception("[TitleGen] _generate_title failed")
|
||||
return None, None
|
||||
|
||||
|
||||
async def maybe_emit_title_update(
|
||||
*,
|
||||
title_task: asyncio.Task[tuple[str | None, dict | None]] | None,
|
||||
title_emitted: bool,
|
||||
chat_id: int,
|
||||
accumulator: TokenAccumulator,
|
||||
streaming_service: VercelStreamingService,
|
||||
):
|
||||
"""Inject one ``thread-title-update`` SSE if the task completed.
|
||||
|
||||
Yields the SSE frame (when applicable). Returns nothing; the orchestrator
|
||||
flips ``title_emitted`` itself after iterating so we don't fight Python's
|
||||
nonlocal-in-generator semantics.
|
||||
"""
|
||||
if title_task is None or title_emitted or not title_task.done():
|
||||
return
|
||||
generated_title, title_usage = title_task.result()
|
||||
if title_usage:
|
||||
accumulator.add(**title_usage)
|
||||
if generated_title:
|
||||
async with shielded_async_session() as title_session:
|
||||
title_thread_result = await title_session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == chat_id)
|
||||
)
|
||||
title_thread = title_thread_result.scalars().first()
|
||||
if title_thread:
|
||||
title_thread.title = generated_title
|
||||
await title_session.commit()
|
||||
yield streaming_service.format_thread_title_update(chat_id, generated_title)
|
||||
|
||||
|
||||
async def await_pending_title_update(
|
||||
*,
|
||||
title_task: asyncio.Task[tuple[str | None, dict | None]] | None,
|
||||
title_emitted: bool,
|
||||
chat_id: int,
|
||||
accumulator: TokenAccumulator,
|
||||
streaming_service: VercelStreamingService,
|
||||
):
|
||||
"""If the task hadn't completed during the stream, await it now and emit.
|
||||
|
||||
Used right before the finish frames in the success path. Mirror of
|
||||
``maybe_emit_title_update`` but unconditionally awaits.
|
||||
"""
|
||||
if title_task is None or title_emitted:
|
||||
return
|
||||
generated_title, title_usage = await title_task
|
||||
if title_usage:
|
||||
accumulator.add(**title_usage)
|
||||
if generated_title:
|
||||
async with shielded_async_session() as title_session:
|
||||
title_thread_result = await title_session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == chat_id)
|
||||
)
|
||||
title_thread = title_thread_result.scalars().first()
|
||||
if title_thread:
|
||||
title_thread.title = generated_title
|
||||
await title_session.commit()
|
||||
yield streaming_service.format_thread_title_update(chat_id, generated_title)
|
||||
Loading…
Add table
Add a link
Reference in a new issue