Merge pull request #1443 from CREDO23/feature-automations

[Feat] Automation V1 — Scheduled Agent Tasks, Created via Chat (HITL) or JSON
This commit is contained in:
Rohan Verma 2026-05-28 12:41:41 -07:00 committed by GitHub
commit 4dda02c06c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
219 changed files with 13821 additions and 55 deletions

View file

@ -0,0 +1,8 @@
"""Agent construction and per-turn event-loop drivers."""
from __future__ import annotations
from app.tasks.chat.streaming.agent.builder import build_main_agent_for_thread
from app.tasks.chat.streaming.agent.event_loop import stream_agent_events
__all__ = ["build_main_agent_for_thread", "stream_agent_events"]

View file

@ -0,0 +1,49 @@
"""Single per-thread agent (re)build path.
A graph swap mid-turn would corrupt checkpointer state for the same
``thread_id``, so both the initial build and any mid-stream 429 recovery rebuild
must funnel through this single function.
"""
from __future__ import annotations
from typing import Any
from app.agents.new_chat.filesystem_selection import FilesystemSelection
from app.agents.new_chat.llm_config import AgentConfig
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService
async def build_main_agent_for_thread(
    agent_factory: Any,
    *,
    llm: Any,
    search_space_id: int,
    db_session: Any,
    connector_service: ConnectorService,
    checkpointer: Any,
    user_id: str | None,
    thread_id: int | None,
    agent_config: AgentConfig | None,
    firecrawl_api_key: str | None,
    thread_visibility: ChatVisibility | None,
    filesystem_selection: FilesystemSelection | None,
    disabled_tools: list[str] | None = None,
    mentioned_document_ids: list[int] | None = None,
) -> Any:
    """Build the main agent for a thread by delegating to ``agent_factory``.

    Every keyword argument is forwarded unchanged, under the same name, to
    ``agent_factory``; the awaited result of that call is returned as-is.
    """
    # Collected in one mapping so the forwarded surface is visible at a glance.
    factory_kwargs: dict[str, Any] = {
        "llm": llm,
        "search_space_id": search_space_id,
        "db_session": db_session,
        "connector_service": connector_service,
        "checkpointer": checkpointer,
        "user_id": user_id,
        "thread_id": thread_id,
        "agent_config": agent_config,
        "firecrawl_api_key": firecrawl_api_key,
        "thread_visibility": thread_visibility,
        "filesystem_selection": filesystem_selection,
        "disabled_tools": disabled_tools,
        "mentioned_document_ids": mentioned_document_ids,
    }
    return await agent_factory(**factory_kwargs)

View file

@ -0,0 +1,175 @@
"""Per-turn agent event-loop driver.
Drives ``stream_output`` (graph_stream relay) for one agent turn, then runs the
post-stream agent-state inspection: safety-net commit of any staged filesystem
state (in case ``aafter_agent`` was skipped), file-operation contract scoring,
intent classification, and interrupt detection.
"""
from __future__ import annotations
from collections.abc import AsyncGenerator
from typing import Any
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware.kb_persistence import (
commit_staged_filesystem_state,
)
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.streaming.contract.file_contract import (
contract_enforcement_active,
evaluate_file_contract_outcome,
log_file_contract,
)
from app.tasks.chat.streaming.graph_stream.event_stream import stream_output
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
all_interrupt_values,
)
from app.tasks.chat.streaming.shared.stream_result import StreamResult
from app.tasks.chat.streaming.shared.utils import safe_float
from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
async def stream_agent_events(
    agent: Any,
    config: dict[str, Any],
    input_data: Any,
    streaming_service: VercelStreamingService,
    result: StreamResult,
    step_prefix: str = "thinking",
    initial_step_id: str | None = None,
    initial_step_title: str = "",
    initial_step_items: list[str] | None = None,
    *,
    fallback_commit_search_space_id: int | None = None,
    fallback_commit_created_by_id: str | None = None,
    fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
    fallback_commit_thread_id: int | None = None,
    runtime_context: Any = None,
    content_builder: Any | None = None,
) -> AsyncGenerator[str, None]:
    """Relay one agent turn as SSE frames, then inspect the final agent state.

    Phases, in order:

    1. Re-yield every frame produced by ``stream_output``.
    2. Read the agent state and, in cloud mode with a known search space,
       commit any filesystem work still staged in that state.
    3. Copy the file-operation contract's intent / confidence onto ``result``
       (only when the contract belongs to the current turn).
    4. For ``file_write`` intents, evaluate the commit gate and, when
       enforcement is active and the gate failed, emit a notice as text frames
       plus a terminal-info frame.
    5. Emit one interrupt-request frame per pending interrupt value.

    ``result`` is mutated in place: ``intent_detected``, ``intent_confidence``,
    ``commit_gate_passed``, ``commit_gate_reason``, ``accumulated_text`` and
    ``is_interrupted`` are written here.
    """
    relay = stream_output(
        agent=agent,
        config=config,
        input_data=input_data,
        streaming_service=streaming_service,
        result=result,
        step_prefix=step_prefix,
        initial_step_id=initial_step_id,
        initial_step_title=initial_step_title,
        initial_step_items=initial_step_items,
        content_builder=content_builder,
        runtime_context=runtime_context,
    )
    async for frame in relay:
        yield frame

    accumulated_text = result.accumulated_text
    state = await agent.aget_state(config)
    state_values = getattr(state, "values", {}) or {}

    # Safety net: if astream_events was cancelled before
    # KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work is
    # still in the checkpointed state. Run the SAME shared commit helper so the
    # turn's writes don't get lost on client disconnect, then push the delta
    # back into the graph using ``as_node=...`` so reducers fire as if the
    # after_agent hook produced it.
    staged_keys = (
        "dirty_paths",
        "staged_dirs",
        "pending_moves",
        "pending_deletes",
        "pending_dir_deletes",
    )
    safety_net_applicable = (
        fallback_commit_filesystem_mode == FilesystemMode.CLOUD
        and fallback_commit_search_space_id is not None
    )
    if safety_net_applicable and any(state_values.get(key) for key in staged_keys):
        try:
            delta = await commit_staged_filesystem_state(
                state_values,
                search_space_id=fallback_commit_search_space_id,
                created_by_id=fallback_commit_created_by_id,
                filesystem_mode=fallback_commit_filesystem_mode,
                thread_id=fallback_commit_thread_id,
                dispatch_events=False,
            )
            if delta:
                await agent.aupdate_state(
                    config,
                    delta,
                    as_node="KnowledgeBasePersistenceMiddleware.after_agent",
                )
        except Exception as exc:
            # Best effort: a failed fallback commit must not abort the turn.
            _perf_log.warning("[stream_agent_events] safety-net commit failed: %s", exc)

    contract_state = state_values.get("file_operation_contract") or {}
    current_turn_id = config.get("configurable", {}).get("turn_id", "")
    contract_is_current = contract_state.get("turn_id") == current_turn_id

    intent_value = contract_state.get("intent")
    known_intents = ("chat_only", "file_write", "file_read")
    if isinstance(intent_value, str) and intent_value in known_intents:
        # Ignore stale intent contracts from previous turns/checkpoints by
        # downgrading them to "chat_only".
        result.intent_detected = intent_value if contract_is_current else "chat_only"

    if contract_is_current:
        result.intent_confidence = safe_float(
            contract_state.get("confidence"), default=0.0
        )
    else:
        result.intent_confidence = 0.0

    if result.intent_detected == "file_write":
        gate_passed, gate_reason = evaluate_file_contract_outcome(result)
        result.commit_gate_passed = gate_passed
        result.commit_gate_reason = gate_reason
        if not result.commit_gate_passed and contract_enforcement_active(result):
            gate_notice = (
                "I could not complete the requested file write because no successful "
                "write_file/edit_file operation was confirmed."
            )
            gate_text_id = streaming_service.generate_text_id()
            has_builder = content_builder is not None

            # Each builder hook runs right after its matching frame is yielded.
            yield streaming_service.format_text_start(gate_text_id)
            if has_builder:
                content_builder.on_text_start(gate_text_id)
            yield streaming_service.format_text_delta(gate_text_id, gate_notice)
            if has_builder:
                content_builder.on_text_delta(gate_text_id, gate_notice)
            yield streaming_service.format_text_end(gate_text_id)
            if has_builder:
                content_builder.on_text_end(gate_text_id)
            yield streaming_service.format_terminal_info(gate_notice, "error")
            # The notice replaces the accumulated text recorded for the turn.
            accumulated_text = gate_notice
    else:
        result.commit_gate_passed = True
        result.commit_gate_reason = ""

    result.accumulated_text = accumulated_text
    log_file_contract("turn_outcome", result)

    pending_values = all_interrupt_values(state)
    if pending_values:
        result.is_interrupted = True
        # One frame per paused subagent so each parallel HITL renders its own
        # approval card on the wire. Order matches ``state.interrupts``, which
        # the resume slicer in
        # ``checkpointed_subagent_middleware.resume_routing`` consumes in the
        # same order — keeping emit and resume in lock-step.
        for pending in pending_values:
            yield streaming_service.format_interrupt_request(pending)

View file

@ -0,0 +1,15 @@
"""Pre-agent context shaping: mentioned-doc rendering and todos extraction."""
from __future__ import annotations
from app.tasks.chat.streaming.context.deepagents_todos import (
extract_todos_from_deepagents,
)
from app.tasks.chat.streaming.context.mentioned_docs import (
format_mentioned_surfsense_docs_as_context,
)
__all__ = [
"extract_todos_from_deepagents",
"format_mentioned_surfsense_docs_as_context",
]

View file

@ -0,0 +1,27 @@
"""Extract todos from a deepagents ``TodoListMiddleware`` ``Command`` output."""
from __future__ import annotations
from typing import Any
def extract_todos_from_deepagents(command_output: Any) -> dict:
    """Normalize todos out of a deepagents ``Command`` or dict payload.

    deepagents returns a ``Command`` whose ``update['todos']`` is a list of
    ``{'content': str, 'status': str}`` dicts. The UI expects the same shape,
    so no transformation is required — only extraction.

    Accepted inputs:

    * a plain dict carrying ``"todos"`` directly;
    * a plain dict carrying ``{"update": {"todos": [...]}}``;
    * any object exposing an ``update`` attribute that is a mapping.

    Anything else (including an ``update`` of ``None``) yields an empty list.

    Returns:
        ``{"todos": <list>}``.
    """
    # Dicts must be handled FIRST: every dict has an ``update`` *method*, so a
    # bare ``hasattr(command_output, "update")`` probe would match plain dicts
    # and then fail calling ``.get`` on the bound method.
    if isinstance(command_output, dict):
        if "todos" in command_output:
            return {"todos": command_output.get("todos", [])}
        nested_update = command_output.get("update")
        if isinstance(nested_update, dict):
            return {"todos": nested_update.get("todos", [])}
        return {"todos": []}

    # Command-like object: only read from ``update`` when it behaves like a
    # mapping (it may be ``None`` or another shape).
    update = getattr(command_output, "update", None)
    getter = getattr(update, "get", None)
    todos_data = getter("todos", []) if callable(getter) else []
    return {"todos": todos_data}

View file

@ -0,0 +1,58 @@
"""Render user-mentioned SurfSense docs as XML context for the agent."""
from __future__ import annotations
import json
from app.db import SurfsenseDocsDocument
from app.utils.surfsense_docs import surfsense_docs_public_url
def format_mentioned_surfsense_docs_as_context(
    documents: list[SurfsenseDocsDocument],
) -> str:
    """Render mentioned SurfSense docs as one newline-joined XML-like block.

    Returns an empty string when ``documents`` is empty. Otherwise the output
    is wrapped in ``<mentioned_surfsense_docs>`` and holds one ``<document>``
    per entry: a metadata section (id, type, title, public URL, JSON metadata)
    followed by one ``<chunk>`` per ``doc.chunks`` item, or a single ``doc-0``
    chunk built from ``doc.content`` when the document has no chunks.
    """
    if not documents:
        return ""

    lines: list[str] = [
        "<mentioned_surfsense_docs>",
        "The user has explicitly mentioned the following SurfSense documentation pages. "
        "These are official documentation about how to use SurfSense and should be used to answer questions about the application. "
        "Use [citation:CHUNK_ID] format for citations (e.g., [citation:doc-123]).",
    ]

    for document in documents:
        url = surfsense_docs_public_url(document.source)
        meta = json.dumps(
            {"source": document.source, "public_url": url}, ensure_ascii=False
        )
        lines.extend(
            [
                "<document>",
                "<document_metadata>",
                f" <document_id>doc-{document.id}</document_id>",
                " <document_type>SURFSENSE_DOCS</document_type>",
                f" <title><![CDATA[{document.title}]]></title>",
                f" <url><![CDATA[{url}]]></url>",
                f" <metadata_json><![CDATA[{meta}]]></metadata_json>",
                "</document_metadata>",
                "",
                "<document_content>",
            ]
        )

        doc_chunks = getattr(document, "chunks", None)
        if doc_chunks:
            lines.extend(
                f" <chunk id='doc-{piece.id}'><![CDATA[{piece.content}]]></chunk>"
                for piece in doc_chunks
            )
        else:
            # No chunks available: fall back to the whole document body.
            lines.append(
                f" <chunk id='doc-0'><![CDATA[{document.content}]]></chunk>"
            )

        lines.extend(["</document_content>", "</document>", ""])

    lines.append("</mentioned_surfsense_docs>")
    return "\n".join(lines)

View file

@ -0,0 +1,15 @@
"""File-operation contract evaluation and logging."""
from __future__ import annotations
from app.tasks.chat.streaming.contract.file_contract import (
contract_enforcement_active,
evaluate_file_contract_outcome,
log_file_contract,
)
__all__ = [
"contract_enforcement_active",
"evaluate_file_contract_outcome",
"log_file_contract",
]

View file

@ -0,0 +1,53 @@
"""File-operation contract: when to enforce, how to score, how to log."""
from __future__ import annotations
import json
from typing import Any
from app.tasks.chat.streaming.shared.stream_result import StreamResult
from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
def contract_enforcement_active(result: StreamResult) -> bool:
    """Tell whether the file-operation contract is enforced for this turn.

    Enforcement applies only in desktop local-folder mode. The decision is
    deterministic: it depends solely on ``result.filesystem_mode`` and is not
    driven by environment settings.
    """
    mode = result.filesystem_mode
    return mode == "desktop_local_folder"
def evaluate_file_contract_outcome(result: StreamResult) -> tuple[bool, str]:
    """Score the file-write contract for a finished turn.

    Returns ``(passed, reason)``. Turns whose detected intent is not
    ``"file_write"`` pass unconditionally. Otherwise the first failing stage,
    checked in order, decides the reason:

    * ``write_attempted`` falsy -> ``"no_write_attempt"``
    * ``write_succeeded`` falsy -> ``"write_failed"``
    * ``verification_succeeded`` falsy -> ``"verification_failed"``

    ``reason`` is an empty string whenever ``passed`` is true.
    """
    if result.intent_detected != "file_write":
        return True, ""

    # Ordered (flag attribute, failure reason) pairs; flags are read lazily so
    # later ones are only touched once the earlier stages have passed.
    stages = (
        ("write_attempted", "no_write_attempt"),
        ("write_succeeded", "write_failed"),
        ("verification_succeeded", "verification_failed"),
    )
    for flag_name, failure_reason in stages:
        if not getattr(result, flag_name):
            return False, failure_reason
    return True, ""
def log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None:
    """Log one structured ``[file_operation_contract]`` line for ``result``.

    Args:
        stage: Label for the point in the turn being logged.
        result: Stream result whose contract-related fields are snapshotted.
        **extra: Additional key/value pairs merged into the payload; they
            override snapshot keys of the same name.

    The payload is serialised with ``json.dumps(..., ensure_ascii=False)`` and
    written at INFO level on the perf logger. ``chat_id`` is the portion of
    ``turn_id`` before the first ``":"``, or ``"unknown"`` when there is no
    colon.
    """
    # Normalise once so a missing (``None``) turn id is reported as "unknown"
    # for both ``turn_id`` and ``chat_id`` instead of raising on the ``":"``
    # lookup.
    turn_id = result.turn_id or ""
    chat_id, separator, _ = turn_id.partition(":")

    payload: dict[str, Any] = {
        "stage": stage,
        "request_id": result.request_id or "unknown",
        "turn_id": turn_id or "unknown",
        "chat_id": chat_id if separator else "unknown",
        "filesystem_mode": result.filesystem_mode,
        "client_platform": result.client_platform,
        "intent_detected": result.intent_detected,
        "intent_confidence": result.intent_confidence,
        "write_attempted": result.write_attempted,
        "write_succeeded": result.write_succeeded,
        "verification_succeeded": result.verification_succeeded,
        "commit_gate_passed": result.commit_gate_passed,
        "commit_gate_reason": result.commit_gate_reason or None,
    }
    payload.update(extra)
    _perf_log.info(
        "[file_operation_contract] %s", json.dumps(payload, ensure_ascii=False)
    )

View file

@ -0,0 +1,17 @@
"""Top-level streaming flows: ``new_chat`` and ``resume_chat`` orchestrators.
Re-exports the public entry points so callers can write::
from app.tasks.chat.streaming.flows import stream_new_chat, stream_resume_chat
The orchestrators themselves live under ``new_chat/orchestrator.py`` and
``resume_chat/orchestrator.py`` (slim composition of the per-concern modules in
each flow folder and the building blocks in ``shared/``).
"""
from __future__ import annotations
from app.tasks.chat.streaming.flows.new_chat import stream_new_chat
from app.tasks.chat.streaming.flows.resume_chat import stream_resume_chat
__all__ = ["stream_new_chat", "stream_resume_chat"]

View file

@ -0,0 +1,12 @@
"""New-chat streaming flow.
The public entry point ``stream_new_chat`` is the slim coroutine in
``orchestrator.py`` that composes the per-concern modules in this folder and
the building blocks under ``flows/shared/``.
"""
from __future__ import annotations
from app.tasks.chat.streaming.flows.new_chat.orchestrator import stream_new_chat
__all__ = ["stream_new_chat"]

View file

@ -0,0 +1,95 @@
"""Resolve the auto-pin for the *initial* turn config.
Auto-pin (``selected_llm_config_id=0``) picks the best eligible LLM config for
this thread / search space / user, optionally filtered to vision-capable
configs when the turn carries images.
Errors classified here:
* ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` — the auto-pin pool has no
vision-capable cfg for an image-bearing turn. The same gate fires later
in ``llm_capability`` for explicit selections; mapping both to the same
code keeps the FE error UI consistent.
* ``SERVER_ERROR`` — any other ``ValueError`` from the resolver.
This module owns *initial* pin resolution; the rate-limit recovery loop has
its own narrower auto-pin call (with ``exclude_config_ids``) in
``flows/shared/rate_limit_recovery``.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal
from sqlalchemy.ext.asyncio import AsyncSession
from app.observability import otel as ot
from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
@dataclass
class AutoPinResult:
    """Result returned by ``resolve_initial_auto_pin``.

    Exactly one side is populated: on success ``llm_config_id`` holds the
    resolved config id and ``error`` is ``None``; on failure ``llm_config_id``
    is ``None`` and ``error`` is a ``(message, error_code, error_kind)`` triple
    the orchestrator can turn into one terminal-error SSE frame.
    """

    # Resolved LLM config id; ``None`` when resolution failed.
    llm_config_id: int | None
    # ``(user-facing message, error code, error kind)``; ``None`` on success.
    error: tuple[str, str, Literal["user_error", "server_error"]] | None
async def resolve_initial_auto_pin(
    session: AsyncSession,
    *,
    chat_id: int,
    search_space_id: int,
    user_id: str | None,
    selected_llm_config_id: int,
    requires_image_input: bool,
    requested_llm_config_id: int,
) -> AutoPinResult:
    """Resolve the pinned LLM config and classify resolver ``ValueError``s.

    On success a ``model.pin.resolved`` telemetry event is recorded and the
    resolved id is returned. A ``ValueError`` from the resolver is converted
    into an error result instead of propagating; other exception types are not
    caught here.
    """
    try:
        pinned = await resolve_or_get_pinned_llm_config_id(
            session,
            thread_id=chat_id,
            search_space_id=search_space_id,
            user_id=user_id,
            selected_llm_config_id=selected_llm_config_id,
            requires_image_input=requires_image_input,
        )
        resolved_id = pinned.resolved_llm_config_id
        ot.add_event(
            "model.pin.resolved",
            {
                "pin.requested_id": requested_llm_config_id,
                "pin.resolved_id": resolved_id,
                "pin.requires_image_input": requires_image_input,
            },
        )
        return AutoPinResult(llm_config_id=resolved_id, error=None)
    except ValueError as pin_error:
        message = str(pin_error)
        # The "no vision-capable cfg" path raises a ValueError whose message
        # we map to the friendly image-input SSE error so the user sees the
        # same message regardless of whether the gate fired in the resolver or
        # in ``llm_capability.assert_vision_capability_for_image_turn``.
        if requires_image_input and "vision-capable" in message:
            vision_code = "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
            ot.add_event("quota.denied", {"quota.code": vision_code})
            return AutoPinResult(
                llm_config_id=None, error=(message, vision_code, "user_error")
            )
        return AutoPinResult(
            llm_config_id=None, error=(message, "SERVER_ERROR", "server_error")
        )

View file

@ -0,0 +1,95 @@
"""Build and emit the first ``thinking-1`` step for a new-chat turn.
The step title and "Processing X" items are derived from what the user sent
(text snippet, image count, mentioned doc titles) so the FE can render a
meaningful placeholder while the agent stream warms up.
``thinking-1`` is the canonical id for this step — every subsequent
``thinking-N`` produced by ``stream_agent_events`` folds into the same
singleton ``data-thinking-steps`` part on the FE.
"""
from __future__ import annotations
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any
from app.db import SurfsenseDocsDocument
from app.services.new_streaming_service import VercelStreamingService
@dataclass
class InitialThinkingStep:
    """Fields shared by the initial SSE frame and the content-builder hook.

    The builder in this module always sets ``step_id`` to ``thinking-1`` so
    the FE Timeline can de-duplicate against the prior assistant message on
    resume.
    """

    # Step identifier.
    step_id: str
    # One-line header shown for the step.
    title: str
    # Bullet list rendered under the title.
    items: list[str]
def build_initial_thinking_step(
    *,
    user_query: str,
    user_image_data_urls: list[str] | None,
    mentioned_surfsense_docs: list[SurfsenseDocsDocument],
) -> InitialThinkingStep:
    """Derive the placeholder ``thinking-1`` step from what the user sent.

    The single item reads ``"<verb>: <summary>"`` where the summary is, in
    priority order, the query truncated to 80 characters, an image count, or
    the literal ``(message)``. When docs are mentioned, a bracketed doc title
    (truncated to 30 characters) or a doc count is appended, and the title /
    verb switch to the "Analyzing" wording.
    """
    has_docs = bool(mentioned_surfsense_docs)
    if has_docs:
        title, verb = "Analyzing referenced content", "Analyzing"
    else:
        title, verb = "Understanding your request", "Processing"

    if user_query.strip():
        suffix = "..." if len(user_query) > 80 else ""
        summary = [user_query[:80] + suffix]
    elif user_image_data_urls:
        summary = [f"[{len(user_image_data_urls)} image(s)]"]
    else:
        summary = ["(message)"]

    if has_docs:
        names = [
            doc.title if len(doc.title) <= 30 else doc.title[:27] + "..."
            for doc in mentioned_surfsense_docs
        ]
        # A single doc is shown by name; several collapse to a count.
        summary.append(f"[{names[0]}]" if len(names) == 1 else f"[{len(names)} docs]")

    return InitialThinkingStep(
        step_id="thinking-1",
        title=title,
        items=[f"{verb}: {' '.join(summary)}"],
    )
def iter_initial_thinking_step_frame(
    step: InitialThinkingStep,
    *,
    streaming_service: VercelStreamingService,
    content_builder: Any | None,
) -> Iterator[str]:
    """Yield the single ``in_progress`` frame for the initial thinking step.

    When a ``content_builder`` is supplied, its ``on_thinking_step`` hook is
    invoked with the same step data before the frame is yielded, so the
    server-side builder mirrors the fold the FE applies to the
    ``data-thinking-steps`` part.
    """
    status = "in_progress"
    step_id, title, items = step.step_id, step.title, step.items

    if content_builder is not None:
        content_builder.on_thinking_step(step_id, title, status, items)

    yield streaming_service.format_thinking_step(
        step_id=step_id,
        title=title,
        status=status,
        items=items,
    )

View file

@ -0,0 +1,264 @@
r"""Assemble the LangGraph ``input_state`` for the new-chat turn.
Pipeline:
1. **History bootstrap** — only for cloned chats with no LangGraph checkpoint
yet; flips the per-thread ``needs_history_bootstrap`` flag back to False
once the rows are loaded.
2. **Mentioned SurfSense docs** — eager-load chunks so the formatter has the
full content without a second roundtrip.
3. **Recent reports** — top 3 by id desc with non-null content, so the LLM
can resolve ``report_id`` for versioning without spelunking history.
4. **@-mention resolve** (cloud mode) — substitute ``@title`` tokens in the
query with canonical ``\`/documents/...\``` paths the LLM expects.
5. **Context block render** — XML-wrap surfsense docs + reports, prepend to
the rewritten query, optionally prefix with display name for SEARCH_SPACE
visibility.
6. **HumanMessage** — multimodal content if images are attached.
Returns the assembled ``input_state`` dict plus side-channel data the
orchestrator needs downstream (``accepted_folder_ids`` for runtime context;
``mentioned_surfsense_docs`` for the initial thinking step).
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Any
from langchain_core.messages import HumanMessage
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in_text
from app.db import (
ChatVisibility,
NewChatThread,
Report,
SurfsenseDocsDocument,
)
from app.tasks.chat.streaming.context.mentioned_docs import (
format_mentioned_surfsense_docs_as_context,
)
from app.utils.content_utils import bootstrap_history_from_db
from app.utils.user_message_multimodal import build_human_message_content
logger = logging.getLogger(__name__)
@dataclass
class NewChatInputState:
    """Bundle returned by ``build_new_chat_input_state``.

    Carries the agent input together with the side-channel data the
    orchestrator needs after input assembly.
    """

    # Dict fed straight to the agent.
    input_state: dict[str, Any]
    # Folder ids kept by the mention resolver (it may drop some chips); feeds
    # the runtime context.
    accepted_folder_ids: list[int]
    # Docs consumed by the initial thinking-step builder for the FE
    # placeholder shown before the agent stream starts.
    mentioned_surfsense_docs: list[SurfsenseDocsDocument]
async def build_new_chat_input_state(
    session: AsyncSession,
    *,
    chat_id: int,
    search_space_id: int,
    user_query: str,
    user_image_data_urls: list[str] | None,
    mentioned_document_ids: list[int] | None,
    mentioned_surfsense_doc_ids: list[int] | None,
    mentioned_folder_ids: list[int] | None,
    mentioned_documents: list[dict[str, Any]] | None,
    needs_history_bootstrap: bool,
    thread_visibility: ChatVisibility,
    current_user_display_name: str | None,
    filesystem_mode: str,
    request_id: str | None,
    turn_id: str,
) -> NewChatInputState:
    """Assemble the agent ``input_state`` for a new-chat turn.

    Steps, in order: optional history bootstrap (clearing the thread's
    ``needs_history_bootstrap`` flag and committing), loading mentioned
    SurfSense docs with their chunks, loading the three most recent reports
    with content, resolving @-mentions, rendering the context-prefixed query,
    and appending the resulting ``HumanMessage``.

    Returns a ``NewChatInputState`` holding the input dict, the accepted folder
    ids and the loaded SurfSense docs.
    """
    messages: list[Any] = []
    if needs_history_bootstrap:
        messages = await bootstrap_history_from_db(
            session, chat_id, thread_visibility=thread_visibility
        )
        thread_query = select(NewChatThread).filter(NewChatThread.id == chat_id)
        thread_row = (await session.execute(thread_query)).scalars().first()
        if thread_row:
            # Bootstrap is one-shot: clear the flag once the history is loaded.
            thread_row.needs_history_bootstrap = False
            await session.commit()

    surfsense_docs: list[SurfsenseDocsDocument] = []
    if mentioned_surfsense_doc_ids:
        docs_query = (
            select(SurfsenseDocsDocument)
            .options(selectinload(SurfsenseDocsDocument.chunks))
            .filter(SurfsenseDocsDocument.id.in_(mentioned_surfsense_doc_ids))
        )
        docs_rows = await session.execute(docs_query)
        surfsense_docs = list(docs_rows.scalars().all())

    # Top 3 reports keyed by id desc (newest first) with content present,
    # surfaced inline so the LLM resolves ``report_id`` for versioning without
    # digging through conversation history.
    reports_query = (
        select(Report)
        .filter(
            Report.thread_id == chat_id,
            Report.content.isnot(None),
        )
        .order_by(Report.id.desc())
        .limit(3)
    )
    reports_rows = await session.execute(reports_query)
    latest_reports = list(reports_rows.scalars().all())

    rewritten_query, folder_ids = await _resolve_mentions_for_query(
        session,
        search_space_id=search_space_id,
        user_query=user_query,
        filesystem_mode=filesystem_mode,
        mentioned_document_ids=mentioned_document_ids,
        mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
        mentioned_folder_ids=mentioned_folder_ids,
        mentioned_documents=mentioned_documents,
    )

    prompt_text = _render_query_with_context(
        agent_user_query=rewritten_query,
        mentioned_surfsense_docs=surfsense_docs,
        recent_reports=latest_reports,
    )
    shared_thread = thread_visibility == ChatVisibility.SEARCH_SPACE
    if shared_thread and current_user_display_name:
        # Attribute the message to its author in shared threads.
        prompt_text = f"**[{current_user_display_name}]:** {prompt_text}"

    image_urls = list(user_image_data_urls or ())
    messages.append(
        HumanMessage(content=build_human_message_content(prompt_text, image_urls))
    )

    return NewChatInputState(
        input_state={
            "messages": messages,
            "search_space_id": search_space_id,
            "request_id": request_id or "unknown",
            "turn_id": turn_id,
        },
        accepted_folder_ids=folder_ids,
        mentioned_surfsense_docs=surfsense_docs,
    )
async def _resolve_mentions_for_query(
    session: AsyncSession,
    *,
    search_space_id: int,
    user_query: str,
    filesystem_mode: str,
    mentioned_document_ids: list[int] | None,
    mentioned_surfsense_doc_ids: list[int] | None,
    mentioned_folder_ids: list[int] | None,
    mentioned_documents: list[dict[str, Any]] | None,
) -> tuple[str, list[int]]:
    """Resolve @-mention chips and rewrite the query to canonical paths.

    Returns ``(agent_user_query, accepted_folder_ids)``.

    Only cloud mode rewrites anything. In any other filesystem mode, or when
    no mention of any kind was supplied, the query is returned untouched with
    an empty folder-id list (local-folder mode keeps the legacy ``@title``
    text path).

    The substitution is applied to the returned query only; the caller's
    original ``user_query`` (still carrying ``@title`` tokens) is not modified,
    so surfaces that want what the user typed can keep using it.
    """
    mention_inputs = (
        mentioned_document_ids,
        mentioned_surfsense_doc_ids,
        mentioned_folder_ids,
        mentioned_documents,
    )
    if filesystem_mode != FilesystemMode.CLOUD.value or not any(mention_inputs):
        return user_query, []

    from app.schemas.new_chat import MentionedDocumentInfo

    chips: list[MentionedDocumentInfo] | None = None
    if mentioned_documents:
        chips = []
        for raw_chip in mentioned_documents:
            if isinstance(raw_chip, MentionedDocumentInfo):
                chips.append(raw_chip)
            else:
                try:
                    chips.append(MentionedDocumentInfo.model_validate(raw_chip))
                except Exception:
                    # Malformed chips are dropped rather than failing the turn.
                    logger.debug(
                        "stream_new_chat: dropping malformed mention chip %r",
                        raw_chip,
                    )

    resolution = await resolve_mentions(
        session,
        search_space_id=search_space_id,
        mentioned_documents=chips,
        mentioned_document_ids=mentioned_document_ids,
        mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
        mentioned_folder_ids=mentioned_folder_ids,
    )
    rewritten = substitute_in_text(user_query, resolution.token_to_path)
    return rewritten, resolution.mentioned_folder_ids
def _render_query_with_context(
    *,
    agent_user_query: str,
    mentioned_surfsense_docs: list[SurfsenseDocsDocument],
    recent_reports: list[Report],
) -> str:
    """Prefix the query with SurfSense-docs and recent-reports context blocks.

    When neither block applies, the query is returned unchanged. Otherwise the
    blocks are joined with blank lines and the query is appended inside a
    ``<user_query>`` element.
    """
    blocks: list[str] = []

    if mentioned_surfsense_docs:
        blocks.append(
            format_mentioned_surfsense_docs_as_context(mentioned_surfsense_docs)
        )

    if recent_reports:
        listing_rows: list[str] = []
        for report in recent_reports:
            # Reports without an explicit style are listed as "detailed".
            style = report.report_style or "detailed"
            listing_rows.append(
                f' - report_id={report.id}, title="{report.title}", style="{style}"'
            )
        listing = "\n".join(listing_rows)
        blocks.append(
            "<report_context>\n"
            "Previously generated reports in this conversation:\n"
            f"{listing}\n\n"
            "If the user wants to MODIFY, REVISE, UPDATE, or ADD to one of "
            "these reports, set parent_report_id to the relevant report_id above.\n"
            "If the user wants a completely NEW report on a different topic, "
            "leave parent_report_id unset.\n"
            "</report_context>"
        )

    if not blocks:
        return agent_user_query

    preamble = "\n\n".join(blocks)
    return f"{preamble}\n\n<user_query>{agent_user_query}</user_query>"

View file

@ -0,0 +1,62 @@
"""Vision-capability gate for image-bearing turns.
Capability safety net for explicit (non-auto-pin) selections: a turn carrying
user-uploaded images cannot be routed to a chat config that LiteLLM's
authoritative model map *explicitly* marks as text-only (``supports_vision``
set to False). The check is intentionally narrow — it only fires when LiteLLM
is *certain* the model can't accept image input; unknown / unmapped /
vision-capable models pass through.
Without this guard a known-text-only model would 404 at the provider with
``"No endpoints found that support image input"``, surfacing as an opaque
``SERVER_ERROR`` SSE chunk; failing here lets us return a friendly message that
tells the user what to change.
"""
from __future__ import annotations
from app.agents.new_chat.llm_config import AgentConfig
from app.observability import otel as ot
def check_image_input_capability(
    *,
    user_image_data_urls: list[str] | None,
    agent_config: AgentConfig | None,
) -> tuple[str, str] | None:
    """Gate image-bearing turns against models known to be text-only.

    Returns ``None`` when the turn has no images, when no agent config is
    available, or when the configured model is not flagged as known text-only.
    Otherwise a ``quota.denied`` telemetry event is recorded and a
    ``(user_message, error_code)`` pair is returned for the caller to emit as
    one terminal-error SSE frame.
    """
    if not user_image_data_urls or agent_config is None:
        return None

    from app.services.provider_capabilities import is_known_text_only_chat_model

    litellm_params = agent_config.litellm_params or {}
    base_model = None
    if isinstance(litellm_params, dict):
        base_model = litellm_params.get("base_model")

    known_text_only = is_known_text_only_chat_model(
        provider=agent_config.provider,
        model_name=agent_config.model_name,
        base_model=base_model,
        custom_provider=agent_config.custom_provider,
    )
    if not known_text_only:
        return None

    error_code = "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
    model_label = agent_config.config_name or agent_config.model_name or "model"
    ot.add_event("quota.denied", {"quota.code": "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"})
    user_message = (
        f"The selected model ({model_label}) does not support "
        "image input. Switch to a vision-capable model "
        "(e.g. GPT-4o, Claude, Gemini) or remove the image "
        "attachment and try again."
    )
    return user_message, error_code

View file

@ -0,0 +1,868 @@
"""``stream_new_chat`` — public entry point for a fresh chat turn.
Slim composition layer over the per-concern modules in this folder and the
building blocks under ``flows/shared/``. Each phase corresponds to a numbered
block in the surrounding code so the on-the-wire ordering stays explicit:
1. Validation / config — auto-pin, LLM bundle, capability, premium reserve.
2. Concurrent persistence + pre-stream setup — spawn DB writes, build the
connector, fetch the checkpointer, build the agent.
3. Input assembly — history bootstrap, mentions, surfsense docs, reports.
4. First SSE frames — message_start, start_step, turn-info, turn-status.
5. Persistence join + message-id frames (ghost-thread protection).
6. Initial thinking step + title task + runtime context.
7. Stream loop with in-stream rate-limit recovery + mid-stream title emit.
8. Finalize — premium debit, token-usage SSE, finish frames.
9. Exception branch — classify, emit terminal error, finish frames.
10. Finally — premium release, session close, assistant finalize, GC, span.
"""
from __future__ import annotations
import asyncio
import contextlib
import logging
import time
from collections.abc import AsyncGenerator
from functools import partial
from typing import Any, Literal
import anyio
from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
from app.agents.new_chat.middleware.busy_mutex import end_turn
from app.config import config as _app_config
from app.db import ChatVisibility, async_session_maker
from app.observability import otel as ot
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.content_builder import AssistantContentBuilder
from app.tasks.chat.streaming.agent.builder import build_main_agent_for_thread
from app.tasks.chat.streaming.contract.file_contract import log_file_contract
from app.tasks.chat.streaming.errors.emitter import emit_stream_terminal_error
from app.tasks.chat.streaming.flows.new_chat.auto_pin import resolve_initial_auto_pin
from app.tasks.chat.streaming.flows.new_chat.initial_thinking_step import (
build_initial_thinking_step,
iter_initial_thinking_step_frame,
)
from app.tasks.chat.streaming.flows.new_chat.input_state import (
build_new_chat_input_state,
)
from app.tasks.chat.streaming.flows.new_chat.llm_capability import (
check_image_input_capability,
)
from app.tasks.chat.streaming.flows.new_chat.persistence_spawn import (
await_persist_task,
spawn_persist_assistant_shell_task,
spawn_persist_user_task,
spawn_set_ai_responding_bg,
)
from app.tasks.chat.streaming.flows.new_chat.runtime_context import (
build_new_chat_runtime_context,
)
from app.tasks.chat.streaming.flows.new_chat.title_gen import (
await_pending_title_update,
maybe_emit_title_update,
spawn_title_task,
)
from app.tasks.chat.streaming.flows.shared.assistant_finalize import (
finalize_assistant_message,
)
from app.tasks.chat.streaming.flows.shared.finalize_emit import iter_token_usage_frame
from app.tasks.chat.streaming.flows.shared.finally_cleanup import (
close_session_and_clear_ai_responding,
run_gc_pass,
)
from app.tasks.chat.streaming.flows.shared.first_frames import (
iter_final_frames,
iter_initial_frames,
)
from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle
from app.tasks.chat.streaming.flows.shared.pre_stream_setup import (
get_chat_checkpointer,
setup_connector_and_firecrawl,
)
from app.tasks.chat.streaming.flows.shared.premium_quota import (
PremiumReservation,
finalize_premium,
needs_premium_quota,
release_premium,
reserve_premium,
)
from app.tasks.chat.streaming.flows.shared.rate_limit_recovery import (
can_recover_provider_rate_limit,
log_rate_limit_recovered,
reroute_to_next_auto_pin,
)
from app.tasks.chat.streaming.flows.shared.span import (
close_chat_request_span,
open_chat_request_span,
set_agent_mode,
)
from app.tasks.chat.streaming.flows.shared.stream_loop import run_stream_loop
from app.tasks.chat.streaming.flows.shared.terminal_error import (
handle_terminal_exception,
)
from app.tasks.chat.streaming.shared.stream_result import StreamResult
from app.utils.perf import get_perf_logger, log_system_snapshot
# Standard module logger.
logger = logging.getLogger(__name__)
# Separate perf logger used for the "[stream_new_chat] ... in %.3fs" timing lines.
_perf_log = get_perf_logger()
# Holds spawned background tasks (set_ai_responding, persist_user, persist_asst)
# so the GC doesn't drop them before they finish. Kept at module level so it
# survives across turns within one process.
_background_tasks: set[asyncio.Task] = set()
async def stream_new_chat(
    user_query: str,
    search_space_id: int,
    chat_id: int,
    user_id: str | None = None,
    llm_config_id: int = -1,
    mentioned_document_ids: list[int] | None = None,
    mentioned_surfsense_doc_ids: list[int] | None = None,
    mentioned_folder_ids: list[int] | None = None,
    mentioned_documents: list[dict[str, Any]] | None = None,
    checkpoint_id: str | None = None,
    needs_history_bootstrap: bool = False,
    thread_visibility: ChatVisibility | None = None,
    current_user_display_name: str | None = None,
    disabled_tools: list[str] | None = None,
    filesystem_selection: FilesystemSelection | None = None,
    request_id: str | None = None,
    user_image_data_urls: list[str] | None = None,
    flow: Literal["new", "regenerate"] = "new",
) -> AsyncGenerator[str, None]:
    """Stream a new chat turn using the SurfSense deep agent.

    Uses the Vercel AI SDK Data Stream Protocol (SSE). ``chat_id`` is the
    LangGraph thread id (durable conversation memory via the checkpointer).
    Manages its own database session so cleanup runs even when Starlette
    cancels the task on client disconnect.

    Args:
        user_query: Text of the user's message for this turn.
        search_space_id: Search space the turn runs against; forwarded to
            config resolution, the agent build and input assembly.
        chat_id: Thread id; stringified into ``configurable["thread_id"]``.
        user_id: Optional user id. Premium finalize/release are skipped when
            it is falsy.
        llm_config_id: Requested LLM config id. It is passed through
            ``resolve_initial_auto_pin`` and replaced by the resolved id.
            When the requested id is ``0`` and the premium reservation is
            denied, the turn is re-pinned instead of failing.
        mentioned_document_ids: Forwarded to the agent build, input assembly
            and the runtime context.
        mentioned_surfsense_doc_ids: Forwarded to input assembly.
        mentioned_folder_ids: Forwarded to input assembly and the runtime
            context.
        mentioned_documents: Forwarded to the user-row persistence task and
            input assembly.
        checkpoint_id: When truthy, added to ``configurable`` as
            ``checkpoint_id``.
        needs_history_bootstrap: Forwarded to input assembly.
        thread_visibility: Defaults to ``ChatVisibility.PRIVATE`` when
            ``None``.
        current_user_display_name: Forwarded to input assembly.
        disabled_tools: Forwarded to the agent build.
        filesystem_selection: When ``None``, mode ``"cloud"`` and platform
            ``"web"`` are recorded, and ``FilesystemMode.CLOUD`` is used for
            the stream loop's fallback commit.
        request_id: Correlation id for logs, the span and ``configurable``
            (``"unknown"`` there when ``None``).
        user_image_data_urls: Image attachments. A non-empty list makes the
            turn require image input and triggers the capability check.
        flow: ``"new"`` or ``"regenerate"``; used as a label for the span and
            for error logging.

    Yields:
        SSE frames as strings, in the block order documented in the module
        docstring.
    """
    streaming_service = VercelStreamingService()
    stream_result = StreamResult()
    _t_total = time.perf_counter()
    fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud"
    fs_platform = (
        filesystem_selection.client_platform.value if filesystem_selection else "web"
    )
    stream_result.request_id = request_id
    # Turn id = thread id + wall-clock milliseconds at turn start.
    stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}"
    stream_result.filesystem_mode = fs_mode
    stream_result.client_platform = fs_platform
    # Span attributes; updated as the turn progresses and reported when the
    # span is closed at the end of ``finally``.
    chat_agent_mode = "unknown"
    chat_outcome = "success"
    chat_error_category: str | None = None
    chat_span_cm, chat_span = open_chat_request_span(
        chat_id=chat_id,
        search_space_id=search_space_id,
        flow=flow,
        request_id=request_id,
        turn_id=stream_result.turn_id,
        filesystem_mode=fs_mode,
        client_platform=fs_platform,
        agent_mode=chat_agent_mode,
    )
    log_file_contract("turn_start", stream_result)
    _perf_log.info(
        "[stream_new_chat] filesystem_mode=%s client_platform=%s",
        fs_mode,
        fs_platform,
    )
    log_system_snapshot("stream_new_chat_START")
    from app.services.token_tracking_service import start_turn

    accumulator = start_turn()
    # Stays non-``None`` only while a reservation is outstanding; ``finally``
    # releases whatever is still set here.
    premium_reservation: PremiumReservation | None = None
    busy_error_raised = False
    # Pre-bind the per-request fields so each error site only supplies the
    # message / kind / code.
    emit_stream_error = partial(
        emit_stream_terminal_error,
        streaming_service=streaming_service,
        flow=flow,
        request_id=request_id,
        thread_id=chat_id,
        search_space_id=search_space_id,
        user_id=user_id,
    )
    session = async_session_maker()
    # Declared at function scope so SSE-yield join points and the finally
    # clause see them on every exit path.
    persist_user_task: asyncio.Task[int | None] | None = None
    persist_asst_task: asyncio.Task[int | None] | None = None
    try:
        spawn_set_ai_responding_bg(
            chat_id=chat_id, user_id=user_id, background_tasks=_background_tasks
        )
        # --- Block 1: LLM config + capability ---
        # Keep the caller's original request; ``llm_config_id`` is rebound to
        # the resolved id below.
        requested_llm_config_id = llm_config_id
        requires_image_input = bool(user_image_data_urls)
        _t0 = time.perf_counter()
        pin_result = await resolve_initial_auto_pin(
            session,
            chat_id=chat_id,
            search_space_id=search_space_id,
            user_id=user_id,
            selected_llm_config_id=llm_config_id,
            requires_image_input=requires_image_input,
            requested_llm_config_id=requested_llm_config_id,
        )
        # Failures in this block end with a terminal error frame followed by
        # a bare DONE frame.
        if pin_result.error is not None:
            message, error_code, error_kind = pin_result.error
            yield emit_stream_error(
                message=message, error_kind=error_kind, error_code=error_code
            )
            yield streaming_service.format_done()
            return
        llm_config_id = pin_result.llm_config_id  # type: ignore[assignment]
        llm, agent_config, llm_load_error = await load_llm_bundle(
            session, config_id=llm_config_id, search_space_id=search_space_id
        )
        if llm_load_error:
            yield emit_stream_error(
                message=llm_load_error,
                error_kind="server_error",
                error_code="SERVER_ERROR",
            )
            yield streaming_service.format_done()
            return
        _perf_log.info(
            "[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)",
            time.perf_counter() - _t0,
            llm_config_id,
        )
        capability_error = check_image_input_capability(
            user_image_data_urls=user_image_data_urls, agent_config=agent_config
        )
        if capability_error is not None:
            message, error_code = capability_error
            yield emit_stream_error(
                message=message,
                error_kind="user_error",
                error_code=error_code,
            )
            yield streaming_service.format_done()
            return
        if needs_premium_quota(agent_config, user_id):
            premium_reservation = await reserve_premium(
                agent_config=agent_config, user_id=user_id  # type: ignore[arg-type]
            )
            if not premium_reservation.allowed:
                ot.add_event("quota.denied", {"quota.code": "PREMIUM_QUOTA_EXHAUSTED"})
                # Requested id 0: re-resolve the pin and continue on the
                # returned config instead of failing the turn.
                if requested_llm_config_id == 0:
                    pin_fallback = await resolve_initial_auto_pin(
                        session,
                        chat_id=chat_id,
                        search_space_id=search_space_id,
                        user_id=user_id,
                        selected_llm_config_id=0,
                        requires_image_input=requires_image_input,
                        requested_llm_config_id=requested_llm_config_id,
                    )
                    if pin_fallback.error is not None:
                        message, error_code, error_kind = pin_fallback.error
                        yield emit_stream_error(
                            message=message,
                            error_kind=error_kind,
                            error_code=error_code,
                        )
                        yield streaming_service.format_done()
                        return
                    llm_config_id = pin_fallback.llm_config_id  # type: ignore[assignment]
                    ot.add_event(
                        "model.repin",
                        {
                            "repin.reason": "premium_quota_exhausted",
                            "repin.to_config_id": llm_config_id,
                        },
                    )
                    llm, agent_config, llm_load_error = await load_llm_bundle(
                        session,
                        config_id=llm_config_id,
                        search_space_id=search_space_id,
                    )
                    if llm_load_error:
                        yield emit_stream_error(
                            message=llm_load_error,
                            error_kind="server_error",
                            error_code="SERVER_ERROR",
                        )
                        yield streaming_service.format_done()
                        return
                    # Drop the denied reservation so neither the finalize
                    # step nor ``finally`` acts on it.
                    premium_reservation = None
                    # Re-route to free fallback logged via the structured
                    # stream-error logger so cost/analytics see the auto-switch.
                    from app.tasks.chat.streaming.errors.classifier import (
                        log_chat_stream_error,
                    )

                    log_chat_stream_error(
                        flow=flow,
                        error_kind="premium_quota_exhausted",
                        error_code="PREMIUM_QUOTA_EXHAUSTED",
                        severity="info",
                        is_expected=True,
                        request_id=request_id,
                        thread_id=chat_id,
                        search_space_id=search_space_id,
                        user_id=user_id,
                        message=(
                            "Premium quota exhausted on pinned model; "
                            "auto-fallback switched to a free model"
                        ),
                        extra={
                            "fallback_config_id": llm_config_id,
                            "auto_fallback": True,
                        },
                    )
                else:
                    yield emit_stream_error(
                        message=(
                            "Buy more tokens to continue with this model, or "
                            "switch to a free model"
                        ),
                        error_kind="premium_quota_exhausted",
                        error_code="PREMIUM_QUOTA_EXHAUSTED",
                        severity="info",
                        is_expected=True,
                        extra={
                            "resolved_config_id": llm_config_id,
                            "auto_fallback": False,
                        },
                    )
                    yield streaming_service.format_done()
                    return
        if not llm:
            yield emit_stream_error(
                message="Failed to create LLM instance",
                error_kind="server_error",
                error_code="SERVER_ERROR",
            )
            yield streaming_service.format_done()
            return
        # --- Block 2: Spawn concurrent persistence; build pre-stream setup ---
        # Both INSERT tasks are started here and joined in Block 5, so they
        # overlap with connector / checkpointer / agent setup.
        persist_user_task = spawn_persist_user_task(
            chat_id=chat_id,
            user_id=user_id,
            turn_id=stream_result.turn_id,
            user_query=user_query,
            user_image_data_urls=user_image_data_urls,
            mentioned_documents=mentioned_documents,
            background_tasks=_background_tasks,
        )
        persist_asst_task = spawn_persist_assistant_shell_task(
            chat_id=chat_id,
            user_id=user_id,
            turn_id=stream_result.turn_id,
            background_tasks=_background_tasks,
        )
        _t0 = time.perf_counter()
        connector_service, firecrawl_api_key = await setup_connector_and_firecrawl(
            session, search_space_id=search_space_id
        )
        _perf_log.info(
            "[stream_new_chat] Connector service + firecrawl key in %.3fs",
            time.perf_counter() - _t0,
        )
        _t0 = time.perf_counter()
        checkpointer = await get_chat_checkpointer()
        _perf_log.info(
            "[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
        )
        visibility = thread_visibility or ChatVisibility.PRIVATE
        use_multi_agent = bool(_app_config.MULTI_AGENT_CHAT_ENABLED)
        chat_agent_mode = "multi" if use_multi_agent else "single"
        set_agent_mode(chat_span, chat_agent_mode)
        _t0 = time.perf_counter()
        agent_factory = (
            create_multi_agent_chat_deep_agent
            if use_multi_agent
            else create_surfsense_deep_agent
        )
        # Build the agent inline. Provider 429s surface through the in-stream
        # recovery loop below, which repins the thread to an eligible
        # alternative config and rebuilds the agent before the user sees any
        # output.
        agent = await build_main_agent_for_thread(
            agent_factory,
            llm=llm,
            search_space_id=search_space_id,
            db_session=session,
            connector_service=connector_service,
            checkpointer=checkpointer,
            user_id=user_id,
            thread_id=chat_id,
            agent_config=agent_config,
            firecrawl_api_key=firecrawl_api_key,
            thread_visibility=visibility,
            filesystem_selection=filesystem_selection,
            disabled_tools=disabled_tools,
            mentioned_document_ids=mentioned_document_ids,
        )
        _perf_log.info(
            "[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
        )
        # --- Block 3: Input assembly ---
        _t0 = time.perf_counter()
        assembled = await build_new_chat_input_state(
            session,
            chat_id=chat_id,
            search_space_id=search_space_id,
            user_query=user_query,
            user_image_data_urls=user_image_data_urls,
            mentioned_document_ids=mentioned_document_ids,
            mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
            mentioned_folder_ids=mentioned_folder_ids,
            mentioned_documents=mentioned_documents,
            needs_history_bootstrap=needs_history_bootstrap,
            thread_visibility=visibility,
            current_user_display_name=current_user_display_name,
            filesystem_mode=fs_mode,
            request_id=request_id,
            turn_id=stream_result.turn_id,
        )
        input_state = assembled.input_state
        accepted_folder_ids = assembled.accepted_folder_ids
        mentioned_surfsense_docs = assembled.mentioned_surfsense_docs
        _perf_log.info(
            "[stream_new_chat] History bootstrap + doc/report queries in %.3fs",
            time.perf_counter() - _t0,
        )
        # All pre-streaming DB reads done. Commit to release the transaction
        # and its ACCESS SHARE locks so we don't block DDL (e.g. migrations)
        # for the entire LLM streaming duration. Tools that need DB access
        # during streaming start their own short-lived transactions (or use
        # isolated sessions).
        await session.commit()
        # Detach heavy ORM objects (documents with chunks, reports, etc.)
        # from the session identity map now that we've extracted what we
        # need. Without this they accumulate in memory for the entire
        # streaming duration (which can be several minutes).
        session.expunge_all()
        _perf_log.info(
            "[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)",
            time.perf_counter() - _t_total,
            chat_id,
        )
        configurable: dict[str, Any] = {
            "thread_id": str(chat_id),
            "request_id": request_id or "unknown",
            "turn_id": stream_result.turn_id,
        }
        if checkpoint_id:
            configurable["checkpoint_id"] = checkpoint_id
        config = {
            "configurable": configurable,
            # Effectively uncapped, matching the agent-level ``with_config``
            # default in ``chat_deepagent.create_agent`` and the unbounded
            # ``while(true)`` in OpenCode's ``session/processor.ts``. Real
            # circuit-breakers live in middleware (``DoomLoopMiddleware``,
            # plus ``enable_tool_call_limit`` / ``enable_model_call_limit``).
            # The original 25 (and our previous 80 bump) hit users on
            # legitimate multi-tool plans.
            "recursion_limit": 10_000,
        }
        # --- Block 4: First SSE frames ---
        for sse in iter_initial_frames(streaming_service, turn_id=stream_result.turn_id):
            yield sse
        # --- Block 5: Persistence join + message-id frames ---
        # From here on, abort paths close the stream with
        # ``iter_final_frames`` rather than a bare DONE frame.
        user_message_id = await await_persist_task(
            persist_user_task,
            chat_id=chat_id,
            turn_id=stream_result.turn_id,
            log_label="persist_user_task",
        )
        if user_message_id is None:
            yield emit_stream_error(
                message="We couldn't save your message. Please try again in a moment.",
                error_kind="server_error",
                error_code="MESSAGE_PERSIST_FAILED",
            )
            for sse in iter_final_frames(streaming_service):
                yield sse
            return
        # Emit canonical user message id BEFORE any LLM streaming so the FE
        # can rename its optimistic ``msg-user-XXX`` placeholder to
        # ``msg-{user_message_id}`` and unlock features gated on a real DB id
        # (comments, edit-from-this-message). See B4 in the
        # ``sse-based_message_id_handshake`` plan.
        yield streaming_service.format_data(
            "user-message-id",
            {"message_id": user_message_id, "turn_id": stream_result.turn_id},
        )
        assistant_message_id = await await_persist_task(
            persist_asst_task,
            chat_id=chat_id,
            turn_id=stream_result.turn_id,
            log_label="persist_asst_task",
        )
        if assistant_message_id is None:
            # Genuine DB failure — abort the turn rather than stream into a
            # void. The user row is already persisted so the legacy
            # ghost-thread gate isn't reopened.
            yield emit_stream_error(
                message=(
                    "We couldn't initialize the assistant message. Please try again."
                ),
                error_kind="server_error",
                error_code="MESSAGE_PERSIST_FAILED",
            )
            for sse in iter_final_frames(streaming_service):
                yield sse
            return
        yield streaming_service.format_data(
            "assistant-message-id",
            {"message_id": assistant_message_id, "turn_id": stream_result.turn_id},
        )
        stream_result.assistant_message_id = assistant_message_id
        stream_result.content_builder = AssistantContentBuilder()
        # --- Block 6: Initial thinking step + title task + runtime context ---
        initial_step = build_initial_thinking_step(
            user_query=user_query,
            user_image_data_urls=user_image_data_urls,
            mentioned_surfsense_docs=mentioned_surfsense_docs,
        )
        for sse in iter_initial_thinking_step_frame(
            initial_step,
            streaming_service=streaming_service,
            content_builder=stream_result.content_builder,
        ):
            yield sse
        initial_step_id = initial_step.step_id
        initial_step_title = initial_step.title
        initial_step_items = initial_step.items
        # Drop the heavy ORM objects + the container that holds them so they
        # aren't retained for the entire streaming duration. ``input_state``
        # already carries the langchain_messages list independently.
        del assembled, mentioned_surfsense_docs
        title_task = spawn_title_task(
            chat_id=chat_id,
            user_query=user_query,
            user_image_data_urls=user_image_data_urls,
            assistant_message_id=assistant_message_id,
            llm=llm,
            agent_config=agent_config,
        )
        title_emitted = False
        runtime_context = build_new_chat_runtime_context(
            search_space_id=search_space_id,
            mentioned_document_ids=mentioned_document_ids,
            accepted_folder_ids=accepted_folder_ids,
            mentioned_folder_ids=mentioned_folder_ids,
            request_id=request_id,
            turn_id=stream_result.turn_id,
        )
        # --- Block 7: Stream loop ---
        _t_stream_start = time.perf_counter()
        # One-shot latch: set on the first recovery attempt and passed back
        # into ``can_recover_provider_rate_limit`` on later calls.
        runtime_rate_limit_recovered = False

        def _on_first_event() -> None:
            # Perf-log hook handed to ``run_stream_loop``.
            _perf_log.info(
                "[stream_new_chat] First agent event in %.3fs (time since stream start), "
                "%.3fs (total since request start) (chat_id=%s)",
                time.perf_counter() - _t_stream_start,
                time.perf_counter() - _t_total,
                chat_id,
            )

        async def _recover(exc: BaseException, first_event_seen: bool):
            # Recovery hook handed to ``run_stream_loop``. Returns a rebuilt
            # agent on success, ``None`` when recovery is not possible.
            nonlocal llm_config_id, llm, agent_config, runtime_rate_limit_recovered
            nonlocal title_task
            if not can_recover_provider_rate_limit(
                exc,
                first_event_seen=first_event_seen,
                runtime_rate_limit_recovered=runtime_rate_limit_recovered,
                requested_llm_config_id=requested_llm_config_id,
                current_llm_config_id=llm_config_id,
            ):
                return None
            runtime_rate_limit_recovered = True
            previous_config_id = llm_config_id
            llm_config_id = await reroute_to_next_auto_pin(
                session,
                chat_id=chat_id,
                search_space_id=search_space_id,
                user_id=user_id,
                current_llm_config_id=llm_config_id,
                requires_image_input=requires_image_input,
            )
            new_llm, new_agent_config, llm_load_err = await load_llm_bundle(
                session, config_id=llm_config_id, search_space_id=search_space_id
            )
            if llm_load_err:
                # Re-raise the original so the terminal-error path classifies
                # it correctly (don't swallow as "config load error").
                # NOTE(review): this returns ``None``; presumably
                # ``run_stream_loop`` then re-raises the original exception —
                # confirm there.
                return None
            llm = new_llm
            agent_config = new_agent_config
            # Title gen used the initial llm object. After a runtime repin we
            # keep the stream focused on response recovery and skip title gen
            # for this turn.
            if title_task is not None and not title_task.done():
                title_task.cancel()
                title_task = None
            _t_rebuild = time.perf_counter()
            new_agent = await build_main_agent_for_thread(
                agent_factory,
                llm=llm,
                search_space_id=search_space_id,
                db_session=session,
                connector_service=connector_service,
                checkpointer=checkpointer,
                user_id=user_id,
                thread_id=chat_id,
                agent_config=agent_config,
                firecrawl_api_key=firecrawl_api_key,
                thread_visibility=visibility,
                filesystem_selection=filesystem_selection,
                disabled_tools=disabled_tools,
                mentioned_document_ids=mentioned_document_ids,
            )
            _perf_log.info(
                "[stream_new_chat] Runtime rate-limit recovery repinned "
                "config_id=%s -> %s and rebuilt agent in %.3fs",
                previous_config_id,
                llm_config_id,
                time.perf_counter() - _t_rebuild,
            )
            log_rate_limit_recovered(
                flow=flow,
                request_id=request_id,
                chat_id=chat_id,
                search_space_id=search_space_id,
                user_id=user_id,
                previous_config_id=previous_config_id,
                new_config_id=llm_config_id,
            )
            return new_agent

        async for sse in run_stream_loop(
            agent=agent,
            streaming_service=streaming_service,
            config=config,
            input_data=input_state,
            stream_result=stream_result,
            step_prefix="thinking",
            initial_step_id=initial_step_id,
            initial_step_title=initial_step_title,
            initial_step_items=initial_step_items,
            fallback_commit_search_space_id=search_space_id,
            fallback_commit_created_by_id=user_id,
            fallback_commit_filesystem_mode=(
                filesystem_selection.mode if filesystem_selection else FilesystemMode.CLOUD
            ),
            fallback_commit_thread_id=chat_id,
            runtime_context=runtime_context,
            content_builder=stream_result.content_builder,
            recover=_recover,
            on_first_event=_on_first_event,
        ):
            yield sse
            # Inject the title update mid-stream as soon as the background
            # task finishes; gated so we emit at most once.
            async for title_sse in maybe_emit_title_update(
                title_task=title_task,
                title_emitted=title_emitted,
                chat_id=chat_id,
                accumulator=accumulator,
                streaming_service=streaming_service,
            ):
                yield title_sse
                title_emitted = True
            # Account for the case where the task completed but produced no
            # title — flip the flag anyway so we don't keep checking it.
            if (
                title_task is not None
                and title_task.done()
                and not title_emitted
            ):
                title_emitted = True
        _perf_log.info(
            "[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)",
            time.perf_counter() - _t_stream_start,
            chat_id,
        )
        log_system_snapshot("stream_new_chat_END")
        # --- Block 8: Finalize ---
        if stream_result.is_interrupted:
            # Interrupted turns skip the title join and the premium finalize;
            # any outstanding reservation is released in ``finally``.
            ot.add_event("chat.interrupted", {"chat.flow": flow})
            if title_task is not None and not title_task.done():
                title_task.cancel()
            for sse in iter_token_usage_frame(
                streaming_service,
                accumulator=accumulator,
                log_label="interrupted new_chat",
            ):
                yield sse
            yield streaming_service.format_finish_step()
            yield streaming_service.format_finish()
            yield streaming_service.format_done()
            return
        async for title_sse in await_pending_title_update(
            title_task=title_task,
            title_emitted=title_emitted,
            chat_id=chat_id,
            accumulator=accumulator,
            streaming_service=streaming_service,
        ):
            yield title_sse
        # Finalize premium credit debit with the actual provider cost reported
        # by LiteLLM, summed across every call in the turn. Mirrors the
        # pre-cost behaviour of "premium turn → all calls count" so free
        # sub-agent calls during a premium turn still contribute to the bill
        # (they're $0 in practice anyway).
        if premium_reservation is not None and user_id:
            await finalize_premium(
                reservation=premium_reservation,
                user_id=user_id,
                accumulator=accumulator,
            )
            # Cleared so ``finally`` does not release a finalized reservation.
            premium_reservation = None
        for sse in iter_token_usage_frame(
            streaming_service, accumulator=accumulator, log_label="normal new_chat"
        ):
            yield sse
        for sse in iter_final_frames(streaming_service):
            yield sse
    except Exception as exc:
        # --- Block 9: Exception branch ---
        # ``handle_terminal_exception`` returns the frames to emit plus a
        # summary that feeds the span attributes and the busy-error flag.
        frames, summary = handle_terminal_exception(
            exc,
            flow=flow,
            flow_label="chat",
            log_prefix="stream_new_chat",
            streaming_service=streaming_service,
            request_id=request_id,
            chat_id=chat_id,
            search_space_id=search_space_id,
            user_id=user_id,
            chat_span=chat_span,
        )
        if summary["busy_error_raised"]:
            busy_error_raised = True
        chat_outcome = summary["chat_outcome"]
        chat_error_category = summary["chat_error_category"]
        for sse in frames:
            yield sse
    finally:
        # --- Block 10: Finally ---
        # Shield the ENTIRE async cleanup from anyio cancel-scope cancellation.
        # Starlette's BaseHTTPMiddleware uses anyio task groups; on client
        # disconnect, it cancels the scope with level-triggered cancellation
        # — every unshielded ``await`` would raise CancelledError immediately.
        # Without this the very first ``await`` (session.rollback) would
        # raise, ``except Exception`` wouldn't catch it (CancelledError is a
        # BaseException), and the rest of cleanup — including session.close()
        # — would never run.
        with anyio.CancelScope(shield=True):
            # Authoritative fallback cleanup for lock/cancel state. Middleware
            # teardown can be skipped on some client-abort paths.
            end_turn(str(chat_id))
            if premium_reservation is not None and user_id:
                await release_premium(
                    reservation=premium_reservation, user_id=user_id
                )
            await close_session_and_clear_ai_responding(session, chat_id)
            await finalize_assistant_message(
                stream_result=stream_result,
                chat_id=chat_id,
                search_space_id=search_space_id,
                user_id=user_id,
                accumulator=accumulator,
                log_prefix="stream_new_chat",
            )
        # Persist any sandbox-produced files to local storage so they remain
        # downloadable after the Daytona sandbox auto-deletes.
        if stream_result and stream_result.sandbox_files:
            with contextlib.suppress(Exception):
                from app.agents.new_chat.sandbox import (
                    is_sandbox_enabled,
                    persist_and_delete_sandbox,
                )

                if is_sandbox_enabled():
                    with anyio.CancelScope(shield=True):
                        await persist_and_delete_sandbox(
                            chat_id, stream_result.sandbox_files
                        )
        # ``aafter_agent`` doesn't fire on ``interrupt()`` or early bailout.
        # Skip on ``BusyError`` (caller never acquired the lock).
        if not busy_error_raised:
            with contextlib.suppress(Exception):
                end_turn(str(chat_id))
                _perf_log.info(
                    "[stream_new_chat] end_turn cleanup (chat_id=%s)", chat_id
                )
        # Break circular refs held by the agent graph, tools, and LLM
        # wrappers so the GC can reclaim them in a single pass.
        agent = llm = connector_service = None  # noqa: F841
        input_state = stream_result = None  # noqa: F841
        session = None  # noqa: F841
        run_gc_pass(log_prefix="stream_new_chat", chat_id=chat_id)
        close_chat_request_span(
            span_cm=chat_span_cm,
            span=chat_span,
            chat_outcome=chat_outcome,
            chat_agent_mode=chat_agent_mode,
            flow=flow,
            chat_error_category=chat_error_category,
            duration_seconds=time.perf_counter() - _t_total,
        )

View file

@ -0,0 +1,129 @@
"""Concurrent persistence tasks spawned right after the initial validation gate.
These run *during* the rest of the pre-stream setup so we don't serialize
their latency against agent construction. Awaiting them at the SSE message-id
yield sites preserves the ghost-thread protection (the user-row INSERT must
succeed before any LLM streaming begins).
The ``set_ai_responding`` flag flip runs fully fire-and-forget on its own
shielded session — failures only delay the "AI is responding…" UI flag, not
the response itself.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any
from uuid import UUID
from app.db import shielded_async_session
from app.services.chat_session_state_service import set_ai_responding
from app.tasks.chat.persistence import (
persist_assistant_shell,
persist_user_turn,
)
logger = logging.getLogger(__name__)
def spawn_set_ai_responding_bg(
    *,
    chat_id: int,
    user_id: str | None,
    background_tasks: set[asyncio.Task[Any]],
) -> None:
    """Schedule the per-thread "AI responding" flag write without awaiting it.

    The write runs in its own task on a ``shielded_async_session``. Any
    exception it raises (including an invalid ``user_id`` UUID) is logged as
    a warning and dropped, so a failed flag write never delays or breaks the
    SSE stream. Nothing is scheduled when ``user_id`` is falsy.

    Args:
        chat_id: Thread whose flag is set.
        user_id: String UUID of the user, or ``None``/empty to skip the write.
        background_tasks: Set that holds a strong reference to the task until
            it completes; the task removes itself when done.
    """

    async def _mark_responding(owner: str) -> None:
        try:
            async with shielded_async_session() as flag_session:
                await set_ai_responding(flag_session, chat_id, UUID(owner))
        except Exception:
            logger.warning(
                "set_ai_responding failed (chat_id=%s)",
                chat_id,
                exc_info=True,
            )

    if not user_id:
        return

    flag_task = asyncio.create_task(_mark_responding(user_id))
    # Keep a strong reference until completion, then drop it.
    background_tasks.add(flag_task)
    flag_task.add_done_callback(background_tasks.discard)
def spawn_persist_user_task(
    *,
    chat_id: int,
    user_id: str | None,
    turn_id: str,
    user_query: str,
    user_image_data_urls: list[str] | None,
    mentioned_documents: list[dict[str, Any]] | None,
    background_tasks: set[asyncio.Task[Any]],
) -> asyncio.Task[int | None]:
    """Start ``persist_user_turn`` as a task and return it for a later join.

    The caller awaits the returned task at the user-message-id yield site.
    The task is also tracked in ``background_tasks`` until it finishes so a
    strong reference to it exists for its whole lifetime.
    """
    insert_coro = persist_user_turn(
        chat_id=chat_id,
        user_id=user_id,
        turn_id=turn_id,
        user_query=user_query,
        user_image_data_urls=user_image_data_urls,
        mentioned_documents=mentioned_documents,
    )
    user_row_task = asyncio.create_task(insert_coro)
    background_tasks.add(user_row_task)
    user_row_task.add_done_callback(background_tasks.discard)
    return user_row_task
def spawn_persist_assistant_shell_task(
    *,
    chat_id: int,
    user_id: str | None,
    turn_id: str,
    background_tasks: set[asyncio.Task[Any]],
) -> asyncio.Task[int | None]:
    """Start ``persist_assistant_shell`` as a task and return it for a later join.

    The caller awaits the returned task at the assistant-message-id yield
    site. The task is also tracked in ``background_tasks`` until it finishes
    so a strong reference to it exists for its whole lifetime.
    """
    shell_coro = persist_assistant_shell(
        chat_id=chat_id,
        user_id=user_id,
        turn_id=turn_id,
    )
    shell_task = asyncio.create_task(shell_coro)
    background_tasks.add(shell_task)
    shell_task.add_done_callback(background_tasks.discard)
    return shell_task
async def await_persist_task(
    task: asyncio.Task[int | None] | None,
    *,
    chat_id: int,
    turn_id: str,
    log_label: str,
) -> int | None:
    """Wait for a spawned persistence task and return its result.

    The task is awaited through ``asyncio.shield`` so cancelling this await
    (e.g. the SSE generator being cancelled on client disconnect) does not
    cancel the underlying DB write.

    Args:
        task: Task to join, or ``None`` if none was spawned.
        chat_id: Thread id, used only in the failure log line.
        turn_id: Turn id, used only in the failure log line.
        log_label: Prefix identifying the task in the failure log line.

    Returns:
        The task's result, or ``None`` when ``task`` is ``None`` or the task
        raised an ``Exception`` (which is logged with its traceback).

    Raises:
        asyncio.CancelledError: Propagated unchanged.
    """
    if task is not None:
        try:
            row_id = await asyncio.shield(task)
        except asyncio.CancelledError:
            # Cancellation must keep propagating; never turn it into ``None``.
            raise
        except Exception:
            logger.exception(
                "%s failed (chat_id=%s, turn_id=%s)", log_label, chat_id, turn_id
            )
        else:
            return row_id
    return None

View file

@ -0,0 +1,38 @@
"""Build the per-invocation ``SurfSenseContextSchema`` for a new-chat turn.
Carries the per-turn read inputs that middlewares read via
``runtime.context.*`` instead of from their ``__init__`` closures, so the same
compiled-agent instance can serve multiple turns with different
mention lists / request ids / turn ids without rebuilding the graph.
"""
from __future__ import annotations
from app.agents.new_chat.context import SurfSenseContextSchema
def build_new_chat_runtime_context(
    *,
    search_space_id: int,
    mentioned_document_ids: list[int] | None,
    accepted_folder_ids: list[int],
    mentioned_folder_ids: list[int] | None,
    request_id: str | None,
    turn_id: str,
) -> SurfSenseContextSchema:
    """Assemble the ``SurfSenseContextSchema`` passed to the agent for one turn.

    ``mentioned_document_ids`` is copied into a fresh list (empty when
    ``None``).

    For folders, the post-resolve ``accepted_folder_ids`` takes precedence
    over the raw ``mentioned_folder_ids`` from the request whenever it is
    non-empty. When it is empty, the raw ``mentioned_folder_ids`` are used
    instead, and an empty list if those are missing too.

    NOTE(review): that fallback means an empty resolver result lets the raw
    request ids through. Confirm this is intended for callers that rely on
    the resolver to drop deleted or unauthorized folders.
    """
    document_ids = list(mentioned_document_ids) if mentioned_document_ids else []
    folder_source = accepted_folder_ids or mentioned_folder_ids or []
    folder_ids = list(folder_source)
    return SurfSenseContextSchema(
        search_space_id=search_space_id,
        mentioned_document_ids=document_ids,
        mentioned_folder_ids=folder_ids,
        request_id=request_id,
        turn_id=turn_id,
    )

View file

@ -0,0 +1,237 @@
"""Background thread-title generation (first-response only).
The first assistant response in a thread gets a short auto-generated title
inserted into ``new_chat_threads.title``. We:
1. Spawn the generation as an ``asyncio.Task`` so it runs in parallel with
the agent stream (no extra TTFT).
2. Probe inside the task (on its own shielded session) whether this is
actually the first response — newer turns short-circuit to ``None``.
3. Inject the resulting ``thread-title-update`` SSE frame on the first agent
event after the task completes (mid-stream interlock), or right before
the finish frames (post-stream join) if the task hadn't finished yet.
Usage tokens come directly off the response (LiteLLM's async callback fires
via fire-and-forget ``create_task``, so the ``TokenTrackingCallback`` would
run too late). We also blank the per-task accumulator so the late callback
doesn't double-count.
"""
from __future__ import annotations
import asyncio
import logging
from typing import TYPE_CHECKING, Any
from sqlalchemy.future import select
from app.db import NewChatMessage, NewChatThread, shielded_async_session
from app.prompts import TITLE_GENERATION_PROMPT
from app.services.new_streaming_service import VercelStreamingService
if TYPE_CHECKING:
from app.agents.new_chat.llm_config import AgentConfig
from app.services.token_tracking_service import TokenAccumulator
logger = logging.getLogger(__name__)
def spawn_title_task(
    *,
    chat_id: int,
    user_query: str,
    user_image_data_urls: list[str] | None,
    assistant_message_id: int | None,
    llm: Any,
    agent_config: AgentConfig | None,
) -> asyncio.Task[tuple[str | None, dict | None]] | None:
    """Schedule ``_generate_title`` as a background task when the turn is eligible.

    The task is only created once the assistant row has been persisted
    (``assistant_message_id`` is not ``None``); a stream that aborts before
    persistence therefore never produces a title without anchoring rows.

    Returns:
        The scheduled ``asyncio.Task`` resolving to ``(title, usage)``, or
        ``None`` when no assistant message id is available.
    """
    if assistant_message_id is not None:
        title_coro = _generate_title(
            chat_id=chat_id,
            user_query=user_query,
            user_image_data_urls=user_image_data_urls,
            assistant_message_id=assistant_message_id,
            llm=llm,
            agent_config=agent_config,
        )
        return asyncio.create_task(title_coro)
    return None
async def _generate_title(
    *,
    chat_id: int,
    user_query: str,
    user_image_data_urls: list[str] | None,
    assistant_message_id: int,
    llm: Any,
    agent_config: AgentConfig | None,
) -> tuple[str | None, dict | None]:
    """Probe is-first-response, then call ``acompletion``. Returns ``(title, usage)``.

    Args:
        chat_id: Thread whose title may be generated.
        user_query: The user's text for this turn; seeds the title prompt.
        user_image_data_urls: Image attachments; only their count is used, as
            a fallback seed when ``user_query`` is blank.
        assistant_message_id: This turn's pre-written assistant row, excluded
            from the first-response probe.
        llm: Chat model wrapper; ``model``, ``api_key`` and ``api_base`` are
            read off it via ``getattr``.
        agent_config: Optional config supplying ``provider`` for api-base
            resolution.

    Returns:
        ``(title, usage)``. ``title`` is ``None`` when this is not the first
        response, the probe failed, or the model output was empty or longer
        than 100 characters. ``usage`` is ``None`` when no completion was made
        or the response carried no usage. Never raises ``Exception``; failures
        are logged and reported as ``(None, None)``.
    """
    try:
        from litellm import acompletion

        from app.services.llm_router_service import LLMRouterService
        from app.services.provider_api_base import resolve_api_base
        from app.services.token_tracking_service import _turn_accumulator

        # Excludes this turn's own assistant row (pre-written by
        # ``persist_assistant_shell``) — without the ``!=`` filter the gate
        # would false-negative on every turn after the first.
        try:
            async with shielded_async_session() as probe_session:
                probe_result = await probe_session.execute(
                    select(NewChatMessage.id)
                    .filter(
                        NewChatMessage.thread_id == chat_id,
                        NewChatMessage.role == "assistant",
                        NewChatMessage.id != assistant_message_id,
                    )
                    .limit(1)
                )
                is_first_response = probe_result.scalars().first() is None
        except Exception:
            logger.warning(
                "[TitleGen] first-response probe failed (chat_id=%s)",
                chat_id,
                exc_info=True,
            )
            return None, None
        if not is_first_response:
            return None, None

        # Blank the per-task accumulator so the late LiteLLM callback does not
        # double-count; usage is read directly off the response below.
        _turn_accumulator.set(None)

        title_seed = user_query.strip() or (
            f"[{len(user_image_data_urls or [])} image(s)]"
            if user_image_data_urls
            else ""
        )
        prompt = TITLE_GENERATION_PROMPT.replace(
            "{user_query}", title_seed[:500] or "(message)"
        )
        messages = [{"role": "user", "content": prompt}]

        # Resolved once and reused for routing, the completion call and the
        # usage record.
        raw_model = getattr(llm, "model", "") or ""
        if raw_model == "auto":
            router = LLMRouterService.get_router()
            response = await router.acompletion(model="auto", messages=messages)
        else:
            # Apply the same ``api_base`` cascade chat / vision / image-gen
            # call sites use so we never inherit ``litellm.api_base``
            # (commonly set by ``AZURE_OPENAI_ENDPOINT``) when the chat
            # config itself ships an empty ``api_base``. Without this the
            # title-gen on an OpenRouter chat config would 404 against the
            # inherited Azure endpoint — see ``provider_api_base`` for the
            # same bug repro on the image-gen / vision paths.
            provider_prefix = (
                raw_model.split("/", 1)[0] if "/" in raw_model else None
            )
            provider_value = (
                agent_config.provider if agent_config is not None else None
            )
            title_api_base = resolve_api_base(
                provider=provider_value,
                provider_prefix=provider_prefix,
                config_api_base=getattr(llm, "api_base", None),
            )
            response = await acompletion(
                model=raw_model,
                messages=messages,
                api_key=getattr(llm, "api_key", None),
                api_base=title_api_base,
            )

        usage_info = None
        usage = getattr(response, "usage", None)
        if usage:
            model_name = (
                raw_model.split("/", 1)[-1]
                if "/" in raw_model
                else (raw_model or response.model or "unknown")
            )
            usage_info = {
                "model": model_name,
                "prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0,
                "completion_tokens": getattr(usage, "completion_tokens", 0) or 0,
                "total_tokens": getattr(usage, "total_tokens", 0) or 0,
            }

        # Treat missing choices or ``None`` content as "no title" rather than
        # letting ``.strip()`` raise: the outer handler would otherwise discard
        # ``usage_info`` for a call that already consumed tokens.
        choices = getattr(response, "choices", None) or []
        content = choices[0].message.content if choices else None
        raw_title = (content or "").strip()
        if raw_title and len(raw_title) <= 100:
            return raw_title.strip("\"'"), usage_info
        return None, usage_info
    except Exception:
        logger.exception("[TitleGen] _generate_title failed")
        return None, None
async def maybe_emit_title_update(
    *,
    title_task: asyncio.Task[tuple[str | None, dict | None]] | None,
    title_emitted: bool,
    chat_id: int,
    accumulator: TokenAccumulator,
    streaming_service: VercelStreamingService,
):
    """Emit at most one ``thread-title-update`` SSE frame, without blocking.

    Does nothing unless a title task exists, no title was emitted yet, and the
    task has already finished. Otherwise the task's usage (if any) is added to
    ``accumulator``, and a generated title (if any) is written to the thread
    row and yielded as an SSE frame.

    The caller is responsible for flipping ``title_emitted`` after iterating;
    this generator returns nothing.
    """
    if title_emitted or title_task is None:
        return
    if not title_task.done():
        return

    new_title, usage_payload = title_task.result()
    if usage_payload:
        accumulator.add(**usage_payload)
    if not new_title:
        return

    # Persist on a dedicated shielded session, separate from the main stream session.
    async with shielded_async_session() as db:
        lookup = await db.execute(
            select(NewChatThread).filter(NewChatThread.id == chat_id)
        )
        thread_row = lookup.scalars().first()
        if thread_row:
            thread_row.title = new_title
            await db.commit()
    yield streaming_service.format_thread_title_update(chat_id, new_title)
async def await_pending_title_update(
    *,
    title_task: asyncio.Task[tuple[str | None, dict | None]] | None,
    title_emitted: bool,
    chat_id: int,
    accumulator: TokenAccumulator,
    streaming_service: VercelStreamingService,
):
    """Join the title task (if still relevant) and emit its SSE frame.

    Counterpart of ``maybe_emit_title_update`` for the success path right
    before the finish frames: instead of checking ``done()`` it awaits the
    task unconditionally. Does nothing when there is no task or a title was
    already emitted.
    """
    if title_emitted or title_task is None:
        return

    new_title, usage_payload = await title_task
    if usage_payload:
        accumulator.add(**usage_payload)
    if not new_title:
        return

    # Persist on a dedicated shielded session, separate from the main stream session.
    async with shielded_async_session() as db:
        lookup = await db.execute(
            select(NewChatThread).filter(NewChatThread.id == chat_id)
        )
        thread_row = lookup.scalars().first()
        if thread_row:
            thread_row.title = new_title
            await db.commit()
    yield streaming_service.format_thread_title_update(chat_id, new_title)

View file

@ -0,0 +1,12 @@
"""Resume-chat streaming flow.
Public entry point ``stream_resume_chat`` is the slim coroutine in
``orchestrator.py`` that composes the per-concern modules in this folder and
the building blocks under ``flows/shared/``.
"""
from __future__ import annotations
from app.tasks.chat.streaming.flows.resume_chat.orchestrator import stream_resume_chat
__all__ = ["stream_resume_chat"]

View file

@ -0,0 +1,31 @@
"""Pre-write a fresh assistant row for this resume turn.
The original (interrupted) ``stream_new_chat`` invocation already persisted
its own assistant row anchored to a different ``turn_id``; resume allocates a
new ``turn_id`` (per-request, see ``orchestrator``) so we need a separate row
keyed on the same ``(thread_id, turn_id, ASSISTANT)`` invariant.
Idempotent against migration 141's partial unique index — recovers the
existing id on retry.
Resume does NOT emit ``data-user-message-id``: the user row is from the
original interrupted turn (different ``turn_id``) and is never re-persisted
here. See B5 in the ``sse-based_message_id_handshake`` plan.
"""
from __future__ import annotations
from app.tasks.chat.persistence import persist_assistant_shell
async def persist_resume_assistant_shell(
    *,
    chat_id: int,
    user_id: str | None,
    turn_id: str,
) -> int | None:
    """Pre-write the assistant row for this resume turn.

    Thin delegate to ``persist_assistant_shell`` with the resume turn's own
    ``turn_id``; returns whatever id (or ``None``) that call produces.
    """
    message_id = await persist_assistant_shell(
        chat_id=chat_id,
        user_id=user_id,
        turn_id=turn_id,
    )
    return message_id

View file

@ -0,0 +1,629 @@
"""``stream_resume_chat`` — public entry point for a HITL resume turn.
Slim composition layer over the per-concern modules in this folder and the
building blocks under ``flows/shared/``. Mirrors ``stream_new_chat`` but:
* No user-message persistence (the original turn already wrote it).
* No mentions / surfsense-doc / report context assembly (seeded by original).
* No title generation (only fires on first-response).
* Synchronous ``persist_assistant_shell`` call (we have no other in-flight
pre-stream work to overlap it with).
* ``input_data`` is a ``Command(resume=lg_resume_map)`` instead of a
LangChain message list.
"""
from __future__ import annotations
import contextlib
import gc
import logging
import sys
import time
import uuid as _uuid
from collections.abc import AsyncGenerator
from functools import partial
from typing import Any
from uuid import UUID
import anyio
from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
from app.agents.new_chat.middleware.busy_mutex import end_turn
from app.config import config as _app_config
from app.db import ChatVisibility, async_session_maker, shielded_async_session
from app.observability import otel as ot
from app.services.chat_session_state_service import set_ai_responding
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.content_builder import AssistantContentBuilder
from app.tasks.chat.streaming.agent.builder import build_main_agent_for_thread
from app.tasks.chat.streaming.contract.file_contract import log_file_contract
from app.tasks.chat.streaming.errors.emitter import emit_stream_terminal_error
from app.tasks.chat.streaming.flows.resume_chat.assistant_shell import (
persist_resume_assistant_shell,
)
from app.tasks.chat.streaming.flows.resume_chat.resume_routing import (
build_resume_routing,
)
from app.tasks.chat.streaming.flows.resume_chat.runtime_context import (
build_resume_chat_runtime_context,
)
from app.tasks.chat.streaming.flows.shared.assistant_finalize import (
finalize_assistant_message,
)
from app.tasks.chat.streaming.flows.shared.finalize_emit import iter_token_usage_frame
from app.tasks.chat.streaming.flows.shared.finally_cleanup import (
close_session_and_clear_ai_responding,
run_gc_pass,
)
from app.tasks.chat.streaming.flows.shared.first_frames import (
iter_final_frames,
iter_initial_frames,
)
from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle
from app.tasks.chat.streaming.flows.shared.pre_stream_setup import (
get_chat_checkpointer,
setup_connector_and_firecrawl,
)
from app.tasks.chat.streaming.flows.shared.premium_quota import (
PremiumReservation,
finalize_premium,
needs_premium_quota,
release_premium,
reserve_premium,
)
from app.tasks.chat.streaming.flows.shared.rate_limit_recovery import (
can_recover_provider_rate_limit,
log_rate_limit_recovered,
reroute_to_next_auto_pin,
)
from app.tasks.chat.streaming.flows.shared.span import (
close_chat_request_span,
open_chat_request_span,
set_agent_mode,
)
from app.tasks.chat.streaming.flows.shared.stream_loop import run_stream_loop
from app.tasks.chat.streaming.flows.shared.terminal_error import (
handle_terminal_exception,
)
from app.tasks.chat.streaming.shared.stream_result import StreamResult
from app.tasks.chat.streaming.shared.utils import resume_step_prefix
from app.utils.perf import get_perf_logger, log_system_snapshot
logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
async def stream_resume_chat(
    chat_id: int,
    search_space_id: int,
    decisions: list[dict],
    user_id: str | None = None,
    llm_config_id: int = -1,
    thread_visibility: ChatVisibility | None = None,
    filesystem_selection: FilesystemSelection | None = None,
    request_id: str | None = None,
    disabled_tools: list[str] | None = None,
) -> AsyncGenerator[str, None]:
    """Resume a paused HITL turn with the user's decisions.

    Mirrors ``stream_new_chat`` except for the resume-specific routing of
    ``decisions`` to per-``tool_call_id`` slices (``build_resume_routing``).

    Args:
        chat_id: Thread being resumed; also used as the checkpointer
            ``thread_id``.
        search_space_id: Search space the thread belongs to.
        decisions: Flat list of HITL decisions, routed to the pending
            interrupts by ``build_resume_routing``.
        user_id: Acting user, if any. Enables the AI-responding flag and
            premium-quota accounting.
        llm_config_id: Requested LLM config id; resolved (and possibly
            re-pinned) before the agent is built.
        thread_visibility: Defaults to ``ChatVisibility.PRIVATE`` when omitted.
        filesystem_selection: Filesystem mode / client platform; defaults to
            ``"cloud"`` / ``"web"`` for logging when omitted.
        request_id: Correlation id propagated to logs, spans and config.
        disabled_tools: Tool names forwarded to the agent factory.

    Yields:
        SSE frames produced by the streaming service, including terminal
        error frames when setup or streaming fails.
    """
    streaming_service = VercelStreamingService()
    stream_result = StreamResult()
    _t_total = time.perf_counter()
    fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud"
    fs_platform = (
        filesystem_selection.client_platform.value if filesystem_selection else "web"
    )
    stream_result.request_id = request_id
    stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}"
    stream_result.filesystem_mode = fs_mode
    stream_result.client_platform = fs_platform

    chat_agent_mode = "unknown"
    chat_outcome = "success"
    chat_error_category: str | None = None
    chat_span_cm, chat_span = open_chat_request_span(
        chat_id=chat_id,
        search_space_id=search_space_id,
        flow="resume",
        request_id=request_id,
        turn_id=stream_result.turn_id,
        filesystem_mode=fs_mode,
        client_platform=fs_platform,
        agent_mode=chat_agent_mode,
    )
    log_file_contract("turn_start", stream_result)
    _perf_log.info(
        "[stream_resume] filesystem_mode=%s client_platform=%s",
        fs_mode,
        fs_platform,
    )

    from app.services.token_tracking_service import start_turn

    accumulator = start_turn()
    premium_reservation: PremiumReservation | None = None
    busy_error_raised = False
    emit_stream_error = partial(
        emit_stream_terminal_error,
        streaming_service=streaming_service,
        flow="resume",
        request_id=request_id,
        thread_id=chat_id,
        search_space_id=search_space_id,
        user_id=user_id,
    )
    session = async_session_maker()
    try:
        if user_id:
            await set_ai_responding(session, chat_id, UUID(user_id))
        requested_llm_config_id = llm_config_id

        # --- LLM config ---
        _t0 = time.perf_counter()
        try:
            from app.services.auto_model_pin_service import (
                resolve_or_get_pinned_llm_config_id,
            )

            pinned = await resolve_or_get_pinned_llm_config_id(
                session,
                thread_id=chat_id,
                search_space_id=search_space_id,
                user_id=user_id,
                selected_llm_config_id=llm_config_id,
            )
            llm_config_id = pinned.resolved_llm_config_id
            ot.add_event(
                "model.pin.resolved",
                {
                    "pin.requested_id": requested_llm_config_id,
                    "pin.resolved_id": llm_config_id,
                    "pin.requires_image_input": False,
                },
            )
        except ValueError as pin_error:
            yield emit_stream_error(
                message=str(pin_error),
                error_kind="server_error",
                error_code="SERVER_ERROR",
            )
            yield streaming_service.format_done()
            return

        llm, agent_config, llm_load_error = await load_llm_bundle(
            session, config_id=llm_config_id, search_space_id=search_space_id
        )
        if llm_load_error:
            yield emit_stream_error(
                message=llm_load_error,
                error_kind="server_error",
                error_code="SERVER_ERROR",
            )
            yield streaming_service.format_done()
            return
        _perf_log.info(
            "[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
        )

        if needs_premium_quota(agent_config, user_id):
            premium_reservation = await reserve_premium(
                agent_config=agent_config, user_id=user_id  # type: ignore[arg-type]
            )
            if not premium_reservation.allowed:
                ot.add_event(
                    "quota.denied", {"quota.code": "PREMIUM_QUOTA_EXHAUSTED"}
                )
                if requested_llm_config_id == 0:
                    # Requested id 0: re-pin with ``force_repin_free`` instead of
                    # failing the turn.
                    try:
                        pinned_fb = await resolve_or_get_pinned_llm_config_id(
                            session,
                            thread_id=chat_id,
                            search_space_id=search_space_id,
                            user_id=user_id,
                            selected_llm_config_id=0,
                            force_repin_free=True,
                        )
                        llm_config_id = pinned_fb.resolved_llm_config_id
                        ot.add_event(
                            "model.repin",
                            {
                                "repin.reason": "premium_quota_exhausted",
                                "repin.to_config_id": llm_config_id,
                            },
                        )
                    except ValueError as pin_error:
                        yield emit_stream_error(
                            message=str(pin_error),
                            error_kind="server_error",
                            error_code="SERVER_ERROR",
                        )
                        yield streaming_service.format_done()
                        return
                    llm, agent_config, llm_load_error = await load_llm_bundle(
                        session,
                        config_id=llm_config_id,
                        search_space_id=search_space_id,
                    )
                    if llm_load_error:
                        yield emit_stream_error(
                            message=llm_load_error,
                            error_kind="server_error",
                            error_code="SERVER_ERROR",
                        )
                        yield streaming_service.format_done()
                        return
                    # The denied reservation no longer applies to the
                    # fallback model; drop it so ``finally`` has nothing to
                    # release.
                    premium_reservation = None

                    from app.tasks.chat.streaming.errors.classifier import (
                        log_chat_stream_error,
                    )

                    log_chat_stream_error(
                        flow="resume",
                        error_kind="premium_quota_exhausted",
                        error_code="PREMIUM_QUOTA_EXHAUSTED",
                        severity="info",
                        is_expected=True,
                        request_id=request_id,
                        thread_id=chat_id,
                        search_space_id=search_space_id,
                        user_id=user_id,
                        message=(
                            "Premium quota exhausted on pinned model; "
                            "auto-fallback switched to a free model"
                        ),
                        extra={
                            "fallback_config_id": llm_config_id,
                            "auto_fallback": True,
                        },
                    )
                else:
                    yield emit_stream_error(
                        message=(
                            "Buy more tokens to continue with this model, or "
                            "switch to a free model"
                        ),
                        error_kind="premium_quota_exhausted",
                        error_code="PREMIUM_QUOTA_EXHAUSTED",
                        severity="info",
                        is_expected=True,
                        extra={
                            "resolved_config_id": llm_config_id,
                            "auto_fallback": False,
                        },
                    )
                    yield streaming_service.format_done()
                    return

        if not llm:
            yield emit_stream_error(
                message="Failed to create LLM instance",
                error_kind="server_error",
                error_code="SERVER_ERROR",
            )
            yield streaming_service.format_done()
            return

        # --- Pre-stream setup ---
        _t0 = time.perf_counter()
        connector_service, firecrawl_api_key = await setup_connector_and_firecrawl(
            session, search_space_id=search_space_id
        )
        _perf_log.info(
            "[stream_resume] Connector service + firecrawl key in %.3fs",
            time.perf_counter() - _t0,
        )
        _t0 = time.perf_counter()
        checkpointer = await get_chat_checkpointer()
        _perf_log.info(
            "[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
        )
        visibility = thread_visibility or ChatVisibility.PRIVATE
        use_multi_agent = bool(_app_config.MULTI_AGENT_CHAT_ENABLED)
        chat_agent_mode = "multi" if use_multi_agent else "single"
        set_agent_mode(chat_span, chat_agent_mode)

        _t0 = time.perf_counter()
        agent_factory = (
            create_multi_agent_chat_deep_agent
            if use_multi_agent
            else create_surfsense_deep_agent
        )
        agent = await build_main_agent_for_thread(
            agent_factory,
            llm=llm,
            search_space_id=search_space_id,
            db_session=session,
            connector_service=connector_service,
            checkpointer=checkpointer,
            user_id=user_id,
            thread_id=chat_id,
            agent_config=agent_config,
            firecrawl_api_key=firecrawl_api_key,
            thread_visibility=visibility,
            filesystem_selection=filesystem_selection,
            disabled_tools=disabled_tools,
        )
        _perf_log.info(
            "[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
        )

        # Release the transaction before streaming (same rationale as stream_new_chat).
        await session.commit()
        session.expunge_all()
        _perf_log.info(
            "[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)",
            time.perf_counter() - _t_total,
            chat_id,
        )

        # --- Resume routing ---
        from langgraph.types import Command

        routing = await build_resume_routing(
            agent, chat_id=chat_id, decisions=decisions
        )
        config = {
            "configurable": {
                "thread_id": str(chat_id),
                "request_id": request_id or "unknown",
                "turn_id": stream_result.turn_id,
                # Per-``tool_call_id`` resume slices read by
                # ``SurfSenseCheckpointedSubAgentMiddleware``. Parallel
                # siblings each pop their own entry, so they never race.
                "surfsense_resume_value": routing.routed_resume_value,
            },
            # Same rationale as ``stream_new_chat``: effectively uncapped to
            # mirror the agent default and OpenCode's session loop. Doom-loop
            # / call-limit middleware enforce the real ceiling.
            "recursion_limit": 10_000,
        }

        # --- First SSE frames ---
        for sse in iter_initial_frames(streaming_service, turn_id=stream_result.turn_id):
            yield sse

        # --- Assistant-shell persistence + id frame ---
        assistant_message_id = await persist_resume_assistant_shell(
            chat_id=chat_id,
            user_id=user_id,
            turn_id=stream_result.turn_id,
        )
        if assistant_message_id is None:
            yield emit_stream_error(
                message=(
                    "We couldn't initialize the assistant message. Please try again."
                ),
                error_kind="server_error",
                error_code="MESSAGE_PERSIST_FAILED",
            )
            for sse in iter_final_frames(streaming_service):
                yield sse
            return
        yield streaming_service.format_data(
            "assistant-message-id",
            {"message_id": assistant_message_id, "turn_id": stream_result.turn_id},
        )
        stream_result.assistant_message_id = assistant_message_id
        stream_result.content_builder = AssistantContentBuilder()

        runtime_context = build_resume_chat_runtime_context(
            search_space_id=search_space_id,
            request_id=request_id,
            turn_id=stream_result.turn_id,
        )

        # --- Stream loop ---
        _t_stream_start = time.perf_counter()
        runtime_rate_limit_recovered = False

        def _on_first_event() -> None:
            _perf_log.info(
                "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
                time.perf_counter() - _t_stream_start,
                time.perf_counter() - _t_total,
                chat_id,
            )

        async def _recover(exc: BaseException, first_event_seen: bool):
            # Attempt a single provider rate-limit recovery: re-pin to the next
            # auto config and rebuild the agent. Returns the new agent, or
            # ``None`` when recovery is not applicable or the new config
            # failed to load.
            nonlocal llm_config_id, llm, agent_config, runtime_rate_limit_recovered
            if not can_recover_provider_rate_limit(
                exc,
                first_event_seen=first_event_seen,
                runtime_rate_limit_recovered=runtime_rate_limit_recovered,
                requested_llm_config_id=requested_llm_config_id,
                current_llm_config_id=llm_config_id,
            ):
                return None
            runtime_rate_limit_recovered = True
            previous_config_id = llm_config_id
            llm_config_id = await reroute_to_next_auto_pin(
                session,
                chat_id=chat_id,
                search_space_id=search_space_id,
                user_id=user_id,
                current_llm_config_id=llm_config_id,
                requires_image_input=False,
            )
            new_llm, new_agent_config, llm_load_err = await load_llm_bundle(
                session, config_id=llm_config_id, search_space_id=search_space_id
            )
            if llm_load_err:
                return None
            llm = new_llm
            agent_config = new_agent_config
            _t_rebuild = time.perf_counter()
            new_agent = await build_main_agent_for_thread(
                agent_factory,
                llm=llm,
                search_space_id=search_space_id,
                db_session=session,
                connector_service=connector_service,
                checkpointer=checkpointer,
                user_id=user_id,
                thread_id=chat_id,
                agent_config=agent_config,
                firecrawl_api_key=firecrawl_api_key,
                thread_visibility=visibility,
                filesystem_selection=filesystem_selection,
                disabled_tools=disabled_tools,
            )
            _perf_log.info(
                "[stream_resume] Runtime rate-limit recovery repinned "
                "config_id=%s -> %s and rebuilt agent in %.3fs",
                previous_config_id,
                llm_config_id,
                time.perf_counter() - _t_rebuild,
            )
            log_rate_limit_recovered(
                flow="resume",
                request_id=request_id,
                chat_id=chat_id,
                search_space_id=search_space_id,
                user_id=user_id,
                previous_config_id=previous_config_id,
                new_config_id=llm_config_id,
            )
            return new_agent

        async for sse in run_stream_loop(
            agent=agent,
            streaming_service=streaming_service,
            config=config,
            input_data=Command(resume=routing.lg_resume_map),
            stream_result=stream_result,
            step_prefix=resume_step_prefix(stream_result.turn_id),
            fallback_commit_search_space_id=search_space_id,
            fallback_commit_created_by_id=user_id,
            fallback_commit_filesystem_mode=(
                filesystem_selection.mode if filesystem_selection else FilesystemMode.CLOUD
            ),
            fallback_commit_thread_id=chat_id,
            runtime_context=runtime_context,
            content_builder=stream_result.content_builder,
            recover=_recover,
            on_first_event=_on_first_event,
        ):
            yield sse
        _perf_log.info(
            "[stream_resume] Agent stream completed in %.3fs (chat_id=%s)",
            time.perf_counter() - _t_stream_start,
            chat_id,
        )

        # --- Finalize ---
        if stream_result.is_interrupted:
            # Re-interrupted: emit usage and finish frames, but leave the
            # premium reservation for ``finally`` to release.
            ot.add_event("chat.interrupted", {"chat.flow": "resume"})
            for sse in iter_token_usage_frame(
                streaming_service,
                accumulator=accumulator,
                log_label="interrupted resume_chat",
            ):
                yield sse
            yield streaming_service.format_finish_step()
            yield streaming_service.format_finish()
            yield streaming_service.format_done()
            return

        if premium_reservation is not None and user_id:
            await finalize_premium(
                reservation=premium_reservation,
                user_id=user_id,
                accumulator=accumulator,
            )
            # Finalized: nothing left for ``finally`` to release.
            premium_reservation = None

        for sse in iter_token_usage_frame(
            streaming_service, accumulator=accumulator, log_label="normal resume_chat"
        ):
            yield sse
        for sse in iter_final_frames(streaming_service):
            yield sse
    except Exception as exc:
        frames, summary = handle_terminal_exception(
            exc,
            flow="resume",
            flow_label="resume",
            log_prefix="stream_resume_chat",
            streaming_service=streaming_service,
            request_id=request_id,
            chat_id=chat_id,
            search_space_id=search_space_id,
            user_id=user_id,
            chat_span=chat_span,
        )
        if summary["busy_error_raised"]:
            busy_error_raised = True
        chat_outcome = summary["chat_outcome"]
        chat_error_category = summary["chat_error_category"]
        for sse in frames:
            yield sse
    finally:
        # Shield the async cleanup so a cancelled request scope cannot abort
        # it at the first ``await``.
        with anyio.CancelScope(shield=True):
            if premium_reservation is not None and user_id:
                await release_premium(
                    reservation=premium_reservation, user_id=user_id
                )
            await close_session_and_clear_ai_responding(session, chat_id)
            await finalize_assistant_message(
                stream_result=stream_result,
                chat_id=chat_id,
                search_space_id=search_space_id,
                user_id=user_id,
                accumulator=accumulator,
                log_prefix="stream_resume",
            )
        # Release the lock from the original interrupted turn or any
        # re-interrupt/bailout. Skip on ``BusyError`` (lock not held here).
        # This is the ONLY release: an unconditional ``end_turn`` ahead of the
        # cleanup would release the lock on ``BusyError`` as well and, being
        # unsuppressed, could abort the cleanup above if it raised.
        if not busy_error_raised:
            with contextlib.suppress(Exception):
                end_turn(str(chat_id))
                _perf_log.info(
                    "[stream_resume] end_turn cleanup (chat_id=%s)", chat_id
                )
        # Drop references to the agent graph, LLM and session before the GC pass.
        agent = llm = connector_service = None  # noqa: F841
        stream_result = None  # noqa: F841
        session = None  # noqa: F841
        run_gc_pass(log_prefix="stream_resume", chat_id=chat_id)
        close_chat_request_span(
            span_cm=chat_span_cm,
            span=chat_span,
            chat_outcome=chat_outcome,
            chat_agent_mode=chat_agent_mode,
            flow="resume",
            chat_error_category=chat_error_category,
            duration_seconds=time.perf_counter() - _t_total,
        )

View file

@ -0,0 +1,65 @@
"""Route a flat ``decisions`` list back to the right paused subagent.
Each pending interrupt is stamped with its originating ``tool_call_id`` (see
``checkpointed_subagent_middleware.propagation``) so the resume slicer can
re-target each ``HumanReview`` decision at the right ``tool_call_id``.
LangGraph rejects scalar ``Command(resume=...)`` when multiple interrupts are
pending (parallel HITL); the mapped form works for the single-pause case too,
so we always use it.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Any
from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
logger = logging.getLogger(__name__)
@dataclass
class ResumeRoutingPayload:
    """Pair of resume structures produced by ``build_resume_routing``."""

    # Decisions sliced per ``tool_call_id``; passed to the agent through the
    # ``surfsense_resume_value`` configurable.
    routed_resume_value: dict[str, Any]
    # Mapping handed to ``Command(resume=...)`` when resuming the graph.
    lg_resume_map: dict[str, Any]
async def build_resume_routing(
    agent: Any,
    *,
    chat_id: int,
    decisions: list[dict],
) -> ResumeRoutingPayload:
    """Resolve the flat ``decisions`` list into per-``tool_call_id`` resume data.

    Reads the agent's state for this thread, collects the pending tool calls,
    slices ``decisions`` across them, and builds the resume map from the state
    and the slices. The sliced value is what the middleware later reads from
    the ``surfsense_resume_value`` configurable.

    Returns:
        A ``ResumeRoutingPayload`` holding both the sliced decisions and the
        resume map.
    """
    from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
        build_lg_resume_map,
        collect_pending_tool_calls,
        slice_decisions_by_tool_call,
    )

    thread_config = {"configurable": {"thread_id": str(chat_id)}}
    state_snapshot = await agent.aget_state(thread_config)
    pending_calls = collect_pending_tool_calls(state_snapshot)
    _perf_log.info(
        "[hitl_route] resume_entry chat_id=%s decisions=%d pending_subagents=%d",
        chat_id,
        len(decisions),
        len(pending_calls),
    )
    per_call_slices = slice_decisions_by_tool_call(decisions, pending_calls)
    resume_map = build_lg_resume_map(state_snapshot, per_call_slices)
    return ResumeRoutingPayload(
        routed_resume_value=per_call_slices,
        lg_resume_map=resume_map,
    )

View file

@ -0,0 +1,23 @@
"""Build the per-invocation ``SurfSenseContextSchema`` for a resume turn.
Resume doesn't carry new ``mentioned_document_ids`` (those are seeded by the
original turn). We still build the context so future middleware extensions
can rely on ``runtime.context`` always being populated.
"""
from __future__ import annotations
from app.agents.new_chat.context import SurfSenseContextSchema
def build_resume_chat_runtime_context(
    *,
    search_space_id: int,
    request_id: str | None,
    turn_id: str,
) -> SurfSenseContextSchema:
    """Create the runtime context for a resume turn.

    Only the search space and the request/turn identifiers are set; no
    mention ids are passed on resume.
    """
    context_fields = {
        "search_space_id": search_space_id,
        "request_id": request_id,
        "turn_id": turn_id,
    }
    return SurfSenseContextSchema(**context_fields)

View file

@ -0,0 +1,3 @@
"""Building blocks shared by ``new_chat`` and ``resume_chat`` orchestrators."""
from __future__ import annotations

View file

@ -0,0 +1,109 @@
"""Server-side assistant-message + token_usage finalization.
Runs inside the streaming flow's ``finally`` block, after the main session has
been closed (uses its own shielded session, so we don't fight the same DB
connection).
Idempotent against the legacy frontend ``appendMessage`` recovery branch:
* the assistant row was already INSERTed by ``persist_assistant_shell``
earlier in the turn, so this just UPDATEs it with the rich
``ContentPart[]`` projection from the builder.
* ``token_usage`` uses ``INSERT ... ON CONFLICT DO NOTHING`` against the
partial unique index from migration 142, so a racing append_message
recovery branch can never double-write.
``mark_interrupted`` closes any open text/reasoning blocks and flips running
tool-calls (no result) to ``state=aborted`` so the persisted JSONB reflects a
coherent end-state even on client disconnect.
Never raises (best-effort, logs only).
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from app.tasks.chat.streaming.shared.stream_result import StreamResult
from app.utils.perf import get_perf_logger
if TYPE_CHECKING:
from app.services.token_tracking_service import TokenAccumulator
_perf_log = get_perf_logger()
async def finalize_assistant_message(
    *,
    stream_result: StreamResult | None,
    chat_id: int,
    search_space_id: int,
    user_id: str | None,
    accumulator: TokenAccumulator,
    log_prefix: str,
) -> None:
    """Persist the final assistant payload captured during the turn.

    Returns immediately when there is nothing to finalize: no
    ``stream_result``, no ``turn_id``, or no ``assistant_message_id`` (the
    turn never reached ``persist_assistant_shell``). Otherwise the content
    builder is marked interrupted, snapshotted, and handed to
    ``finalize_assistant_turn`` together with the token accumulator.
    """
    if (
        not stream_result
        or not stream_result.turn_id
        or not stream_result.assistant_message_id
    ):
        return

    from app.tasks.chat.persistence import finalize_assistant_turn

    builder = stream_result.content_builder
    stats: dict[str, int] | None = None
    if builder is None:
        # Defensive fallback: the orchestrator sets the builder together with
        # ``assistant_message_id``, so this only triggers if the two are ever
        # decoupled. Persist the accumulated text so the row still renders.
        payload = [
            {
                "type": "text",
                "text": stream_result.accumulated_text or "",
            }
        ]
    else:
        builder.mark_interrupted()
        # Stats are taken after ``mark_interrupted`` and before ``snapshot()``
        # so the perf log describes the finalised payload.
        stats = builder.stats()
        payload = builder.snapshot()

    if stats is not None:
        stat_keys = (
            "parts",
            "bytes",
            "text",
            "reasoning",
            "tool_calls",
            "tool_calls_completed",
            "tool_calls_aborted",
            "thinking_step_parts",
            "step_separators",
        )
        _perf_log.info(
            "[%s] finalize_payload chat_id=%s "
            "message_id=%s parts=%d bytes=%d text=%d "
            "reasoning=%d tool_calls=%d "
            "tool_calls_completed=%d tool_calls_aborted=%d "
            "thinking_step_parts=%d step_separators=%d",
            log_prefix,
            chat_id,
            stream_result.assistant_message_id,
            *[stats[key] for key in stat_keys],
        )

    await finalize_assistant_turn(
        message_id=stream_result.assistant_message_id,
        chat_id=chat_id,
        search_space_id=search_space_id,
        user_id=user_id,
        turn_id=stream_result.turn_id,
        content=payload,
        accumulator=accumulator,
    )

View file

@ -0,0 +1,54 @@
"""Emit the per-turn token-usage SSE frame from the accumulator.
``per_message_summary()`` returns ``None`` when the turn made no chargeable
LLM calls (e.g. interrupt-on-input). In that case we skip the frame; the
frontend has no usage to render.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from app.services.new_streaming_service import VercelStreamingService
from app.utils.perf import get_perf_logger
if TYPE_CHECKING:
from app.services.token_tracking_service import TokenAccumulator
_perf_log = get_perf_logger()
logger = logging.getLogger(__name__)
def iter_token_usage_frame(
    streaming_service: VercelStreamingService,
    *,
    accumulator: TokenAccumulator,
    log_label: str,
):
    """Yield the ``token-usage`` SSE frame for this turn, if there is usage.

    Always logs a single ``[token_usage] {log_label}: ...`` line with the call
    count, grand total, cost and summary. A frame is yielded only when
    ``accumulator.per_message_summary()`` is truthy.
    """
    summary = accumulator.per_message_summary()
    _perf_log.info(
        "[token_usage] %s: calls=%d total=%d cost_micros=%d summary=%s",
        log_label,
        len(accumulator.calls),
        accumulator.grand_total,
        accumulator.total_cost_micros,
        summary,
    )
    if not summary:
        return

    frame_payload = {
        "usage": summary,
        "prompt_tokens": accumulator.total_prompt_tokens,
        "completion_tokens": accumulator.total_completion_tokens,
        "total_tokens": accumulator.grand_total,
        "cost_micros": accumulator.total_cost_micros,
        "call_details": accumulator.serialized_calls(),
    }
    yield streaming_service.format_data("token-usage", frame_payload)

View file

@ -0,0 +1,69 @@
"""Shared finally-block helpers: session close, GC pass, native-heap trim.
These are called from inside an ``anyio.CancelScope(shield=True)`` block in
each flow's ``finally`` (Starlette's BaseHTTPMiddleware cancels the scope on
client disconnect; without the shield the very first ``await`` would raise
``CancelledError`` and the rest of cleanup, including ``session.close()``,
would never run).
"""
from __future__ import annotations
import contextlib
import gc
import logging
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import shielded_async_session
from app.services.chat_session_state_service import clear_ai_responding
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
_perf_log = get_perf_logger()
logger = logging.getLogger(__name__)
async def close_session_and_clear_ai_responding(
    session: AsyncSession, chat_id: int
) -> None:
    """Roll back, clear the AI-responding flag, then expunge and close.

    Steps, in order:

    1. ``session.rollback()`` followed by ``clear_ai_responding`` on the same
       session.
    2. If either of those raises, retry the flag clear on a fresh
       ``shielded_async_session`` so the flag is not left set after a crash.
       A failure of the retry is logged (with traceback) and swallowed.
    3. ``session.expunge_all()`` and ``session.close()`` are each attempted
       independently; errors from them are suppressed.

    Args:
        session: The request-scoped session to clean up and close.
        chat_id: Thread whose AI-responding flag should be cleared.
    """
    try:
        await session.rollback()
        await clear_ai_responding(session, chat_id)
    except Exception:
        # Keep the first failure visible for debugging instead of dropping it.
        logger.debug(
            "Primary AI-responding clear failed for thread %s; "
            "retrying on a fresh session",
            chat_id,
            exc_info=True,
        )
        try:
            async with shielded_async_session() as fresh_session:
                await clear_ai_responding(fresh_session, chat_id)
        except Exception:
            logger.warning(
                "Failed to clear AI responding state for thread %s",
                chat_id,
                exc_info=True,
            )

    # Best-effort teardown: one failing step must not block the next one.
    with contextlib.suppress(Exception):
        session.expunge_all()
    with contextlib.suppress(Exception):
        await session.close()
def run_gc_pass(*, log_prefix: str, chat_id: int) -> None:
    """Collect generations 0, 1 and 2, trim the native heap, log a snapshot.

    The caller is responsible for dropping its own references (setting locals
    to ``None``) beforehand; this helper only runs the collector, reports how
    many objects were reclaimed when that number is non-zero, and then calls
    ``trim_native_heap`` and ``log_system_snapshot("{log_prefix}_END")``.
    """
    reclaimed = sum(gc.collect(generation) for generation in (0, 1, 2))
    if reclaimed:
        _perf_log.info(
            "[%s] gc.collect() reclaimed %d objects (chat_id=%s)",
            log_prefix,
            reclaimed,
            chat_id,
        )
    trim_native_heap()
    snapshot_label = f"{log_prefix}_END"
    log_system_snapshot(snapshot_label)

View file

@ -0,0 +1,40 @@
"""Initial SSE frames every flow emits right after pre-stream setup.
Order matters: ``message_start`` opens the assistant message, ``start_step``
opens the first thinking step, ``turn-info`` lets the frontend stamp the
correlation id onto the in-flight message, and ``turn-status: busy`` flips the
UI into the streaming state.
"""
from __future__ import annotations
from collections.abc import Iterator
from app.services.new_streaming_service import VercelStreamingService
def iter_initial_frames(
    streaming_service: VercelStreamingService,
    *,
    turn_id: str,
) -> Iterator[str]:
    """Yield the opening frames of a turn, in a fixed order.

    Order: message start, step start, a ``turn-info`` data frame carrying
    ``chat_turn_id``, and a ``turn-status`` data frame with status ``busy``.
    Each frame is formatted lazily, right before it is yielded.
    """
    turn_info = {"chat_turn_id": turn_id}
    busy_status = {"status": "busy"}

    yield streaming_service.format_message_start()
    yield streaming_service.format_start_step()
    yield streaming_service.format_data("turn-info", turn_info)
    yield streaming_service.format_data("turn-status", busy_status)
def iter_final_frames(
    streaming_service: VercelStreamingService,
) -> Iterator[str]:
    """Yield the closing frames: ``turn-status: idle``, finish-step, finish, done."""
    idle_status = {"status": "idle"}

    yield streaming_service.format_data("turn-status", idle_status)
    yield streaming_service.format_finish_step()
    yield streaming_service.format_finish()
    yield streaming_service.format_done()

View file

@ -0,0 +1,57 @@
"""Load an LLM + AgentConfig bundle for a given config id.
Handles both code paths uniformly:
- ``config_id >= 0``: database-backed ``NewLLMConfig`` row (per-user/per-space).
- ``config_id < 0``: YAML-defined global LLM config (built-in defaults).
Returns ``(llm, agent_config, error_message)``; on success ``error_message`` is
``None``. The caller emits the friendly SSE error frame.
"""
from __future__ import annotations
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.llm_config import (
AgentConfig,
create_chat_litellm_from_agent_config,
create_chat_litellm_from_config,
load_agent_config,
load_global_llm_config_by_id,
)
async def load_llm_bundle(
    session: AsyncSession,
    *,
    config_id: int,
    search_space_id: int,
) -> tuple[Any, AgentConfig | None, str | None]:
    """Resolve ``config_id`` into ``(llm, agent_config, error_message)``.

    Negative ids are looked up through ``load_global_llm_config_by_id``;
    non-negative ids are loaded with ``load_agent_config`` for the given
    search space. On success ``error_message`` is ``None``; on a failed
    lookup the first two slots are ``None`` and the third holds a message.
    """
    if config_id < 0:
        # Global (YAML) configuration path.
        global_config = load_global_llm_config_by_id(config_id)
        if not global_config:
            return None, None, f"Failed to load LLM config with id {config_id}"
        llm = create_chat_litellm_from_config(global_config)
        agent_config = AgentConfig.from_yaml_config(global_config)
        return llm, agent_config, None

    # Database-backed configuration path.
    db_agent_config = await load_agent_config(
        session=session,
        config_id=config_id,
        search_space_id=search_space_id,
    )
    if not db_agent_config:
        return None, None, f"Failed to load NewLLMConfig with id {config_id}"
    llm = create_chat_litellm_from_agent_config(db_agent_config)
    return llm, db_agent_config, None

View file

@ -0,0 +1,40 @@
"""Pre-stream setup: connector service, firecrawl key, checkpointer."""
from __future__ import annotations
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.checkpointer import get_checkpointer
from app.db import SearchSourceConnectorType
from app.services.connector_service import ConnectorService
async def setup_connector_and_firecrawl(
    session: AsyncSession,
    *,
    search_space_id: int,
) -> tuple[ConnectorService, str | None]:
    """Create the connector service and look up the firecrawl API key.

    Returns ``(connector_service, firecrawl_api_key)``. The key is read from
    the ``FIRECRAWL_API_KEY`` entry of the web-crawler connector's config and
    stays ``None`` when no such connector exists or its config is empty.
    """
    service = ConnectorService(session, search_space_id=search_space_id)
    crawler = await service.get_connector_by_type(
        SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
    )

    api_key: str | None = None
    if crawler and crawler.config:
        api_key = crawler.config.get("FIRECRAWL_API_KEY")
    return service, api_key
async def get_chat_checkpointer():
    """Return the checkpointer provided by ``get_checkpointer``.

    This is a pass-through wrapper: it gives flow orchestrators a
    streaming-local name to depend on and a single place to change should the
    checkpointer source ever need to differ per flow.
    """
    checkpointer = await get_checkpointer()
    return checkpointer

View file

@ -0,0 +1,132 @@
"""Premium credit (USD micro-units) reserve / finalize / release lifecycle.
Both ``stream_new_chat`` and ``stream_resume_chat`` reserve premium credits up
front (so a single LLM call can't run away with the budget), then finalize the
actual provider cost reported by LiteLLM when the turn completes successfully,
or release the reservation on the cancellation / interrupted-without-finalize
paths.
State is held by the orchestrator as a simple ``PremiumReservation`` dataclass
so reservation, fallback-on-denied, finalize, and release can all be reasoned
about from one place.
"""
from __future__ import annotations
import logging
import uuid as _uuid
from dataclasses import dataclass
from typing import TYPE_CHECKING
from uuid import UUID
from app.agents.new_chat.llm_config import AgentConfig
from app.db import shielded_async_session
if TYPE_CHECKING:
from app.services.token_tracking_service import TokenAccumulator
@dataclass
class PremiumReservation:
    """Active premium-credit reservation for one turn.

    Created by ``reserve_premium`` and handed back to ``finalize_premium`` or
    ``release_premium`` when the turn ends.

    NOTE(review): in this module ``finalize_premium`` forwards ``request_id``
    to the quota service, while ``release_premium`` forwards only
    ``reserved_micros`` — confirm that asymmetry is intended.
    """

    # Per-reservation key generated in ``reserve_premium`` (16 hex characters).
    request_id: str
    # Up-front estimate, in micro-units, that was reserved for the turn.
    reserved_micros: int
    # Whether the quota service accepted the reservation.
    allowed: bool
def needs_premium_quota(
    agent_config: AgentConfig | None, user_id: str | None
) -> bool:
    """Return ``True`` only for a premium config used by an identified user.

    All three must hold: a config is present, ``user_id`` is truthy, and the
    config's ``is_premium`` flag is truthy.
    """
    if agent_config is None:
        return False
    if not user_id:
        return False
    return bool(agent_config.is_premium)
async def reserve_premium(
    *,
    agent_config: AgentConfig,
    user_id: str,
) -> PremiumReservation:
    """Reserve an estimated amount of premium credit before the turn runs.

    The estimate comes from ``estimate_call_reserve_micros`` using the
    config's base model (``litellm_params["base_model"]`` when available,
    otherwise ``model_name``, otherwise an empty string) and its
    ``quota_reserve_tokens``. The reservation is made on a dedicated shielded
    session.

    Returns:
        A ``PremiumReservation`` carrying the generated request id, the
        reserved amount, and whether the quota service allowed it.
    """
    from app.services.token_quota_service import (
        TokenQuotaService,
        estimate_call_reserve_micros,
    )

    request_id = _uuid.uuid4().hex[:16]

    params = agent_config.litellm_params or {}
    configured_base = params.get("base_model") if isinstance(params, dict) else None
    base_model = configured_base or agent_config.model_name or ""

    reserve_micros = estimate_call_reserve_micros(
        base_model=base_model,
        quota_reserve_tokens=agent_config.quota_reserve_tokens,
    )

    async with shielded_async_session() as quota_session:
        outcome = await TokenQuotaService.premium_reserve(
            db_session=quota_session,
            user_id=UUID(user_id),
            request_id=request_id,
            reserve_micros=reserve_micros,
        )

    return PremiumReservation(
        request_id=request_id,
        reserved_micros=reserve_micros,
        allowed=outcome.allowed,
    )
async def finalize_premium(
    *,
    reservation: PremiumReservation,
    user_id: str,
    accumulator: TokenAccumulator,
) -> None:
    """Settle the reservation against the accumulator's recorded cost.

    Passes ``accumulator.total_cost_micros`` as the actual amount alongside
    the originally reserved amount. This is best-effort: any exception
    (including a failing import or an invalid ``user_id``) is logged with its
    traceback and swallowed so it cannot reach the already-delivered stream.
    """
    try:
        from app.services.token_quota_service import TokenQuotaService

        async with shielded_async_session() as quota_session:
            await TokenQuotaService.premium_finalize(
                db_session=quota_session,
                user_id=UUID(user_id),
                request_id=reservation.request_id,
                actual_micros=accumulator.total_cost_micros,
                reserved_micros=reservation.reserved_micros,
            )
    except Exception:
        log = logging.getLogger(__name__)
        log.warning(
            "Failed to finalize premium quota for user %s",
            user_id,
            exc_info=True,
        )
async def release_premium(
    *,
    reservation: PremiumReservation,
    user_id: str,
) -> None:
    """Give back the reserved amount on cancellation paths; never raises.

    Calls ``TokenQuotaService.premium_release`` with the reservation's
    ``reserved_micros`` on a dedicated shielded session. Any exception
    (including a failing import or an invalid ``user_id``) is logged and
    swallowed.

    Args:
        reservation: Handle returned by ``reserve_premium``.
        user_id: String form of the user's UUID.
    """
    try:
        from app.services.token_quota_service import TokenQuotaService

        async with shielded_async_session() as quota_session:
            await TokenQuotaService.premium_release(
                db_session=quota_session,
                user_id=UUID(user_id),
                reserved_micros=reservation.reserved_micros,
            )
    except Exception:
        # Include the traceback, matching ``finalize_premium``; without it a
        # leaked reservation cannot be diagnosed from the logs.
        logging.getLogger(__name__).warning(
            "Failed to release premium quota for user %s",
            user_id,
            exc_info=True,
        )

View file

@ -0,0 +1,129 @@
"""Shared steps for the in-stream provider rate-limit recovery loop.
Both flows wrap ``run_stream_loop`` with a flow-specific ``recover`` closure;
the *guard*, the *auto-pin reroute*, and the *post-recovery telemetry* are the
same on both sides and live here so behaviour can't drift.
The orchestrator owns the parts that genuinely diverge:
* cancelling the title task (new_chat only),
* passing ``mentioned_document_ids`` to ``build_main_agent_for_thread``,
* the log prefix (``stream_new_chat`` vs ``stream_resume``).
"""
from __future__ import annotations
from typing import Literal
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.middleware.busy_mutex import end_turn
from app.observability import otel as ot
from app.services.auto_model_pin_service import (
mark_runtime_cooldown,
resolve_or_get_pinned_llm_config_id,
)
from app.tasks.chat.streaming.errors.classifier import (
is_provider_rate_limited,
log_chat_stream_error,
)
def can_recover_provider_rate_limit(
    exc: BaseException,
    *,
    first_event_seen: bool,
    runtime_rate_limit_recovered: bool,
    requested_llm_config_id: int,
    current_llm_config_id: int,
) -> bool:
    """Decide whether a failed attempt may be silently retried on another model.

    Recovery is allowed only when every condition below holds:

    * no recovery has happened yet in this turn;
    * the caller requested config id ``0``;
    * the config currently in use has a negative id;
    * no frame has been observed from the agent yet;
    * ``is_provider_rate_limited(exc)`` classifies the error as a provider
      rate limit (checked last, and only if the cheaper checks pass).
    """
    if runtime_rate_limit_recovered:
        return False
    if not (requested_llm_config_id == 0):
        return False
    if not (current_llm_config_id < 0):
        return False
    if first_event_seen:
        return False
    return is_provider_rate_limited(exc)
async def reroute_to_next_auto_pin(
    session: AsyncSession,
    *,
    chat_id: int,
    search_space_id: int,
    user_id: str | None,
    current_llm_config_id: int,
    requires_image_input: bool,
) -> int:
    """Pick a replacement config after the current one failed.

    Steps: call ``end_turn`` for the thread (so a lock still held by the
    failed attempt does not block the same-request retry), mark the failing
    config as cooling down with reason ``provider_rate_limited``, then resolve
    a new pin with the failing config excluded.

    Returns:
        The ``resolved_llm_config_id`` of the newly resolved pin.
    """
    thread_key = str(chat_id)
    end_turn(thread_key)
    mark_runtime_cooldown(current_llm_config_id, reason="provider_rate_limited")

    excluded_ids = {current_llm_config_id}
    resolution = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=chat_id,
        search_space_id=search_space_id,
        user_id=user_id,
        selected_llm_config_id=0,
        exclude_config_ids=excluded_ids,
        requires_image_input=requires_image_input,
    )
    return resolution.resolved_llm_config_id
def log_rate_limit_recovered(
    *,
    flow: Literal["new", "regenerate", "resume"],
    request_id: str | None,
    chat_id: int,
    search_space_id: int,
    user_id: str | None,
    previous_config_id: int,
    new_config_id: int,
) -> None:
    """Record a successful rate-limit recovery in tracing and in the logs.

    Adds a ``chat.rate_limit.recovered`` event via ``ot.add_event`` and then
    writes an info-severity, expected ``RATE_LIMITED`` entry through
    ``log_chat_stream_error``; both carry the previous and fallback config
    ids.
    """
    event_attributes = {
        "recovery.reason": "provider_rate_limited",
        "recovery.previous_config_id": previous_config_id,
        "recovery.fallback_config_id": new_config_id,
    }
    ot.add_event("chat.rate_limit.recovered", event_attributes)

    log_extra = {
        "auto_runtime_recover": True,
        "previous_config_id": previous_config_id,
        "fallback_config_id": new_config_id,
    }
    summary = (
        "Auto-pinned model hit runtime rate limit; switched to "
        "another eligible model and retried."
    )
    log_chat_stream_error(
        flow=flow,
        error_kind="rate_limited",
        error_code="RATE_LIMITED",
        severity="info",
        is_expected=True,
        request_id=request_id,
        thread_id=chat_id,
        search_space_id=search_space_id,
        user_id=user_id,
        message=summary,
        extra=log_extra,
    )

View file

@ -0,0 +1,80 @@
"""OpenTelemetry chat-request span wrapper for streaming flows."""
from __future__ import annotations
import contextlib
import sys
from typing import Any, Literal
from app.observability import metrics as ot_metrics
from app.observability import otel as ot
def open_chat_request_span(
    *,
    chat_id: int,
    search_space_id: int,
    flow: Literal["new", "regenerate", "resume"],
    request_id: str | None,
    turn_id: str,
    filesystem_mode: str,
    client_platform: str,
    agent_mode: str,
) -> tuple[Any, Any]:
    """Create and enter the per-request span context manager.

    Returns ``(span_cm, span)``: the context manager itself (so the caller
    can exit it later) and the value produced by entering it.
    """
    span_attributes = {
        "chat_id": chat_id,
        "search_space_id": search_space_id,
        "flow": flow,
        "request_id": request_id,
        "turn_id": turn_id,
        "filesystem_mode": filesystem_mode,
        "client_platform": client_platform,
        "agent_mode": agent_mode,
    }
    span_cm = ot.chat_request_span(**span_attributes)
    # Entered manually because the matching exit happens in the caller's
    # cleanup path rather than in a lexical ``with`` block.
    return span_cm, span_cm.__enter__()
def set_agent_mode(span: Any, agent_mode: str) -> None:
    """Set the ``agent.mode`` attribute on ``span``, ignoring any failure."""
    try:
        span.set_attribute("agent.mode", agent_mode)
    except Exception:
        # Telemetry is best-effort; a broken span must not affect the turn.
        pass
def close_chat_request_span(
    *,
    span_cm: Any,
    span: Any,
    chat_outcome: str,
    chat_agent_mode: str,
    flow: Literal["new", "regenerate", "resume"],
    chat_error_category: str | None,
    duration_seconds: float,
) -> None:
    """Record request metrics and exit the span; never raises ``Exception``.

    Each step (outcome attribute, duration metric, outcome metric, span exit)
    is suppressed independently, so a failure in an earlier step cannot
    prevent the span from being exited.

    Args:
        span_cm: Context manager returned when the span was opened.
        span: The entered span object.
        chat_outcome: Outcome label recorded on the span and both metrics.
        chat_agent_mode: Agent mode label recorded on both metrics.
        flow: Which streaming flow produced this request.
        chat_error_category: Error category for the outcome metric, if any.
        duration_seconds: Request duration; recorded multiplied by 1000.
    """
    # Snapshot the in-flight exception (if any) once, up front, so the span
    # exit sees the caller's exception state.
    exc_info = sys.exc_info()

    with contextlib.suppress(Exception):
        span.set_attribute("chat.outcome", chat_outcome)
    with contextlib.suppress(Exception):
        ot_metrics.record_chat_request_duration(
            duration_seconds * 1000,
            flow=flow,
            outcome=chat_outcome,
            agent_mode=chat_agent_mode,
        )
    with contextlib.suppress(Exception):
        ot_metrics.record_chat_request_outcome(
            flow=flow,
            outcome=chat_outcome,
            agent_mode=chat_agent_mode,
            error_category=chat_error_category,
        )
    # Always attempted, regardless of what happened above.
    with contextlib.suppress(Exception):
        span_cm.__exit__(*exc_info)
def record_outcome_attrs(
    span: Any, *, chat_outcome: str, chat_error_category: str | None
) -> None:
    """Set ``chat.outcome`` and, when given, ``error.category`` on the span.

    Failures while setting attributes are suppressed; if the first attribute
    fails, the second is not attempted.
    """
    attributes = [("chat.outcome", chat_outcome)]
    if chat_error_category is not None:
        attributes.append(("error.category", chat_error_category))

    with contextlib.suppress(Exception):
        for key, value in attributes:
            span.set_attribute(key, value)

View file

@ -0,0 +1,85 @@
"""Drive ``stream_agent_events`` with in-stream rate-limit recovery.
Both ``stream_new_chat`` and ``stream_resume_chat`` wrap the agent event loop
in a ``while True`` that catches the *first* provider rate-limit error
(``can_runtime_recover``) before any SSE event reaches the user, rebuilds the
agent on an alternative auto-pin, and retries the turn.
The recovery callback is flow-specific (different ``mentioned_document_ids``
contract, different logging label, etc.); this module owns the loop shape,
the caller owns the rebuild.
"""
from __future__ import annotations
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.streaming.agent.event_loop import stream_agent_events
from app.tasks.chat.streaming.shared.stream_result import StreamResult
# Recovery callback type used by ``run_stream_loop``.
#
# Called as ``recover(exc, first_event_seen)`` with the exception raised by
# the event loop and a flag telling whether any frame had already been
# observed when it was raised.
#
# Returns the rebuilt agent on a successful recovery, or ``None`` to re-raise
# the original exception (and let the orchestrator's terminal-error path
# handle it).
RecoverFn = Callable[[BaseException, bool], Awaitable[Any | None]]
async def run_stream_loop(
    *,
    agent: Any,
    streaming_service: VercelStreamingService,
    config: dict[str, Any],
    input_data: Any,
    stream_result: StreamResult,
    step_prefix: str = "thinking",
    initial_step_id: str | None = None,
    initial_step_title: str = "",
    initial_step_items: list[str] | None = None,
    fallback_commit_search_space_id: int | None,
    fallback_commit_created_by_id: str | None,
    fallback_commit_filesystem_mode: FilesystemMode,
    fallback_commit_thread_id: int | None,
    runtime_context: Any,
    content_builder: Any | None,
    recover: RecoverFn,
    on_first_event: Callable[[], None] | None = None,
) -> AsyncGenerator[str, None]:
    """Relay frames from ``stream_agent_events``, retrying when ``recover`` allows.

    Every frame produced by the event loop is yielded unchanged. When the
    event loop raises an ``Exception``, ``recover`` is awaited with the
    exception and a flag saying whether a frame has been seen yet; a
    non-``None`` result replaces the agent and the loop starts over, while
    ``None`` re-raises the original exception.

    ``on_first_event`` (if provided) is invoked exactly once, just before the
    first frame is yielded.
    """
    # Everything except ``agent`` stays the same across retry attempts.
    relay_kwargs: dict[str, Any] = {
        "config": config,
        "input_data": input_data,
        "streaming_service": streaming_service,
        "result": stream_result,
        "step_prefix": step_prefix,
        "initial_step_id": initial_step_id,
        "initial_step_title": initial_step_title,
        "initial_step_items": initial_step_items,
        "fallback_commit_search_space_id": fallback_commit_search_space_id,
        "fallback_commit_created_by_id": fallback_commit_created_by_id,
        "fallback_commit_filesystem_mode": fallback_commit_filesystem_mode,
        "fallback_commit_thread_id": fallback_commit_thread_id,
        "runtime_context": runtime_context,
        "content_builder": content_builder,
    }

    saw_first_frame = False
    while True:
        try:
            async for frame in stream_agent_events(agent=agent, **relay_kwargs):
                if not saw_first_frame:
                    if on_first_event is not None:
                        on_first_event()
                    saw_first_frame = True
                yield frame
        except Exception as failure:
            replacement = await recover(failure, saw_first_frame)
            if replacement is None:
                raise
            # Swap in the rebuilt agent and go around again.
            agent = replacement
        else:
            # The event loop finished normally; the turn is done.
            return

View file

@ -0,0 +1,120 @@
"""Handle the ``except Exception`` branch of a streaming flow.
Classifies the exception, records OpenTelemetry attributes, emits one terminal
error SSE frame and the trailing ``turn-status: idle`` + finish/done frames.
Used by both ``stream_new_chat`` and ``stream_resume_chat``; flow-specific bits
(label, span, BusyError tracking) are passed by the caller.
"""
from __future__ import annotations
import logging
import traceback
from collections.abc import Iterator
from typing import Any, Literal
from app.agents.new_chat.errors import BusyError
from app.observability import metrics as ot_metrics
from app.observability import otel as ot
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.streaming.errors.classifier import classify_stream_exception
from app.tasks.chat.streaming.errors.emitter import emit_stream_terminal_error
from app.tasks.chat.streaming.flows.shared.first_frames import iter_final_frames
from app.tasks.chat.streaming.flows.shared.span import record_outcome_attrs
logger = logging.getLogger(__name__)
def handle_terminal_exception(
    exc: Exception,
    *,
    flow: Literal["new", "regenerate", "resume"],
    flow_label: str,
    log_prefix: str,
    streaming_service: VercelStreamingService,
    request_id: str | None,
    chat_id: int,
    search_space_id: int,
    user_id: str | None,
    chat_span: Any,
) -> tuple[Iterator[str], dict[str, Any]]:
    """Classify, log, and produce the SSE frames for a terminal exception.

    Args:
        exc: The exception that ended the stream.
        flow: Flow identifier forwarded to the terminal-error emitter.
        flow_label: Human-readable flow name used for classification and in
            the printed error message.
        log_prefix: Bracketed prefix for the printed diagnostic lines.
        streaming_service: Formatter used to build the frames.
        request_id: Request correlation id, if any.
        chat_id: Thread id (passed to the emitter as ``thread_id``).
        search_space_id: Search space the turn belongs to.
        user_id: User id, if known.
        chat_span: Span that receives the outcome attributes and the error.

    Returns:
        ``(frame_iterator, summary)``. The iterator lazily yields a
        ``turn-status`` frame (``cancelling`` plus any classifier extras when
        the error code is ``TURN_CANCELLING``, otherwise ``busy``), one
        terminal error frame, and then the frames from ``iter_final_frames``.
        ``summary`` carries:

        - ``busy_error_raised``: ``True`` when ``exc`` is a ``BusyError``
          (the caller must then skip its lock-release path).
        - ``chat_outcome``: error code, else error kind, else ``"error"``.
        - ``chat_error_category``: result of
          ``ot_metrics.categorize_exception(exc)``.
    """
    busy_error_raised = isinstance(exc, BusyError)
    (
        error_kind,
        error_code,
        severity,
        is_expected,
        user_message,
        error_extra,
    ) = classify_stream_exception(exc, flow_label=flow_label)

    chat_outcome = error_code or error_kind or "error"
    chat_error_category = ot_metrics.categorize_exception(exc)
    record_outcome_attrs(
        chat_span,
        chat_outcome=chat_outcome,
        chat_error_category=chat_error_category,
    )
    with __suppress():
        ot.record_error(chat_span, exc)

    error_message = f"Error during {flow_label}: {exc!s}"
    # Format the traceback of the exception that was passed in, not whatever
    # happens to be the ambient one: ``traceback.format_exc()`` yields
    # "NoneType: None" when this helper runs outside an ``except`` block.
    formatted_traceback = "".join(
        traceback.format_exception(type(exc), exc, exc.__traceback__)
    )
    # Match the original behavior: log full traceback via ``print`` so it lands
    # in the process output regardless of the logger config.
    print(f"[{log_prefix}] {error_message}")
    print(f"[{log_prefix}] Exception type: {type(exc).__name__}")
    print(f"[{log_prefix}] Traceback:\n{formatted_traceback}")

    def _iter_frames() -> Iterator[str]:
        if error_code == "TURN_CANCELLING":
            status_payload: dict[str, Any] = {"status": "cancelling"}
            if error_extra:
                status_payload.update(error_extra)
            yield streaming_service.format_data("turn-status", status_payload)
        else:
            yield streaming_service.format_data("turn-status", {"status": "busy"})

        yield emit_stream_terminal_error(
            streaming_service=streaming_service,
            flow=flow,
            request_id=request_id,
            thread_id=chat_id,
            search_space_id=search_space_id,
            user_id=user_id,
            message=user_message,
            error_kind=error_kind,
            error_code=error_code,
            severity=severity,
            is_expected=is_expected,
            extra=error_extra,
        )
        yield from iter_final_frames(streaming_service)

    return (
        _iter_frames(),
        {
            "busy_error_raised": busy_error_raised,
            "chat_outcome": chat_outcome,
            "chat_error_category": chat_error_category,
        },
    )
def __suppress():
    """Return a fresh context manager that swallows any ``Exception``.

    ``contextlib`` is imported inside the function so the module does not
    need a top-level import for this single call site.
    """
    from contextlib import suppress

    return suppress(Exception)

View file

@ -0,0 +1,15 @@
"""Shared building blocks used across every streaming flow."""
from __future__ import annotations
from app.tasks.chat.streaming.shared.stream_result import StreamResult
from app.tasks.chat.streaming.shared.utils import (
resume_step_prefix,
safe_float,
)
__all__ = [
"StreamResult",
"resume_step_prefix",
"safe_float",
]

View file

@ -0,0 +1,37 @@
"""Per-turn streaming state shared between the orchestrator and event loop."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
@dataclass
class StreamResult:
    """Mutable per-turn state passed into the streaming event loop.

    Every field has a default, so a bare ``StreamResult()`` is valid. The
    fields are grouped below by what their names and defaults indicate; the
    exact producer of each value lives outside this module.
    """

    # Text accumulated for the turn and whether the turn ended in an interrupt.
    accumulated_text: str = ""
    is_interrupted: bool = False
    # Fresh list per instance (``default_factory``), never shared.
    sandbox_files: list[str] = field(default_factory=list)
    # Correlation identifiers and request context.
    request_id: str | None = None
    turn_id: str = ""
    filesystem_mode: str = "cloud"
    client_platform: str = "web"
    # Intent classification — presumably set by an upstream detector; verify
    # against the caller.
    intent_detected: str = "chat_only"
    intent_confidence: float = 0.0
    # Write / verification / commit-gate outcome flags. Note that
    # ``commit_gate_passed`` defaults to ``True`` while the others default to
    # ``False``.
    write_attempted: bool = False
    write_succeeded: bool = False
    verification_succeeded: bool = False
    commit_gate_passed: bool = True
    commit_gate_reason: str = ""
    # Pre-allocated assistant ``new_chat_messages.id`` for this turn, captured by
    # ``persist_assistant_shell`` right after the user row is persisted. ``None``
    # for the legacy/anonymous code paths that don't opt in to server-side
    # ``ContentPart[]`` projection.
    assistant_message_id: int | None = None
    # In-memory mirror of the FE's assistant-ui ``ContentPartsState``, populated
    # by the lifecycle methods called from the streaming event loop at each
    # ``streaming_service.format_*`` yield site. Snapshot in the streaming
    # ``finally`` to produce the rich JSONB persisted by
    # ``finalize_assistant_turn``. ``repr=False`` keeps the log-on-error path
    # (``StreamResult`` is logged in some error branches) from dumping a
    # potentially-large parts list.
    content_builder: Any | None = field(default=None, repr=False)

View file

@ -0,0 +1,27 @@
"""Small utilities used by streaming orchestrators and phases."""
from __future__ import annotations
from typing import Any
def resume_step_prefix(turn_id: str) -> str:
    """Build the ``step_prefix`` for a resume invocation of ``turn_id``.

    The result is ``thinking-resume-`` followed by the turn id. Including
    the turn id keeps step ids from different resume turns of one thread
    disjoint, even though each invocation restarts its step counter; without
    it, consecutive resumes would produce identical ids and the frontend
    would render sibling rows with the same key.
    """
    base = "thinking-resume-"
    return f"{base}{turn_id}"
def safe_float(value: Any, default: float = 0.0) -> float:
    """Convert ``value`` to ``float``, returning ``default`` when it cannot be.

    Args:
        value: Anything ``float()`` might accept.
        default: Value returned when the conversion fails.

    Returns:
        ``float(value)``, or ``default`` if the conversion raises
        ``TypeError`` (e.g. ``None``), ``ValueError`` (e.g. a non-numeric
        string) or ``OverflowError`` (an integer too large for a float).
    """
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: ``float(10**400)`` — an int that exceeds float range.
        return default