refactor(chat): add streaming/flows/shared/ base helpers

Six small, single-purpose modules shared by the upcoming new_chat and
resume_chat orchestrators:

* llm_bundle: dispatches negative config_id to the YAML loader and
  non-negative config_id to the DB loader, returning
  (llm, agent_config, error_message).
* pre_stream_setup: builds the connector service, resolves the
  Firecrawl API key, and returns the chat checkpointer.
* first_frames: iter_initial_frames + iter_final_frames emit the canonical
  message-start / step-start / idle / finish / done SSE envelope.
* finalize_emit: iter_token_usage_frame emits the per-turn usage frame
  from a TokenAccumulator summary.
* finally_cleanup: close_session_and_clear_ai_responding and run_gc_pass
  centralize the finally-block bookkeeping.
* span: open_chat_request_span / set_agent_mode / close_chat_request_span /
  record_outcome_attrs wrap the OpenTelemetry chat_request span.

Add-only; these are not yet wired into stream_new_chat.py.
This commit is contained in:
CREDO23 2026-05-25 21:49:09 +02:00
parent 26c569467d
commit e9a98ecafb
7 changed files with 343 additions and 0 deletions

View file

@ -0,0 +1,3 @@
"""Building blocks shared by ``new_chat`` and ``resume_chat`` orchestrators."""
from __future__ import annotations

View file

@ -0,0 +1,54 @@
"""Emit the per-turn token-usage SSE frame from the accumulator.
``per_message_summary()`` returns ``None`` when the turn made no chargeable
LLM calls (e.g. interrupt-on-input). In that case we skip the frame; the
frontend has no usage to render.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from app.services.new_streaming_service import VercelStreamingService
from app.utils.perf import get_perf_logger
if TYPE_CHECKING:
from app.services.token_tracking_service import TokenAccumulator
_perf_log = get_perf_logger()
logger = logging.getLogger(__name__)
def iter_token_usage_frame(
    streaming_service: VercelStreamingService,
    *,
    accumulator: TokenAccumulator,
    log_label: str,
):
    """Emit the per-turn ``token-usage`` data frame when there is usage to report.

    A ``[token_usage] {log_label}: ...`` line (call count, grand total,
    cost_micros and the raw summary) is always written to the perf logger,
    whether or not a frame is produced, so usage can be grepped across flows.

    Args:
        streaming_service: Formatter whose ``format_data`` builds the frame.
        accumulator: Token accumulator for the turn being finalized.
        log_label: Flow label interpolated into the perf-log line.

    Yields:
        One formatted ``token-usage`` frame if
        ``accumulator.per_message_summary()`` is truthy; otherwise nothing.
    """
    summary = accumulator.per_message_summary()

    _perf_log.info(
        "[token_usage] %s: calls=%d total=%d cost_micros=%d summary=%s",
        log_label,
        len(accumulator.calls),
        accumulator.grand_total,
        accumulator.total_cost_micros,
        summary,
    )

    # Nothing to render on the frontend: skip the frame entirely.
    if not summary:
        return

    payload = {
        "usage": summary,
        "prompt_tokens": accumulator.total_prompt_tokens,
        "completion_tokens": accumulator.total_completion_tokens,
        "total_tokens": accumulator.grand_total,
        "cost_micros": accumulator.total_cost_micros,
        "call_details": accumulator.serialized_calls(),
    }
    yield streaming_service.format_data("token-usage", payload)

View file

@ -0,0 +1,69 @@
"""Shared finally-block helpers: session close, GC pass, native-heap trim.
These are called from inside an ``anyio.CancelScope(shield=True)`` block in
each flow's ``finally`` (Starlette's BaseHTTPMiddleware cancels the scope on
client disconnect; without the shield the very first ``await`` would raise
``CancelledError`` and the rest of cleanup — including ``session.close()`` —
would never run).
"""
from __future__ import annotations
import contextlib
import gc
import logging
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import shielded_async_session
from app.services.chat_session_state_service import clear_ai_responding
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
_perf_log = get_perf_logger()
logger = logging.getLogger(__name__)
async def close_session_and_clear_ai_responding(
    session: AsyncSession, chat_id: int
) -> None:
    """Roll back, clear the AI-responding flag, then detach and close the session.

    Steps, each guarded so a failure in one does not block the next:

    1. ``session.rollback()`` followed by ``clear_ai_responding`` on the same
       session.
    2. If either of those raises, ``clear_ai_responding`` is retried on a
       fresh ``shielded_async_session()`` so a UI is never stuck on
       "AI is responding…" after a crash.
    3. ``session.expunge_all()`` and ``session.close()``, each with
       ``Exception`` suppressed.

    Failures in steps 1 and 2 are logged with their traceback rather than
    propagated.

    Args:
        session: The request session to roll back and close.
        chat_id: Thread whose AI-responding flag is cleared.
    """
    try:
        await session.rollback()
        await clear_ai_responding(session, chat_id)
    except Exception:
        # Record why the primary path failed before trying the fallback;
        # without this the root cause is lost whenever the fallback succeeds.
        logger.debug(
            "Primary AI-responding clear failed for thread %s; "
            "retrying on a fresh session",
            chat_id,
            exc_info=True,
        )
        try:
            async with shielded_async_session() as fresh_session:
                await clear_ai_responding(fresh_session, chat_id)
        except Exception:
            # exc_info attaches the traceback so the stuck flag is diagnosable.
            logger.warning(
                "Failed to clear AI responding state for thread %s",
                chat_id,
                exc_info=True,
            )

    # Detach loaded instances, then release the connection. Both are
    # best-effort: nothing useful can be done here if they fail.
    with contextlib.suppress(Exception):
        session.expunge_all()

    with contextlib.suppress(Exception):
        await session.close()
def run_gc_pass(*, log_prefix: str, chat_id: int) -> None:
    """Collect generations 0, 1 and 2, trim the native heap, log an END snapshot.

    Breaking circular refs held by the agent graph, tools, and LLM wrappers
    needs to happen in the caller (set the locals to ``None``); this helper
    just runs the collector and logs how many objects came back.

    Args:
        log_prefix: Tag used in the perf-log line and as the prefix of the
            ``{log_prefix}_END`` system snapshot label.
        chat_id: Chat identifier included in the perf-log line.
    """
    # Same order as three explicit calls: gen 0, then gen 1, then gen 2.
    reclaimed = sum(gc.collect(generation) for generation in range(3))

    # Only log when the collector actually found something.
    if reclaimed:
        _perf_log.info(
            "[%s] gc.collect() reclaimed %d objects (chat_id=%s)",
            log_prefix,
            reclaimed,
            chat_id,
        )

    trim_native_heap()
    log_system_snapshot(f"{log_prefix}_END")

View file

@ -0,0 +1,40 @@
"""SSE frames every flow emits right after pre-stream setup, plus the closing trailer.
Order matters: ``message_start`` opens the assistant message, ``start_step``
opens the first thinking step, ``turn-info`` lets the frontend stamp the
correlation id onto the in-flight message, and ``turn-status: busy`` flips the
UI into the streaming state.
"""
from __future__ import annotations
from collections.abc import Iterator
from app.services.new_streaming_service import VercelStreamingService
def iter_initial_frames(
    streaming_service: VercelStreamingService,
    *,
    turn_id: str,
) -> Iterator[str]:
    """Yield the four opening frames of a turn, in a fixed order.

    The sequence is: message start, step start, a ``turn-info`` data frame
    and a ``turn-status`` data frame with status ``busy``.

    ``turn-info`` carries ``chat_turn_id`` so even pure-text turns (which
    never produce a tool / action-log event) still hand the frontend the
    turn correlation id.

    Args:
        streaming_service: Formatter that renders each frame.
        turn_id: Correlation id placed in the ``turn-info`` payload.

    Yields:
        Each formatted frame, produced lazily one at a time.
    """
    turn_info = {"chat_turn_id": turn_id}
    busy_status = {"status": "busy"}

    yield streaming_service.format_message_start()
    yield streaming_service.format_start_step()
    yield streaming_service.format_data("turn-info", turn_info)
    yield streaming_service.format_data("turn-status", busy_status)
def iter_final_frames(
    streaming_service: VercelStreamingService,
) -> Iterator[str]:
    """Yield the closing frames of a turn, in a fixed order.

    The sequence is: a ``turn-status`` data frame with status ``idle``, then
    the finish-step, finish and done frames.

    Args:
        streaming_service: Formatter that renders each frame.

    Yields:
        Each formatted frame, produced lazily one at a time.
    """
    idle_status = {"status": "idle"}

    yield streaming_service.format_data("turn-status", idle_status)
    yield streaming_service.format_finish_step()
    yield streaming_service.format_finish()
    yield streaming_service.format_done()

View file

@ -0,0 +1,57 @@
"""Load an LLM + AgentConfig bundle for a given config id.
Handles both code paths uniformly:
- ``config_id >= 0``: database-backed ``NewLLMConfig`` row (per-user/per-space).
- ``config_id < 0``: YAML-defined global LLM config (built-in defaults).
Returns ``(llm, agent_config, error_message)``; on success ``error_message`` is
``None``. The caller emits the friendly SSE error frame.
"""
from __future__ import annotations
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.llm_config import (
AgentConfig,
create_chat_litellm_from_agent_config,
create_chat_litellm_from_config,
load_agent_config,
load_global_llm_config_by_id,
)
async def load_llm_bundle(
    session: AsyncSession,
    *,
    config_id: int,
    search_space_id: int,
) -> tuple[Any, AgentConfig | None, str | None]:
    """Resolve ``config_id`` to an ``(llm, agent_config, error_message)`` triple.

    Negative ids are looked up with ``load_global_llm_config_by_id``;
    non-negative ids are loaded with ``load_agent_config`` using ``session``
    and ``search_space_id``.

    Args:
        session: Database session used for the non-negative-id lookup.
        config_id: Identifier of the LLM config to load.
        search_space_id: Search space passed to the database lookup.

    Returns:
        ``(llm, agent_config, None)`` when the config was found, or
        ``(None, None, message)`` when the lookup returned nothing.

    NOTE(review): the value returned by the ``create_chat_litellm_*`` factory
    is passed through unchecked. If those factories can return ``None``,
    callers must test ``llm`` as well as ``error_message`` — confirm.
    """
    if config_id < 0:
        yaml_config = load_global_llm_config_by_id(config_id)
        if not yaml_config:
            return None, None, f"Failed to load LLM config with id {config_id}"
        llm = create_chat_litellm_from_config(yaml_config)
        return llm, AgentConfig.from_yaml_config(yaml_config), None

    agent_config = await load_agent_config(
        session=session,
        config_id=config_id,
        search_space_id=search_space_id,
    )
    if not agent_config:
        return None, None, f"Failed to load NewLLMConfig with id {config_id}"

    return create_chat_litellm_from_agent_config(agent_config), agent_config, None

View file

@ -0,0 +1,40 @@
"""Pre-stream setup: connector service, firecrawl key, checkpointer."""
from __future__ import annotations
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.checkpointer import get_checkpointer
from app.db import SearchSourceConnectorType
from app.services.connector_service import ConnectorService
async def setup_connector_and_firecrawl(
    session: AsyncSession,
    *,
    search_space_id: int,
) -> tuple[ConnectorService, str | None]:
    """Create the turn's ``ConnectorService`` and look up the Firecrawl API key.

    The key is read from the ``FIRECRAWL_API_KEY`` entry of the web-crawler
    connector's ``config``.

    Args:
        session: Database session handed to ``ConnectorService``.
        search_space_id: Search space the connector service is scoped to.

    Returns:
        ``(connector_service, firecrawl_api_key)``. The key is ``None`` when
        no web-crawler connector is found, its ``config`` is empty, or the
        config has no ``FIRECRAWL_API_KEY`` entry.
    """
    service = ConnectorService(session, search_space_id=search_space_id)
    crawler = await service.get_connector_by_type(
        SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
    )

    # No connector, or a connector without config: nothing to extract.
    if not (crawler and crawler.config):
        return service, None

    return service, crawler.config.get("FIRECRAWL_API_KEY")
async def get_chat_checkpointer():
    """Return the checkpointer used for persistent conversation memory.

    Delegates directly to ``app.agents.new_chat.checkpointer.get_checkpointer``.
    The indirection gives flow orchestrators a streaming-local symbol and a
    single hook point should the checkpointer source ever vary per flow.
    """
    checkpointer = await get_checkpointer()
    return checkpointer

View file

@ -0,0 +1,80 @@
"""OpenTelemetry chat-request span wrapper for streaming flows."""
from __future__ import annotations
import contextlib
import sys
from typing import Any, Literal
from app.observability import metrics as ot_metrics
from app.observability import otel as ot
def open_chat_request_span(
    *,
    chat_id: int,
    search_space_id: int,
    flow: Literal["new", "regenerate", "resume"],
    request_id: str | None,
    turn_id: str,
    filesystem_mode: str,
    client_platform: str,
    agent_mode: str,
) -> tuple[Any, Any]:
    """Create and enter the per-request chat span.

    All keyword arguments are forwarded unchanged to ``ot.chat_request_span``.

    Returns:
        ``(span_cm, span)``: the context manager and the object its
        ``__enter__`` returned. The context manager is entered manually here,
        so the caller must keep ``span_cm`` and make sure it is exited
        (``close_chat_request_span`` does this).
    """
    span_cm = ot.chat_request_span(
        chat_id=chat_id,
        search_space_id=search_space_id,
        flow=flow,
        request_id=request_id,
        turn_id=turn_id,
        filesystem_mode=filesystem_mode,
        client_platform=client_platform,
        agent_mode=agent_mode,
    )
    # Entered by hand rather than with ``with``: the span outlives this call.
    return span_cm, span_cm.__enter__()
def set_agent_mode(span: Any, agent_mode: str) -> None:
    """Set the ``agent.mode`` attribute on ``span``, ignoring any failure.

    Args:
        span: Object exposing ``set_attribute(key, value)``.
        agent_mode: Resolved agent mode (single / multi).
    """
    attribute_key = "agent.mode"
    # Telemetry is best-effort: a broken span must not break the request.
    with contextlib.suppress(Exception):
        span.set_attribute(attribute_key, agent_mode)
def close_chat_request_span(
    *,
    span_cm: Any,
    span: Any,
    chat_outcome: str,
    chat_agent_mode: str,
    flow: Literal["new", "regenerate", "resume"],
    chat_error_category: str | None,
    duration_seconds: float,
) -> None:
    """Stamp the outcome, record request metrics, and exit the span.

    Intended for a ``finally`` block, so every step is independently
    best-effort: an ``Exception`` raised while tagging the span or recording
    a metric is swallowed and never prevents ``span_cm.__exit__`` from
    running. An error from ``__exit__`` itself is swallowed too, so it cannot
    replace an exception that is already propagating.

    Args:
        span_cm: Context manager returned when the span was opened.
        span: Span object to tag with ``chat.outcome``.
        chat_outcome: Outcome label recorded on the span and both metrics.
        chat_agent_mode: Agent-mode label recorded on both metrics.
        flow: Which flow handled the request.
        chat_error_category: Error category for the outcome metric, if any.
        duration_seconds: Request duration; multiplied by 1000 before being
            passed to ``record_chat_request_duration``.
    """
    # Capture once, up front: this is what the span's ``__exit__`` receives.
    exc_info = sys.exc_info()

    with contextlib.suppress(Exception):
        span.set_attribute("chat.outcome", chat_outcome)

    with contextlib.suppress(Exception):
        ot_metrics.record_chat_request_duration(
            duration_seconds * 1000,
            flow=flow,
            outcome=chat_outcome,
            agent_mode=chat_agent_mode,
        )

    with contextlib.suppress(Exception):
        ot_metrics.record_chat_request_outcome(
            flow=flow,
            outcome=chat_outcome,
            agent_mode=chat_agent_mode,
            error_category=chat_error_category,
        )

    # Always reached, whatever happened above, so the span is never left open.
    with contextlib.suppress(Exception):
        span_cm.__exit__(*exc_info)
def record_outcome_attrs(
    span: Any, *, chat_outcome: str, chat_error_category: str | None
) -> None:
    """Set ``chat.outcome`` and, when given, ``error.category`` on ``span``.

    Used in the except branch, so any ``Exception`` raised by the span is
    suppressed. If setting ``chat.outcome`` fails, ``error.category`` is not
    attempted.

    Args:
        span: Object exposing ``set_attribute(key, value)``.
        chat_outcome: Value for the ``chat.outcome`` attribute.
        chat_error_category: Value for ``error.category``; skipped when
            ``None``.
    """
    pending = [("chat.outcome", chat_outcome)]
    if chat_error_category is not None:
        pending.append(("error.category", chat_error_category))

    with contextlib.suppress(Exception):
        for key, value in pending:
            span.set_attribute(key, value)