mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-29 19:35:20 +02:00
refactor(chat): add streaming/flows/shared/ base helpers
Six small, single-purpose modules shared by the upcoming new_chat and resume_chat orchestrators: * llm_bundle: dispatches negative config_id to the YAML loader and non-negative config_id to the DB loader, returning (llm, AgentConfig, error_message), where error_message is None on success. * pre_stream_setup: builds the connector service, resolves the Firecrawl API key, and returns the chat checkpointer. * first_frames: iter_initial_frames + iter_final_frames emit the canonical message-start / step-start / turn-info / busy / idle / finish-step / finish / done SSE envelope. * finalize_emit: iter_token_usage_frame emits the per-turn usage frame from a TokenAccumulator summary. * finally_cleanup: close_session_and_clear_ai_responding and run_gc_pass centralize the finally-block bookkeeping. * span: open_chat_request_span / set_agent_mode / close_chat_request_span / record_outcome_attrs wrap the OpenTelemetry chat_request span. Add-only; these are not yet wired into stream_new_chat.py.
This commit is contained in:
parent
26c569467d
commit
e9a98ecafb
7 changed files with 343 additions and 0 deletions
|
|
@ -0,0 +1,3 @@
|
||||||
|
"""Building blocks shared by ``new_chat`` and ``resume_chat`` orchestrators."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
@ -0,0 +1,54 @@
|
||||||
|
"""Emit the per-turn token-usage SSE frame from the accumulator.
|
||||||
|
|
||||||
|
``per_message_summary()`` returns ``None`` when the turn made no chargeable
|
||||||
|
LLM calls (e.g. interrupt-on-input). In that case we skip the frame; the
|
||||||
|
frontend has no usage to render.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from app.services.new_streaming_service import VercelStreamingService
|
||||||
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.services.token_tracking_service import TokenAccumulator
|
||||||
|
|
||||||
|
_perf_log = get_perf_logger()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def iter_token_usage_frame(
    streaming_service: VercelStreamingService,
    *,
    accumulator: TokenAccumulator,
    log_label: str,
):
    """Emit the turn's ``token-usage`` data frame when there is usage to report.

    A ``[token_usage] {log_label}: ...`` line (call count, grand total, cost
    in micros, raw summary) is always written to the perf logger, whether or
    not a frame follows.

    Args:
        streaming_service: Formatter used to build the SSE frame.
        accumulator: Per-turn token accumulator to read totals from.
        log_label: Prefix identifying the calling flow in the perf log line.

    Yields:
        Exactly one formatted ``token-usage`` frame when
        ``accumulator.per_message_summary()`` is truthy; nothing otherwise.
    """
    summary = accumulator.per_message_summary()
    _perf_log.info(
        "[token_usage] %s: calls=%d total=%d cost_micros=%d summary=%s",
        log_label,
        len(accumulator.calls),
        accumulator.grand_total,
        accumulator.total_cost_micros,
        summary,
    )

    # Nothing to report for this turn: log only, emit no frame.
    if not summary:
        return

    payload = {
        "usage": summary,
        "prompt_tokens": accumulator.total_prompt_tokens,
        "completion_tokens": accumulator.total_completion_tokens,
        "total_tokens": accumulator.grand_total,
        "cost_micros": accumulator.total_cost_micros,
        "call_details": accumulator.serialized_calls(),
    }
    yield streaming_service.format_data("token-usage", payload)
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
"""Shared finally-block helpers: session close, GC pass, native-heap trim.
|
||||||
|
|
||||||
|
These are called from inside an ``anyio.CancelScope(shield=True)`` block in
|
||||||
|
each flow's ``finally`` (Starlette's BaseHTTPMiddleware cancels the scope on
|
||||||
|
client disconnect; without the shield the very first ``await`` would raise
|
||||||
|
``CancelledError`` and the rest of cleanup — including ``session.close()`` —
|
||||||
|
would never run).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import gc
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import shielded_async_session
|
||||||
|
from app.services.chat_session_state_service import clear_ai_responding
|
||||||
|
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
|
||||||
|
|
||||||
|
_perf_log = get_perf_logger()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def close_session_and_clear_ai_responding(
    session: AsyncSession, chat_id: int
) -> None:
    """Roll back ``session``, clear the AI-responding flag, then expunge and close.

    Best-effort cleanup intended for a ``finally`` block: failures in any
    individual step are contained so the remaining steps still run.

    Args:
        session: The request's session; rolled back, expunged, and closed here.
        chat_id: Identifier passed to ``clear_ai_responding``.

    If the rollback or the flag clear on ``session`` fails, the flag clear is
    retried once on a fresh ``shielded_async_session()`` so a UI is never stuck
    on "AI is responding…" after a crash. If that retry also fails, a warning
    is logged together with the traceback.
    """
    try:
        await session.rollback()
        await clear_ai_responding(session, chat_id)
    except Exception:
        # The request session could not be used; retry the flag clear on an
        # independent session rather than giving up.
        try:
            async with shielded_async_session() as fresh_session:
                await clear_ai_responding(fresh_session, chat_id)
        except Exception:
            # Attach the traceback: without it there is no record of why the
            # flag could not be cleared (the chained context also carries the
            # original failure from the first attempt).
            logger.warning(
                "Failed to clear AI responding state for thread %s",
                chat_id,
                exc_info=True,
            )

    # Expunge and close are independent: a failure in one must not skip the
    # other.
    with contextlib.suppress(Exception):
        session.expunge_all()

    with contextlib.suppress(Exception):
        await session.close()
|
||||||
|
|
||||||
|
|
||||||
|
def run_gc_pass(*, log_prefix: str, chat_id: int) -> None:
    """Collect generations 0, 1 and 2, then trim the native heap and snapshot.

    The caller is responsible for dropping its own references (setting the
    agent graph / tool / LLM wrapper locals to ``None``) beforehand; this
    function only runs the collector and reports the result.

    Args:
        log_prefix: Tag used in the perf log line and, suffixed with ``_END``,
            as the label of the system snapshot.
        chat_id: Chat identifier included in the perf log line.
    """
    reclaimed = 0
    for generation in (0, 1, 2):
        reclaimed += gc.collect(generation)

    # Only log when the collector actually freed something.
    if reclaimed:
        _perf_log.info(
            "[%s] gc.collect() reclaimed %d objects (chat_id=%s)",
            log_prefix,
            reclaimed,
            chat_id,
        )

    trim_native_heap()
    log_system_snapshot(f"{log_prefix}_END")
|
||||||
|
|
@ -0,0 +1,40 @@
|
||||||
|
"""SSE frames every flow emits: the initial frames right after pre-stream setup, and the final trailer.
|
||||||
|
|
||||||
|
Order matters: ``message_start`` opens the assistant message, ``start_step``
|
||||||
|
opens the first thinking step, ``turn-info`` lets the frontend stamp the
|
||||||
|
correlation id onto the in-flight message, and ``turn-status: busy`` flips the
|
||||||
|
UI into the streaming state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
|
from app.services.new_streaming_service import VercelStreamingService
|
||||||
|
|
||||||
|
|
||||||
|
def iter_initial_frames(
    streaming_service: VercelStreamingService,
    *,
    turn_id: str,
) -> Iterator[str]:
    """Produce the four opening frames of a turn, one at a time and in order.

    The sequence is: message start, step start, a ``turn-info`` data frame
    carrying ``chat_turn_id``, and a ``turn-status`` data frame with status
    ``busy``. ``turn-info`` is sent unconditionally so that pure-text turns
    (which never produce a tool / action-log event) still give the frontend
    the turn correlation id.

    Args:
        streaming_service: Formatter that renders each frame.
        turn_id: Correlation id placed in the ``turn-info`` payload.

    Yields:
        Each formatted frame; a frame is only built when it is requested.
    """
    frame_builders = (
        streaming_service.format_message_start,
        streaming_service.format_start_step,
        lambda: streaming_service.format_data(
            "turn-info", {"chat_turn_id": turn_id}
        ),
        lambda: streaming_service.format_data("turn-status", {"status": "busy"}),
    )
    for build_frame in frame_builders:
        yield build_frame()
