mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-27 19:25:15 +02:00
feat(chat): add model retry and stream lifecycle events
This commit is contained in:
parent
dbb652d4f8
commit
dc893281ba
3 changed files with 138 additions and 6 deletions
|
|
@ -45,6 +45,8 @@ from langchain.agents.middleware.types import (
|
||||||
from langchain_core.callbacks import adispatch_custom_event, dispatch_custom_event
|
from langchain_core.callbacks import adispatch_custom_event, dispatch_custom_event
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
from app.observability import metrics as ot_metrics, otel as ot
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Names of exception classes for which a retry would not help — context
|
# Names of exception classes for which a retry would not help — context
|
||||||
|
|
@ -198,6 +200,15 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp
|
||||||
if not self._should_retry(exc) or attempt >= self.max_retries:
|
if not self._should_retry(exc) or attempt >= self.max_retries:
|
||||||
raise
|
raise
|
||||||
delay = self._delay_for_attempt(attempt, exc)
|
delay = self._delay_for_attempt(attempt, exc)
|
||||||
|
ot.add_event(
|
||||||
|
"model.retry.scheduled",
|
||||||
|
{
|
||||||
|
"retry.attempt": attempt + 1,
|
||||||
|
"retry.max": self.max_retries,
|
||||||
|
"retry.delay_ms": int(delay * 1000),
|
||||||
|
"retry.reason": ot_metrics.categorize_exception(exc),
|
||||||
|
},
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
dispatch_custom_event(
|
dispatch_custom_event(
|
||||||
"surfsense.retrying",
|
"surfsense.retrying",
|
||||||
|
|
@ -231,6 +242,15 @@ class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Resp
|
||||||
if not self._should_retry(exc) or attempt >= self.max_retries:
|
if not self._should_retry(exc) or attempt >= self.max_retries:
|
||||||
raise
|
raise
|
||||||
delay = self._delay_for_attempt(attempt, exc)
|
delay = self._delay_for_attempt(attempt, exc)
|
||||||
|
ot.add_event(
|
||||||
|
"model.retry.scheduled",
|
||||||
|
{
|
||||||
|
"retry.attempt": attempt + 1,
|
||||||
|
"retry.max": self.max_retries,
|
||||||
|
"retry.delay_ms": int(delay * 1000),
|
||||||
|
"retry.reason": ot_metrics.categorize_exception(exc),
|
||||||
|
},
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
await adispatch_custom_event(
|
await adispatch_custom_event(
|
||||||
"surfsense.retrying",
|
"surfsense.retrying",
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,8 @@ from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from langchain.agents.middleware import ModelFallbackMiddleware
|
from langchain.agents.middleware import ModelFallbackMiddleware
|
||||||
|
|
||||||
|
from app.observability import metrics as ot_metrics, otel as ot
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
|
|
||||||
|
|
@ -55,7 +57,16 @@ class ScopedModelFallbackMiddleware(ModelFallbackMiddleware):
|
||||||
raise
|
raise
|
||||||
last_exception = e
|
last_exception = e
|
||||||
|
|
||||||
for fallback_model in self.models:
|
for attempt, fallback_model in enumerate(self.models, start=1):
|
||||||
|
ot.add_event(
|
||||||
|
"model.fallback",
|
||||||
|
{
|
||||||
|
"fallback.attempt": attempt,
|
||||||
|
"fallback.from": attempt - 1,
|
||||||
|
"fallback.to": attempt,
|
||||||
|
"fallback.reason": ot_metrics.categorize_exception(last_exception),
|
||||||
|
},
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
return handler(request.override(model=fallback_model))
|
return handler(request.override(model=fallback_model))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -79,7 +90,16 @@ class ScopedModelFallbackMiddleware(ModelFallbackMiddleware):
|
||||||
raise
|
raise
|
||||||
last_exception = e
|
last_exception = e
|
||||||
|
|
||||||
for fallback_model in self.models:
|
for attempt, fallback_model in enumerate(self.models, start=1):
|
||||||
|
ot.add_event(
|
||||||
|
"model.fallback",
|
||||||
|
{
|
||||||
|
"fallback.attempt": attempt,
|
||||||
|
"fallback.from": attempt - 1,
|
||||||
|
"fallback.to": attempt,
|
||||||
|
"fallback.reason": ot_metrics.categorize_exception(last_exception),
|
||||||
|
},
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
return await handler(request.override(model=fallback_model))
|
return await handler(request.override(model=fallback_model))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ import contextlib
|
||||||
import gc
|
import gc
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
@ -886,6 +887,7 @@ async def stream_new_chat(
|
||||||
stream_result.client_platform = fs_platform
|
stream_result.client_platform = fs_platform
|
||||||
chat_agent_mode = "unknown"
|
chat_agent_mode = "unknown"
|
||||||
chat_outcome = "success"
|
chat_outcome = "success"
|
||||||
|
chat_error_category: str | None = None
|
||||||
chat_span_cm = ot.chat_request_span(
|
chat_span_cm = ot.chat_request_span(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
|
|
@ -985,6 +987,14 @@ async def stream_new_chat(
|
||||||
requires_image_input=_requires_image_input,
|
requires_image_input=_requires_image_input,
|
||||||
)
|
)
|
||||||
).resolved_llm_config_id
|
).resolved_llm_config_id
|
||||||
|
ot.add_event(
|
||||||
|
"model.pin.resolved",
|
||||||
|
{
|
||||||
|
"pin.requested_id": requested_llm_config_id,
|
||||||
|
"pin.resolved_id": llm_config_id,
|
||||||
|
"pin.requires_image_input": _requires_image_input,
|
||||||
|
},
|
||||||
|
)
|
||||||
except ValueError as pin_error:
|
except ValueError as pin_error:
|
||||||
# Auto-pin's "no vision-capable cfg" path raises a ValueError
|
# Auto-pin's "no vision-capable cfg" path raises a ValueError
|
||||||
# whose message we map to the friendly image-input SSE error
|
# whose message we map to the friendly image-input SSE error
|
||||||
|
|
@ -1001,6 +1011,13 @@ async def stream_new_chat(
|
||||||
if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
|
if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
|
||||||
else "server_error"
|
else "server_error"
|
||||||
)
|
)
|
||||||
|
if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT":
|
||||||
|
ot.add_event(
|
||||||
|
"quota.denied",
|
||||||
|
{
|
||||||
|
"quota.code": error_code,
|
||||||
|
},
|
||||||
|
)
|
||||||
yield _emit_stream_error(
|
yield _emit_stream_error(
|
||||||
message=str(pin_error),
|
message=str(pin_error),
|
||||||
error_kind=error_kind,
|
error_kind=error_kind,
|
||||||
|
|
@ -1055,6 +1072,12 @@ async def stream_new_chat(
|
||||||
model_label = (
|
model_label = (
|
||||||
agent_config.config_name or agent_config.model_name or "model"
|
agent_config.config_name or agent_config.model_name or "model"
|
||||||
)
|
)
|
||||||
|
ot.add_event(
|
||||||
|
"quota.denied",
|
||||||
|
{
|
||||||
|
"quota.code": "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT",
|
||||||
|
},
|
||||||
|
)
|
||||||
yield _emit_stream_error(
|
yield _emit_stream_error(
|
||||||
message=(
|
message=(
|
||||||
f"The selected model ({model_label}) does not support "
|
f"The selected model ({model_label}) does not support "
|
||||||
|
|
@ -1098,6 +1121,12 @@ async def stream_new_chat(
|
||||||
)
|
)
|
||||||
_premium_reserved_micros = reserve_amount_micros
|
_premium_reserved_micros = reserve_amount_micros
|
||||||
if not quota_result.allowed:
|
if not quota_result.allowed:
|
||||||
|
ot.add_event(
|
||||||
|
"quota.denied",
|
||||||
|
{
|
||||||
|
"quota.code": "PREMIUM_QUOTA_EXHAUSTED",
|
||||||
|
},
|
||||||
|
)
|
||||||
if requested_llm_config_id == 0:
|
if requested_llm_config_id == 0:
|
||||||
try:
|
try:
|
||||||
llm_config_id = (
|
llm_config_id = (
|
||||||
|
|
@ -1111,6 +1140,13 @@ async def stream_new_chat(
|
||||||
requires_image_input=_requires_image_input,
|
requires_image_input=_requires_image_input,
|
||||||
)
|
)
|
||||||
).resolved_llm_config_id
|
).resolved_llm_config_id
|
||||||
|
ot.add_event(
|
||||||
|
"model.repin",
|
||||||
|
{
|
||||||
|
"repin.reason": "premium_quota_exhausted",
|
||||||
|
"repin.to_config_id": llm_config_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
except ValueError as pin_error:
|
except ValueError as pin_error:
|
||||||
yield _emit_stream_error(
|
yield _emit_stream_error(
|
||||||
message=str(pin_error),
|
message=str(pin_error),
|
||||||
|
|
@ -1880,6 +1916,14 @@ async def stream_new_chat(
|
||||||
llm_config_id,
|
llm_config_id,
|
||||||
time.perf_counter() - _t0,
|
time.perf_counter() - _t0,
|
||||||
)
|
)
|
||||||
|
ot.add_event(
|
||||||
|
"chat.rate_limit.recovered",
|
||||||
|
{
|
||||||
|
"recovery.reason": "provider_rate_limited",
|
||||||
|
"recovery.previous_config_id": previous_config_id,
|
||||||
|
"recovery.fallback_config_id": llm_config_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
_log_chat_stream_error(
|
_log_chat_stream_error(
|
||||||
flow=flow,
|
flow=flow,
|
||||||
error_kind="rate_limited",
|
error_kind="rate_limited",
|
||||||
|
|
@ -1910,6 +1954,12 @@ async def stream_new_chat(
|
||||||
log_system_snapshot("stream_new_chat_END")
|
log_system_snapshot("stream_new_chat_END")
|
||||||
|
|
||||||
if stream_result.is_interrupted:
|
if stream_result.is_interrupted:
|
||||||
|
ot.add_event(
|
||||||
|
"chat.interrupted",
|
||||||
|
{
|
||||||
|
"chat.flow": flow,
|
||||||
|
},
|
||||||
|
)
|
||||||
if title_task is not None and not title_task.done():
|
if title_task is not None and not title_task.done():
|
||||||
title_task.cancel()
|
title_task.cancel()
|
||||||
|
|
||||||
|
|
@ -2029,9 +2079,11 @@ async def stream_new_chat(
|
||||||
error_extra,
|
error_extra,
|
||||||
) = _classify_stream_exception(e, flow_label="chat")
|
) = _classify_stream_exception(e, flow_label="chat")
|
||||||
chat_outcome = error_code or error_kind or "error"
|
chat_outcome = error_code or error_kind or "error"
|
||||||
|
chat_error_category = ot_metrics.categorize_exception(e)
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
chat_span.set_attribute("chat.outcome", chat_outcome)
|
chat_span.set_attribute("chat.outcome", chat_outcome)
|
||||||
chat_span.record_exception(e)
|
chat_span.set_attribute("error.category", chat_error_category)
|
||||||
|
ot.record_error(chat_span, e)
|
||||||
error_message = f"Error during chat: {e!s}"
|
error_message = f"Error during chat: {e!s}"
|
||||||
print(f"[stream_new_chat] {error_message}")
|
print(f"[stream_new_chat] {error_message}")
|
||||||
print(f"[stream_new_chat] Exception type: {type(e).__name__}")
|
print(f"[stream_new_chat] Exception type: {type(e).__name__}")
|
||||||
|
|
@ -2234,8 +2286,9 @@ async def stream_new_chat(
|
||||||
flow=flow,
|
flow=flow,
|
||||||
outcome=chat_outcome,
|
outcome=chat_outcome,
|
||||||
agent_mode=chat_agent_mode,
|
agent_mode=chat_agent_mode,
|
||||||
|
error_category=chat_error_category,
|
||||||
)
|
)
|
||||||
chat_span_cm.__exit__(None, None, None)
|
chat_span_cm.__exit__(*sys.exc_info())
|
||||||
|
|
||||||
|
|
||||||
async def stream_resume_chat(
|
async def stream_resume_chat(
|
||||||
|
|
@ -2262,6 +2315,7 @@ async def stream_resume_chat(
|
||||||
stream_result.client_platform = fs_platform
|
stream_result.client_platform = fs_platform
|
||||||
chat_agent_mode = "unknown"
|
chat_agent_mode = "unknown"
|
||||||
chat_outcome = "success"
|
chat_outcome = "success"
|
||||||
|
chat_error_category: str | None = None
|
||||||
chat_span_cm = ot.chat_request_span(
|
chat_span_cm = ot.chat_request_span(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
|
|
@ -2345,6 +2399,14 @@ async def stream_resume_chat(
|
||||||
selected_llm_config_id=llm_config_id,
|
selected_llm_config_id=llm_config_id,
|
||||||
)
|
)
|
||||||
).resolved_llm_config_id
|
).resolved_llm_config_id
|
||||||
|
ot.add_event(
|
||||||
|
"model.pin.resolved",
|
||||||
|
{
|
||||||
|
"pin.requested_id": requested_llm_config_id,
|
||||||
|
"pin.resolved_id": llm_config_id,
|
||||||
|
"pin.requires_image_input": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
except ValueError as pin_error:
|
except ValueError as pin_error:
|
||||||
yield _emit_stream_error(
|
yield _emit_stream_error(
|
||||||
message=str(pin_error),
|
message=str(pin_error),
|
||||||
|
|
@ -2401,6 +2463,12 @@ async def stream_resume_chat(
|
||||||
)
|
)
|
||||||
_resume_premium_reserved_micros = reserve_amount_micros
|
_resume_premium_reserved_micros = reserve_amount_micros
|
||||||
if not quota_result.allowed:
|
if not quota_result.allowed:
|
||||||
|
ot.add_event(
|
||||||
|
"quota.denied",
|
||||||
|
{
|
||||||
|
"quota.code": "PREMIUM_QUOTA_EXHAUSTED",
|
||||||
|
},
|
||||||
|
)
|
||||||
if requested_llm_config_id == 0:
|
if requested_llm_config_id == 0:
|
||||||
try:
|
try:
|
||||||
llm_config_id = (
|
llm_config_id = (
|
||||||
|
|
@ -2413,6 +2481,13 @@ async def stream_resume_chat(
|
||||||
force_repin_free=True,
|
force_repin_free=True,
|
||||||
)
|
)
|
||||||
).resolved_llm_config_id
|
).resolved_llm_config_id
|
||||||
|
ot.add_event(
|
||||||
|
"model.repin",
|
||||||
|
{
|
||||||
|
"repin.reason": "premium_quota_exhausted",
|
||||||
|
"repin.to_config_id": llm_config_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
except ValueError as pin_error:
|
except ValueError as pin_error:
|
||||||
yield _emit_stream_error(
|
yield _emit_stream_error(
|
||||||
message=str(pin_error),
|
message=str(pin_error),
|
||||||
|
|
@ -2748,6 +2823,14 @@ async def stream_resume_chat(
|
||||||
llm_config_id,
|
llm_config_id,
|
||||||
time.perf_counter() - _t0,
|
time.perf_counter() - _t0,
|
||||||
)
|
)
|
||||||
|
ot.add_event(
|
||||||
|
"chat.rate_limit.recovered",
|
||||||
|
{
|
||||||
|
"recovery.reason": "provider_rate_limited",
|
||||||
|
"recovery.previous_config_id": previous_config_id,
|
||||||
|
"recovery.fallback_config_id": llm_config_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
_log_chat_stream_error(
|
_log_chat_stream_error(
|
||||||
flow="resume",
|
flow="resume",
|
||||||
error_kind="rate_limited",
|
error_kind="rate_limited",
|
||||||
|
|
@ -2775,6 +2858,12 @@ async def stream_resume_chat(
|
||||||
chat_id,
|
chat_id,
|
||||||
)
|
)
|
||||||
if stream_result.is_interrupted:
|
if stream_result.is_interrupted:
|
||||||
|
ot.add_event(
|
||||||
|
"chat.interrupted",
|
||||||
|
{
|
||||||
|
"chat.flow": "resume",
|
||||||
|
},
|
||||||
|
)
|
||||||
usage_summary = accumulator.per_message_summary()
|
usage_summary = accumulator.per_message_summary()
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
|
"[token_usage] interrupted resume_chat: calls=%d total=%d cost_micros=%d summary=%s",
|
||||||
|
|
@ -2869,9 +2958,11 @@ async def stream_resume_chat(
|
||||||
error_extra,
|
error_extra,
|
||||||
) = _classify_stream_exception(e, flow_label="resume")
|
) = _classify_stream_exception(e, flow_label="resume")
|
||||||
chat_outcome = error_code or error_kind or "error"
|
chat_outcome = error_code or error_kind or "error"
|
||||||
|
chat_error_category = ot_metrics.categorize_exception(e)
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
chat_span.set_attribute("chat.outcome", chat_outcome)
|
chat_span.set_attribute("chat.outcome", chat_outcome)
|
||||||
chat_span.record_exception(e)
|
chat_span.set_attribute("error.category", chat_error_category)
|
||||||
|
ot.record_error(chat_span, e)
|
||||||
error_message = f"Error during resume: {e!s}"
|
error_message = f"Error during resume: {e!s}"
|
||||||
print(f"[stream_resume_chat] {error_message}")
|
print(f"[stream_resume_chat] {error_message}")
|
||||||
print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}")
|
print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}")
|
||||||
|
|
@ -3033,5 +3124,6 @@ async def stream_resume_chat(
|
||||||
flow="resume",
|
flow="resume",
|
||||||
outcome=chat_outcome,
|
outcome=chat_outcome,
|
||||||
agent_mode=chat_agent_mode,
|
agent_mode=chat_agent_mode,
|
||||||
|
error_category=chat_error_category,
|
||||||
)
|
)
|
||||||
chat_span_cm.__exit__(None, None, None)
|
chat_span_cm.__exit__(*sys.exc_info())
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue