feat(chat): add model retry and stream lifecycle events

This commit is contained in:
Anish Sarkar 2026-05-22 17:48:43 +05:30
parent dbb652d4f8
commit dc893281ba
3 changed files with 138 additions and 6 deletions

View file

@ -45,6 +45,8 @@ from langchain.agents.middleware.types import (
from langchain_core.callbacks import adispatch_custom_event, dispatch_custom_event
from langchain_core.messages import AIMessage
from app.observability import metrics as ot_metrics, otel as ot
logger = logging.getLogger(__name__)
# Names of exception classes for which a retry would not help — context
@ -198,6 +200,15 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp
if not self._should_retry(exc) or attempt >= self.max_retries:
raise
delay = self._delay_for_attempt(attempt, exc)
ot.add_event(
"model.retry.scheduled",
{
"retry.attempt": attempt + 1,
"retry.max": self.max_retries,
"retry.delay_ms": int(delay * 1000),
"retry.reason": ot_metrics.categorize_exception(exc),
},
)
try:
dispatch_custom_event(
"surfsense.retrying",
@ -231,6 +242,15 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp
if not self._should_retry(exc) or attempt >= self.max_retries:
raise
delay = self._delay_for_attempt(attempt, exc)
ot.add_event(
"model.retry.scheduled",
{
"retry.attempt": attempt + 1,
"retry.max": self.max_retries,
"retry.delay_ms": int(delay * 1000),
"retry.reason": ot_metrics.categorize_exception(exc),
},
)
try:
await adispatch_custom_event(
"surfsense.retrying",

View file

@ -6,6 +6,8 @@ from typing import TYPE_CHECKING, Any
from langchain.agents.middleware import ModelFallbackMiddleware
from app.observability import metrics as ot_metrics, otel as ot
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
@ -55,7 +57,16 @@ class ScopedModelFallbackMiddleware(ModelFallbackMiddleware):
raise
last_exception = e
for fallback_model in self.models:
for attempt, fallback_model in enumerate(self.models, start=1):
ot.add_event(
"model.fallback",
{
"fallback.attempt": attempt,
"fallback.from": attempt - 1,
"fallback.to": attempt,
"fallback.reason": ot_metrics.categorize_exception(last_exception),
},
)
try:
return handler(request.override(model=fallback_model))
except Exception as e:
@ -79,7 +90,16 @@ class ScopedModelFallbackMiddleware(ModelFallbackMiddleware):
raise
last_exception = e
for fallback_model in self.models:
for attempt, fallback_model in enumerate(self.models, start=1):
ot.add_event(
"model.fallback",
{
"fallback.attempt": attempt,
"fallback.from": attempt - 1,
"fallback.to": attempt,
"fallback.reason": ot_metrics.categorize_exception(last_exception),
},
)
try:
return await handler(request.override(model=fallback_model))
except Exception as e:

View file

@ -14,6 +14,7 @@ import contextlib
import gc
import json
import logging
import sys
import time
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
@ -886,6 +887,7 @@ async def stream_new_chat(
stream_result.client_platform = fs_platform
chat_agent_mode = "unknown"
chat_outcome = "success"
chat_error_category: str | None = None
chat_span_cm = ot.chat_request_span(
chat_id=chat_id,
search_space_id=search_space_id,
@ -985,6 +987,14 @@ async def stream_new_chat(
requires_image_input=_requires_image_input,
)
).resolved_llm_config_id
ot.add_event(
"model.pin.resolved",
{
"pin.requested_id": requested_llm_config_id,
"pin.resolved_id": llm_config_id,
"pin.requires_image_input": _requires_image_input,
},
)
except ValueError as pin_error:
# Auto-pin's "no vision-capable cfg" path raises a ValueError
# whose message we map to the friendly image-input SSE error
@ -1001,6 +1011,13 @@ async def stream_new_chat(
if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
else "server_error"
)
if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT":
ot.add_event(
"quota.denied",
{
"quota.code": error_code,
},
)
yield _emit_stream_error(
message=str(pin_error),
error_kind=error_kind,
@ -1055,6 +1072,12 @@ async def stream_new_chat(
model_label = (
agent_config.config_name or agent_config.model_name or "model"
)
ot.add_event(
"quota.denied",
{
"quota.code": "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT",
},
)
yield _emit_stream_error(
message=(
f"The selected model ({model_label}) does not support "
@ -1098,6 +1121,12 @@ async def stream_new_chat(
)
_premium_reserved_micros = reserve_amount_micros
if not quota_result.allowed:
ot.add_event(
"quota.denied",
{
"quota.code": "PREMIUM_QUOTA_EXHAUSTED",
},
)
if requested_llm_config_id == 0:
try:
llm_config_id = (
@ -1111,6 +1140,13 @@ async def stream_new_chat(
requires_image_input=_requires_image_input,
)
).resolved_llm_config_id
ot.add_event(
"model.repin",
{
"repin.reason": "premium_quota_exhausted",
"repin.to_config_id": llm_config_id,
},
)
except ValueError as pin_error:
yield _emit_stream_error(
message=str(pin_error),
@ -1880,6 +1916,14 @@ async def stream_new_chat(
llm_config_id,
time.perf_counter() - _t0,
)
ot.add_event(
"chat.rate_limit.recovered",
{
"recovery.reason": "provider_rate_limited",
"recovery.previous_config_id": previous_config_id,
"recovery.fallback_config_id": llm_config_id,
},
)
_log_chat_stream_error(
flow=flow,
error_kind="rate_limited",
@ -1910,6 +1954,12 @@ async def stream_new_chat(
log_system_snapshot("stream_new_chat_END")
if stream_result.is_interrupted:
ot.add_event(
"chat.interrupted",
{
"chat.flow": flow,
},
)
if title_task is not None and not title_task.done():
title_task.cancel()
@ -2029,9 +2079,11 @@ async def stream_new_chat(
error_extra,
) = _classify_stream_exception(e, flow_label="chat")
chat_outcome = error_code or error_kind or "error"
chat_error_category = ot_metrics.categorize_exception(e)
with contextlib.suppress(Exception):
chat_span.set_attribute("chat.outcome", chat_outcome)
chat_span.record_exception(e)
chat_span.set_attribute("error.category", chat_error_category)
ot.record_error(chat_span, e)
error_message = f"Error during chat: {e!s}"
print(f"[stream_new_chat] {error_message}")
print(f"[stream_new_chat] Exception type: {type(e).__name__}")
@ -2234,8 +2286,9 @@ async def stream_new_chat(
flow=flow,
outcome=chat_outcome,
agent_mode=chat_agent_mode,
error_category=chat_error_category,
)
chat_span_cm.__exit__(None, None, None)
chat_span_cm.__exit__(*sys.exc_info())
async def stream_resume_chat(
@ -2262,6 +2315,7 @@ async def stream_resume_chat(
stream_result.client_platform = fs_platform
chat_agent_mode = "unknown"
chat_outcome = "success"
chat_error_category: str | None = None
chat_span_cm = ot.chat_request_span(
chat_id=chat_id,
search_space_id=search_space_id,
@ -2345,6 +2399,14 @@ async def stream_resume_chat(
selected_llm_config_id=llm_config_id,
)
).resolved_llm_config_id
ot.add_event(
"model.pin.resolved",
{
"pin.requested_id": requested_llm_config_id,
"pin.resolved_id": llm_config_id,
"pin.requires_image_input": False,
},
)
except ValueError as pin_error:
yield _emit_stream_error(
message=str(pin_error),
@ -2401,6 +2463,12 @@ async def stream_resume_chat(
)
_resume_premium_reserved_micros = reserve_amount_micros
if not quota_result.allowed:
ot.add_event(
"quota.denied",
{
"quota.code": "PREMIUM_QUOTA_EXHAUSTED",
},
)
if requested_llm_config_id == 0:
try:
llm_config_id = (
@ -2413,6 +2481,13 @@ async def stream_resume_chat(
force_repin_free=True,
)
).resolved_llm_config_id
ot.add_event(
"model.repin",
{
"repin.reason": "premium_quota_exhausted",
"repin.to_config_id": llm_config_id,
},
)
except ValueError as pin_error:
yield _emit_stream_error(
message=str(pin_error),
@ -2748,6 +2823,14 @@ async def stream_resume_chat(
llm_config_id,
time.perf_counter() - _t0,
)
ot.add_event(
"chat.rate_limit.recovered",
{
"recovery.reason": "provider_rate_limited",
"recovery.previous_config_id": previous_config_id,
"recovery.fallback_config_id": llm_config_id,
},
)
_log_chat_stream_error(
flow="resume",
error_kind="rate_limited",
@ -2775,6 +2858,12 @@ async def stream_resume_chat(
chat_id,
)
if stream_result.is_interrupted:
ot.add_event(
"chat.interrupted",
{
"chat.flow": "resume",
},
)
usage_summary = accumulator.per_message_summary()
_perf_log.info(
"[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
@ -2869,9 +2958,11 @@ async def stream_resume_chat(
error_extra,
) = _classify_stream_exception(e, flow_label="resume")
chat_outcome = error_code or error_kind or "error"
chat_error_category = ot_metrics.categorize_exception(e)
with contextlib.suppress(Exception):
chat_span.set_attribute("chat.outcome", chat_outcome)
chat_span.record_exception(e)
chat_span.set_attribute("error.category", chat_error_category)
ot.record_error(chat_span, e)
error_message = f"Error during resume: {e!s}"
print(f"[stream_resume_chat] {error_message}")
print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}")
@ -3033,5 +3124,6 @@ async def stream_resume_chat(
flow="resume",
outcome=chat_outcome,
agent_mode=chat_agent_mode,
error_category=chat_error_category,
)
chat_span_cm.__exit__(None, None, None)
chat_span_cm.__exit__(*sys.exc_info())