feat(chat): add model retry and stream lifecycle events

This commit is contained in:
Anish Sarkar 2026-05-22 17:48:43 +05:30
parent dbb652d4f8
commit dc893281ba
3 changed files with 138 additions and 6 deletions

View file

@ -45,6 +45,8 @@ from langchain.agents.middleware.types import (
from langchain_core.callbacks import adispatch_custom_event, dispatch_custom_event from langchain_core.callbacks import adispatch_custom_event, dispatch_custom_event
from langchain_core.messages import AIMessage from langchain_core.messages import AIMessage
from app.observability import metrics as ot_metrics, otel as ot
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Names of exception classes for which a retry would not help — context # Names of exception classes for which a retry would not help — context
@ -198,6 +200,15 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp
if not self._should_retry(exc) or attempt >= self.max_retries: if not self._should_retry(exc) or attempt >= self.max_retries:
raise raise
delay = self._delay_for_attempt(attempt, exc) delay = self._delay_for_attempt(attempt, exc)
ot.add_event(
"model.retry.scheduled",
{
"retry.attempt": attempt + 1,
"retry.max": self.max_retries,
"retry.delay_ms": int(delay * 1000),
"retry.reason": ot_metrics.categorize_exception(exc),
},
)
try: try:
dispatch_custom_event( dispatch_custom_event(
"surfsense.retrying", "surfsense.retrying",
@ -231,6 +242,15 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp
if not self._should_retry(exc) or attempt >= self.max_retries: if not self._should_retry(exc) or attempt >= self.max_retries:
raise raise
delay = self._delay_for_attempt(attempt, exc) delay = self._delay_for_attempt(attempt, exc)
ot.add_event(
"model.retry.scheduled",
{
"retry.attempt": attempt + 1,
"retry.max": self.max_retries,
"retry.delay_ms": int(delay * 1000),
"retry.reason": ot_metrics.categorize_exception(exc),
},
)
try: try:
await adispatch_custom_event( await adispatch_custom_event(
"surfsense.retrying", "surfsense.retrying",

View file

@ -6,6 +6,8 @@ from typing import TYPE_CHECKING, Any
from langchain.agents.middleware import ModelFallbackMiddleware from langchain.agents.middleware import ModelFallbackMiddleware
from app.observability import metrics as ot_metrics, otel as ot
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
@ -55,7 +57,16 @@ class ScopedModelFallbackMiddleware(ModelFallbackMiddleware):
raise raise
last_exception = e last_exception = e
for fallback_model in self.models: for attempt, fallback_model in enumerate(self.models, start=1):
ot.add_event(
"model.fallback",
{
"fallback.attempt": attempt,
"fallback.from": attempt - 1,
"fallback.to": attempt,
"fallback.reason": ot_metrics.categorize_exception(last_exception),
},
)
try: try:
return handler(request.override(model=fallback_model)) return handler(request.override(model=fallback_model))
except Exception as e: except Exception as e:
@ -79,7 +90,16 @@ class ScopedModelFallbackMiddleware(ModelFallbackMiddleware):
raise raise
last_exception = e last_exception = e
for fallback_model in self.models: for attempt, fallback_model in enumerate(self.models, start=1):
ot.add_event(
"model.fallback",
{
"fallback.attempt": attempt,
"fallback.from": attempt - 1,
"fallback.to": attempt,
"fallback.reason": ot_metrics.categorize_exception(last_exception),
},
)
try: try:
return await handler(request.override(model=fallback_model)) return await handler(request.override(model=fallback_model))
except Exception as e: except Exception as e:

View file

@ -14,6 +14,7 @@ import contextlib
import gc import gc
import json import json
import logging import logging
import sys
import time import time
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -886,6 +887,7 @@ async def stream_new_chat(
stream_result.client_platform = fs_platform stream_result.client_platform = fs_platform
chat_agent_mode = "unknown" chat_agent_mode = "unknown"
chat_outcome = "success" chat_outcome = "success"
chat_error_category: str | None = None
chat_span_cm = ot.chat_request_span( chat_span_cm = ot.chat_request_span(
chat_id=chat_id, chat_id=chat_id,
search_space_id=search_space_id, search_space_id=search_space_id,
@ -985,6 +987,14 @@ async def stream_new_chat(
requires_image_input=_requires_image_input, requires_image_input=_requires_image_input,
) )
).resolved_llm_config_id ).resolved_llm_config_id
ot.add_event(
"model.pin.resolved",
{
"pin.requested_id": requested_llm_config_id,
"pin.resolved_id": llm_config_id,
"pin.requires_image_input": _requires_image_input,
},
)
except ValueError as pin_error: except ValueError as pin_error:
# Auto-pin's "no vision-capable cfg" path raises a ValueError # Auto-pin's "no vision-capable cfg" path raises a ValueError
# whose message we map to the friendly image-input SSE error # whose message we map to the friendly image-input SSE error
@ -1001,6 +1011,13 @@ async def stream_new_chat(
if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT" if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
else "server_error" else "server_error"
) )
if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT":
ot.add_event(
"quota.denied",
{
"quota.code": error_code,
},
)
yield _emit_stream_error( yield _emit_stream_error(
message=str(pin_error), message=str(pin_error),
error_kind=error_kind, error_kind=error_kind,
@ -1055,6 +1072,12 @@ async def stream_new_chat(
model_label = ( model_label = (
agent_config.config_name or agent_config.model_name or "model" agent_config.config_name or agent_config.model_name or "model"
) )
ot.add_event(
"quota.denied",
{
"quota.code": "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT",
},
)
yield _emit_stream_error( yield _emit_stream_error(
message=( message=(
f"The selected model ({model_label}) does not support " f"The selected model ({model_label}) does not support "
@ -1098,6 +1121,12 @@ async def stream_new_chat(
) )
_premium_reserved_micros = reserve_amount_micros _premium_reserved_micros = reserve_amount_micros
if not quota_result.allowed: if not quota_result.allowed:
ot.add_event(
"quota.denied",
{
"quota.code": "PREMIUM_QUOTA_EXHAUSTED",
},
)
if requested_llm_config_id == 0: if requested_llm_config_id == 0:
try: try:
llm_config_id = ( llm_config_id = (
@ -1111,6 +1140,13 @@ async def stream_new_chat(
requires_image_input=_requires_image_input, requires_image_input=_requires_image_input,
) )
).resolved_llm_config_id ).resolved_llm_config_id
ot.add_event(
"model.repin",
{
"repin.reason": "premium_quota_exhausted",
"repin.to_config_id": llm_config_id,
},
)
except ValueError as pin_error: except ValueError as pin_error:
yield _emit_stream_error( yield _emit_stream_error(
message=str(pin_error), message=str(pin_error),
@ -1880,6 +1916,14 @@ async def stream_new_chat(
llm_config_id, llm_config_id,
time.perf_counter() - _t0, time.perf_counter() - _t0,
) )
ot.add_event(
"chat.rate_limit.recovered",
{
"recovery.reason": "provider_rate_limited",
"recovery.previous_config_id": previous_config_id,
"recovery.fallback_config_id": llm_config_id,
},
)
_log_chat_stream_error( _log_chat_stream_error(
flow=flow, flow=flow,
error_kind="rate_limited", error_kind="rate_limited",
@ -1910,6 +1954,12 @@ async def stream_new_chat(
log_system_snapshot("stream_new_chat_END") log_system_snapshot("stream_new_chat_END")
if stream_result.is_interrupted: if stream_result.is_interrupted:
ot.add_event(
"chat.interrupted",
{
"chat.flow": flow,
},
)
if title_task is not None and not title_task.done(): if title_task is not None and not title_task.done():
title_task.cancel() title_task.cancel()
@ -2029,9 +2079,11 @@ async def stream_new_chat(
error_extra, error_extra,
) = _classify_stream_exception(e, flow_label="chat") ) = _classify_stream_exception(e, flow_label="chat")
chat_outcome = error_code or error_kind or "error" chat_outcome = error_code or error_kind or "error"
chat_error_category = ot_metrics.categorize_exception(e)
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
chat_span.set_attribute("chat.outcome", chat_outcome) chat_span.set_attribute("chat.outcome", chat_outcome)
chat_span.record_exception(e) chat_span.set_attribute("error.category", chat_error_category)
ot.record_error(chat_span, e)
error_message = f"Error during chat: {e!s}" error_message = f"Error during chat: {e!s}"
print(f"[stream_new_chat] {error_message}") print(f"[stream_new_chat] {error_message}")
print(f"[stream_new_chat] Exception type: {type(e).__name__}") print(f"[stream_new_chat] Exception type: {type(e).__name__}")
@ -2234,8 +2286,9 @@ async def stream_new_chat(
flow=flow, flow=flow,
outcome=chat_outcome, outcome=chat_outcome,
agent_mode=chat_agent_mode, agent_mode=chat_agent_mode,
error_category=chat_error_category,
) )
chat_span_cm.__exit__(None, None, None) chat_span_cm.__exit__(*sys.exc_info())
async def stream_resume_chat( async def stream_resume_chat(
@ -2262,6 +2315,7 @@ async def stream_resume_chat(
stream_result.client_platform = fs_platform stream_result.client_platform = fs_platform
chat_agent_mode = "unknown" chat_agent_mode = "unknown"
chat_outcome = "success" chat_outcome = "success"
chat_error_category: str | None = None
chat_span_cm = ot.chat_request_span( chat_span_cm = ot.chat_request_span(
chat_id=chat_id, chat_id=chat_id,
search_space_id=search_space_id, search_space_id=search_space_id,
@ -2345,6 +2399,14 @@ async def stream_resume_chat(
selected_llm_config_id=llm_config_id, selected_llm_config_id=llm_config_id,
) )
).resolved_llm_config_id ).resolved_llm_config_id
ot.add_event(
"model.pin.resolved",
{
"pin.requested_id": requested_llm_config_id,
"pin.resolved_id": llm_config_id,
"pin.requires_image_input": False,
},
)
except ValueError as pin_error: except ValueError as pin_error:
yield _emit_stream_error( yield _emit_stream_error(
message=str(pin_error), message=str(pin_error),
@ -2401,6 +2463,12 @@ async def stream_resume_chat(
) )
_resume_premium_reserved_micros = reserve_amount_micros _resume_premium_reserved_micros = reserve_amount_micros
if not quota_result.allowed: if not quota_result.allowed:
ot.add_event(
"quota.denied",
{
"quota.code": "PREMIUM_QUOTA_EXHAUSTED",
},
)
if requested_llm_config_id == 0: if requested_llm_config_id == 0:
try: try:
llm_config_id = ( llm_config_id = (
@ -2413,6 +2481,13 @@ async def stream_resume_chat(
force_repin_free=True, force_repin_free=True,
) )
).resolved_llm_config_id ).resolved_llm_config_id
ot.add_event(
"model.repin",
{
"repin.reason": "premium_quota_exhausted",
"repin.to_config_id": llm_config_id,
},
)
except ValueError as pin_error: except ValueError as pin_error:
yield _emit_stream_error( yield _emit_stream_error(
message=str(pin_error), message=str(pin_error),
@ -2748,6 +2823,14 @@ async def stream_resume_chat(
llm_config_id, llm_config_id,
time.perf_counter() - _t0, time.perf_counter() - _t0,
) )
ot.add_event(
"chat.rate_limit.recovered",
{
"recovery.reason": "provider_rate_limited",
"recovery.previous_config_id": previous_config_id,
"recovery.fallback_config_id": llm_config_id,
},
)
_log_chat_stream_error( _log_chat_stream_error(
flow="resume", flow="resume",
error_kind="rate_limited", error_kind="rate_limited",
@ -2775,6 +2858,12 @@ async def stream_resume_chat(
chat_id, chat_id,
) )
if stream_result.is_interrupted: if stream_result.is_interrupted:
ot.add_event(
"chat.interrupted",
{
"chat.flow": "resume",
},
)
usage_summary = accumulator.per_message_summary() usage_summary = accumulator.per_message_summary()
_perf_log.info( _perf_log.info(
"[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s", "[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
@ -2869,9 +2958,11 @@ async def stream_resume_chat(
error_extra, error_extra,
) = _classify_stream_exception(e, flow_label="resume") ) = _classify_stream_exception(e, flow_label="resume")
chat_outcome = error_code or error_kind or "error" chat_outcome = error_code or error_kind or "error"
chat_error_category = ot_metrics.categorize_exception(e)
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
chat_span.set_attribute("chat.outcome", chat_outcome) chat_span.set_attribute("chat.outcome", chat_outcome)
chat_span.record_exception(e) chat_span.set_attribute("error.category", chat_error_category)
ot.record_error(chat_span, e)
error_message = f"Error during resume: {e!s}" error_message = f"Error during resume: {e!s}"
print(f"[stream_resume_chat] {error_message}") print(f"[stream_resume_chat] {error_message}")
print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}") print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}")
@ -3033,5 +3124,6 @@ async def stream_resume_chat(
flow="resume", flow="resume",
outcome=chat_outcome, outcome=chat_outcome,
agent_mode=chat_agent_mode, agent_mode=chat_agent_mode,
error_category=chat_error_category,
) )
chat_span_cm.__exit__(None, None, None) chat_span_cm.__exit__(*sys.exc_info())