feat(chat): add model retry and stream lifecycle events

This commit is contained in:
Anish Sarkar 2026-05-22 17:48:43 +05:30
parent dbb652d4f8
commit dc893281ba
3 changed files with 138 additions and 6 deletions

View file

@ -45,6 +45,8 @@ from langchain.agents.middleware.types import (
from langchain_core.callbacks import adispatch_custom_event, dispatch_custom_event
from langchain_core.messages import AIMessage
from app.observability import metrics as ot_metrics, otel as ot
logger = logging.getLogger(__name__)
# Names of exception classes for which a retry would not help — context
@ -198,6 +200,15 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp
if not self._should_retry(exc) or attempt >= self.max_retries:
raise
delay = self._delay_for_attempt(attempt, exc)
ot.add_event(
"model.retry.scheduled",
{
"retry.attempt": attempt + 1,
"retry.max": self.max_retries,
"retry.delay_ms": int(delay * 1000),
"retry.reason": ot_metrics.categorize_exception(exc),
},
)
try:
dispatch_custom_event(
"surfsense.retrying",
@ -231,6 +242,15 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp
if not self._should_retry(exc) or attempt >= self.max_retries:
raise
delay = self._delay_for_attempt(attempt, exc)
ot.add_event(
"model.retry.scheduled",
{
"retry.attempt": attempt + 1,
"retry.max": self.max_retries,
"retry.delay_ms": int(delay * 1000),
"retry.reason": ot_metrics.categorize_exception(exc),
},
)
try:
await adispatch_custom_event(
"surfsense.retrying",

View file

@ -6,6 +6,8 @@ from typing import TYPE_CHECKING, Any
from langchain.agents.middleware import ModelFallbackMiddleware
from app.observability import metrics as ot_metrics, otel as ot
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
@ -55,7 +57,16 @@ class ScopedModelFallbackMiddleware(ModelFallbackMiddleware):
raise
last_exception = e
for fallback_model in self.models:
for attempt, fallback_model in enumerate(self.models, start=1):
ot.add_event(
"model.fallback",
{
"fallback.attempt": attempt,
"fallback.from": attempt - 1,
"fallback.to": attempt,
"fallback.reason": ot_metrics.categorize_exception(last_exception),
},
)
try:
return handler(request.override(model=fallback_model))
except Exception as e:
@ -79,7 +90,16 @@ class ScopedModelFallbackMiddleware(ModelFallbackMiddleware):
raise
last_exception = e
for fallback_model in self.models:
for attempt, fallback_model in enumerate(self.models, start=1):
ot.add_event(
"model.fallback",
{
"fallback.attempt": attempt,
"fallback.from": attempt - 1,
"fallback.to": attempt,
"fallback.reason": ot_metrics.categorize_exception(last_exception),
},
)
try:
return await handler(request.override(model=fallback_model))
except Exception as e: