mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-29 19:35:20 +02:00
Merge pull request #1443 from CREDO23/feature-automations
[Feat] Automation V1 — Scheduled Agent Tasks, Created via Chat (HITL) or JSON
This commit is contained in:
commit
4dda02c06c
219 changed files with 13821 additions and 55 deletions
|
|
@ -0,0 +1,8 @@
|
|||
"""Agent construction and per-turn event-loop drivers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.tasks.chat.streaming.agent.builder import build_main_agent_for_thread
|
||||
from app.tasks.chat.streaming.agent.event_loop import stream_agent_events
|
||||
|
||||
__all__ = ["build_main_agent_for_thread", "stream_agent_events"]
|
||||
49
surfsense_backend/app/tasks/chat/streaming/agent/builder.py
Normal file
49
surfsense_backend/app/tasks/chat/streaming/agent/builder.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
"""Single per-thread agent (re)build path.
|
||||
|
||||
A graph swap mid-turn would corrupt checkpointer state for the same
|
||||
``thread_id``, so both the initial build and any mid-stream 429 recovery rebuild
|
||||
must funnel through this single function.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemSelection
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
from app.db import ChatVisibility
|
||||
from app.services.connector_service import ConnectorService
|
||||
|
||||
|
||||
async def build_main_agent_for_thread(
|
||||
agent_factory: Any,
|
||||
*,
|
||||
llm: Any,
|
||||
search_space_id: int,
|
||||
db_session: Any,
|
||||
connector_service: ConnectorService,
|
||||
checkpointer: Any,
|
||||
user_id: str | None,
|
||||
thread_id: int | None,
|
||||
agent_config: AgentConfig | None,
|
||||
firecrawl_api_key: str | None,
|
||||
thread_visibility: ChatVisibility | None,
|
||||
filesystem_selection: FilesystemSelection | None,
|
||||
disabled_tools: list[str] | None = None,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
) -> Any:
|
||||
return await agent_factory(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=db_session,
|
||||
connector_service=connector_service,
|
||||
checkpointer=checkpointer,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=thread_visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
disabled_tools=disabled_tools,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
)
|
||||
175
surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py
Normal file
175
surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
"""Per-turn agent event-loop driver.
|
||||
|
||||
Drives ``stream_output`` (graph_stream relay) for one agent turn, then runs the
|
||||
post-stream agent-state inspection: safety-net commit of any staged filesystem
|
||||
state (in case ``aafter_agent`` was skipped), file-operation contract scoring,
|
||||
intent classification, and interrupt detection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware.kb_persistence import (
|
||||
commit_staged_filesystem_state,
|
||||
)
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.tasks.chat.streaming.contract.file_contract import (
|
||||
contract_enforcement_active,
|
||||
evaluate_file_contract_outcome,
|
||||
log_file_contract,
|
||||
)
|
||||
from app.tasks.chat.streaming.graph_stream.event_stream import stream_output
|
||||
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
|
||||
all_interrupt_values,
|
||||
)
|
||||
from app.tasks.chat.streaming.shared.stream_result import StreamResult
|
||||
from app.tasks.chat.streaming.shared.utils import safe_float
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
async def stream_agent_events(
|
||||
agent: Any,
|
||||
config: dict[str, Any],
|
||||
input_data: Any,
|
||||
streaming_service: VercelStreamingService,
|
||||
result: StreamResult,
|
||||
step_prefix: str = "thinking",
|
||||
initial_step_id: str | None = None,
|
||||
initial_step_title: str = "",
|
||||
initial_step_items: list[str] | None = None,
|
||||
*,
|
||||
fallback_commit_search_space_id: int | None = None,
|
||||
fallback_commit_created_by_id: str | None = None,
|
||||
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
||||
fallback_commit_thread_id: int | None = None,
|
||||
runtime_context: Any = None,
|
||||
content_builder: Any | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream and format ``astream_events`` from the agent.
|
||||
|
||||
Yields SSE-formatted strings; after exhausting, ``result`` carries
|
||||
``accumulated_text`` and interrupt state. See ``StreamResult`` for the
|
||||
side-channel surface populated by the underlying relay.
|
||||
"""
|
||||
async for sse in stream_output(
|
||||
agent=agent,
|
||||
config=config,
|
||||
input_data=input_data,
|
||||
streaming_service=streaming_service,
|
||||
result=result,
|
||||
step_prefix=step_prefix,
|
||||
initial_step_id=initial_step_id,
|
||||
initial_step_title=initial_step_title,
|
||||
initial_step_items=initial_step_items,
|
||||
content_builder=content_builder,
|
||||
runtime_context=runtime_context,
|
||||
):
|
||||
yield sse
|
||||
|
||||
accumulated_text = result.accumulated_text
|
||||
|
||||
state = await agent.aget_state(config)
|
||||
state_values = getattr(state, "values", {}) or {}
|
||||
|
||||
# Safety net: if astream_events was cancelled before
|
||||
# KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work
|
||||
# (dirty_paths / staged_dirs / pending_moves / pending_deletes /
|
||||
# pending_dir_deletes) is still in the checkpointed state. Run the SAME
|
||||
# shared commit helper so the turn's writes don't get lost on client
|
||||
# disconnect, then push the delta back into the graph using ``as_node=...``
|
||||
# so reducers fire as if the after_agent hook produced it.
|
||||
if (
|
||||
fallback_commit_filesystem_mode == FilesystemMode.CLOUD
|
||||
and fallback_commit_search_space_id is not None
|
||||
and (
|
||||
(state_values.get("dirty_paths") or [])
|
||||
or (state_values.get("staged_dirs") or [])
|
||||
or (state_values.get("pending_moves") or [])
|
||||
or (state_values.get("pending_deletes") or [])
|
||||
or (state_values.get("pending_dir_deletes") or [])
|
||||
)
|
||||
):
|
||||
try:
|
||||
delta = await commit_staged_filesystem_state(
|
||||
state_values,
|
||||
search_space_id=fallback_commit_search_space_id,
|
||||
created_by_id=fallback_commit_created_by_id,
|
||||
filesystem_mode=fallback_commit_filesystem_mode,
|
||||
thread_id=fallback_commit_thread_id,
|
||||
dispatch_events=False,
|
||||
)
|
||||
if delta:
|
||||
await agent.aupdate_state(
|
||||
config,
|
||||
delta,
|
||||
as_node="KnowledgeBasePersistenceMiddleware.after_agent",
|
||||
)
|
||||
except Exception as exc:
|
||||
_perf_log.warning("[stream_agent_events] safety-net commit failed: %s", exc)
|
||||
|
||||
contract_state = state_values.get("file_operation_contract") or {}
|
||||
contract_turn_id = contract_state.get("turn_id")
|
||||
current_turn_id = config.get("configurable", {}).get("turn_id", "")
|
||||
intent_value = contract_state.get("intent")
|
||||
if (
|
||||
isinstance(intent_value, str)
|
||||
and intent_value in ("chat_only", "file_write", "file_read")
|
||||
and contract_turn_id == current_turn_id
|
||||
):
|
||||
result.intent_detected = intent_value
|
||||
if (
|
||||
isinstance(intent_value, str)
|
||||
and intent_value in ("chat_only", "file_write", "file_read")
|
||||
and contract_turn_id != current_turn_id
|
||||
):
|
||||
# Ignore stale intent contracts from previous turns/checkpoints.
|
||||
result.intent_detected = "chat_only"
|
||||
result.intent_confidence = (
|
||||
safe_float(contract_state.get("confidence"), default=0.0)
|
||||
if contract_turn_id == current_turn_id
|
||||
else 0.0
|
||||
)
|
||||
|
||||
if result.intent_detected == "file_write":
|
||||
result.commit_gate_passed, result.commit_gate_reason = (
|
||||
evaluate_file_contract_outcome(result)
|
||||
)
|
||||
if not result.commit_gate_passed and contract_enforcement_active(result):
|
||||
gate_notice = (
|
||||
"I could not complete the requested file write because no successful "
|
||||
"write_file/edit_file operation was confirmed."
|
||||
)
|
||||
gate_text_id = streaming_service.generate_text_id()
|
||||
yield streaming_service.format_text_start(gate_text_id)
|
||||
if content_builder is not None:
|
||||
content_builder.on_text_start(gate_text_id)
|
||||
yield streaming_service.format_text_delta(gate_text_id, gate_notice)
|
||||
if content_builder is not None:
|
||||
content_builder.on_text_delta(gate_text_id, gate_notice)
|
||||
yield streaming_service.format_text_end(gate_text_id)
|
||||
if content_builder is not None:
|
||||
content_builder.on_text_end(gate_text_id)
|
||||
yield streaming_service.format_terminal_info(gate_notice, "error")
|
||||
accumulated_text = gate_notice
|
||||
else:
|
||||
result.commit_gate_passed = True
|
||||
result.commit_gate_reason = ""
|
||||
|
||||
result.accumulated_text = accumulated_text
|
||||
log_file_contract("turn_outcome", result)
|
||||
|
||||
pending_values = all_interrupt_values(state)
|
||||
if pending_values:
|
||||
result.is_interrupted = True
|
||||
# One frame per paused subagent so each parallel HITL renders its own
|
||||
# approval card on the wire. Order matches ``state.interrupts``, which
|
||||
# the resume slicer in
|
||||
# ``checkpointed_subagent_middleware.resume_routing`` consumes in the
|
||||
# same order — keeping emit and resume in lock-step.
|
||||
for interrupt_value in pending_values:
|
||||
yield streaming_service.format_interrupt_request(interrupt_value)
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
"""Pre-agent context shaping: mentioned-doc rendering and todos extraction."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.tasks.chat.streaming.context.deepagents_todos import (
|
||||
extract_todos_from_deepagents,
|
||||
)
|
||||
from app.tasks.chat.streaming.context.mentioned_docs import (
|
||||
format_mentioned_surfsense_docs_as_context,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"extract_todos_from_deepagents",
|
||||
"format_mentioned_surfsense_docs_as_context",
|
||||
]
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
"""Extract todos from a deepagents ``TodoListMiddleware`` ``Command`` output."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def extract_todos_from_deepagents(command_output: Any) -> dict:
|
||||
"""Normalize todos out of a deepagents ``Command`` or dict payload.
|
||||
|
||||
deepagents returns a ``Command`` whose ``update['todos']`` is a list of
|
||||
``{'content': str, 'status': str}`` dicts. The UI expects the same shape,
|
||||
so no transformation is required — only extraction.
|
||||
"""
|
||||
todos_data: list = []
|
||||
if hasattr(command_output, "update"):
|
||||
update = command_output.update
|
||||
todos_data = update.get("todos", [])
|
||||
elif isinstance(command_output, dict):
|
||||
if "todos" in command_output:
|
||||
todos_data = command_output.get("todos", [])
|
||||
elif "update" in command_output and isinstance(
|
||||
command_output["update"], dict
|
||||
):
|
||||
todos_data = command_output["update"].get("todos", [])
|
||||
|
||||
return {"todos": todos_data}
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
"""Render user-mentioned SurfSense docs as XML context for the agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from app.db import SurfsenseDocsDocument
|
||||
from app.utils.surfsense_docs import surfsense_docs_public_url
|
||||
|
||||
|
||||
def format_mentioned_surfsense_docs_as_context(
|
||||
documents: list[SurfsenseDocsDocument],
|
||||
) -> str:
|
||||
if not documents:
|
||||
return ""
|
||||
|
||||
context_parts = ["<mentioned_surfsense_docs>"]
|
||||
context_parts.append(
|
||||
"The user has explicitly mentioned the following SurfSense documentation pages. "
|
||||
"These are official documentation about how to use SurfSense and should be used to answer questions about the application. "
|
||||
"Use [citation:CHUNK_ID] format for citations (e.g., [citation:doc-123])."
|
||||
)
|
||||
|
||||
for doc in documents:
|
||||
public_url = surfsense_docs_public_url(doc.source)
|
||||
metadata_json = json.dumps(
|
||||
{"source": doc.source, "public_url": public_url}, ensure_ascii=False
|
||||
)
|
||||
|
||||
context_parts.append("<document>")
|
||||
context_parts.append("<document_metadata>")
|
||||
context_parts.append(f" <document_id>doc-{doc.id}</document_id>")
|
||||
context_parts.append(" <document_type>SURFSENSE_DOCS</document_type>")
|
||||
context_parts.append(f" <title><![CDATA[{doc.title}]]></title>")
|
||||
context_parts.append(f" <url><![CDATA[{public_url}]]></url>")
|
||||
context_parts.append(
|
||||
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>"
|
||||
)
|
||||
context_parts.append("</document_metadata>")
|
||||
context_parts.append("")
|
||||
context_parts.append("<document_content>")
|
||||
|
||||
if hasattr(doc, "chunks") and doc.chunks:
|
||||
for chunk in doc.chunks:
|
||||
context_parts.append(
|
||||
f" <chunk id='doc-{chunk.id}'><![CDATA[{chunk.content}]]></chunk>"
|
||||
)
|
||||
else:
|
||||
context_parts.append(
|
||||
f" <chunk id='doc-0'><![CDATA[{doc.content}]]></chunk>"
|
||||
)
|
||||
|
||||
context_parts.append("</document_content>")
|
||||
context_parts.append("</document>")
|
||||
context_parts.append("")
|
||||
|
||||
context_parts.append("</mentioned_surfsense_docs>")
|
||||
return "\n".join(context_parts)
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
"""File-operation contract evaluation and logging."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.tasks.chat.streaming.contract.file_contract import (
|
||||
contract_enforcement_active,
|
||||
evaluate_file_contract_outcome,
|
||||
log_file_contract,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"contract_enforcement_active",
|
||||
"evaluate_file_contract_outcome",
|
||||
"log_file_contract",
|
||||
]
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
"""File-operation contract: when to enforce, how to score, how to log."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from app.tasks.chat.streaming.shared.stream_result import StreamResult
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
def contract_enforcement_active(result: StreamResult) -> bool:
|
||||
# Enforce only in desktop local-folder mode. Kept deterministic, no
|
||||
# env-driven progression modes.
|
||||
return result.filesystem_mode == "desktop_local_folder"
|
||||
|
||||
|
||||
def evaluate_file_contract_outcome(result: StreamResult) -> tuple[bool, str]:
|
||||
if result.intent_detected != "file_write":
|
||||
return True, ""
|
||||
if not result.write_attempted:
|
||||
return False, "no_write_attempt"
|
||||
if not result.write_succeeded:
|
||||
return False, "write_failed"
|
||||
if not result.verification_succeeded:
|
||||
return False, "verification_failed"
|
||||
return True, ""
|
||||
|
||||
|
||||
def log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None:
|
||||
payload: dict[str, Any] = {
|
||||
"stage": stage,
|
||||
"request_id": result.request_id or "unknown",
|
||||
"turn_id": result.turn_id or "unknown",
|
||||
"chat_id": (
|
||||
result.turn_id.split(":", 1)[0] if ":" in result.turn_id else "unknown"
|
||||
),
|
||||
"filesystem_mode": result.filesystem_mode,
|
||||
"client_platform": result.client_platform,
|
||||
"intent_detected": result.intent_detected,
|
||||
"intent_confidence": result.intent_confidence,
|
||||
"write_attempted": result.write_attempted,
|
||||
"write_succeeded": result.write_succeeded,
|
||||
"verification_succeeded": result.verification_succeeded,
|
||||
"commit_gate_passed": result.commit_gate_passed,
|
||||
"commit_gate_reason": result.commit_gate_reason or None,
|
||||
}
|
||||
payload.update(extra)
|
||||
_perf_log.info(
|
||||
"[file_operation_contract] %s", json.dumps(payload, ensure_ascii=False)
|
||||
)
|
||||
17
surfsense_backend/app/tasks/chat/streaming/flows/__init__.py
Normal file
17
surfsense_backend/app/tasks/chat/streaming/flows/__init__.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
"""Top-level streaming flows: ``new_chat`` and ``resume_chat`` orchestrators.
|
||||
|
||||
Re-exports the public entry points so callers can write::
|
||||
|
||||
from app.tasks.chat.streaming.flows import stream_new_chat, stream_resume_chat
|
||||
|
||||
The orchestrators themselves live under ``new_chat/orchestrator.py`` and
|
||||
``resume_chat/orchestrator.py`` (slim composition of the per-concern modules in
|
||||
each flow folder and the building blocks in ``shared/``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.tasks.chat.streaming.flows.new_chat import stream_new_chat
|
||||
from app.tasks.chat.streaming.flows.resume_chat import stream_resume_chat
|
||||
|
||||
__all__ = ["stream_new_chat", "stream_resume_chat"]
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""New-chat streaming flow.
|
||||
|
||||
The public entry point ``stream_new_chat`` is the slim coroutine in
|
||||
``orchestrator.py`` that composes the per-concern modules in this folder and
|
||||
the building blocks under ``flows/shared/``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.tasks.chat.streaming.flows.new_chat.orchestrator import stream_new_chat
|
||||
|
||||
__all__ = ["stream_new_chat"]
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
"""Resolve the auto-pin for the *initial* turn config.
|
||||
|
||||
Auto-pin (``selected_llm_config_id=0``) picks the best eligible LLM config for
|
||||
this thread / search space / user, optionally filtered to vision-capable
|
||||
configs when the turn carries images.
|
||||
|
||||
Errors classified here:
|
||||
|
||||
* ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` — the auto-pin pool has no
|
||||
vision-capable cfg for an image-bearing turn. The same gate fires later
|
||||
in ``llm_capability`` for explicit selections; mapping both to the same
|
||||
code keeps the FE error UI consistent.
|
||||
* ``SERVER_ERROR`` — any other ``ValueError`` from the resolver.
|
||||
|
||||
This module owns *initial* pin resolution; the rate-limit recovery loop has
|
||||
its own narrower auto-pin call (with ``exclude_config_ids``) in
|
||||
``flows/shared/rate_limit_recovery``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.observability import otel as ot
|
||||
from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoPinResult:
|
||||
"""Outcome of ``resolve_initial_auto_pin``.
|
||||
|
||||
``llm_config_id`` is set when ``error`` is ``None``; ``error`` carries the
|
||||
classified user-facing message plus error code/kind so the orchestrator can
|
||||
emit one terminal-error SSE frame.
|
||||
"""
|
||||
|
||||
llm_config_id: int | None
|
||||
error: tuple[str, str, Literal["user_error", "server_error"]] | None
|
||||
|
||||
|
||||
async def resolve_initial_auto_pin(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
chat_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
selected_llm_config_id: int,
|
||||
requires_image_input: bool,
|
||||
requested_llm_config_id: int,
|
||||
) -> AutoPinResult:
|
||||
"""Run the resolver and classify any ``ValueError`` for the SSE error path."""
|
||||
try:
|
||||
pinned = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=selected_llm_config_id,
|
||||
requires_image_input=requires_image_input,
|
||||
)
|
||||
ot.add_event(
|
||||
"model.pin.resolved",
|
||||
{
|
||||
"pin.requested_id": requested_llm_config_id,
|
||||
"pin.resolved_id": pinned.resolved_llm_config_id,
|
||||
"pin.requires_image_input": requires_image_input,
|
||||
},
|
||||
)
|
||||
return AutoPinResult(
|
||||
llm_config_id=pinned.resolved_llm_config_id, error=None
|
||||
)
|
||||
except ValueError as pin_error:
|
||||
# The "no vision-capable cfg" path raises a ValueError whose message
|
||||
# we map to the friendly image-input SSE error so the user sees the
|
||||
# same message regardless of whether the gate fired in the resolver or
|
||||
# in ``llm_capability.assert_vision_capability_for_image_turn``.
|
||||
is_vision_failure = (
|
||||
requires_image_input and "vision-capable" in str(pin_error)
|
||||
)
|
||||
error_code = (
|
||||
"MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
|
||||
if is_vision_failure
|
||||
else "SERVER_ERROR"
|
||||
)
|
||||
error_kind: Literal["user_error", "server_error"] = (
|
||||
"user_error" if is_vision_failure else "server_error"
|
||||
)
|
||||
if is_vision_failure:
|
||||
ot.add_event("quota.denied", {"quota.code": error_code})
|
||||
return AutoPinResult(
|
||||
llm_config_id=None, error=(str(pin_error), error_code, error_kind)
|
||||
)
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
"""Build and emit the first ``thinking-1`` step for a new-chat turn.
|
||||
|
||||
The step title and "Processing X" items are derived from what the user sent
|
||||
(text snippet, image count, mentioned doc titles) so the FE can render a
|
||||
meaningful placeholder while the agent stream warms up.
|
||||
|
||||
``thinking-1`` is the canonical id for this step — every subsequent
|
||||
``thinking-N`` produced by ``stream_agent_events`` folds into the same
|
||||
singleton ``data-thinking-steps`` part on the FE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from app.db import SurfsenseDocsDocument
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
||||
|
||||
@dataclass
|
||||
class InitialThinkingStep:
|
||||
"""Resolved fields passed both into the SSE frame and the builder hook.
|
||||
|
||||
``items`` is the bullet list under the step title; ``title`` is the
|
||||
one-line step header. ``step_id`` is hard-coded ``thinking-1`` so the FE
|
||||
Timeline can de-duplicate against the prior assistant message on resume.
|
||||
"""
|
||||
|
||||
step_id: str
|
||||
title: str
|
||||
items: list[str]
|
||||
|
||||
|
||||
def build_initial_thinking_step(
|
||||
*,
|
||||
user_query: str,
|
||||
user_image_data_urls: list[str] | None,
|
||||
mentioned_surfsense_docs: list[SurfsenseDocsDocument],
|
||||
) -> InitialThinkingStep:
|
||||
if mentioned_surfsense_docs:
|
||||
title = "Analyzing referenced content"
|
||||
action_verb = "Analyzing"
|
||||
else:
|
||||
title = "Understanding your request"
|
||||
action_verb = "Processing"
|
||||
|
||||
processing_parts: list[str] = []
|
||||
if user_query.strip():
|
||||
query_text = user_query[:80] + ("..." if len(user_query) > 80 else "")
|
||||
processing_parts.append(query_text)
|
||||
elif user_image_data_urls:
|
||||
processing_parts.append(f"[{len(user_image_data_urls)} image(s)]")
|
||||
else:
|
||||
processing_parts.append("(message)")
|
||||
|
||||
if mentioned_surfsense_docs:
|
||||
doc_names: list[str] = []
|
||||
for doc in mentioned_surfsense_docs:
|
||||
t = doc.title
|
||||
if len(t) > 30:
|
||||
t = t[:27] + "..."
|
||||
doc_names.append(t)
|
||||
if len(doc_names) == 1:
|
||||
processing_parts.append(f"[{doc_names[0]}]")
|
||||
else:
|
||||
processing_parts.append(f"[{len(doc_names)} docs]")
|
||||
|
||||
items = [f"{action_verb}: {' '.join(processing_parts)}"]
|
||||
return InitialThinkingStep(step_id="thinking-1", title=title, items=items)
|
||||
|
||||
|
||||
def iter_initial_thinking_step_frame(
|
||||
step: InitialThinkingStep,
|
||||
*,
|
||||
streaming_service: VercelStreamingService,
|
||||
content_builder: Any | None,
|
||||
) -> Iterator[str]:
|
||||
"""Drive both the SSE emission and the builder hook for the initial step.
|
||||
|
||||
The FE folds this step into the same singleton ``data-thinking-steps`` part
|
||||
as everything the agent stream emits later, so we mirror that fold
|
||||
server-side by driving the builder lifecycle ourselves.
|
||||
"""
|
||||
if content_builder is not None:
|
||||
content_builder.on_thinking_step(
|
||||
step.step_id, step.title, "in_progress", step.items
|
||||
)
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=step.step_id,
|
||||
title=step.title,
|
||||
status="in_progress",
|
||||
items=step.items,
|
||||
)
|
||||
|
|
@ -0,0 +1,264 @@
|
|||
r"""Assemble the LangGraph ``input_state`` for the new-chat turn.
|
||||
|
||||
Pipeline:
|
||||
|
||||
1. **History bootstrap** — only for cloned chats with no LangGraph checkpoint
|
||||
yet; flips the per-thread ``needs_history_bootstrap`` flag back to False
|
||||
once the rows are loaded.
|
||||
2. **Mentioned SurfSense docs** — eager-load chunks so the formatter has the
|
||||
full content without a second roundtrip.
|
||||
3. **Recent reports** — top 3 by id desc with non-null content, so the LLM
|
||||
can resolve ``report_id`` for versioning without spelunking history.
|
||||
4. **@-mention resolve** (cloud mode) — substitute ``@title`` tokens in the
|
||||
query with canonical ``\`/documents/...\``` paths the LLM expects.
|
||||
5. **Context block render** — XML-wrap surfsense docs + reports, prepend to
|
||||
the rewritten query, optionally prefix with display name for SEARCH_SPACE
|
||||
visibility.
|
||||
6. **HumanMessage** — multimodal content if images are attached.
|
||||
|
||||
Returns the assembled ``input_state`` dict plus side-channel data the
|
||||
orchestrator needs downstream (``accepted_folder_ids`` for runtime context;
|
||||
``mentioned_surfsense_docs`` for the initial thinking step).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in_text
|
||||
from app.db import (
|
||||
ChatVisibility,
|
||||
NewChatThread,
|
||||
Report,
|
||||
SurfsenseDocsDocument,
|
||||
)
|
||||
from app.tasks.chat.streaming.context.mentioned_docs import (
|
||||
format_mentioned_surfsense_docs_as_context,
|
||||
)
|
||||
from app.utils.content_utils import bootstrap_history_from_db
|
||||
from app.utils.user_message_multimodal import build_human_message_content
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NewChatInputState:
|
||||
"""Everything ``build_new_chat_input_state`` produces.
|
||||
|
||||
``input_state`` is fed straight to the agent. ``accepted_folder_ids``
|
||||
feeds the runtime context (the resolver may have dropped some chips).
|
||||
``mentioned_surfsense_docs`` is consumed by the initial thinking-step
|
||||
builder for the FE placeholder before the agent stream starts.
|
||||
"""
|
||||
|
||||
input_state: dict[str, Any]
|
||||
accepted_folder_ids: list[int]
|
||||
mentioned_surfsense_docs: list[SurfsenseDocsDocument]
|
||||
|
||||
|
||||
async def build_new_chat_input_state(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
chat_id: int,
|
||||
search_space_id: int,
|
||||
user_query: str,
|
||||
user_image_data_urls: list[str] | None,
|
||||
mentioned_document_ids: list[int] | None,
|
||||
mentioned_surfsense_doc_ids: list[int] | None,
|
||||
mentioned_folder_ids: list[int] | None,
|
||||
mentioned_documents: list[dict[str, Any]] | None,
|
||||
needs_history_bootstrap: bool,
|
||||
thread_visibility: ChatVisibility,
|
||||
current_user_display_name: str | None,
|
||||
filesystem_mode: str,
|
||||
request_id: str | None,
|
||||
turn_id: str,
|
||||
) -> NewChatInputState:
|
||||
langchain_messages: list[Any] = []
|
||||
|
||||
if needs_history_bootstrap:
|
||||
langchain_messages = await bootstrap_history_from_db(
|
||||
session, chat_id, thread_visibility=thread_visibility
|
||||
)
|
||||
thread_result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == chat_id)
|
||||
)
|
||||
thread = thread_result.scalars().first()
|
||||
if thread:
|
||||
thread.needs_history_bootstrap = False
|
||||
await session.commit()
|
||||
|
||||
mentioned_surfsense_docs: list[SurfsenseDocsDocument] = []
|
||||
if mentioned_surfsense_doc_ids:
|
||||
result = await session.execute(
|
||||
select(SurfsenseDocsDocument)
|
||||
.options(selectinload(SurfsenseDocsDocument.chunks))
|
||||
.filter(SurfsenseDocsDocument.id.in_(mentioned_surfsense_doc_ids))
|
||||
)
|
||||
mentioned_surfsense_docs = list(result.scalars().all())
|
||||
|
||||
# Top 3 reports keyed by id desc (newest first) with content present,
|
||||
# surfaced inline so the LLM resolves ``report_id`` for versioning without
|
||||
# digging through conversation history.
|
||||
recent_reports_result = await session.execute(
|
||||
select(Report)
|
||||
.filter(
|
||||
Report.thread_id == chat_id,
|
||||
Report.content.isnot(None),
|
||||
)
|
||||
.order_by(Report.id.desc())
|
||||
.limit(3)
|
||||
)
|
||||
recent_reports = list(recent_reports_result.scalars().all())
|
||||
|
||||
agent_user_query, accepted_folder_ids = await _resolve_mentions_for_query(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
user_query=user_query,
|
||||
filesystem_mode=filesystem_mode,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
|
||||
mentioned_folder_ids=mentioned_folder_ids,
|
||||
mentioned_documents=mentioned_documents,
|
||||
)
|
||||
|
||||
final_query = _render_query_with_context(
|
||||
agent_user_query=agent_user_query,
|
||||
mentioned_surfsense_docs=mentioned_surfsense_docs,
|
||||
recent_reports=recent_reports,
|
||||
)
|
||||
|
||||
if thread_visibility == ChatVisibility.SEARCH_SPACE and current_user_display_name:
|
||||
final_query = f"**[{current_user_display_name}]:** {final_query}"
|
||||
|
||||
human_content = build_human_message_content(
|
||||
final_query, list(user_image_data_urls or ())
|
||||
)
|
||||
langchain_messages.append(HumanMessage(content=human_content))
|
||||
|
||||
input_state = {
|
||||
"messages": langchain_messages,
|
||||
"search_space_id": search_space_id,
|
||||
"request_id": request_id or "unknown",
|
||||
"turn_id": turn_id,
|
||||
}
|
||||
|
||||
return NewChatInputState(
|
||||
input_state=input_state,
|
||||
accepted_folder_ids=accepted_folder_ids,
|
||||
mentioned_surfsense_docs=mentioned_surfsense_docs,
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_mentions_for_query(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
search_space_id: int,
|
||||
user_query: str,
|
||||
filesystem_mode: str,
|
||||
mentioned_document_ids: list[int] | None,
|
||||
mentioned_surfsense_doc_ids: list[int] | None,
|
||||
mentioned_folder_ids: list[int] | None,
|
||||
mentioned_documents: list[dict[str, Any]] | None,
|
||||
) -> tuple[str, list[int]]:
|
||||
r"""Resolve @-mention chips and rewrite the user query to canonical paths.
|
||||
|
||||
Cloud mode only: local-folder mode keeps the legacy ``@title`` text path
|
||||
(mention support there is a follow-up task — the path scheme is
|
||||
mount-rooted and the picker UI both need separate work).
|
||||
|
||||
The substitution lands in the returned ``agent_user_query`` ONLY — the
|
||||
original ``user_query`` (with ``@title`` tokens) flows untouched into
|
||||
``persist_user_turn`` so chip rendering on reload still works
|
||||
(``UserTextPart`` → ``parseMentionSegments`` matches ``@title``, not
|
||||
``\`/documents/...\```). It also feeds the human-readable surfaces — SSE
|
||||
"Processing X" status, auto thread title, memory seed — which all want
|
||||
what the user typed.
|
||||
"""
|
||||
agent_user_query = user_query
|
||||
accepted_folder_ids: list[int] = []
|
||||
|
||||
has_any_mention = bool(
|
||||
mentioned_document_ids
|
||||
or mentioned_surfsense_doc_ids
|
||||
or mentioned_folder_ids
|
||||
or mentioned_documents
|
||||
)
|
||||
if filesystem_mode != FilesystemMode.CLOUD.value or not has_any_mention:
|
||||
return agent_user_query, accepted_folder_ids
|
||||
|
||||
from app.schemas.new_chat import MentionedDocumentInfo
|
||||
|
||||
chip_objs: list[MentionedDocumentInfo] | None = None
|
||||
if mentioned_documents:
|
||||
chip_objs = []
|
||||
for raw in mentioned_documents:
|
||||
if isinstance(raw, MentionedDocumentInfo):
|
||||
chip_objs.append(raw)
|
||||
continue
|
||||
try:
|
||||
chip_objs.append(MentionedDocumentInfo.model_validate(raw))
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"stream_new_chat: dropping malformed mention chip %r", raw
|
||||
)
|
||||
|
||||
resolved = await resolve_mentions(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
mentioned_documents=chip_objs,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
|
||||
mentioned_folder_ids=mentioned_folder_ids,
|
||||
)
|
||||
agent_user_query = substitute_in_text(user_query, resolved.token_to_path)
|
||||
accepted_folder_ids = resolved.mentioned_folder_ids
|
||||
return agent_user_query, accepted_folder_ids
|
||||
|
||||
|
||||
def _render_query_with_context(
|
||||
*,
|
||||
agent_user_query: str,
|
||||
mentioned_surfsense_docs: list[SurfsenseDocsDocument],
|
||||
recent_reports: list[Report],
|
||||
) -> str:
|
||||
"""Prepend surfsense-docs + recent-reports XML blocks to the user query."""
|
||||
context_parts: list[str] = []
|
||||
|
||||
if mentioned_surfsense_docs:
|
||||
context_parts.append(
|
||||
format_mentioned_surfsense_docs_as_context(mentioned_surfsense_docs)
|
||||
)
|
||||
|
||||
if recent_reports:
|
||||
report_lines: list[str] = []
|
||||
for r in recent_reports:
|
||||
report_lines.append(
|
||||
f' - report_id={r.id}, title="{r.title}", '
|
||||
f'style="{r.report_style or "detailed"}"'
|
||||
)
|
||||
reports_listing = "\n".join(report_lines)
|
||||
context_parts.append(
|
||||
"<report_context>\n"
|
||||
"Previously generated reports in this conversation:\n"
|
||||
f"{reports_listing}\n\n"
|
||||
"If the user wants to MODIFY, REVISE, UPDATE, or ADD to one of "
|
||||
"these reports, set parent_report_id to the relevant report_id above.\n"
|
||||
"If the user wants a completely NEW report on a different topic, "
|
||||
"leave parent_report_id unset.\n"
|
||||
"</report_context>"
|
||||
)
|
||||
|
||||
if context_parts:
|
||||
context = "\n\n".join(context_parts)
|
||||
return f"{context}\n\n<user_query>{agent_user_query}</user_query>"
|
||||
|
||||
return agent_user_query
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
"""Vision-capability gate for image-bearing turns.
|
||||
|
||||
Capability safety net for explicit (non-auto-pin) selections: a turn carrying
|
||||
user-uploaded images cannot be routed to a chat config that LiteLLM's
|
||||
authoritative model map *explicitly* marks as text-only (``supports_vision``
|
||||
set to False). The check is intentionally narrow — it only fires when LiteLLM
|
||||
is *certain* the model can't accept image input; unknown / unmapped /
|
||||
vision-capable models pass through.
|
||||
|
||||
Without this guard a known-text-only model would 404 at the provider with
|
||||
``"No endpoints found that support image input"``, surfacing as an opaque
|
||||
``SERVER_ERROR`` SSE chunk; failing here lets us return a friendly message that
|
||||
tells the user what to change.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
from app.observability import otel as ot
|
||||
|
||||
|
||||
def check_image_input_capability(
|
||||
*,
|
||||
user_image_data_urls: list[str] | None,
|
||||
agent_config: AgentConfig | None,
|
||||
) -> tuple[str, str] | None:
|
||||
"""Return ``(user_message, error_code)`` when the gate trips, else ``None``.
|
||||
|
||||
The caller emits one terminal-error SSE frame on a non-``None`` return.
|
||||
"""
|
||||
if not (user_image_data_urls and agent_config is not None):
|
||||
return None
|
||||
|
||||
from app.services.provider_capabilities import is_known_text_only_chat_model
|
||||
|
||||
agent_litellm_params = agent_config.litellm_params or {}
|
||||
agent_base_model = (
|
||||
agent_litellm_params.get("base_model")
|
||||
if isinstance(agent_litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
if not is_known_text_only_chat_model(
|
||||
provider=agent_config.provider,
|
||||
model_name=agent_config.model_name,
|
||||
base_model=agent_base_model,
|
||||
custom_provider=agent_config.custom_provider,
|
||||
):
|
||||
return None
|
||||
|
||||
model_label = agent_config.config_name or agent_config.model_name or "model"
|
||||
ot.add_event(
|
||||
"quota.denied", {"quota.code": "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"}
|
||||
)
|
||||
return (
|
||||
(
|
||||
f"The selected model ({model_label}) does not support "
|
||||
"image input. Switch to a vision-capable model "
|
||||
"(e.g. GPT-4o, Claude, Gemini) or remove the image "
|
||||
"attachment and try again."
|
||||
),
|
||||
"MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT",
|
||||
)
|
||||
|
|
@ -0,0 +1,868 @@
|
|||
"""``stream_new_chat`` — public entry point for a fresh chat turn.
|
||||
|
||||
Slim composition layer over the per-concern modules in this folder and the
|
||||
building blocks under ``flows/shared/``. Each phase corresponds to a numbered
|
||||
block in the surrounding code so the on-the-wire ordering stays explicit:
|
||||
|
||||
1. Validation / config — auto-pin, LLM bundle, capability, premium reserve.
|
||||
2. Concurrent persistence + pre-stream setup — spawn DB writes, build the
|
||||
connector, fetch the checkpointer, build the agent.
|
||||
3. Input assembly — history bootstrap, mentions, surfsense docs, reports.
|
||||
4. First SSE frames — message_start, start_step, turn-info, turn-status.
|
||||
5. Persistence join + message-id frames (ghost-thread protection).
|
||||
6. Initial thinking step + title task + runtime context.
|
||||
7. Stream loop with in-stream rate-limit recovery + mid-stream title emit.
|
||||
8. Finalize — premium debit, token-usage SSE, finish frames.
|
||||
9. Exception branch — classify, emit terminal error, finish frames.
|
||||
10. Finally — premium release, session close, assistant finalize, GC, span.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from functools import partial
|
||||
from typing import Any, Literal
|
||||
|
||||
import anyio
|
||||
|
||||
from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent
|
||||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||
from app.agents.new_chat.middleware.busy_mutex import end_turn
|
||||
from app.config import config as _app_config
|
||||
from app.db import ChatVisibility, async_session_maker
|
||||
from app.observability import otel as ot
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.tasks.chat.content_builder import AssistantContentBuilder
|
||||
from app.tasks.chat.streaming.agent.builder import build_main_agent_for_thread
|
||||
from app.tasks.chat.streaming.contract.file_contract import log_file_contract
|
||||
from app.tasks.chat.streaming.errors.emitter import emit_stream_terminal_error
|
||||
from app.tasks.chat.streaming.flows.new_chat.auto_pin import resolve_initial_auto_pin
|
||||
from app.tasks.chat.streaming.flows.new_chat.initial_thinking_step import (
|
||||
build_initial_thinking_step,
|
||||
iter_initial_thinking_step_frame,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.new_chat.input_state import (
|
||||
build_new_chat_input_state,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.new_chat.llm_capability import (
|
||||
check_image_input_capability,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.new_chat.persistence_spawn import (
|
||||
await_persist_task,
|
||||
spawn_persist_assistant_shell_task,
|
||||
spawn_persist_user_task,
|
||||
spawn_set_ai_responding_bg,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.new_chat.runtime_context import (
|
||||
build_new_chat_runtime_context,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.new_chat.title_gen import (
|
||||
await_pending_title_update,
|
||||
maybe_emit_title_update,
|
||||
spawn_title_task,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.assistant_finalize import (
|
||||
finalize_assistant_message,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.finalize_emit import iter_token_usage_frame
|
||||
from app.tasks.chat.streaming.flows.shared.finally_cleanup import (
|
||||
close_session_and_clear_ai_responding,
|
||||
run_gc_pass,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.first_frames import (
|
||||
iter_final_frames,
|
||||
iter_initial_frames,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle
|
||||
from app.tasks.chat.streaming.flows.shared.pre_stream_setup import (
|
||||
get_chat_checkpointer,
|
||||
setup_connector_and_firecrawl,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.premium_quota import (
|
||||
PremiumReservation,
|
||||
finalize_premium,
|
||||
needs_premium_quota,
|
||||
release_premium,
|
||||
reserve_premium,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.rate_limit_recovery import (
|
||||
can_recover_provider_rate_limit,
|
||||
log_rate_limit_recovered,
|
||||
reroute_to_next_auto_pin,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.span import (
|
||||
close_chat_request_span,
|
||||
open_chat_request_span,
|
||||
set_agent_mode,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.stream_loop import run_stream_loop
|
||||
from app.tasks.chat.streaming.flows.shared.terminal_error import (
|
||||
handle_terminal_exception,
|
||||
)
|
||||
from app.tasks.chat.streaming.shared.stream_result import StreamResult
|
||||
from app.utils.perf import get_perf_logger, log_system_snapshot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
# Holds spawned background tasks (set_ai_responding, persist_user, persist_asst)
|
||||
# so the GC doesn't drop them before they finish. Kept at module level so it
|
||||
# survives across turns within one process.
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
async def stream_new_chat(
|
||||
user_query: str,
|
||||
search_space_id: int,
|
||||
chat_id: int,
|
||||
user_id: str | None = None,
|
||||
llm_config_id: int = -1,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
mentioned_surfsense_doc_ids: list[int] | None = None,
|
||||
mentioned_folder_ids: list[int] | None = None,
|
||||
mentioned_documents: list[dict[str, Any]] | None = None,
|
||||
checkpoint_id: str | None = None,
|
||||
needs_history_bootstrap: bool = False,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
current_user_display_name: str | None = None,
|
||||
disabled_tools: list[str] | None = None,
|
||||
filesystem_selection: FilesystemSelection | None = None,
|
||||
request_id: str | None = None,
|
||||
user_image_data_urls: list[str] | None = None,
|
||||
flow: Literal["new", "regenerate"] = "new",
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream a new chat turn using the SurfSense deep agent.
|
||||
|
||||
Uses the Vercel AI SDK Data Stream Protocol (SSE). ``chat_id`` is the
|
||||
LangGraph thread id (durable conversation memory via the checkpointer).
|
||||
Manages its own database session so cleanup runs even when Starlette
|
||||
cancels the task on client disconnect.
|
||||
"""
|
||||
streaming_service = VercelStreamingService()
|
||||
stream_result = StreamResult()
|
||||
_t_total = time.perf_counter()
|
||||
fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud"
|
||||
fs_platform = (
|
||||
filesystem_selection.client_platform.value if filesystem_selection else "web"
|
||||
)
|
||||
stream_result.request_id = request_id
|
||||
stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}"
|
||||
stream_result.filesystem_mode = fs_mode
|
||||
stream_result.client_platform = fs_platform
|
||||
|
||||
chat_agent_mode = "unknown"
|
||||
chat_outcome = "success"
|
||||
chat_error_category: str | None = None
|
||||
chat_span_cm, chat_span = open_chat_request_span(
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
flow=flow,
|
||||
request_id=request_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
filesystem_mode=fs_mode,
|
||||
client_platform=fs_platform,
|
||||
agent_mode=chat_agent_mode,
|
||||
)
|
||||
log_file_contract("turn_start", stream_result)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] filesystem_mode=%s client_platform=%s",
|
||||
fs_mode,
|
||||
fs_platform,
|
||||
)
|
||||
log_system_snapshot("stream_new_chat_START")
|
||||
|
||||
from app.services.token_tracking_service import start_turn
|
||||
|
||||
accumulator = start_turn()
|
||||
|
||||
premium_reservation: PremiumReservation | None = None
|
||||
busy_error_raised = False
|
||||
|
||||
emit_stream_error = partial(
|
||||
emit_stream_terminal_error,
|
||||
streaming_service=streaming_service,
|
||||
flow=flow,
|
||||
request_id=request_id,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
session = async_session_maker()
|
||||
# Declared at function scope so SSE-yield join points and the finally
|
||||
# clause see them on every exit path.
|
||||
persist_user_task: asyncio.Task[int | None] | None = None
|
||||
persist_asst_task: asyncio.Task[int | None] | None = None
|
||||
try:
|
||||
spawn_set_ai_responding_bg(
|
||||
chat_id=chat_id, user_id=user_id, background_tasks=_background_tasks
|
||||
)
|
||||
|
||||
# --- Block 1: LLM config + capability ---
|
||||
|
||||
requested_llm_config_id = llm_config_id
|
||||
requires_image_input = bool(user_image_data_urls)
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
pin_result = await resolve_initial_auto_pin(
|
||||
session,
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=llm_config_id,
|
||||
requires_image_input=requires_image_input,
|
||||
requested_llm_config_id=requested_llm_config_id,
|
||||
)
|
||||
if pin_result.error is not None:
|
||||
message, error_code, error_kind = pin_result.error
|
||||
yield emit_stream_error(
|
||||
message=message, error_kind=error_kind, error_code=error_code
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
llm_config_id = pin_result.llm_config_id # type: ignore[assignment]
|
||||
|
||||
llm, agent_config, llm_load_error = await load_llm_bundle(
|
||||
session, config_id=llm_config_id, search_space_id=search_space_id
|
||||
)
|
||||
if llm_load_error:
|
||||
yield emit_stream_error(
|
||||
message=llm_load_error,
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)",
|
||||
time.perf_counter() - _t0,
|
||||
llm_config_id,
|
||||
)
|
||||
|
||||
capability_error = check_image_input_capability(
|
||||
user_image_data_urls=user_image_data_urls, agent_config=agent_config
|
||||
)
|
||||
if capability_error is not None:
|
||||
message, error_code = capability_error
|
||||
yield emit_stream_error(
|
||||
message=message,
|
||||
error_kind="user_error",
|
||||
error_code=error_code,
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
if needs_premium_quota(agent_config, user_id):
|
||||
premium_reservation = await reserve_premium(
|
||||
agent_config=agent_config, user_id=user_id # type: ignore[arg-type]
|
||||
)
|
||||
if not premium_reservation.allowed:
|
||||
ot.add_event("quota.denied", {"quota.code": "PREMIUM_QUOTA_EXHAUSTED"})
|
||||
if requested_llm_config_id == 0:
|
||||
pin_fallback = await resolve_initial_auto_pin(
|
||||
session,
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=0,
|
||||
requires_image_input=requires_image_input,
|
||||
requested_llm_config_id=requested_llm_config_id,
|
||||
)
|
||||
if pin_fallback.error is not None:
|
||||
message, error_code, error_kind = pin_fallback.error
|
||||
yield emit_stream_error(
|
||||
message=message,
|
||||
error_kind=error_kind,
|
||||
error_code=error_code,
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
llm_config_id = pin_fallback.llm_config_id # type: ignore[assignment]
|
||||
ot.add_event(
|
||||
"model.repin",
|
||||
{
|
||||
"repin.reason": "premium_quota_exhausted",
|
||||
"repin.to_config_id": llm_config_id,
|
||||
},
|
||||
)
|
||||
llm, agent_config, llm_load_error = await load_llm_bundle(
|
||||
session,
|
||||
config_id=llm_config_id,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
if llm_load_error:
|
||||
yield emit_stream_error(
|
||||
message=llm_load_error,
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
premium_reservation = None
|
||||
# Re-route to free fallback logged via the structured
|
||||
# stream-error logger so cost/analytics see the auto-switch.
|
||||
from app.tasks.chat.streaming.errors.classifier import (
|
||||
log_chat_stream_error,
|
||||
)
|
||||
|
||||
log_chat_stream_error(
|
||||
flow=flow,
|
||||
error_kind="premium_quota_exhausted",
|
||||
error_code="PREMIUM_QUOTA_EXHAUSTED",
|
||||
severity="info",
|
||||
is_expected=True,
|
||||
request_id=request_id,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
message=(
|
||||
"Premium quota exhausted on pinned model; "
|
||||
"auto-fallback switched to a free model"
|
||||
),
|
||||
extra={
|
||||
"fallback_config_id": llm_config_id,
|
||||
"auto_fallback": True,
|
||||
},
|
||||
)
|
||||
else:
|
||||
yield emit_stream_error(
|
||||
message=(
|
||||
"Buy more tokens to continue with this model, or "
|
||||
"switch to a free model"
|
||||
),
|
||||
error_kind="premium_quota_exhausted",
|
||||
error_code="PREMIUM_QUOTA_EXHAUSTED",
|
||||
severity="info",
|
||||
is_expected=True,
|
||||
extra={
|
||||
"resolved_config_id": llm_config_id,
|
||||
"auto_fallback": False,
|
||||
},
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
if not llm:
|
||||
yield emit_stream_error(
|
||||
message="Failed to create LLM instance",
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
# --- Block 2: Spawn concurrent persistence; build pre-stream setup ---
|
||||
|
||||
persist_user_task = spawn_persist_user_task(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
user_query=user_query,
|
||||
user_image_data_urls=user_image_data_urls,
|
||||
mentioned_documents=mentioned_documents,
|
||||
background_tasks=_background_tasks,
|
||||
)
|
||||
persist_asst_task = spawn_persist_assistant_shell_task(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
background_tasks=_background_tasks,
|
||||
)
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
connector_service, firecrawl_api_key = await setup_connector_and_firecrawl(
|
||||
session, search_space_id=search_space_id
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Connector service + firecrawl key in %.3fs",
|
||||
time.perf_counter() - _t0,
|
||||
)
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
checkpointer = await get_chat_checkpointer()
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
use_multi_agent = bool(_app_config.MULTI_AGENT_CHAT_ENABLED)
|
||||
chat_agent_mode = "multi" if use_multi_agent else "single"
|
||||
set_agent_mode(chat_span, chat_agent_mode)
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
agent_factory = (
|
||||
create_multi_agent_chat_deep_agent
|
||||
if use_multi_agent
|
||||
else create_surfsense_deep_agent
|
||||
)
|
||||
# Build the agent inline. Provider 429s surface through the in-stream
|
||||
# recovery loop below, which repins the thread to an eligible
|
||||
# alternative config and rebuilds the agent before the user sees any
|
||||
# output.
|
||||
agent = await build_main_agent_for_thread(
|
||||
agent_factory,
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
connector_service=connector_service,
|
||||
checkpointer=checkpointer,
|
||||
user_id=user_id,
|
||||
thread_id=chat_id,
|
||||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
disabled_tools=disabled_tools,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
# --- Block 3: Input assembly ---
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
assembled = await build_new_chat_input_state(
|
||||
session,
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_query=user_query,
|
||||
user_image_data_urls=user_image_data_urls,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
|
||||
mentioned_folder_ids=mentioned_folder_ids,
|
||||
mentioned_documents=mentioned_documents,
|
||||
needs_history_bootstrap=needs_history_bootstrap,
|
||||
thread_visibility=visibility,
|
||||
current_user_display_name=current_user_display_name,
|
||||
filesystem_mode=fs_mode,
|
||||
request_id=request_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
)
|
||||
input_state = assembled.input_state
|
||||
accepted_folder_ids = assembled.accepted_folder_ids
|
||||
mentioned_surfsense_docs = assembled.mentioned_surfsense_docs
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] History bootstrap + doc/report queries in %.3fs",
|
||||
time.perf_counter() - _t0,
|
||||
)
|
||||
|
||||
# All pre-streaming DB reads done. Commit to release the transaction
|
||||
# and its ACCESS SHARE locks so we don't block DDL (e.g. migrations)
|
||||
# for the entire LLM streaming duration. Tools that need DB access
|
||||
# during streaming start their own short-lived transactions (or use
|
||||
# isolated sessions).
|
||||
await session.commit()
|
||||
# Detach heavy ORM objects (documents with chunks, reports, etc.)
|
||||
# from the session identity map now that we've extracted what we
|
||||
# need. Without this they accumulate in memory for the entire
|
||||
# streaming duration (which can be several minutes).
|
||||
session.expunge_all()
|
||||
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)",
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
)
|
||||
|
||||
configurable: dict[str, Any] = {
|
||||
"thread_id": str(chat_id),
|
||||
"request_id": request_id or "unknown",
|
||||
"turn_id": stream_result.turn_id,
|
||||
}
|
||||
if checkpoint_id:
|
||||
configurable["checkpoint_id"] = checkpoint_id
|
||||
|
||||
config = {
|
||||
"configurable": configurable,
|
||||
# Effectively uncapped, matching the agent-level ``with_config``
|
||||
# default in ``chat_deepagent.create_agent`` and the unbounded
|
||||
# ``while(true)`` in OpenCode's ``session/processor.ts``. Real
|
||||
# circuit-breakers live in middleware (``DoomLoopMiddleware``,
|
||||
# plus ``enable_tool_call_limit`` / ``enable_model_call_limit``).
|
||||
# The original 25 (and our previous 80 bump) hit users on
|
||||
# legitimate multi-tool plans.
|
||||
"recursion_limit": 10_000,
|
||||
}
|
||||
|
||||
# --- Block 4: First SSE frames ---
|
||||
|
||||
for sse in iter_initial_frames(streaming_service, turn_id=stream_result.turn_id):
|
||||
yield sse
|
||||
|
||||
# --- Block 5: Persistence join + message-id frames ---
|
||||
|
||||
user_message_id = await await_persist_task(
|
||||
persist_user_task,
|
||||
chat_id=chat_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
log_label="persist_user_task",
|
||||
)
|
||||
if user_message_id is None:
|
||||
yield emit_stream_error(
|
||||
message="We couldn't save your message. Please try again in a moment.",
|
||||
error_kind="server_error",
|
||||
error_code="MESSAGE_PERSIST_FAILED",
|
||||
)
|
||||
for sse in iter_final_frames(streaming_service):
|
||||
yield sse
|
||||
return
|
||||
|
||||
# Emit canonical user message id BEFORE any LLM streaming so the FE
|
||||
# can rename its optimistic ``msg-user-XXX`` placeholder to
|
||||
# ``msg-{user_message_id}`` and unlock features gated on a real DB id
|
||||
# (comments, edit-from-this-message). See B4 in the
|
||||
# ``sse-based_message_id_handshake`` plan.
|
||||
yield streaming_service.format_data(
|
||||
"user-message-id",
|
||||
{"message_id": user_message_id, "turn_id": stream_result.turn_id},
|
||||
)
|
||||
|
||||
assistant_message_id = await await_persist_task(
|
||||
persist_asst_task,
|
||||
chat_id=chat_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
log_label="persist_asst_task",
|
||||
)
|
||||
if assistant_message_id is None:
|
||||
# Genuine DB failure — abort the turn rather than stream into a
|
||||
# void. The user row is already persisted so the legacy
|
||||
# ghost-thread gate isn't reopened.
|
||||
yield emit_stream_error(
|
||||
message=(
|
||||
"We couldn't initialize the assistant message. Please try again."
|
||||
),
|
||||
error_kind="server_error",
|
||||
error_code="MESSAGE_PERSIST_FAILED",
|
||||
)
|
||||
for sse in iter_final_frames(streaming_service):
|
||||
yield sse
|
||||
return
|
||||
|
||||
yield streaming_service.format_data(
|
||||
"assistant-message-id",
|
||||
{"message_id": assistant_message_id, "turn_id": stream_result.turn_id},
|
||||
)
|
||||
|
||||
stream_result.assistant_message_id = assistant_message_id
|
||||
stream_result.content_builder = AssistantContentBuilder()
|
||||
|
||||
# --- Block 6: Initial thinking step + title task + runtime context ---
|
||||
|
||||
initial_step = build_initial_thinking_step(
|
||||
user_query=user_query,
|
||||
user_image_data_urls=user_image_data_urls,
|
||||
mentioned_surfsense_docs=mentioned_surfsense_docs,
|
||||
)
|
||||
for sse in iter_initial_thinking_step_frame(
|
||||
initial_step,
|
||||
streaming_service=streaming_service,
|
||||
content_builder=stream_result.content_builder,
|
||||
):
|
||||
yield sse
|
||||
|
||||
initial_step_id = initial_step.step_id
|
||||
initial_step_title = initial_step.title
|
||||
initial_step_items = initial_step.items
|
||||
# Drop the heavy ORM objects + the container that holds them so they
|
||||
# aren't retained for the entire streaming duration. ``input_state``
|
||||
# already carries the langchain_messages list independently.
|
||||
del assembled, mentioned_surfsense_docs
|
||||
|
||||
title_task = spawn_title_task(
|
||||
chat_id=chat_id,
|
||||
user_query=user_query,
|
||||
user_image_data_urls=user_image_data_urls,
|
||||
assistant_message_id=assistant_message_id,
|
||||
llm=llm,
|
||||
agent_config=agent_config,
|
||||
)
|
||||
title_emitted = False
|
||||
|
||||
runtime_context = build_new_chat_runtime_context(
|
||||
search_space_id=search_space_id,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
accepted_folder_ids=accepted_folder_ids,
|
||||
mentioned_folder_ids=mentioned_folder_ids,
|
||||
request_id=request_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
)
|
||||
|
||||
# --- Block 7: Stream loop ---
|
||||
|
||||
_t_stream_start = time.perf_counter()
|
||||
runtime_rate_limit_recovered = False
|
||||
|
||||
def _on_first_event() -> None:
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] First agent event in %.3fs (time since stream start), "
|
||||
"%.3fs (total since request start) (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
)
|
||||
|
||||
async def _recover(exc: BaseException, first_event_seen: bool):
|
||||
nonlocal llm_config_id, llm, agent_config, runtime_rate_limit_recovered
|
||||
nonlocal title_task
|
||||
if not can_recover_provider_rate_limit(
|
||||
exc,
|
||||
first_event_seen=first_event_seen,
|
||||
runtime_rate_limit_recovered=runtime_rate_limit_recovered,
|
||||
requested_llm_config_id=requested_llm_config_id,
|
||||
current_llm_config_id=llm_config_id,
|
||||
):
|
||||
return None
|
||||
runtime_rate_limit_recovered = True
|
||||
previous_config_id = llm_config_id
|
||||
llm_config_id = await reroute_to_next_auto_pin(
|
||||
session,
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
current_llm_config_id=llm_config_id,
|
||||
requires_image_input=requires_image_input,
|
||||
)
|
||||
new_llm, new_agent_config, llm_load_err = await load_llm_bundle(
|
||||
session, config_id=llm_config_id, search_space_id=search_space_id
|
||||
)
|
||||
if llm_load_err:
|
||||
# Re-raise the original so the terminal-error path classifies
|
||||
# it correctly (don't swallow as "config load error").
|
||||
return None
|
||||
llm = new_llm
|
||||
agent_config = new_agent_config
|
||||
|
||||
# Title gen used the initial llm object. After a runtime repin we
|
||||
# keep the stream focused on response recovery and skip title gen
|
||||
# for this turn.
|
||||
if title_task is not None and not title_task.done():
|
||||
title_task.cancel()
|
||||
title_task = None
|
||||
|
||||
_t_rebuild = time.perf_counter()
|
||||
new_agent = await build_main_agent_for_thread(
|
||||
agent_factory,
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
connector_service=connector_service,
|
||||
checkpointer=checkpointer,
|
||||
user_id=user_id,
|
||||
thread_id=chat_id,
|
||||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
disabled_tools=disabled_tools,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Runtime rate-limit recovery repinned "
|
||||
"config_id=%s -> %s and rebuilt agent in %.3fs",
|
||||
previous_config_id,
|
||||
llm_config_id,
|
||||
time.perf_counter() - _t_rebuild,
|
||||
)
|
||||
log_rate_limit_recovered(
|
||||
flow=flow,
|
||||
request_id=request_id,
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
previous_config_id=previous_config_id,
|
||||
new_config_id=llm_config_id,
|
||||
)
|
||||
return new_agent
|
||||
|
||||
async for sse in run_stream_loop(
|
||||
agent=agent,
|
||||
streaming_service=streaming_service,
|
||||
config=config,
|
||||
input_data=input_state,
|
||||
stream_result=stream_result,
|
||||
step_prefix="thinking",
|
||||
initial_step_id=initial_step_id,
|
||||
initial_step_title=initial_step_title,
|
||||
initial_step_items=initial_step_items,
|
||||
fallback_commit_search_space_id=search_space_id,
|
||||
fallback_commit_created_by_id=user_id,
|
||||
fallback_commit_filesystem_mode=(
|
||||
filesystem_selection.mode if filesystem_selection else FilesystemMode.CLOUD
|
||||
),
|
||||
fallback_commit_thread_id=chat_id,
|
||||
runtime_context=runtime_context,
|
||||
content_builder=stream_result.content_builder,
|
||||
recover=_recover,
|
||||
on_first_event=_on_first_event,
|
||||
):
|
||||
yield sse
|
||||
# Inject the title update mid-stream as soon as the background
|
||||
# task finishes; gated so we emit at most once.
|
||||
async for title_sse in maybe_emit_title_update(
|
||||
title_task=title_task,
|
||||
title_emitted=title_emitted,
|
||||
chat_id=chat_id,
|
||||
accumulator=accumulator,
|
||||
streaming_service=streaming_service,
|
||||
):
|
||||
yield title_sse
|
||||
title_emitted = True
|
||||
# Account for the case where the task completed but produced no
|
||||
# title — flip the flag anyway so we don't keep checking it.
|
||||
if (
|
||||
title_task is not None
|
||||
and title_task.done()
|
||||
and not title_emitted
|
||||
):
|
||||
title_emitted = True
|
||||
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
chat_id,
|
||||
)
|
||||
log_system_snapshot("stream_new_chat_END")
|
||||
|
||||
# --- Block 8: Finalize ---
|
||||
|
||||
if stream_result.is_interrupted:
|
||||
ot.add_event("chat.interrupted", {"chat.flow": flow})
|
||||
if title_task is not None and not title_task.done():
|
||||
title_task.cancel()
|
||||
for sse in iter_token_usage_frame(
|
||||
streaming_service,
|
||||
accumulator=accumulator,
|
||||
log_label="interrupted new_chat",
|
||||
):
|
||||
yield sse
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
async for title_sse in await_pending_title_update(
|
||||
title_task=title_task,
|
||||
title_emitted=title_emitted,
|
||||
chat_id=chat_id,
|
||||
accumulator=accumulator,
|
||||
streaming_service=streaming_service,
|
||||
):
|
||||
yield title_sse
|
||||
|
||||
# Finalize premium credit debit with the actual provider cost reported
|
||||
# by LiteLLM, summed across every call in the turn. Mirrors the
|
||||
# pre-cost behaviour of "premium turn → all calls count" so free
|
||||
# sub-agent calls during a premium turn still contribute to the bill
|
||||
# (they're $0 in practice anyway).
|
||||
if premium_reservation is not None and user_id:
|
||||
await finalize_premium(
|
||||
reservation=premium_reservation,
|
||||
user_id=user_id,
|
||||
accumulator=accumulator,
|
||||
)
|
||||
premium_reservation = None
|
||||
|
||||
for sse in iter_token_usage_frame(
|
||||
streaming_service, accumulator=accumulator, log_label="normal new_chat"
|
||||
):
|
||||
yield sse
|
||||
|
||||
for sse in iter_final_frames(streaming_service):
|
||||
yield sse
|
||||
|
||||
except Exception as exc:
|
||||
frames, summary = handle_terminal_exception(
|
||||
exc,
|
||||
flow=flow,
|
||||
flow_label="chat",
|
||||
log_prefix="stream_new_chat",
|
||||
streaming_service=streaming_service,
|
||||
request_id=request_id,
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
chat_span=chat_span,
|
||||
)
|
||||
if summary["busy_error_raised"]:
|
||||
busy_error_raised = True
|
||||
chat_outcome = summary["chat_outcome"]
|
||||
chat_error_category = summary["chat_error_category"]
|
||||
for sse in frames:
|
||||
yield sse
|
||||
|
||||
finally:
|
||||
# Shield the ENTIRE async cleanup from anyio cancel-scope cancellation.
|
||||
# Starlette's BaseHTTPMiddleware uses anyio task groups; on client
|
||||
# disconnect, it cancels the scope with level-triggered cancellation
|
||||
# — every unshielded ``await`` would raise CancelledError immediately.
|
||||
# Without this the very first ``await`` (session.rollback) would
|
||||
# raise, ``except Exception`` wouldn't catch it (CancelledError is a
|
||||
# BaseException), and the rest of cleanup — including session.close()
|
||||
# — would never run.
|
||||
with anyio.CancelScope(shield=True):
|
||||
# Authoritative fallback cleanup for lock/cancel state. Middleware
|
||||
# teardown can be skipped on some client-abort paths.
|
||||
end_turn(str(chat_id))
|
||||
|
||||
if premium_reservation is not None and user_id:
|
||||
await release_premium(
|
||||
reservation=premium_reservation, user_id=user_id
|
||||
)
|
||||
|
||||
await close_session_and_clear_ai_responding(session, chat_id)
|
||||
|
||||
await finalize_assistant_message(
|
||||
stream_result=stream_result,
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
accumulator=accumulator,
|
||||
log_prefix="stream_new_chat",
|
||||
)
|
||||
|
||||
# Persist any sandbox-produced files to local storage so they remain
|
||||
# downloadable after the Daytona sandbox auto-deletes.
|
||||
if stream_result and stream_result.sandbox_files:
|
||||
with contextlib.suppress(Exception):
|
||||
from app.agents.new_chat.sandbox import (
|
||||
is_sandbox_enabled,
|
||||
persist_and_delete_sandbox,
|
||||
)
|
||||
|
||||
if is_sandbox_enabled():
|
||||
with anyio.CancelScope(shield=True):
|
||||
await persist_and_delete_sandbox(
|
||||
chat_id, stream_result.sandbox_files
|
||||
)
|
||||
|
||||
# ``aafter_agent`` doesn't fire on ``interrupt()`` or early bailout.
|
||||
# Skip on ``BusyError`` (caller never acquired the lock).
|
||||
if not busy_error_raised:
|
||||
with contextlib.suppress(Exception):
|
||||
end_turn(str(chat_id))
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] end_turn cleanup (chat_id=%s)", chat_id
|
||||
)
|
||||
|
||||
# Break circular refs held by the agent graph, tools, and LLM
|
||||
# wrappers so the GC can reclaim them in a single pass.
|
||||
agent = llm = connector_service = None # noqa: F841
|
||||
input_state = stream_result = None # noqa: F841
|
||||
session = None # noqa: F841
|
||||
|
||||
run_gc_pass(log_prefix="stream_new_chat", chat_id=chat_id)
|
||||
close_chat_request_span(
|
||||
span_cm=chat_span_cm,
|
||||
span=chat_span,
|
||||
chat_outcome=chat_outcome,
|
||||
chat_agent_mode=chat_agent_mode,
|
||||
flow=flow,
|
||||
chat_error_category=chat_error_category,
|
||||
duration_seconds=time.perf_counter() - _t_total,
|
||||
)
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
"""Concurrent persistence tasks spawned right after the initial validation gate.
|
||||
|
||||
These run *during* the rest of the pre-stream setup so we don't serialize
|
||||
their latency against agent construction. Awaiting them at the SSE message-id
|
||||
yield sites preserves the ghost-thread protection (the user-row INSERT must
|
||||
succeed before any LLM streaming begins).
|
||||
|
||||
The ``set_ai_responding`` flag flip runs fully fire-and-forget on its own
|
||||
shielded session — failures only delay the "AI is responding…" UI flag, not
|
||||
the response itself.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from app.db import shielded_async_session
|
||||
from app.services.chat_session_state_service import set_ai_responding
|
||||
from app.tasks.chat.persistence import (
|
||||
persist_assistant_shell,
|
||||
persist_user_turn,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def spawn_set_ai_responding_bg(
|
||||
*,
|
||||
chat_id: int,
|
||||
user_id: str | None,
|
||||
background_tasks: set[asyncio.Task[Any]],
|
||||
) -> None:
|
||||
"""Fire-and-forget: flip the per-thread AI-responding flag on its own session.
|
||||
|
||||
Errors are swallowed and logged — the worst case is a stale UI flag, which
|
||||
is preferable to delaying the SSE stream behind a flag write.
|
||||
"""
|
||||
if not user_id:
|
||||
return
|
||||
|
||||
async def _bg_set_ai_responding() -> None:
|
||||
try:
|
||||
async with shielded_async_session() as s:
|
||||
await set_ai_responding(s, chat_id, UUID(user_id))
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"set_ai_responding failed (chat_id=%s)",
|
||||
chat_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
t = asyncio.create_task(_bg_set_ai_responding())
|
||||
background_tasks.add(t)
|
||||
t.add_done_callback(background_tasks.discard)
|
||||
|
||||
|
||||
def spawn_persist_user_task(
|
||||
*,
|
||||
chat_id: int,
|
||||
user_id: str | None,
|
||||
turn_id: str,
|
||||
user_query: str,
|
||||
user_image_data_urls: list[str] | None,
|
||||
mentioned_documents: list[dict[str, Any]] | None,
|
||||
background_tasks: set[asyncio.Task[Any]],
|
||||
) -> asyncio.Task[int | None]:
|
||||
"""Spawn the user-row INSERT; await at the user-message-id yield site."""
|
||||
task = asyncio.create_task(
|
||||
persist_user_turn(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
turn_id=turn_id,
|
||||
user_query=user_query,
|
||||
user_image_data_urls=user_image_data_urls,
|
||||
mentioned_documents=mentioned_documents,
|
||||
)
|
||||
)
|
||||
background_tasks.add(task)
|
||||
task.add_done_callback(background_tasks.discard)
|
||||
return task
|
||||
|
||||
|
||||
def spawn_persist_assistant_shell_task(
|
||||
*,
|
||||
chat_id: int,
|
||||
user_id: str | None,
|
||||
turn_id: str,
|
||||
background_tasks: set[asyncio.Task[Any]],
|
||||
) -> asyncio.Task[int | None]:
|
||||
"""Spawn the assistant-shell INSERT; await at the assistant-message-id yield site."""
|
||||
task = asyncio.create_task(
|
||||
persist_assistant_shell(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
)
|
||||
background_tasks.add(task)
|
||||
task.add_done_callback(background_tasks.discard)
|
||||
return task
|
||||
|
||||
|
||||
async def await_persist_task(
|
||||
task: asyncio.Task[int | None] | None,
|
||||
*,
|
||||
chat_id: int,
|
||||
turn_id: str,
|
||||
log_label: str,
|
||||
) -> int | None:
|
||||
"""Join a spawned persistence task with ``shield`` + uniform error handling.
|
||||
|
||||
``shield`` keeps the DB write alive if the SSE generator is cancelled by
|
||||
client disconnect mid-await. Returns ``None`` on failure; the caller
|
||||
abort-paths the turn with a friendly error SSE.
|
||||
"""
|
||||
if task is None:
|
||||
return None
|
||||
try:
|
||||
return await asyncio.shield(task)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"%s failed (chat_id=%s, turn_id=%s)", log_label, chat_id, turn_id
|
||||
)
|
||||
return None
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
"""Build the per-invocation ``SurfSenseContextSchema`` for a new-chat turn.
|
||||
|
||||
Carries the per-turn read inputs that middlewares read via
|
||||
``runtime.context.*`` instead of from their ``__init__`` closures, so the same
|
||||
compiled-agent instance can serve multiple turns with different
|
||||
mention lists / request ids / turn ids without rebuilding the graph.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
|
||||
|
||||
def build_new_chat_runtime_context(
|
||||
*,
|
||||
search_space_id: int,
|
||||
mentioned_document_ids: list[int] | None,
|
||||
accepted_folder_ids: list[int],
|
||||
mentioned_folder_ids: list[int] | None,
|
||||
request_id: str | None,
|
||||
turn_id: str,
|
||||
) -> SurfSenseContextSchema:
|
||||
"""``mentioned_document_ids`` is consumed by ``KnowledgePriorityMiddleware``.
|
||||
|
||||
``accepted_folder_ids`` (post-resolve) wins over the raw
|
||||
``mentioned_folder_ids`` from the request: the resolver drops chips that
|
||||
pointed at deleted folders or folders the caller can't see, so middlewares
|
||||
only get authorized ids.
|
||||
"""
|
||||
return SurfSenseContextSchema(
|
||||
search_space_id=search_space_id,
|
||||
mentioned_document_ids=list(mentioned_document_ids or []),
|
||||
mentioned_folder_ids=list(
|
||||
accepted_folder_ids or mentioned_folder_ids or []
|
||||
),
|
||||
request_id=request_id,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
|
|
@ -0,0 +1,237 @@
|
|||
"""Background thread-title generation (first-response only).
|
||||
|
||||
The first assistant response in a thread gets a short auto-generated title
|
||||
inserted into ``new_chat_threads.title``. We:
|
||||
|
||||
1. Spawn the generation as an ``asyncio.Task`` so it runs in parallel with
|
||||
the agent stream (no extra TTFT).
|
||||
2. Probe inside the task (on its own shielded session) whether this is
|
||||
actually the first response — newer turns short-circuit to ``None``.
|
||||
3. Inject the resulting ``thread-title-update`` SSE frame on the first agent
|
||||
event after the task completes (mid-stream interlock), or right before
|
||||
the finish frames (post-stream join) if the task hadn't finished yet.
|
||||
|
||||
Usage tokens come directly off the response (LiteLLM's async callback fires
|
||||
via fire-and-forget ``create_task``, so the ``TokenTrackingCallback`` would
|
||||
run too late). We also blank the per-task accumulator so the late callback
|
||||
doesn't double-count.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import NewChatMessage, NewChatThread, shielded_async_session
|
||||
from app.prompts import TITLE_GENERATION_PROMPT
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
from app.services.token_tracking_service import TokenAccumulator
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def spawn_title_task(
|
||||
*,
|
||||
chat_id: int,
|
||||
user_query: str,
|
||||
user_image_data_urls: list[str] | None,
|
||||
assistant_message_id: int | None,
|
||||
llm: Any,
|
||||
agent_config: AgentConfig | None,
|
||||
) -> asyncio.Task[tuple[str | None, dict | None]] | None:
|
||||
"""Spawn ``_generate_title``; returns ``None`` when prerequisites aren't met.
|
||||
|
||||
Title gen is gated on a real ``assistant_message_id`` so a stream that
|
||||
aborts before persistence can never leave a thread with a title and no
|
||||
anchoring rows.
|
||||
"""
|
||||
if assistant_message_id is None:
|
||||
return None
|
||||
return asyncio.create_task(
|
||||
_generate_title(
|
||||
chat_id=chat_id,
|
||||
user_query=user_query,
|
||||
user_image_data_urls=user_image_data_urls,
|
||||
assistant_message_id=assistant_message_id,
|
||||
llm=llm,
|
||||
agent_config=agent_config,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _generate_title(
|
||||
*,
|
||||
chat_id: int,
|
||||
user_query: str,
|
||||
user_image_data_urls: list[str] | None,
|
||||
assistant_message_id: int,
|
||||
llm: Any,
|
||||
agent_config: AgentConfig | None,
|
||||
) -> tuple[str | None, dict | None]:
|
||||
"""Probe is-first-response, then call ``acompletion``. Returns ``(title, usage)``."""
|
||||
try:
|
||||
from litellm import acompletion
|
||||
|
||||
from app.services.llm_router_service import LLMRouterService
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.services.token_tracking_service import _turn_accumulator
|
||||
|
||||
# Excludes this turn's own assistant row (pre-written by
|
||||
# ``persist_assistant_shell``) — without the ``!=`` filter the gate
|
||||
# would false-negative on every turn after the first.
|
||||
try:
|
||||
async with shielded_async_session() as probe_session:
|
||||
probe_result = await probe_session.execute(
|
||||
select(NewChatMessage.id)
|
||||
.filter(
|
||||
NewChatMessage.thread_id == chat_id,
|
||||
NewChatMessage.role == "assistant",
|
||||
NewChatMessage.id != assistant_message_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
is_first_response = probe_result.scalars().first() is None
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[TitleGen] first-response probe failed (chat_id=%s)",
|
||||
chat_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return None, None
|
||||
|
||||
if not is_first_response:
|
||||
return None, None
|
||||
|
||||
_turn_accumulator.set(None)
|
||||
|
||||
title_seed = user_query.strip() or (
|
||||
f"[{len(user_image_data_urls or [])} image(s)]"
|
||||
if user_image_data_urls
|
||||
else ""
|
||||
)
|
||||
prompt = TITLE_GENERATION_PROMPT.replace(
|
||||
"{user_query}", title_seed[:500] or "(message)"
|
||||
)
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
if getattr(llm, "model", None) == "auto":
|
||||
router = LLMRouterService.get_router()
|
||||
response = await router.acompletion(model="auto", messages=messages)
|
||||
else:
|
||||
# Apply the same ``api_base`` cascade chat / vision / image-gen
|
||||
# call sites use so we never inherit ``litellm.api_base``
|
||||
# (commonly set by ``AZURE_OPENAI_ENDPOINT``) when the chat
|
||||
# config itself ships an empty ``api_base``. Without this the
|
||||
# title-gen on an OpenRouter chat config would 404 against the
|
||||
# inherited Azure endpoint — see ``provider_api_base`` for the
|
||||
# same bug repro on the image-gen / vision paths.
|
||||
raw_model = getattr(llm, "model", "") or ""
|
||||
provider_prefix = (
|
||||
raw_model.split("/", 1)[0] if "/" in raw_model else None
|
||||
)
|
||||
provider_value = (
|
||||
agent_config.provider if agent_config is not None else None
|
||||
)
|
||||
title_api_base = resolve_api_base(
|
||||
provider=provider_value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=getattr(llm, "api_base", None),
|
||||
)
|
||||
response = await acompletion(
|
||||
model=raw_model,
|
||||
messages=messages,
|
||||
api_key=getattr(llm, "api_key", None),
|
||||
api_base=title_api_base,
|
||||
)
|
||||
|
||||
usage_info = None
|
||||
usage = getattr(response, "usage", None)
|
||||
if usage:
|
||||
raw_model = getattr(llm, "model", "") or ""
|
||||
model_name = (
|
||||
raw_model.split("/", 1)[-1]
|
||||
if "/" in raw_model
|
||||
else (raw_model or response.model or "unknown")
|
||||
)
|
||||
usage_info = {
|
||||
"model": model_name,
|
||||
"prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0,
|
||||
"completion_tokens": getattr(usage, "completion_tokens", 0) or 0,
|
||||
"total_tokens": getattr(usage, "total_tokens", 0) or 0,
|
||||
}
|
||||
|
||||
raw_title = response.choices[0].message.content.strip()
|
||||
if raw_title and len(raw_title) <= 100:
|
||||
return raw_title.strip("\"'"), usage_info
|
||||
return None, usage_info
|
||||
except Exception:
|
||||
logger.exception("[TitleGen] _generate_title failed")
|
||||
return None, None
|
||||
|
||||
|
||||
async def maybe_emit_title_update(
|
||||
*,
|
||||
title_task: asyncio.Task[tuple[str | None, dict | None]] | None,
|
||||
title_emitted: bool,
|
||||
chat_id: int,
|
||||
accumulator: TokenAccumulator,
|
||||
streaming_service: VercelStreamingService,
|
||||
):
|
||||
"""Inject one ``thread-title-update`` SSE if the task completed.
|
||||
|
||||
Yields the SSE frame (when applicable). Returns nothing; the orchestrator
|
||||
flips ``title_emitted`` itself after iterating so we don't fight Python's
|
||||
nonlocal-in-generator semantics.
|
||||
"""
|
||||
if title_task is None or title_emitted or not title_task.done():
|
||||
return
|
||||
generated_title, title_usage = title_task.result()
|
||||
if title_usage:
|
||||
accumulator.add(**title_usage)
|
||||
if generated_title:
|
||||
async with shielded_async_session() as title_session:
|
||||
title_thread_result = await title_session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == chat_id)
|
||||
)
|
||||
title_thread = title_thread_result.scalars().first()
|
||||
if title_thread:
|
||||
title_thread.title = generated_title
|
||||
await title_session.commit()
|
||||
yield streaming_service.format_thread_title_update(chat_id, generated_title)
|
||||
|
||||
|
||||
async def await_pending_title_update(
|
||||
*,
|
||||
title_task: asyncio.Task[tuple[str | None, dict | None]] | None,
|
||||
title_emitted: bool,
|
||||
chat_id: int,
|
||||
accumulator: TokenAccumulator,
|
||||
streaming_service: VercelStreamingService,
|
||||
):
|
||||
"""If the task hadn't completed during the stream, await it now and emit.
|
||||
|
||||
Used right before the finish frames in the success path. Mirror of
|
||||
``maybe_emit_title_update`` but unconditionally awaits.
|
||||
"""
|
||||
if title_task is None or title_emitted:
|
||||
return
|
||||
generated_title, title_usage = await title_task
|
||||
if title_usage:
|
||||
accumulator.add(**title_usage)
|
||||
if generated_title:
|
||||
async with shielded_async_session() as title_session:
|
||||
title_thread_result = await title_session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == chat_id)
|
||||
)
|
||||
title_thread = title_thread_result.scalars().first()
|
||||
if title_thread:
|
||||
title_thread.title = generated_title
|
||||
await title_session.commit()
|
||||
yield streaming_service.format_thread_title_update(chat_id, generated_title)
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""Resume-chat streaming flow.
|
||||
|
||||
Public entry point ``stream_resume_chat`` is the slim coroutine in
|
||||
``orchestrator.py`` that composes the per-concern modules in this folder and
|
||||
the building blocks under ``flows/shared/``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.tasks.chat.streaming.flows.resume_chat.orchestrator import stream_resume_chat
|
||||
|
||||
__all__ = ["stream_resume_chat"]
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
"""Pre-write a fresh assistant row for this resume turn.
|
||||
|
||||
The original (interrupted) ``stream_new_chat`` invocation already persisted
|
||||
its own assistant row anchored to a different ``turn_id``; resume allocates a
|
||||
new ``turn_id`` (per-request, see ``orchestrator``) so we need a separate row
|
||||
keyed on the same ``(thread_id, turn_id, ASSISTANT)`` invariant.
|
||||
|
||||
Idempotent against migration 141's partial unique index — recovers the
|
||||
existing id on retry.
|
||||
|
||||
Resume does NOT emit ``data-user-message-id``: the user row is from the
|
||||
original interrupted turn (different ``turn_id``) and is never re-persisted
|
||||
here. See B5 in the ``sse-based_message_id_handshake`` plan.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.tasks.chat.persistence import persist_assistant_shell
|
||||
|
||||
|
||||
async def persist_resume_assistant_shell(
|
||||
*,
|
||||
chat_id: int,
|
||||
user_id: str | None,
|
||||
turn_id: str,
|
||||
) -> int | None:
|
||||
return await persist_assistant_shell(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
|
|
@ -0,0 +1,629 @@
|
|||
"""``stream_resume_chat`` — public entry point for a HITL resume turn.
|
||||
|
||||
Slim composition layer over the per-concern modules in this folder and the
|
||||
building blocks under ``flows/shared/``. Mirrors ``stream_new_chat`` but:
|
||||
|
||||
* No user-message persistence (the original turn already wrote it).
|
||||
* No mentions / surfsense-doc / report context assembly (seeded by original).
|
||||
* No title generation (only fires on first-response).
|
||||
* Synchronous ``persist_assistant_shell`` call (we have no other in-flight
|
||||
pre-stream work to overlap it with).
|
||||
* ``input_data`` is a ``Command(resume=lg_resume_map)`` instead of a
|
||||
LangChain message list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import gc
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import uuid as _uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import anyio
|
||||
|
||||
from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent
|
||||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||
from app.agents.new_chat.middleware.busy_mutex import end_turn
|
||||
from app.config import config as _app_config
|
||||
from app.db import ChatVisibility, async_session_maker, shielded_async_session
|
||||
from app.observability import otel as ot
|
||||
from app.services.chat_session_state_service import set_ai_responding
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.tasks.chat.content_builder import AssistantContentBuilder
|
||||
from app.tasks.chat.streaming.agent.builder import build_main_agent_for_thread
|
||||
from app.tasks.chat.streaming.contract.file_contract import log_file_contract
|
||||
from app.tasks.chat.streaming.errors.emitter import emit_stream_terminal_error
|
||||
from app.tasks.chat.streaming.flows.resume_chat.assistant_shell import (
|
||||
persist_resume_assistant_shell,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.resume_chat.resume_routing import (
|
||||
build_resume_routing,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.resume_chat.runtime_context import (
|
||||
build_resume_chat_runtime_context,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.assistant_finalize import (
|
||||
finalize_assistant_message,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.finalize_emit import iter_token_usage_frame
|
||||
from app.tasks.chat.streaming.flows.shared.finally_cleanup import (
|
||||
close_session_and_clear_ai_responding,
|
||||
run_gc_pass,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.first_frames import (
|
||||
iter_final_frames,
|
||||
iter_initial_frames,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle
|
||||
from app.tasks.chat.streaming.flows.shared.pre_stream_setup import (
|
||||
get_chat_checkpointer,
|
||||
setup_connector_and_firecrawl,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.premium_quota import (
|
||||
PremiumReservation,
|
||||
finalize_premium,
|
||||
needs_premium_quota,
|
||||
release_premium,
|
||||
reserve_premium,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.rate_limit_recovery import (
|
||||
can_recover_provider_rate_limit,
|
||||
log_rate_limit_recovered,
|
||||
reroute_to_next_auto_pin,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.span import (
|
||||
close_chat_request_span,
|
||||
open_chat_request_span,
|
||||
set_agent_mode,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.stream_loop import run_stream_loop
|
||||
from app.tasks.chat.streaming.flows.shared.terminal_error import (
|
||||
handle_terminal_exception,
|
||||
)
|
||||
from app.tasks.chat.streaming.shared.stream_result import StreamResult
|
||||
from app.tasks.chat.streaming.shared.utils import resume_step_prefix
|
||||
from app.utils.perf import get_perf_logger, log_system_snapshot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
async def stream_resume_chat(
|
||||
chat_id: int,
|
||||
search_space_id: int,
|
||||
decisions: list[dict],
|
||||
user_id: str | None = None,
|
||||
llm_config_id: int = -1,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
filesystem_selection: FilesystemSelection | None = None,
|
||||
request_id: str | None = None,
|
||||
disabled_tools: list[str] | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Resume a paused HITL turn with the user's decisions.
|
||||
|
||||
Mirrors ``stream_new_chat`` except for the resume-specific routing of
|
||||
``decisions`` to per-``tool_call_id`` slices (``build_resume_routing``).
|
||||
"""
|
||||
streaming_service = VercelStreamingService()
|
||||
stream_result = StreamResult()
|
||||
_t_total = time.perf_counter()
|
||||
fs_mode = filesystem_selection.mode.value if filesystem_selection else "cloud"
|
||||
fs_platform = (
|
||||
filesystem_selection.client_platform.value if filesystem_selection else "web"
|
||||
)
|
||||
stream_result.request_id = request_id
|
||||
stream_result.turn_id = f"{chat_id}:{int(time.time() * 1000)}"
|
||||
stream_result.filesystem_mode = fs_mode
|
||||
stream_result.client_platform = fs_platform
|
||||
|
||||
chat_agent_mode = "unknown"
|
||||
chat_outcome = "success"
|
||||
chat_error_category: str | None = None
|
||||
chat_span_cm, chat_span = open_chat_request_span(
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
flow="resume",
|
||||
request_id=request_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
filesystem_mode=fs_mode,
|
||||
client_platform=fs_platform,
|
||||
agent_mode=chat_agent_mode,
|
||||
)
|
||||
log_file_contract("turn_start", stream_result)
|
||||
_perf_log.info(
|
||||
"[stream_resume] filesystem_mode=%s client_platform=%s",
|
||||
fs_mode,
|
||||
fs_platform,
|
||||
)
|
||||
|
||||
from app.services.token_tracking_service import start_turn
|
||||
|
||||
accumulator = start_turn()
|
||||
|
||||
premium_reservation: PremiumReservation | None = None
|
||||
busy_error_raised = False
|
||||
|
||||
emit_stream_error = partial(
|
||||
emit_stream_terminal_error,
|
||||
streaming_service=streaming_service,
|
||||
flow="resume",
|
||||
request_id=request_id,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
session = async_session_maker()
|
||||
try:
|
||||
if user_id:
|
||||
await set_ai_responding(session, chat_id, UUID(user_id))
|
||||
|
||||
requested_llm_config_id = llm_config_id
|
||||
|
||||
# --- LLM config ---
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
try:
|
||||
from app.services.auto_model_pin_service import (
|
||||
resolve_or_get_pinned_llm_config_id,
|
||||
)
|
||||
|
||||
pinned = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=llm_config_id,
|
||||
)
|
||||
llm_config_id = pinned.resolved_llm_config_id
|
||||
ot.add_event(
|
||||
"model.pin.resolved",
|
||||
{
|
||||
"pin.requested_id": requested_llm_config_id,
|
||||
"pin.resolved_id": llm_config_id,
|
||||
"pin.requires_image_input": False,
|
||||
},
|
||||
)
|
||||
except ValueError as pin_error:
|
||||
yield emit_stream_error(
|
||||
message=str(pin_error),
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
llm, agent_config, llm_load_error = await load_llm_bundle(
|
||||
session, config_id=llm_config_id, search_space_id=search_space_id
|
||||
)
|
||||
if llm_load_error:
|
||||
yield emit_stream_error(
|
||||
message=llm_load_error,
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
_perf_log.info(
|
||||
"[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
if needs_premium_quota(agent_config, user_id):
|
||||
premium_reservation = await reserve_premium(
|
||||
agent_config=agent_config, user_id=user_id # type: ignore[arg-type]
|
||||
)
|
||||
if not premium_reservation.allowed:
|
||||
ot.add_event(
|
||||
"quota.denied", {"quota.code": "PREMIUM_QUOTA_EXHAUSTED"}
|
||||
)
|
||||
if requested_llm_config_id == 0:
|
||||
try:
|
||||
pinned_fb = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=0,
|
||||
force_repin_free=True,
|
||||
)
|
||||
llm_config_id = pinned_fb.resolved_llm_config_id
|
||||
ot.add_event(
|
||||
"model.repin",
|
||||
{
|
||||
"repin.reason": "premium_quota_exhausted",
|
||||
"repin.to_config_id": llm_config_id,
|
||||
},
|
||||
)
|
||||
except ValueError as pin_error:
|
||||
yield emit_stream_error(
|
||||
message=str(pin_error),
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
llm, agent_config, llm_load_error = await load_llm_bundle(
|
||||
session,
|
||||
config_id=llm_config_id,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
if llm_load_error:
|
||||
yield emit_stream_error(
|
||||
message=llm_load_error,
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
premium_reservation = None
|
||||
from app.tasks.chat.streaming.errors.classifier import (
|
||||
log_chat_stream_error,
|
||||
)
|
||||
|
||||
log_chat_stream_error(
|
||||
flow="resume",
|
||||
error_kind="premium_quota_exhausted",
|
||||
error_code="PREMIUM_QUOTA_EXHAUSTED",
|
||||
severity="info",
|
||||
is_expected=True,
|
||||
request_id=request_id,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
message=(
|
||||
"Premium quota exhausted on pinned model; "
|
||||
"auto-fallback switched to a free model"
|
||||
),
|
||||
extra={
|
||||
"fallback_config_id": llm_config_id,
|
||||
"auto_fallback": True,
|
||||
},
|
||||
)
|
||||
else:
|
||||
yield emit_stream_error(
|
||||
message=(
|
||||
"Buy more tokens to continue with this model, or "
|
||||
"switch to a free model"
|
||||
),
|
||||
error_kind="premium_quota_exhausted",
|
||||
error_code="PREMIUM_QUOTA_EXHAUSTED",
|
||||
severity="info",
|
||||
is_expected=True,
|
||||
extra={
|
||||
"resolved_config_id": llm_config_id,
|
||||
"auto_fallback": False,
|
||||
},
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
if not llm:
|
||||
yield emit_stream_error(
|
||||
message="Failed to create LLM instance",
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
# --- Pre-stream setup ---
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
connector_service, firecrawl_api_key = await setup_connector_and_firecrawl(
|
||||
session, search_space_id=search_space_id
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_resume] Connector service + firecrawl key in %.3fs",
|
||||
time.perf_counter() - _t0,
|
||||
)
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
checkpointer = await get_chat_checkpointer()
|
||||
_perf_log.info(
|
||||
"[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
use_multi_agent = bool(_app_config.MULTI_AGENT_CHAT_ENABLED)
|
||||
chat_agent_mode = "multi" if use_multi_agent else "single"
|
||||
set_agent_mode(chat_span, chat_agent_mode)
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
agent_factory = (
|
||||
create_multi_agent_chat_deep_agent
|
||||
if use_multi_agent
|
||||
else create_surfsense_deep_agent
|
||||
)
|
||||
agent = await build_main_agent_for_thread(
|
||||
agent_factory,
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
connector_service=connector_service,
|
||||
checkpointer=checkpointer,
|
||||
user_id=user_id,
|
||||
thread_id=chat_id,
|
||||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
disabled_tools=disabled_tools,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
|
||||
)
|
||||
|
||||
# Release the transaction before streaming (same rationale as stream_new_chat).
|
||||
await session.commit()
|
||||
session.expunge_all()
|
||||
|
||||
_perf_log.info(
|
||||
"[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)",
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
)
|
||||
|
||||
# --- Resume routing ---
|
||||
|
||||
from langgraph.types import Command
|
||||
|
||||
routing = await build_resume_routing(
|
||||
agent, chat_id=chat_id, decisions=decisions
|
||||
)
|
||||
|
||||
config = {
|
||||
"configurable": {
|
||||
"thread_id": str(chat_id),
|
||||
"request_id": request_id or "unknown",
|
||||
"turn_id": stream_result.turn_id,
|
||||
# Per-``tool_call_id`` resume slices read by
|
||||
# ``SurfSenseCheckpointedSubAgentMiddleware``. Parallel
|
||||
# siblings each pop their own entry, so they never race.
|
||||
"surfsense_resume_value": routing.routed_resume_value,
|
||||
},
|
||||
# Same rationale as ``stream_new_chat``: effectively uncapped to
|
||||
# mirror the agent default and OpenCode's session loop. Doom-loop
|
||||
# / call-limit middleware enforce the real ceiling.
|
||||
"recursion_limit": 10_000,
|
||||
}
|
||||
|
||||
# --- First SSE frames ---
|
||||
|
||||
for sse in iter_initial_frames(streaming_service, turn_id=stream_result.turn_id):
|
||||
yield sse
|
||||
|
||||
# --- Assistant-shell persistence + id frame ---
|
||||
|
||||
assistant_message_id = await persist_resume_assistant_shell(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
)
|
||||
if assistant_message_id is None:
|
||||
yield emit_stream_error(
|
||||
message=(
|
||||
"We couldn't initialize the assistant message. Please try again."
|
||||
),
|
||||
error_kind="server_error",
|
||||
error_code="MESSAGE_PERSIST_FAILED",
|
||||
)
|
||||
for sse in iter_final_frames(streaming_service):
|
||||
yield sse
|
||||
return
|
||||
|
||||
yield streaming_service.format_data(
|
||||
"assistant-message-id",
|
||||
{"message_id": assistant_message_id, "turn_id": stream_result.turn_id},
|
||||
)
|
||||
|
||||
stream_result.assistant_message_id = assistant_message_id
|
||||
stream_result.content_builder = AssistantContentBuilder()
|
||||
|
||||
runtime_context = build_resume_chat_runtime_context(
|
||||
search_space_id=search_space_id,
|
||||
request_id=request_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
)
|
||||
|
||||
# --- Stream loop ---
|
||||
|
||||
_t_stream_start = time.perf_counter()
|
||||
runtime_rate_limit_recovered = False
|
||||
|
||||
def _on_first_event() -> None:
|
||||
_perf_log.info(
|
||||
"[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
)
|
||||
|
||||
async def _recover(exc: BaseException, first_event_seen: bool):
|
||||
nonlocal llm_config_id, llm, agent_config, runtime_rate_limit_recovered
|
||||
if not can_recover_provider_rate_limit(
|
||||
exc,
|
||||
first_event_seen=first_event_seen,
|
||||
runtime_rate_limit_recovered=runtime_rate_limit_recovered,
|
||||
requested_llm_config_id=requested_llm_config_id,
|
||||
current_llm_config_id=llm_config_id,
|
||||
):
|
||||
return None
|
||||
runtime_rate_limit_recovered = True
|
||||
previous_config_id = llm_config_id
|
||||
llm_config_id = await reroute_to_next_auto_pin(
|
||||
session,
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
current_llm_config_id=llm_config_id,
|
||||
requires_image_input=False,
|
||||
)
|
||||
new_llm, new_agent_config, llm_load_err = await load_llm_bundle(
|
||||
session, config_id=llm_config_id, search_space_id=search_space_id
|
||||
)
|
||||
if llm_load_err:
|
||||
return None
|
||||
llm = new_llm
|
||||
agent_config = new_agent_config
|
||||
|
||||
_t_rebuild = time.perf_counter()
|
||||
new_agent = await build_main_agent_for_thread(
|
||||
agent_factory,
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
connector_service=connector_service,
|
||||
checkpointer=checkpointer,
|
||||
user_id=user_id,
|
||||
thread_id=chat_id,
|
||||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
disabled_tools=disabled_tools,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_resume] Runtime rate-limit recovery repinned "
|
||||
"config_id=%s -> %s and rebuilt agent in %.3fs",
|
||||
previous_config_id,
|
||||
llm_config_id,
|
||||
time.perf_counter() - _t_rebuild,
|
||||
)
|
||||
log_rate_limit_recovered(
|
||||
flow="resume",
|
||||
request_id=request_id,
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
previous_config_id=previous_config_id,
|
||||
new_config_id=llm_config_id,
|
||||
)
|
||||
return new_agent
|
||||
|
||||
async for sse in run_stream_loop(
|
||||
agent=agent,
|
||||
streaming_service=streaming_service,
|
||||
config=config,
|
||||
input_data=Command(resume=routing.lg_resume_map),
|
||||
stream_result=stream_result,
|
||||
step_prefix=resume_step_prefix(stream_result.turn_id),
|
||||
fallback_commit_search_space_id=search_space_id,
|
||||
fallback_commit_created_by_id=user_id,
|
||||
fallback_commit_filesystem_mode=(
|
||||
filesystem_selection.mode if filesystem_selection else FilesystemMode.CLOUD
|
||||
),
|
||||
fallback_commit_thread_id=chat_id,
|
||||
runtime_context=runtime_context,
|
||||
content_builder=stream_result.content_builder,
|
||||
recover=_recover,
|
||||
on_first_event=_on_first_event,
|
||||
):
|
||||
yield sse
|
||||
|
||||
_perf_log.info(
|
||||
"[stream_resume] Agent stream completed in %.3fs (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
chat_id,
|
||||
)
|
||||
|
||||
# --- Finalize ---
|
||||
|
||||
if stream_result.is_interrupted:
|
||||
ot.add_event("chat.interrupted", {"chat.flow": "resume"})
|
||||
for sse in iter_token_usage_frame(
|
||||
streaming_service,
|
||||
accumulator=accumulator,
|
||||
log_label="interrupted resume_chat",
|
||||
):
|
||||
yield sse
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
if premium_reservation is not None and user_id:
|
||||
await finalize_premium(
|
||||
reservation=premium_reservation,
|
||||
user_id=user_id,
|
||||
accumulator=accumulator,
|
||||
)
|
||||
premium_reservation = None
|
||||
|
||||
for sse in iter_token_usage_frame(
|
||||
streaming_service, accumulator=accumulator, log_label="normal resume_chat"
|
||||
):
|
||||
yield sse
|
||||
|
||||
for sse in iter_final_frames(streaming_service):
|
||||
yield sse
|
||||
|
||||
except Exception as exc:
|
||||
frames, summary = handle_terminal_exception(
|
||||
exc,
|
||||
flow="resume",
|
||||
flow_label="resume",
|
||||
log_prefix="stream_resume_chat",
|
||||
streaming_service=streaming_service,
|
||||
request_id=request_id,
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
chat_span=chat_span,
|
||||
)
|
||||
if summary["busy_error_raised"]:
|
||||
busy_error_raised = True
|
||||
chat_outcome = summary["chat_outcome"]
|
||||
chat_error_category = summary["chat_error_category"]
|
||||
for sse in frames:
|
||||
yield sse
|
||||
|
||||
finally:
|
||||
with anyio.CancelScope(shield=True):
|
||||
end_turn(str(chat_id))
|
||||
|
||||
if premium_reservation is not None and user_id:
|
||||
await release_premium(
|
||||
reservation=premium_reservation, user_id=user_id
|
||||
)
|
||||
|
||||
await close_session_and_clear_ai_responding(session, chat_id)
|
||||
|
||||
await finalize_assistant_message(
|
||||
stream_result=stream_result,
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
accumulator=accumulator,
|
||||
log_prefix="stream_resume",
|
||||
)
|
||||
|
||||
# Release the lock from the original interrupted turn or any
|
||||
# re-interrupt/bailout. Skip on ``BusyError`` (lock not held here).
|
||||
if not busy_error_raised:
|
||||
with contextlib.suppress(Exception):
|
||||
end_turn(str(chat_id))
|
||||
_perf_log.info(
|
||||
"[stream_resume] end_turn cleanup (chat_id=%s)", chat_id
|
||||
)
|
||||
|
||||
agent = llm = connector_service = None # noqa: F841
|
||||
stream_result = None # noqa: F841
|
||||
session = None # noqa: F841
|
||||
|
||||
run_gc_pass(log_prefix="stream_resume", chat_id=chat_id)
|
||||
close_chat_request_span(
|
||||
span_cm=chat_span_cm,
|
||||
span=chat_span,
|
||||
chat_outcome=chat_outcome,
|
||||
chat_agent_mode=chat_agent_mode,
|
||||
flow="resume",
|
||||
chat_error_category=chat_error_category,
|
||||
duration_seconds=time.perf_counter() - _t_total,
|
||||
)
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
"""Route a flat ``decisions`` list back to the right paused subagent.
|
||||
|
||||
Each pending interrupt is stamped with its originating ``tool_call_id`` (see
|
||||
``checkpointed_subagent_middleware.propagation``) so the resume slicer can
|
||||
re-target each ``HumanReview`` decision at the right ``tool_call_id``.
|
||||
|
||||
LangGraph rejects scalar ``Command(resume=...)`` when multiple interrupts are
|
||||
pending (parallel HITL); the mapped form works for the single-pause case too,
|
||||
so we always use it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResumeRoutingPayload:
|
||||
"""Resolved per-``tool_call_id`` resume slices + the lg-shaped resume map."""
|
||||
|
||||
routed_resume_value: dict[str, Any]
|
||||
lg_resume_map: dict[str, Any]
|
||||
|
||||
|
||||
async def build_resume_routing(
|
||||
agent: Any,
|
||||
*,
|
||||
chat_id: int,
|
||||
decisions: list[dict],
|
||||
) -> ResumeRoutingPayload:
|
||||
"""Read parent_state, collect pending tool-calls, slice decisions, build map.
|
||||
|
||||
The middleware reads its per-``tool_call_id`` resume slice from the
|
||||
``surfsense_resume_value`` configurable; parallel siblings each pop their
|
||||
own entry so they never race.
|
||||
"""
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
|
||||
build_lg_resume_map,
|
||||
collect_pending_tool_calls,
|
||||
slice_decisions_by_tool_call,
|
||||
)
|
||||
|
||||
parent_state = await agent.aget_state(
|
||||
{"configurable": {"thread_id": str(chat_id)}}
|
||||
)
|
||||
pending = collect_pending_tool_calls(parent_state)
|
||||
_perf_log.info(
|
||||
"[hitl_route] resume_entry chat_id=%s decisions=%d pending_subagents=%d",
|
||||
chat_id,
|
||||
len(decisions),
|
||||
len(pending),
|
||||
)
|
||||
routed_resume_value = slice_decisions_by_tool_call(decisions, pending)
|
||||
lg_resume_map = build_lg_resume_map(parent_state, routed_resume_value)
|
||||
return ResumeRoutingPayload(
|
||||
routed_resume_value=routed_resume_value,
|
||||
lg_resume_map=lg_resume_map,
|
||||
)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
"""Build the per-invocation ``SurfSenseContextSchema`` for a resume turn.
|
||||
|
||||
Resume doesn't carry new ``mentioned_document_ids`` (those are seeded by the
|
||||
original turn). We still build the context so future middleware extensions
|
||||
can rely on ``runtime.context`` always being populated.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
|
||||
|
||||
def build_resume_chat_runtime_context(
|
||||
*,
|
||||
search_space_id: int,
|
||||
request_id: str | None,
|
||||
turn_id: str,
|
||||
) -> SurfSenseContextSchema:
|
||||
return SurfSenseContextSchema(
|
||||
search_space_id=search_space_id,
|
||||
request_id=request_id,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
"""Building blocks shared by ``new_chat`` and ``resume_chat`` orchestrators."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
"""Server-side assistant-message + token_usage finalization.
|
||||
|
||||
Runs inside the streaming flow's ``finally`` block, after the main session has
|
||||
been closed (uses its own shielded session, so we don't fight the same DB
|
||||
connection).
|
||||
|
||||
Idempotent against the legacy frontend ``appendMessage`` recovery branch:
|
||||
|
||||
* the assistant row was already INSERTed by ``persist_assistant_shell``
|
||||
earlier in the turn, so this just UPDATEs it with the rich
|
||||
``ContentPart[]`` projection from the builder.
|
||||
* ``token_usage`` uses ``INSERT ... ON CONFLICT DO NOTHING`` against the
|
||||
partial unique index from migration 142, so a racing append_message
|
||||
recovery branch can never double-write.
|
||||
|
||||
``mark_interrupted`` closes any open text/reasoning blocks and flips running
|
||||
tool-calls (no result) to ``state=aborted`` so the persisted JSONB reflects a
|
||||
coherent end-state even on client disconnect.
|
||||
|
||||
Never raises (best-effort, logs only).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.tasks.chat.streaming.shared.stream_result import StreamResult
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.token_tracking_service import TokenAccumulator
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
async def finalize_assistant_message(
|
||||
*,
|
||||
stream_result: StreamResult | None,
|
||||
chat_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
accumulator: TokenAccumulator,
|
||||
log_prefix: str,
|
||||
) -> None:
|
||||
"""Snapshot the content builder and persist the final assistant payload.
|
||||
|
||||
No-op when ``stream_result`` was never populated, the turn never reached
|
||||
``persist_assistant_shell`` (no ``assistant_message_id``), or the turn id
|
||||
was never assigned.
|
||||
"""
|
||||
if not (
|
||||
stream_result
|
||||
and stream_result.turn_id
|
||||
and stream_result.assistant_message_id
|
||||
):
|
||||
return
|
||||
|
||||
from app.tasks.chat.persistence import finalize_assistant_turn
|
||||
|
||||
builder_stats: dict[str, int] | None = None
|
||||
if stream_result.content_builder is not None:
|
||||
stream_result.content_builder.mark_interrupted()
|
||||
# Snapshot stats BEFORE ``snapshot()`` deepcopies so the perf log
|
||||
# records the actual finalised payload (post-mark_interrupted), not
|
||||
# the live-mutating builder state.
|
||||
builder_stats = stream_result.content_builder.stats()
|
||||
content_payload = stream_result.content_builder.snapshot()
|
||||
else:
|
||||
# Defensive fallback — we always set the builder alongside
|
||||
# ``assistant_message_id`` in the orchestrator, so this branch only
|
||||
# fires if a future refactor ever decouples them. Persist whatever
|
||||
# accumulated text we captured so the row at least renders.
|
||||
content_payload = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": stream_result.accumulated_text or "",
|
||||
}
|
||||
]
|
||||
|
||||
if builder_stats is not None:
|
||||
_perf_log.info(
|
||||
"[%s] finalize_payload chat_id=%s "
|
||||
"message_id=%s parts=%d bytes=%d text=%d "
|
||||
"reasoning=%d tool_calls=%d "
|
||||
"tool_calls_completed=%d tool_calls_aborted=%d "
|
||||
"thinking_step_parts=%d step_separators=%d",
|
||||
log_prefix,
|
||||
chat_id,
|
||||
stream_result.assistant_message_id,
|
||||
builder_stats["parts"],
|
||||
builder_stats["bytes"],
|
||||
builder_stats["text"],
|
||||
builder_stats["reasoning"],
|
||||
builder_stats["tool_calls"],
|
||||
builder_stats["tool_calls_completed"],
|
||||
builder_stats["tool_calls_aborted"],
|
||||
builder_stats["thinking_step_parts"],
|
||||
builder_stats["step_separators"],
|
||||
)
|
||||
|
||||
await finalize_assistant_turn(
|
||||
message_id=stream_result.assistant_message_id,
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
content=content_payload,
|
||||
accumulator=accumulator,
|
||||
)
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
"""Emit the per-turn token-usage SSE frame from the accumulator.
|
||||
|
||||
``per_message_summary()`` returns ``None`` when the turn made no chargeable
|
||||
LLM calls (e.g. interrupt-on-input). In that case we skip the frame; the
|
||||
frontend has no usage to render.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.token_tracking_service import TokenAccumulator
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def iter_token_usage_frame(
|
||||
streaming_service: VercelStreamingService,
|
||||
*,
|
||||
accumulator: TokenAccumulator,
|
||||
log_label: str,
|
||||
):
|
||||
"""Yield zero or one ``data: token-usage`` SSE frame.
|
||||
|
||||
Side effect: logs a one-line ``[token_usage] {log_label}: ...`` summary so
|
||||
cost analysis can grep call/total/cost across all flows.
|
||||
"""
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
"[token_usage] %s: calls=%d total=%d cost_micros=%d summary=%s",
|
||||
log_label,
|
||||
len(accumulator.calls),
|
||||
accumulator.grand_total,
|
||||
accumulator.total_cost_micros,
|
||||
usage_summary,
|
||||
)
|
||||
if usage_summary:
|
||||
yield streaming_service.format_data(
|
||||
"token-usage",
|
||||
{
|
||||
"usage": usage_summary,
|
||||
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"cost_micros": accumulator.total_cost_micros,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
},
|
||||
)
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
"""Shared finally-block helpers: session close, GC pass, native-heap trim.
|
||||
|
||||
These are called from inside an ``anyio.CancelScope(shield=True)`` block in
|
||||
each flow's ``finally`` (Starlette's BaseHTTPMiddleware cancels the scope on
|
||||
client disconnect; without the shield the very first ``await`` would raise
|
||||
``CancelledError`` and the rest of cleanup — including ``session.close()`` —
|
||||
would never run).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import gc
|
||||
import logging
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import shielded_async_session
|
||||
from app.services.chat_session_state_service import clear_ai_responding
|
||||
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def close_session_and_clear_ai_responding(
|
||||
session: AsyncSession, chat_id: int
|
||||
) -> None:
|
||||
"""Rollback + clear AI-responding flag + expunge_all + close.
|
||||
|
||||
On rollback failure we fall back to a fresh shielded session for the flag
|
||||
clear so a UI is never stuck on "AI is responding…" after a crash.
|
||||
"""
|
||||
try:
|
||||
await session.rollback()
|
||||
await clear_ai_responding(session, chat_id)
|
||||
except Exception:
|
||||
try:
|
||||
async with shielded_async_session() as fresh_session:
|
||||
await clear_ai_responding(fresh_session, chat_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to clear AI responding state for thread %s", chat_id
|
||||
)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
session.expunge_all()
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
await session.close()
|
||||
|
||||
|
||||
def run_gc_pass(*, log_prefix: str, chat_id: int) -> None:
|
||||
"""One full gen0/1/2 pass + native-heap trim + END system snapshot.
|
||||
|
||||
Breaking circular refs held by the agent graph, tools, and LLM wrappers
|
||||
needs to happen in the caller (set the locals to ``None``) — this just
|
||||
runs the collector and logs how many objects came back.
|
||||
"""
|
||||
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
|
||||
if collected:
|
||||
_perf_log.info(
|
||||
"[%s] gc.collect() reclaimed %d objects (chat_id=%s)",
|
||||
log_prefix,
|
||||
collected,
|
||||
chat_id,
|
||||
)
|
||||
trim_native_heap()
|
||||
log_system_snapshot(f"{log_prefix}_END")
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
"""Initial SSE frames every flow emits right after pre-stream setup.
|
||||
|
||||
Order matters: ``message_start`` opens the assistant message, ``start_step``
|
||||
opens the first thinking step, ``turn-info`` lets the frontend stamp the
|
||||
correlation id onto the in-flight message, and ``turn-status: busy`` flips the
|
||||
UI into the streaming state.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
||||
|
||||
def iter_initial_frames(
|
||||
streaming_service: VercelStreamingService,
|
||||
*,
|
||||
turn_id: str,
|
||||
) -> Iterator[str]:
|
||||
"""Yield the four canonical opening frames in order.
|
||||
|
||||
``turn-info`` carries ``chat_turn_id`` so even pure-text turns (which
|
||||
never produce a tool / action-log event) still teach the frontend the
|
||||
turn correlation id used for ``appendMessage`` durable storage.
|
||||
"""
|
||||
yield streaming_service.format_message_start()
|
||||
yield streaming_service.format_start_step()
|
||||
yield streaming_service.format_data("turn-info", {"chat_turn_id": turn_id})
|
||||
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||
|
||||
|
||||
def iter_final_frames(
|
||||
streaming_service: VercelStreamingService,
|
||||
) -> Iterator[str]:
|
||||
"""Yield ``turn-status: idle`` plus the finish/done trailer in order."""
|
||||
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
"""Load an LLM + AgentConfig bundle for a given config id.
|
||||
|
||||
Handles both code paths uniformly:
|
||||
- ``config_id >= 0`` → database-backed ``NewLLMConfig`` row (per-user/per-space).
|
||||
- ``config_id < 0`` → YAML-defined global LLM config (built-in defaults).
|
||||
|
||||
Returns ``(llm, agent_config, error_message)``; on success ``error_message`` is
|
||||
``None``. The caller emits the friendly SSE error frame.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.llm_config import (
|
||||
AgentConfig,
|
||||
create_chat_litellm_from_agent_config,
|
||||
create_chat_litellm_from_config,
|
||||
load_agent_config,
|
||||
load_global_llm_config_by_id,
|
||||
)
|
||||
|
||||
|
||||
async def load_llm_bundle(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
config_id: int,
|
||||
search_space_id: int,
|
||||
) -> tuple[Any, AgentConfig | None, str | None]:
|
||||
if config_id >= 0:
|
||||
loaded_agent_config = await load_agent_config(
|
||||
session=session,
|
||||
config_id=config_id,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
if not loaded_agent_config:
|
||||
return (
|
||||
None,
|
||||
None,
|
||||
f"Failed to load NewLLMConfig with id {config_id}",
|
||||
)
|
||||
return (
|
||||
create_chat_litellm_from_agent_config(loaded_agent_config),
|
||||
loaded_agent_config,
|
||||
None,
|
||||
)
|
||||
|
||||
loaded_llm_config = load_global_llm_config_by_id(config_id)
|
||||
if not loaded_llm_config:
|
||||
return None, None, f"Failed to load LLM config with id {config_id}"
|
||||
return (
|
||||
create_chat_litellm_from_config(loaded_llm_config),
|
||||
AgentConfig.from_yaml_config(loaded_llm_config),
|
||||
None,
|
||||
)
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
"""Pre-stream setup: connector service, firecrawl key, checkpointer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||
from app.db import SearchSourceConnectorType
|
||||
from app.services.connector_service import ConnectorService
|
||||
|
||||
|
||||
async def setup_connector_and_firecrawl(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
search_space_id: int,
|
||||
) -> tuple[ConnectorService, str | None]:
|
||||
"""Build the per-turn connector service and pull the firecrawl API key.
|
||||
|
||||
Returns ``(connector_service, firecrawl_api_key)``. ``firecrawl_api_key`` is
|
||||
``None`` when no web-crawler connector is configured (the agent simply
|
||||
skips firecrawl-backed tools in that case).
|
||||
"""
|
||||
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
||||
firecrawl_api_key: str | None = None
|
||||
webcrawler_connector = await connector_service.get_connector_by_type(
|
||||
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
|
||||
)
|
||||
if webcrawler_connector and webcrawler_connector.config:
|
||||
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
|
||||
return connector_service, firecrawl_api_key
|
||||
|
||||
|
||||
async def get_chat_checkpointer():
|
||||
"""Resolve the PostgreSQL checkpointer for persistent conversation memory.
|
||||
|
||||
Thin wrapper around ``app.agents.new_chat.checkpointer.get_checkpointer`` so
|
||||
flow orchestrators can rely on a streaming-local symbol and we have a hook
|
||||
point if the checkpointer source ever needs to vary per flow.
|
||||
"""
|
||||
return await get_checkpointer()
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
"""Premium credit (USD micro-units) reserve / finalize / release lifecycle.
|
||||
|
||||
Both ``stream_new_chat`` and ``stream_resume_chat`` reserve premium credits up
|
||||
front (so a single LLM call can't run away with the budget), then finalize the
|
||||
actual provider cost reported by LiteLLM when the turn completes successfully,
|
||||
or release the reservation on the cancellation / interrupted-without-finalize
|
||||
paths.
|
||||
|
||||
State is held by the orchestrator as a simple ``PremiumReservation`` tuple
|
||||
so reservation, fallback-on-denied, finalize, and release can all be reasoned
|
||||
about from one place.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid as _uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
from app.db import shielded_async_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.token_tracking_service import TokenAccumulator
|
||||
|
||||
|
||||
@dataclass
|
||||
class PremiumReservation:
|
||||
"""Active premium-credit reservation for one turn.
|
||||
|
||||
``request_id`` is the per-reservation idempotency key (also passed to
|
||||
``finalize``/``release`` so racing branches resolve to the same row).
|
||||
``reserved_micros`` is the up-front estimate; ``finalize`` debits the
|
||||
actual cost, ``release`` returns it untouched.
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
reserved_micros: int
|
||||
allowed: bool
|
||||
|
||||
|
||||
def needs_premium_quota(
|
||||
agent_config: AgentConfig | None, user_id: str | None
|
||||
) -> bool:
|
||||
return bool(agent_config is not None and user_id and agent_config.is_premium)
|
||||
|
||||
|
||||
async def reserve_premium(
|
||||
*,
|
||||
agent_config: AgentConfig,
|
||||
user_id: str,
|
||||
) -> PremiumReservation:
|
||||
"""Reserve estimated micros up front; returns the reservation handle."""
|
||||
from app.services.token_quota_service import (
|
||||
TokenQuotaService,
|
||||
estimate_call_reserve_micros,
|
||||
)
|
||||
|
||||
request_id = _uuid.uuid4().hex[:16]
|
||||
litellm_params = agent_config.litellm_params or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
|
||||
) or agent_config.model_name or ""
|
||||
reserve_amount_micros = estimate_call_reserve_micros(
|
||||
base_model=base_model,
|
||||
quota_reserve_tokens=agent_config.quota_reserve_tokens,
|
||||
)
|
||||
async with shielded_async_session() as quota_session:
|
||||
quota_result = await TokenQuotaService.premium_reserve(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=request_id,
|
||||
reserve_micros=reserve_amount_micros,
|
||||
)
|
||||
return PremiumReservation(
|
||||
request_id=request_id,
|
||||
reserved_micros=reserve_amount_micros,
|
||||
allowed=quota_result.allowed,
|
||||
)
|
||||
|
||||
|
||||
async def finalize_premium(
|
||||
*,
|
||||
reservation: PremiumReservation,
|
||||
user_id: str,
|
||||
accumulator: TokenAccumulator,
|
||||
) -> None:
|
||||
"""Finalize debit using the actual provider cost reported by LiteLLM.
|
||||
|
||||
Best-effort: failures here must not bubble up to the SSE stream — the user
|
||||
has already received their tokens; we log and move on.
|
||||
"""
|
||||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
async with shielded_async_session() as quota_session:
|
||||
await TokenQuotaService.premium_finalize(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=reservation.request_id,
|
||||
actual_micros=accumulator.total_cost_micros,
|
||||
reserved_micros=reservation.reserved_micros,
|
||||
)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to finalize premium quota for user %s",
|
||||
user_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def release_premium(
|
||||
*,
|
||||
reservation: PremiumReservation,
|
||||
user_id: str,
|
||||
) -> None:
|
||||
"""Release the reservation on cancellation paths; never raises."""
|
||||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
async with shielded_async_session() as quota_session:
|
||||
await TokenQuotaService.premium_release(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
reserved_micros=reservation.reserved_micros,
|
||||
)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to release premium quota for user %s", user_id
|
||||
)
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
"""Shared steps for the in-stream provider rate-limit recovery loop.
|
||||
|
||||
Both flows wrap ``run_stream_loop`` with a flow-specific ``recover`` closure;
|
||||
the *guard*, the *auto-pin reroute*, and the *post-recovery telemetry* are the
|
||||
same on both sides and live here so behaviour can't drift.
|
||||
|
||||
The orchestrator owns the parts that genuinely diverge:
|
||||
|
||||
* cancelling the title task (new_chat only),
|
||||
* passing ``mentioned_document_ids`` to ``build_main_agent_for_thread``,
|
||||
* the log prefix (``stream_new_chat`` vs ``stream_resume``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.middleware.busy_mutex import end_turn
|
||||
from app.observability import otel as ot
|
||||
from app.services.auto_model_pin_service import (
|
||||
mark_runtime_cooldown,
|
||||
resolve_or_get_pinned_llm_config_id,
|
||||
)
|
||||
from app.tasks.chat.streaming.errors.classifier import (
|
||||
is_provider_rate_limited,
|
||||
log_chat_stream_error,
|
||||
)
|
||||
|
||||
|
||||
def can_recover_provider_rate_limit(
|
||||
exc: BaseException,
|
||||
*,
|
||||
first_event_seen: bool,
|
||||
runtime_rate_limit_recovered: bool,
|
||||
requested_llm_config_id: int,
|
||||
current_llm_config_id: int,
|
||||
) -> bool:
|
||||
"""Guard: only the first auto-pin → provider-rate-limited failure recovers.
|
||||
|
||||
All conditions must hold:
|
||||
|
||||
* ``runtime_rate_limit_recovered is False`` — at most one recovery per turn.
|
||||
* ``requested_llm_config_id == 0`` — caller opted into auto-pin (id=0).
|
||||
* ``current_llm_config_id < 0`` — currently on a YAML config (the only
|
||||
kind the auto-pin pool draws from).
|
||||
* ``first_event_seen is False`` — we haven't sent any SSE to the user yet,
|
||||
so a silent rebuild + retry is invisible.
|
||||
* The exception is provider-side rate-limited (HTTP 429 or known shape).
|
||||
"""
|
||||
return (
|
||||
not runtime_rate_limit_recovered
|
||||
and requested_llm_config_id == 0
|
||||
and current_llm_config_id < 0
|
||||
and not first_event_seen
|
||||
and is_provider_rate_limited(exc)
|
||||
)
|
||||
|
||||
|
||||
async def reroute_to_next_auto_pin(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
chat_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
current_llm_config_id: int,
|
||||
requires_image_input: bool,
|
||||
) -> int:
|
||||
"""Release lock, cool down the failing config, pick a new auto-pin id.
|
||||
|
||||
Returns the new ``llm_config_id``. ``end_turn`` is called because the failed
|
||||
attempt may still hold the per-thread busy mutex (middleware teardown can
|
||||
lag behind raised provider errors) — the same-request retry would otherwise
|
||||
bounce on ``BusyError``.
|
||||
"""
|
||||
end_turn(str(chat_id))
|
||||
mark_runtime_cooldown(current_llm_config_id, reason="provider_rate_limited")
|
||||
pinned = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=0,
|
||||
exclude_config_ids={current_llm_config_id},
|
||||
requires_image_input=requires_image_input,
|
||||
)
|
||||
return pinned.resolved_llm_config_id
|
||||
|
||||
|
||||
def log_rate_limit_recovered(
|
||||
*,
|
||||
flow: Literal["new", "regenerate", "resume"],
|
||||
request_id: str | None,
|
||||
chat_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
previous_config_id: int,
|
||||
new_config_id: int,
|
||||
) -> None:
|
||||
"""Emit the OTEL event + structured ``[chat_stream_error]`` log line."""
|
||||
ot.add_event(
|
||||
"chat.rate_limit.recovered",
|
||||
{
|
||||
"recovery.reason": "provider_rate_limited",
|
||||
"recovery.previous_config_id": previous_config_id,
|
||||
"recovery.fallback_config_id": new_config_id,
|
||||
},
|
||||
)
|
||||
log_chat_stream_error(
|
||||
flow=flow,
|
||||
error_kind="rate_limited",
|
||||
error_code="RATE_LIMITED",
|
||||
severity="info",
|
||||
is_expected=True,
|
||||
request_id=request_id,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
message=(
|
||||
"Auto-pinned model hit runtime rate limit; switched to "
|
||||
"another eligible model and retried."
|
||||
),
|
||||
extra={
|
||||
"auto_runtime_recover": True,
|
||||
"previous_config_id": previous_config_id,
|
||||
"fallback_config_id": new_config_id,
|
||||
},
|
||||
)
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
"""OpenTelemetry chat-request span wrapper for streaming flows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import sys
|
||||
from typing import Any, Literal
|
||||
|
||||
from app.observability import metrics as ot_metrics
|
||||
from app.observability import otel as ot
|
||||
|
||||
|
||||
def open_chat_request_span(
|
||||
*,
|
||||
chat_id: int,
|
||||
search_space_id: int,
|
||||
flow: Literal["new", "regenerate", "resume"],
|
||||
request_id: str | None,
|
||||
turn_id: str,
|
||||
filesystem_mode: str,
|
||||
client_platform: str,
|
||||
agent_mode: str,
|
||||
) -> tuple[Any, Any]:
|
||||
"""Open the per-request span; returns ``(span_cm, span)`` for finally-close."""
|
||||
span_cm = ot.chat_request_span(
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
flow=flow,
|
||||
request_id=request_id,
|
||||
turn_id=turn_id,
|
||||
filesystem_mode=filesystem_mode,
|
||||
client_platform=client_platform,
|
||||
agent_mode=agent_mode,
|
||||
)
|
||||
span = span_cm.__enter__()
|
||||
return span_cm, span
|
||||
|
||||
|
||||
def set_agent_mode(span: Any, agent_mode: str) -> None:
|
||||
"""Tag the span with the resolved agent mode (single / multi)."""
|
||||
with contextlib.suppress(Exception):
|
||||
span.set_attribute("agent.mode", agent_mode)
|
||||
|
||||
|
||||
def close_chat_request_span(
|
||||
*,
|
||||
span_cm: Any,
|
||||
span: Any,
|
||||
chat_outcome: str,
|
||||
chat_agent_mode: str,
|
||||
flow: Literal["new", "regenerate", "resume"],
|
||||
chat_error_category: str | None,
|
||||
duration_seconds: float,
|
||||
) -> None:
|
||||
"""Record metrics + close the span. Swallows errors (finally-block context)."""
|
||||
with contextlib.suppress(Exception):
|
||||
span.set_attribute("chat.outcome", chat_outcome)
|
||||
ot_metrics.record_chat_request_duration(
|
||||
duration_seconds * 1000,
|
||||
flow=flow,
|
||||
outcome=chat_outcome,
|
||||
agent_mode=chat_agent_mode,
|
||||
)
|
||||
ot_metrics.record_chat_request_outcome(
|
||||
flow=flow,
|
||||
outcome=chat_outcome,
|
||||
agent_mode=chat_agent_mode,
|
||||
error_category=chat_error_category,
|
||||
)
|
||||
span_cm.__exit__(*sys.exc_info())
|
||||
|
||||
|
||||
def record_outcome_attrs(
|
||||
span: Any, *, chat_outcome: str, chat_error_category: str | None
|
||||
) -> None:
|
||||
"""Stamp outcome + error.category on the span (used in the except branch)."""
|
||||
with contextlib.suppress(Exception):
|
||||
span.set_attribute("chat.outcome", chat_outcome)
|
||||
if chat_error_category is not None:
|
||||
span.set_attribute("error.category", chat_error_category)
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
"""Drive ``stream_agent_events`` with in-stream rate-limit recovery.
|
||||
|
||||
Both ``stream_new_chat`` and ``stream_resume_chat`` wrap the agent event loop
|
||||
in a ``while True`` that catches the *first* provider rate-limit error
|
||||
(``can_runtime_recover``) before any SSE event reaches the user, rebuilds the
|
||||
agent on an alternative auto-pin, and retries the turn.
|
||||
|
||||
The recovery callback is flow-specific (different ``mentioned_document_ids``
|
||||
contract, different logging label, etc.) — this module owns the loop shape,
|
||||
the caller owns the rebuild.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.tasks.chat.streaming.agent.event_loop import stream_agent_events
|
||||
from app.tasks.chat.streaming.shared.stream_result import StreamResult
|
||||
|
||||
# Returns the rebuilt agent on a successful recovery, or ``None`` to re-raise
|
||||
# the original exception (and let the orchestrator's terminal-error path
|
||||
# handle it).
|
||||
RecoverFn = Callable[[BaseException, bool], Awaitable[Any | None]]
|
||||
|
||||
|
||||
async def run_stream_loop(
|
||||
*,
|
||||
agent: Any,
|
||||
streaming_service: VercelStreamingService,
|
||||
config: dict[str, Any],
|
||||
input_data: Any,
|
||||
stream_result: StreamResult,
|
||||
step_prefix: str = "thinking",
|
||||
initial_step_id: str | None = None,
|
||||
initial_step_title: str = "",
|
||||
initial_step_items: list[str] | None = None,
|
||||
fallback_commit_search_space_id: int | None,
|
||||
fallback_commit_created_by_id: str | None,
|
||||
fallback_commit_filesystem_mode: FilesystemMode,
|
||||
fallback_commit_thread_id: int | None,
|
||||
runtime_context: Any,
|
||||
content_builder: Any | None,
|
||||
recover: RecoverFn,
|
||||
on_first_event: Callable[[], None] | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Yield SSE frames; rebuild and retry once on a pre-first-event rate limit.
|
||||
|
||||
``on_first_event`` fires after the first frame is observed (used by both
|
||||
flows to write a one-time ``First agent event in N.NNNs`` perf line).
|
||||
"""
|
||||
first_event_logged = False
|
||||
while True:
|
||||
try:
|
||||
async for sse in stream_agent_events(
|
||||
agent=agent,
|
||||
config=config,
|
||||
input_data=input_data,
|
||||
streaming_service=streaming_service,
|
||||
result=stream_result,
|
||||
step_prefix=step_prefix,
|
||||
initial_step_id=initial_step_id,
|
||||
initial_step_title=initial_step_title,
|
||||
initial_step_items=initial_step_items,
|
||||
fallback_commit_search_space_id=fallback_commit_search_space_id,
|
||||
fallback_commit_created_by_id=fallback_commit_created_by_id,
|
||||
fallback_commit_filesystem_mode=fallback_commit_filesystem_mode,
|
||||
fallback_commit_thread_id=fallback_commit_thread_id,
|
||||
runtime_context=runtime_context,
|
||||
content_builder=content_builder,
|
||||
):
|
||||
if not first_event_logged:
|
||||
if on_first_event is not None:
|
||||
on_first_event()
|
||||
first_event_logged = True
|
||||
yield sse
|
||||
return
|
||||
except Exception as exc:
|
||||
new_agent = await recover(exc, first_event_logged)
|
||||
if new_agent is None:
|
||||
raise
|
||||
agent = new_agent
|
||||
continue
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
"""Handle the ``except Exception`` branch of a streaming flow.
|
||||
|
||||
Classifies the exception, records OpenTelemetry attributes, emits one terminal
|
||||
error SSE frame and the trailing ``turn-status: idle`` + finish/done frames.
|
||||
|
||||
Used by both ``stream_new_chat`` and ``stream_resume_chat``; flow-specific bits
|
||||
(label, span, BusyError tracking) are passed by the caller.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
from collections.abc import Iterator
|
||||
from typing import Any, Literal
|
||||
|
||||
from app.agents.new_chat.errors import BusyError
|
||||
from app.observability import metrics as ot_metrics
|
||||
from app.observability import otel as ot
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.tasks.chat.streaming.errors.classifier import classify_stream_exception
|
||||
from app.tasks.chat.streaming.errors.emitter import emit_stream_terminal_error
|
||||
from app.tasks.chat.streaming.flows.shared.first_frames import iter_final_frames
|
||||
from app.tasks.chat.streaming.flows.shared.span import record_outcome_attrs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def handle_terminal_exception(
|
||||
exc: Exception,
|
||||
*,
|
||||
flow: Literal["new", "regenerate", "resume"],
|
||||
flow_label: str,
|
||||
log_prefix: str,
|
||||
streaming_service: VercelStreamingService,
|
||||
request_id: str | None,
|
||||
chat_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
chat_span: Any,
|
||||
) -> tuple[Iterator[str], dict[str, Any]]:
|
||||
"""Classify, log, and produce the SSE frames for a terminal exception.
|
||||
|
||||
Returns ``(frame_iterator, summary)``. ``summary`` carries::
|
||||
|
||||
- ``busy_error_raised``: bool — caller must skip the lock-release path
|
||||
when True (caller never acquired the busy mutex).
|
||||
- ``chat_outcome``: str — span outcome attribute.
|
||||
- ``chat_error_category``: str — categorized error label for metrics.
|
||||
"""
|
||||
busy_error_raised = isinstance(exc, BusyError)
|
||||
|
||||
(
|
||||
error_kind,
|
||||
error_code,
|
||||
severity,
|
||||
is_expected,
|
||||
user_message,
|
||||
error_extra,
|
||||
) = classify_stream_exception(exc, flow_label=flow_label)
|
||||
chat_outcome = error_code or error_kind or "error"
|
||||
chat_error_category = ot_metrics.categorize_exception(exc)
|
||||
record_outcome_attrs(
|
||||
chat_span,
|
||||
chat_outcome=chat_outcome,
|
||||
chat_error_category=chat_error_category,
|
||||
)
|
||||
with __suppress():
|
||||
ot.record_error(chat_span, exc)
|
||||
error_message = f"Error during {flow_label}: {exc!s}"
|
||||
# Match the original behavior: log full traceback via ``print`` so it lands
|
||||
# in stderr regardless of the logger config.
|
||||
print(f"[{log_prefix}] {error_message}")
|
||||
print(f"[{log_prefix}] Exception type: {type(exc).__name__}")
|
||||
print(f"[{log_prefix}] Traceback:\n{traceback.format_exc()}")
|
||||
|
||||
def _iter_frames() -> Iterator[str]:
|
||||
if error_code == "TURN_CANCELLING":
|
||||
status_payload: dict[str, Any] = {"status": "cancelling"}
|
||||
if error_extra:
|
||||
status_payload.update(error_extra)
|
||||
yield streaming_service.format_data("turn-status", status_payload)
|
||||
else:
|
||||
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||
|
||||
yield emit_stream_terminal_error(
|
||||
streaming_service=streaming_service,
|
||||
flow=flow,
|
||||
request_id=request_id,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
message=user_message,
|
||||
error_kind=error_kind,
|
||||
error_code=error_code,
|
||||
severity=severity,
|
||||
is_expected=is_expected,
|
||||
extra=error_extra,
|
||||
)
|
||||
yield from iter_final_frames(streaming_service)
|
||||
|
||||
return (
|
||||
_iter_frames(),
|
||||
{
|
||||
"busy_error_raised": busy_error_raised,
|
||||
"chat_outcome": chat_outcome,
|
||||
"chat_error_category": chat_error_category,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def __suppress():
|
||||
"""Local single-use ``contextlib.suppress(Exception)`` factory.
|
||||
|
||||
Inlined here so callers don't import ``contextlib`` just for the
|
||||
``record_error`` call site.
|
||||
"""
|
||||
import contextlib
|
||||
|
||||
return contextlib.suppress(Exception)
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
"""Shared building blocks used across every streaming flow."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.tasks.chat.streaming.shared.stream_result import StreamResult
|
||||
from app.tasks.chat.streaming.shared.utils import (
|
||||
resume_step_prefix,
|
||||
safe_float,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"StreamResult",
|
||||
"resume_step_prefix",
|
||||
"safe_float",
|
||||
]
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
"""Per-turn streaming state shared between the orchestrator and event loop."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamResult:
|
||||
accumulated_text: str = ""
|
||||
is_interrupted: bool = False
|
||||
sandbox_files: list[str] = field(default_factory=list)
|
||||
request_id: str | None = None
|
||||
turn_id: str = ""
|
||||
filesystem_mode: str = "cloud"
|
||||
client_platform: str = "web"
|
||||
intent_detected: str = "chat_only"
|
||||
intent_confidence: float = 0.0
|
||||
write_attempted: bool = False
|
||||
write_succeeded: bool = False
|
||||
verification_succeeded: bool = False
|
||||
commit_gate_passed: bool = True
|
||||
commit_gate_reason: str = ""
|
||||
# Pre-allocated assistant ``new_chat_messages.id`` for this turn, captured by
|
||||
# ``persist_assistant_shell`` right after the user row is persisted. ``None``
|
||||
# for the legacy/anonymous code paths that don't opt in to server-side
|
||||
# ``ContentPart[]`` projection.
|
||||
assistant_message_id: int | None = None
|
||||
# In-memory mirror of the FE's assistant-ui ``ContentPartsState``, populated
|
||||
# by the lifecycle methods called from the streaming event loop at each
|
||||
# ``streaming_service.format_*`` yield site. Snapshot in the streaming
|
||||
# ``finally`` to produce the rich JSONB persisted by
|
||||
# ``finalize_assistant_turn``. ``repr=False`` keeps the log-on-error path
|
||||
# (``StreamResult`` is logged in some error branches) from dumping a
|
||||
# potentially-large parts list.
|
||||
content_builder: Any | None = field(default=None, repr=False)
|
||||
27
surfsense_backend/app/tasks/chat/streaming/shared/utils.py
Normal file
27
surfsense_backend/app/tasks/chat/streaming/shared/utils.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
"""Small utilities used by streaming orchestrators and phases."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def resume_step_prefix(turn_id: str) -> str:
|
||||
"""Per-turn ``step_prefix`` for resume invocations.
|
||||
|
||||
Each ``stream_agent_events`` call constructs a fresh
|
||||
``AgentEventRelayState`` with ``thinking_step_counter=0``, so two consecutive
|
||||
resume turns would otherwise both emit ``thinking-resume-1``, ``-2`` etc.
|
||||
The frontend rehydrates ``currentThinkingSteps`` from the immediate prior
|
||||
assistant message at the start of every resume — if the new stream's IDs
|
||||
collide with the seeded ones, React renders sibling Timeline rows with the
|
||||
same key. Salting with ``turn_id`` guarantees disjoint IDs across resumes
|
||||
within one thread.
|
||||
"""
|
||||
return f"thinking-resume-{turn_id}"
|
||||
|
||||
|
||||
def safe_float(value: Any, default: float = 0.0) -> float:
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
Loading…
Add table
Add a link
Reference in a new issue