mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-29 19:35:20 +02:00
refactor(chat): add streaming/flows/shared/ base helpers
Six small, single-purpose modules shared by the upcoming new_chat and resume_chat orchestrators: * llm_bundle: dispatches negative config_id to the YAML loader and non-negative config_id to the DB loader, returning (llm, agent_config, error_message), where error_message is None on success. * pre_stream_setup: builds the connector service, resolves the Firecrawl API key, and returns the chat checkpointer. * first_frames: iter_initial_frames + iter_final_frames emit the canonical message-start / step-start / idle / finish / done SSE envelope. * finalize_emit: iter_token_usage_frame emits the per-turn usage frame from a TokenAccumulator summary. * finally_cleanup: close_session_and_clear_ai_responding and run_gc_pass centralize the finally-block bookkeeping. * span: open_chat_request_span / set_agent_mode / close_chat_request_span / record_outcome_attrs wrap the OpenTelemetry chat_request span. Add-only; these are not yet wired into stream_new_chat.py.
This commit is contained in:
parent
26c569467d
commit
e9a98ecafb
7 changed files with 343 additions and 0 deletions
|
|
@ -0,0 +1,3 @@
|
|||
"""Building blocks shared by ``new_chat`` and ``resume_chat`` orchestrators."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
"""Emit the per-turn token-usage SSE frame from the accumulator.
|
||||
|
||||
``per_message_summary()`` returns ``None`` when the turn made no chargeable
|
||||
LLM calls (e.g. interrupt-on-input). In that case we skip the frame; the
|
||||
frontend has no usage to render.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.token_tracking_service import TokenAccumulator
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def iter_token_usage_frame(
    streaming_service: VercelStreamingService,
    *,
    accumulator: TokenAccumulator,
    log_label: str,
):
    """Emit the turn's ``token-usage`` SSE frame, if there is usage to report.

    The accumulator totals are always written to the perf log as a single
    ``[token_usage] {log_label}: ...`` line, whether or not a frame follows.
    A frame is produced only when ``accumulator.per_message_summary()`` is
    truthy.

    Args:
        streaming_service: Formatter used to build the SSE data frame.
        accumulator: Per-turn token accumulator the totals are read from.
        log_label: Flow label interpolated into the perf-log line.

    Yields:
        At most one formatted ``token-usage`` frame.
    """
    summary = accumulator.per_message_summary()
    call_count = len(accumulator.calls)
    _perf_log.info(
        "[token_usage] %s: calls=%d total=%d cost_micros=%d summary=%s",
        log_label,
        call_count,
        accumulator.grand_total,
        accumulator.total_cost_micros,
        summary,
    )

    if not summary:
        # Nothing to render for this turn; the log line above is the only output.
        return

    payload = {
        "usage": summary,
        "prompt_tokens": accumulator.total_prompt_tokens,
        "completion_tokens": accumulator.total_completion_tokens,
        "total_tokens": accumulator.grand_total,
        "cost_micros": accumulator.total_cost_micros,
        "call_details": accumulator.serialized_calls(),
    }
    yield streaming_service.format_data("token-usage", payload)
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
"""Shared finally-block helpers: session close, GC pass, native-heap trim.
|
||||
|
||||
These are called from inside an ``anyio.CancelScope(shield=True)`` block in
|
||||
each flow's ``finally`` (Starlette's BaseHTTPMiddleware cancels the scope on
|
||||
client disconnect; without the shield the very first ``await`` would raise
|
||||
``CancelledError`` and the rest of cleanup — including ``session.close()`` —
|
||||
would never run).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import gc
|
||||
import logging
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import shielded_async_session
|
||||
from app.services.chat_session_state_service import clear_ai_responding
|
||||
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def close_session_and_clear_ai_responding(
    session: AsyncSession, chat_id: int
) -> None:
    """Roll back, clear the AI-responding flag, then expunge and close ``session``.

    If either the rollback or the flag clear on ``session`` raises, the flag
    clear is retried on a fresh ``shielded_async_session()`` so a UI is never
    stuck on "AI is responding…" after a crash. ``Exception`` subclasses
    raised along the way are logged (with traceback) rather than propagated.

    Args:
        session: The request's session; expunge and close are always attempted.
        chat_id: Thread whose AI-responding flag should be cleared.
    """
    try:
        await session.rollback()
        await clear_ai_responding(session, chat_id)
    except Exception:
        # Keep the root cause visible; previously this failure left no trace.
        logger.debug(
            "Primary AI-responding clear failed for thread %s; "
            "retrying on a fresh session",
            chat_id,
            exc_info=True,
        )
        try:
            async with shielded_async_session() as fresh_session:
                await clear_ai_responding(fresh_session, chat_id)
        except Exception:
            logger.warning(
                "Failed to clear AI responding state for thread %s",
                chat_id,
                exc_info=True,
            )

    # Best-effort teardown: a failure in one step must not skip the next.
    with contextlib.suppress(Exception):
        session.expunge_all()

    with contextlib.suppress(Exception):
        await session.close()
|
||||
|
||||
|
||||
def run_gc_pass(*, log_prefix: str, chat_id: int) -> None:
    """Collect generations 0, 1 and 2 in turn, then trim and snapshot.

    The caller is expected to have dropped its own references to the agent
    graph, tools, and LLM wrappers (by setting the locals to ``None``)
    beforehand; this function only runs the collector. When the combined
    count returned by the three ``gc.collect`` calls is non-zero it is
    written to the perf log. ``trim_native_heap()`` and an
    ``{log_prefix}_END`` system snapshot always follow.

    Args:
        log_prefix: Tag used in the perf-log line and the snapshot label.
        chat_id: Chat id included in the perf-log line.
    """
    reclaimed = sum(gc.collect(generation) for generation in (0, 1, 2))
    if reclaimed:
        _perf_log.info(
            "[%s] gc.collect() reclaimed %d objects (chat_id=%s)",
            log_prefix,
            reclaimed,
            chat_id,
        )

    trim_native_heap()
    log_system_snapshot(f"{log_prefix}_END")
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
"""Initial SSE frames every flow emits right after pre-stream setup.
|
||||
|
||||
Order matters: ``message_start`` opens the assistant message, ``start_step``
|
||||
opens the first thinking step, ``turn-info`` lets the frontend stamp the
|
||||
correlation id onto the in-flight message, and ``turn-status: busy`` flips the
|
||||
UI into the streaming state.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
||||
|
||||
def iter_initial_frames(
    streaming_service: VercelStreamingService,
    *,
    turn_id: str,
) -> Iterator[str]:
    """Produce the four opening frames of a turn, one at a time, in order.

    Sequence: message start, step start, ``turn-info`` and then
    ``turn-status: busy``. ``turn-info`` carries ``chat_turn_id`` so that even
    pure-text turns (which never produce a tool / action-log event) still
    teach the frontend the correlation id used for ``appendMessage`` durable
    storage.

    Args:
        streaming_service: Formatter that builds each frame.
        turn_id: Correlation id sent as ``chat_turn_id``.

    Yields:
        Each formatted frame; a frame is only built when it is requested.
    """
    message_start = streaming_service.format_message_start()
    yield message_start

    step_start = streaming_service.format_start_step()
    yield step_start

    turn_info = {"chat_turn_id": turn_id}
    yield streaming_service.format_data("turn-info", turn_info)

    busy_status = {"status": "busy"}
    yield streaming_service.format_data("turn-status", busy_status)
|
||||
|
||||
|
||||
def iter_final_frames(
    streaming_service: VercelStreamingService,
) -> Iterator[str]:
    """Produce the closing frames of a turn, one at a time, in order.

    Sequence: ``turn-status: idle``, finish-step, finish, done.

    Args:
        streaming_service: Formatter that builds each frame.

    Yields:
        Each formatted frame; a frame is only built when it is requested.
    """
    idle_status = {"status": "idle"}
    yield streaming_service.format_data("turn-status", idle_status)

    step_finish = streaming_service.format_finish_step()
    yield step_finish

    message_finish = streaming_service.format_finish()
    yield message_finish

    done = streaming_service.format_done()
    yield done
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
"""Load an LLM + AgentConfig bundle for a given config id.
|
||||
|
||||
Handles both code paths uniformly:
|
||||
- ``config_id >= 0`` → database-backed ``NewLLMConfig`` row (per-user/per-space).
|
||||
- ``config_id < 0`` → YAML-defined global LLM config (built-in defaults).
|
||||
|
||||
Returns ``(llm, agent_config, error_message)``; on success ``error_message`` is
|
||||
``None``. The caller emits the friendly SSE error frame.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.llm_config import (
|
||||
AgentConfig,
|
||||
create_chat_litellm_from_agent_config,
|
||||
create_chat_litellm_from_config,
|
||||
load_agent_config,
|
||||
load_global_llm_config_by_id,
|
||||
)
|
||||
|
||||
|
||||
async def load_llm_bundle(
    session: AsyncSession,
    *,
    config_id: int,
    search_space_id: int,
) -> tuple[Any, AgentConfig | None, str | None]:
    """Resolve ``config_id`` into ``(llm, agent_config, error_message)``.

    Non-negative ids are loaded through ``load_agent_config`` using the
    session and search space; negative ids are looked up with
    ``load_global_llm_config_by_id``. When the lookup returns a falsy value
    the result is ``(None, None, <message>)``; otherwise the third element is
    ``None``.

    Args:
        session: Session passed to the database-backed loader.
        config_id: Config id; its sign selects the loader.
        search_space_id: Search space passed to the database-backed loader.

    Returns:
        The LLM instance, its agent config, and an error message (or ``None``).
    """
    if config_id >= 0:
        db_agent_config = await load_agent_config(
            session=session,
            config_id=config_id,
            search_space_id=search_space_id,
        )
        if not db_agent_config:
            failure = f"Failed to load NewLLMConfig with id {config_id}"
            return None, None, failure

        llm = create_chat_litellm_from_agent_config(db_agent_config)
        return llm, db_agent_config, None

    global_config = load_global_llm_config_by_id(config_id)
    if not global_config:
        return None, None, f"Failed to load LLM config with id {config_id}"

    llm = create_chat_litellm_from_config(global_config)
    agent_config = AgentConfig.from_yaml_config(global_config)
    return llm, agent_config, None
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
"""Pre-stream setup: connector service, firecrawl key, checkpointer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||
from app.db import SearchSourceConnectorType
|
||||
from app.services.connector_service import ConnectorService
|
||||
|
||||
|
||||
async def setup_connector_and_firecrawl(
    session: AsyncSession,
    *,
    search_space_id: int,
) -> tuple[ConnectorService, str | None]:
    """Create the connector service and look up the Firecrawl API key.

    The key is read from the ``FIRECRAWL_API_KEY`` entry of the web-crawler
    connector's config. It is ``None`` when no such connector is found, when
    its config is empty, or when the entry is absent.

    Args:
        session: Session the connector service is built on.
        search_space_id: Search space the service and lookup are scoped to.

    Returns:
        ``(connector_service, firecrawl_api_key)``.
    """
    connector_service = ConnectorService(session, search_space_id=search_space_id)
    crawler = await connector_service.get_connector_by_type(
        SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
    )

    if not (crawler and crawler.config):
        # No usable web-crawler connector: report the key as missing.
        return connector_service, None

    return connector_service, crawler.config.get("FIRECRAWL_API_KEY")
|
||||
|
||||
|
||||
async def get_chat_checkpointer():
    """Return the checkpointer used for persistent conversation memory.

    This simply awaits ``app.agents.new_chat.checkpointer.get_checkpointer``.
    It exists so flow orchestrators depend on a streaming-local symbol, and
    to leave a hook point should the checkpointer source ever need to vary
    per flow.
    """
    checkpointer = await get_checkpointer()
    return checkpointer
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
"""OpenTelemetry chat-request span wrapper for streaming flows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import sys
|
||||
from typing import Any, Literal
|
||||
|
||||
from app.observability import metrics as ot_metrics
|
||||
from app.observability import otel as ot
|
||||
|
||||
|
||||
def open_chat_request_span(
    *,
    chat_id: int,
    search_space_id: int,
    flow: Literal["new", "regenerate", "resume"],
    request_id: str | None,
    turn_id: str,
    filesystem_mode: str,
    client_platform: str,
    agent_mode: str,
) -> tuple[Any, Any]:
    """Create and enter the per-request chat span.

    All keyword arguments are forwarded unchanged to
    ``ot.chat_request_span``. The context manager is entered manually, so the
    caller must keep it and exit it later (see ``close_chat_request_span``).

    Returns:
        ``(span_cm, span)``: the context manager and the value its
        ``__enter__`` returned.
    """
    span_kwargs = {
        "chat_id": chat_id,
        "search_space_id": search_space_id,
        "flow": flow,
        "request_id": request_id,
        "turn_id": turn_id,
        "filesystem_mode": filesystem_mode,
        "client_platform": client_platform,
        "agent_mode": agent_mode,
    }
    span_cm = ot.chat_request_span(**span_kwargs)
    return span_cm, span_cm.__enter__()
|
||||
|
||||
|
||||
def set_agent_mode(span: Any, agent_mode: str) -> None:
    """Record the resolved agent mode on ``span`` as ``agent.mode``.

    Best-effort: any ``Exception`` raised while setting the attribute is
    suppressed.

    Args:
        span: Object exposing ``set_attribute(key, value)``.
        agent_mode: Resolved mode value (single / multi).
    """
    attribute_key = "agent.mode"
    with contextlib.suppress(Exception):
        span.set_attribute(attribute_key, agent_mode)
|
||||
|
||||
|
||||
def close_chat_request_span(
    *,
    span_cm: Any,
    span: Any,
    chat_outcome: str,
    chat_agent_mode: str,
    flow: Literal["new", "regenerate", "resume"],
    chat_error_category: str | None,
    duration_seconds: float,
) -> None:
    """Stamp the outcome, record request metrics, and exit the span.

    Intended for ``finally`` blocks, so every step is best-effort and
    independently guarded: a failure while tagging the span or recording a
    metric is suppressed and never prevents ``span_cm.__exit__`` from
    running, which would otherwise leave the span open.

    Args:
        span_cm: Context manager returned by ``open_chat_request_span``.
        span: Span object to tag with ``chat.outcome``.
        chat_outcome: Outcome label for the span attribute and both metrics.
        chat_agent_mode: Agent mode label for both metrics.
        flow: Flow label for both metrics.
        chat_error_category: Error category for the outcome metric, or ``None``.
        duration_seconds: Request duration in seconds; passed to the duration
            metric multiplied by 1000 (presumably milliseconds; confirm
            against ``record_chat_request_duration``).
    """
    # Snapshot up front so __exit__ receives the exception in flight in the
    # caller's ``finally`` block, unaffected by anything raised below.
    exc_info = sys.exc_info()

    with contextlib.suppress(Exception):
        span.set_attribute("chat.outcome", chat_outcome)

    with contextlib.suppress(Exception):
        ot_metrics.record_chat_request_duration(
            duration_seconds * 1000,
            flow=flow,
            outcome=chat_outcome,
            agent_mode=chat_agent_mode,
        )

    with contextlib.suppress(Exception):
        ot_metrics.record_chat_request_outcome(
            flow=flow,
            outcome=chat_outcome,
            agent_mode=chat_agent_mode,
            error_category=chat_error_category,
        )

    # Always attempted, whatever happened above.
    with contextlib.suppress(Exception):
        span_cm.__exit__(*exc_info)
|
||||
|
||||
|
||||
def record_outcome_attrs(
    span: Any, *, chat_outcome: str, chat_error_category: str | None
) -> None:
    """Set ``chat.outcome`` and, when given, ``error.category`` on ``span``.

    Used from the except branch. Best-effort: any ``Exception`` raised while
    setting an attribute is suppressed, and attributes after the failing one
    are not attempted.

    Args:
        span: Object exposing ``set_attribute(key, value)``.
        chat_outcome: Value stored under ``chat.outcome``.
        chat_error_category: Value stored under ``error.category``; skipped
            when ``None``.
    """
    attributes = [("chat.outcome", chat_outcome)]
    if chat_error_category is not None:
        attributes.append(("error.category", chat_error_category))

    with contextlib.suppress(Exception):
        for key, value in attributes:
            span.set_attribute(key, value)
|
||||
Loading…
Add table
Add a link
Reference in a new issue