|
||||||
|
|
||||||
|
|
||||||
|
def iter_final_frames(
    streaming_service: VercelStreamingService,
) -> Iterator[str]:
    """Produce the closing frames of a turn, one at a time and in order.

    The sequence is: a ``turn-status`` data frame with status ``idle``, then
    the finish-step, finish, and done frames.

    Args:
        streaming_service: Formatter that renders each frame.

    Yields:
        Each formatted frame; a frame is only built when it is requested.
    """
    frame_builders = (
        lambda: streaming_service.format_data("turn-status", {"status": "idle"}),
        streaming_service.format_finish_step,
        streaming_service.format_finish,
        streaming_service.format_done,
    )
    for build_frame in frame_builders:
        yield build_frame()
|
||||||
|
|
@ -0,0 +1,57 @@
|
||||||
|
"""Load an LLM + AgentConfig bundle for a given config id.
|
||||||
|
|
||||||
|
Handles both code paths uniformly:
|
||||||
|
- ``config_id >= 0`` → database-backed ``NewLLMConfig`` row (per-user/per-space).
|
||||||
|
- ``config_id < 0`` → YAML-defined global LLM config (built-in defaults).
|
||||||
|
|
||||||
|
Returns ``(llm, agent_config, error_message)``; on success ``error_message`` is
|
||||||
|
``None``. The caller emits the friendly SSE error frame.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.agents.new_chat.llm_config import (
|
||||||
|
AgentConfig,
|
||||||
|
create_chat_litellm_from_agent_config,
|
||||||
|
create_chat_litellm_from_config,
|
||||||
|
load_agent_config,
|
||||||
|
load_global_llm_config_by_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def load_llm_bundle(
    session: AsyncSession,
    *,
    config_id: int,
    search_space_id: int,
) -> tuple[Any, AgentConfig | None, str | None]:
    """Resolve ``config_id`` into an LLM instance and its ``AgentConfig``.

    Non-negative ids are looked up through ``load_agent_config`` (using
    ``session`` and ``search_space_id``); negative ids are looked up through
    ``load_global_llm_config_by_id``.

    Args:
        session: Session handed to ``load_agent_config`` for non-negative ids.
        config_id: Configuration id; its sign selects the loader.
        search_space_id: Search space handed to ``load_agent_config``.

    Returns:
        ``(llm, agent_config, None)`` on success, or
        ``(None, None, error_message)`` when the lookup yields nothing.
    """
    if config_id >= 0:
        db_config = await load_agent_config(
            session=session,
            config_id=config_id,
            search_space_id=search_space_id,
        )
        if db_config:
            llm = create_chat_litellm_from_agent_config(db_config)
            return llm, db_config, None
        return None, None, f"Failed to load NewLLMConfig with id {config_id}"

    global_config = load_global_llm_config_by_id(config_id)
    if global_config:
        llm = create_chat_litellm_from_config(global_config)
        return llm, AgentConfig.from_yaml_config(global_config), None
    return None, None, f"Failed to load LLM config with id {config_id}"
|
||||||
|
|
@ -0,0 +1,40 @@
|
||||||
|
"""Pre-stream setup: connector service, firecrawl key, checkpointer."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||||
|
from app.db import SearchSourceConnectorType
|
||||||
|
from app.services.connector_service import ConnectorService
|
||||||
|
|
||||||
|
|
||||||
|
async def setup_connector_and_firecrawl(
    session: AsyncSession,
    *,
    search_space_id: int,
) -> tuple[ConnectorService, str | None]:
    """Create the turn's ``ConnectorService`` and look up the Firecrawl API key.

    Args:
        session: Session the connector service is constructed with.
        search_space_id: Search space used both to construct the service and
            to look up the web-crawler connector.

    Returns:
        ``(connector_service, firecrawl_api_key)``. The key is ``None`` when
        no web-crawler connector is found, when its config is empty, or when
        the config has no ``FIRECRAWL_API_KEY`` entry.
    """
    service = ConnectorService(session, search_space_id=search_space_id)
    crawler = await service.get_connector_by_type(
        SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
    )

    # No connector, or a connector without config: there is no key to pull.
    if not (crawler and crawler.config):
        return service, None

    return service, crawler.config.get("FIRECRAWL_API_KEY")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_chat_checkpointer():
    """Return the checkpointer used for persistent conversation memory.

    This only awaits ``app.agents.new_chat.checkpointer.get_checkpointer`` and
    hands back its result. It exists so flow orchestrators import a
    streaming-local symbol, leaving one place to change if the checkpointer
    source ever needs to vary per flow.
    """
    checkpointer = await get_checkpointer()
    return checkpointer
|
||||||
|
|
@ -0,0 +1,80 @@
|
||||||
|
"""OpenTelemetry chat-request span wrapper for streaming flows."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import sys
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from app.observability import metrics as ot_metrics
|
||||||
|
from app.observability import otel as ot
|
||||||
|
|
||||||
|
|
||||||
|
def open_chat_request_span(
    *,
    chat_id: int,
    search_space_id: int,
    flow: Literal["new", "regenerate", "resume"],
    request_id: str | None,
    turn_id: str,
    filesystem_mode: str,
    client_platform: str,
    agent_mode: str,
) -> tuple[Any, Any]:
    """Create and enter the per-request chat span.

    All keyword arguments are forwarded unchanged to ``ot.chat_request_span``.
    The context manager is entered manually so the caller can close it later
    from its own ``finally`` block.

    Returns:
        ``(span_cm, span)``: the context manager and the value returned by its
        ``__enter__``.
    """
    span_kwargs = {
        "chat_id": chat_id,
        "search_space_id": search_space_id,
        "flow": flow,
        "request_id": request_id,
        "turn_id": turn_id,
        "filesystem_mode": filesystem_mode,
        "client_platform": client_platform,
        "agent_mode": agent_mode,
    }
    span_cm = ot.chat_request_span(**span_kwargs)
    return span_cm, span_cm.__enter__()
|
||||||
|
|
||||||
|
|
||||||
|
def set_agent_mode(span: Any, agent_mode: str) -> None:
    """Set the ``agent.mode`` attribute on ``span`` (single / multi).

    Args:
        span: Object exposing ``set_attribute(key, value)``.
        agent_mode: Resolved agent mode to record.

    Any ``Exception`` raised while setting the attribute is swallowed.
    """
    try:
        span.set_attribute("agent.mode", agent_mode)
    except Exception:
        # Span tagging is best-effort; a failure here is deliberately ignored.
        pass
|
||||||
|
|
||||||
|
|
||||||
|
def close_chat_request_span(
    *,
    span_cm: Any,
    span: Any,
    chat_outcome: str,
    chat_agent_mode: str,
    flow: Literal["new", "regenerate", "resume"],
    chat_error_category: str | None,
    duration_seconds: float,
) -> None:
    """Record the request metrics and close the span; never raises ``Exception``.

    Each step is guarded on its own so that a failure while tagging the span
    or recording a metric cannot prevent ``span_cm.__exit__`` from running
    (an un-exited span context manager would leave the span open).

    Args:
        span_cm: Context manager returned when the span was opened.
        span: The span itself; receives the ``chat.outcome`` attribute.
        chat_outcome: Outcome label recorded on the span and both metrics.
        chat_agent_mode: Agent mode label recorded on both metrics.
        flow: Which flow served the request.
        chat_error_category: Error category for the outcome metric, if any.
        duration_seconds: Request duration; multiplied by 1000 before being
            passed to ``record_chat_request_duration`` (presumably
            milliseconds — confirm against ``ot_metrics``).
    """
    # Snapshot up front so ``__exit__`` receives the caller's in-flight
    # exception state, independent of anything raised and suppressed below.
    exc_info = sys.exc_info()

    with contextlib.suppress(Exception):
        span.set_attribute("chat.outcome", chat_outcome)

    with contextlib.suppress(Exception):
        ot_metrics.record_chat_request_duration(
            duration_seconds * 1000,
            flow=flow,
            outcome=chat_outcome,
            agent_mode=chat_agent_mode,
        )

    with contextlib.suppress(Exception):
        ot_metrics.record_chat_request_outcome(
            flow=flow,
            outcome=chat_outcome,
            agent_mode=chat_agent_mode,
            error_category=chat_error_category,
        )

    # Always attempted, regardless of what failed above.
    with contextlib.suppress(Exception):
        span_cm.__exit__(*exc_info)
|
||||||
|
|
||||||
|
|
||||||
|
def record_outcome_attrs(
    span: Any, *, chat_outcome: str, chat_error_category: str | None
) -> None:
    """Write ``chat.outcome`` and, when given, ``error.category`` onto ``span``.

    Used in the except branch of a flow.

    Args:
        span: Object exposing ``set_attribute(key, value)``.
        chat_outcome: Value for the ``chat.outcome`` attribute.
        chat_error_category: Value for ``error.category``; skipped when
            ``None``.

    Any ``Exception`` raised along the way is swallowed. Both writes share one
    guard, so a failure on the first attribute also skips the second.
    """
    try:
        span.set_attribute("chat.outcome", chat_outcome)
        if chat_error_category is None:
            return
        span.set_attribute("error.category", chat_error_category)
    except Exception:
        # Span tagging is best-effort; a failure here is deliberately ignored.
        pass
|
||||||
Loading…
Add table
Add a link
Reference in a new issue