mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-13 17:52:38 +02:00
Add chat streaming error classification, helpers, and StreamResult.
This commit is contained in:
parent
366122da6e
commit
c25b78c304
10 changed files with 444 additions and 0 deletions
3
surfsense_backend/app/tasks/chat/streaming/__init__.py
Normal file
3
surfsense_backend/app/tasks/chat/streaming/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
"""Chat streaming orchestrator and event relay."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
"""Error classification, structured logging, and terminal-error SSE emission."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
187
surfsense_backend/app/tasks/chat/streaming/errors/classifier.py
Normal file
187
surfsense_backend/app/tasks/chat/streaming/errors/classifier.py
Normal file
|
|
@ -0,0 +1,187 @@
|
||||||
|
"""Classify stream exceptions for logging and client error payloads."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from app.agents.new_chat.errors import BusyError
|
||||||
|
from app.agents.new_chat.middleware.busy_mutex import (
|
||||||
|
get_cancel_state,
|
||||||
|
is_cancel_requested,
|
||||||
|
)
|
||||||
|
|
||||||
|
TURN_CANCELLING_INITIAL_DELAY_MS = 200
TURN_CANCELLING_BACKOFF_FACTOR = 2
TURN_CANCELLING_MAX_DELAY_MS = 1500


def compute_turn_cancelling_retry_delay(attempt: int) -> int:
    """Return the capped exponential-backoff retry delay in milliseconds.

    The delay starts at ``TURN_CANCELLING_INITIAL_DELAY_MS`` for the first
    attempt, is multiplied by ``TURN_CANCELLING_BACKOFF_FACTOR`` for every
    further attempt, and never exceeds ``TURN_CANCELLING_MAX_DELAY_MS``.

    Args:
        attempt: 1-based attempt counter. Values below 1 are treated as 1.

    Returns:
        The delay in milliseconds, clamped to the configured maximum.
    """
    remaining_steps = max(attempt, 1) - 1
    delay = TURN_CANCELLING_INITIAL_DELAY_MS
    # Multiply step by step and stop as soon as the cap is reached. Evaluating
    # ``factor ** (attempt - 1)`` up front would build an arbitrarily large
    # integer for a large ``attempt`` only to clamp it right afterwards.
    # (Relies on the backoff factor being greater than 1, as configured above.)
    while remaining_steps > 0 and delay < TURN_CANCELLING_MAX_DELAY_MS:
        delay *= TURN_CANCELLING_BACKOFF_FACTOR
        remaining_steps -= 1
    return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
|
||||||
|
|
||||||
|
|
||||||
|
def log_chat_stream_error(
|
||||||
|
*,
|
||||||
|
flow: Literal["new", "resume", "regenerate"],
|
||||||
|
error_kind: str,
|
||||||
|
error_code: str | None,
|
||||||
|
severity: Literal["info", "warn", "error"],
|
||||||
|
is_expected: bool,
|
||||||
|
request_id: str | None,
|
||||||
|
thread_id: int | None,
|
||||||
|
search_space_id: int | None,
|
||||||
|
user_id: str | None,
|
||||||
|
message: str,
|
||||||
|
extra: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"event": "chat_stream_error",
|
||||||
|
"flow": flow,
|
||||||
|
"error_kind": error_kind,
|
||||||
|
"error_code": error_code,
|
||||||
|
"severity": severity,
|
||||||
|
"is_expected": is_expected,
|
||||||
|
"request_id": request_id or "unknown",
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"search_space_id": search_space_id,
|
||||||
|
"user_id": user_id,
|
||||||
|
"message": message,
|
||||||
|
}
|
||||||
|
if extra:
|
||||||
|
payload.update(extra)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
rendered = json.dumps(payload, ensure_ascii=False)
|
||||||
|
if severity == "error":
|
||||||
|
logger.error("[chat_stream_error] %s", rendered)
|
||||||
|
elif severity == "warn":
|
||||||
|
logger.warning("[chat_stream_error] %s", rendered)
|
||||||
|
else:
|
||||||
|
logger.info("[chat_stream_error] %s", rendered)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_error_payload(message: str) -> dict[str, Any] | None:
|
||||||
|
candidates = [message]
|
||||||
|
first_brace_idx = message.find("{")
|
||||||
|
if first_brace_idx >= 0:
|
||||||
|
candidates.append(message[first_brace_idx:])
|
||||||
|
|
||||||
|
for candidate in candidates:
|
||||||
|
try:
|
||||||
|
parsed = json.loads(candidate)
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
return parsed
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None:
|
||||||
|
if not isinstance(parsed, dict):
|
||||||
|
return None
|
||||||
|
candidates: list[Any] = [parsed.get("code")]
|
||||||
|
nested = parsed.get("error")
|
||||||
|
if isinstance(nested, dict):
|
||||||
|
candidates.append(nested.get("code"))
|
||||||
|
for value in candidates:
|
||||||
|
try:
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
return int(value)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def is_provider_rate_limited(exc: BaseException) -> bool:
    """Return True if the exception looks like an upstream HTTP 429 / rate limit.

    Checks, in order:

    1. the exception class name contains ``"ratelimit"`` (case-insensitive);
    2. a JSON payload embedded in ``str(exc)`` carries code ``429``;
    3. that payload's top-level ``type`` or nested ``error.type`` equals
       ``"rate_limit_error"`` (case-insensitive);
    4. the message text contains ``"rate limited"`` or ``"rate-limited"``.

    Args:
        exc: The exception raised while streaming.

    Returns:
        ``True`` when any of the checks above matches, else ``False``.
    """
    raw = str(exc)
    if "ratelimit" in type(exc).__name__.lower():
        return True

    parsed = _parse_error_payload(raw)
    if _extract_provider_error_code(parsed) == 429:
        return True

    if parsed:
        # Inspect both locations independently: a nested ``error.type`` must
        # not hide a top-level ``rate_limit_error`` (or the other way round).
        type_fields = [parsed.get("type")]
        nested = parsed.get("error")
        if isinstance(nested, dict):
            type_fields.append(nested.get("type"))
        if any(
            isinstance(type_field, str) and type_field.lower() == "rate_limit_error"
            for type_field in type_fields
        ):
            return True

    lowered = raw.lower()
    # "temporarily rate-limited upstream" is already matched by "rate-limited".
    return "rate limited" in lowered or "rate-limited" in lowered
|
||||||
|
|
||||||
|
|
||||||
|
def classify_stream_exception(
    exc: Exception,
    *,
    flow_label: str,
) -> tuple[
    str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None
]:
    """Map a stream exception to ``(kind, code, severity, expected, message, extra)``.

    Outcomes, in priority order:

    * ``BusyError`` (or a message containing the busy-thread text) while a
      cancel is requested -> ``TURN_CANCELLING`` at ``info`` severity, with
      ``retry_after_ms`` / ``retry_after_at`` in ``extra``.
    * Any other busy-thread case -> ``THREAD_BUSY`` at ``warn`` severity.
    * A provider rate limit -> ``RATE_LIMITED`` at ``warn`` severity.
    * Everything else -> ``SERVER_ERROR`` at ``error`` severity, with the raw
      exception text included in the message.

    Args:
        exc: The exception raised while streaming.
        flow_label: Human-readable flow name interpolated into the generic
            server-error message.

    Returns:
        A 6-tuple of kind, code, severity, expected flag, user-facing message,
        and an optional dict of extra fields.
    """
    raw = str(exc)
    is_busy_error = isinstance(exc, BusyError)

    if is_busy_error or "Thread is busy with another request" in raw:
        # NOTE(review): the key passed to the cancel-state helpers is the
        # stringified ``BusyError.request_id``; ``str(None)`` would yield the
        # truthy string "None" -- confirm request_id is always populated.
        busy_thread_id = str(exc.request_id) if is_busy_error else None
        if busy_thread_id and is_cancel_requested(busy_thread_id):
            cancel_state = get_cancel_state(busy_thread_id)
            attempt = cancel_state[0] if cancel_state else 1
            retry_after_ms = compute_turn_cancelling_retry_delay(attempt)
            # Absolute deadline in epoch milliseconds.
            retry_after_at = int(time.time() * 1000) + retry_after_ms
            retry_hint = {
                "retry_after_ms": retry_after_ms,
                "retry_after_at": retry_after_at,
            }
            return (
                "thread_busy",
                "TURN_CANCELLING",
                "info",
                True,
                "A previous response is still stopping. Please try again in a moment.",
                retry_hint,
            )
        return (
            "thread_busy",
            "THREAD_BUSY",
            "warn",
            True,
            "Another response is still finishing for this thread. Please try again in a moment.",
            None,
        )

    if is_provider_rate_limited(exc):
        return (
            "rate_limited",
            "RATE_LIMITED",
            "warn",
            True,
            "This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
            None,
        )

    return (
        "server_error",
        "SERVER_ERROR",
        "error",
        False,
        f"Error during {flow_label}: {raw}",
        None,
    )
|
||||||
38
surfsense_backend/app/tasks/chat/streaming/errors/emitter.py
Normal file
38
surfsense_backend/app/tasks/chat/streaming/errors/emitter.py
Normal file
|
|
@ -0,0 +1,38 @@
|
||||||
|
"""Emit one terminal error SSE frame and log via the stream error classifier."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from .classifier import log_chat_stream_error
|
||||||
|
|
||||||
|
|
||||||
|
def emit_stream_terminal_error(
    *,
    streaming_service: Any,
    flow: Literal["new", "resume", "regenerate"],
    request_id: str | None,
    thread_id: int,
    search_space_id: int,
    user_id: str | None,
    message: str,
    error_kind: str = "server_error",
    error_code: str = "SERVER_ERROR",
    severity: Literal["info", "warn", "error"] = "error",
    is_expected: bool = False,
    extra: dict[str, Any] | None = None,
) -> str:
    """Log a stream error, then build the error frame to send to the client.

    The error is first recorded through ``log_chat_stream_error`` with all of
    the identifying fields, then formatted with
    ``streaming_service.format_error``.

    Args:
        streaming_service: Object exposing
            ``format_error(message, error_code=..., extra=...)``.
        flow: Which chat flow produced the error.
        request_id: Request identifier for the log record; may be ``None``.
        thread_id: Thread identifier for the log record.
        search_space_id: Search-space identifier for the log record.
        user_id: User identifier for the log record; may be ``None``.
        message: Error text, used for both the log record and the frame.
        error_kind: Classification label for the log record.
        error_code: Code used for both the log record and the frame.
        severity: Log severity.
        is_expected: Expected-error flag for the log record.
        extra: Optional extra fields forwarded to both the log record and
            ``format_error``.

    Returns:
        Whatever ``streaming_service.format_error`` returns.
    """
    log_fields: dict[str, Any] = {
        "flow": flow,
        "error_kind": error_kind,
        "error_code": error_code,
        "severity": severity,
        "is_expected": is_expected,
        "request_id": request_id,
        "thread_id": thread_id,
        "search_space_id": search_space_id,
        "user_id": user_id,
        "message": message,
        "extra": extra,
    }
    log_chat_stream_error(**log_fields)

    frame = streaming_service.format_error(
        message, error_code=error_code, extra=extra
    )
    return frame
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
"""Pure helpers for chat streaming."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
@ -0,0 +1,60 @@
|
||||||
|
"""Split a model chunk into text, reasoning, and tool-call fragment lists."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def extract_chunk_parts(chunk: Any) -> dict[str, Any]:
    """Split a chunk into ``text``, ``reasoning`` and ``tool_call_chunks``.

    Sources read from ``chunk``, all optional:

    * ``content``: either a plain string (becomes ``text``) or a list of
      dict blocks typed ``"text"``, ``"reasoning"``, ``"tool_call_chunk"`` or
      ``"tool_use"``. Non-dict blocks and other types are ignored.
    * ``additional_kwargs["reasoning_content"]``: appended to ``reasoning``.
    * ``tool_call_chunks``: dict entries are appended after any tool blocks
      found in ``content``.

    Args:
        chunk: The model chunk, or ``None``.

    Returns:
        A dict with string values for ``text`` and ``reasoning`` and a list
        for ``tool_call_chunks``; all empty when ``chunk`` is ``None``.
    """
    tool_chunks: list[Any] = []
    parts: dict[str, Any] = {
        "text": "",
        "reasoning": "",
        "tool_call_chunks": tool_chunks,
    }
    if chunk is None:
        return parts

    content = getattr(chunk, "content", None)
    if isinstance(content, str):
        if content:
            parts["text"] = content
    elif isinstance(content, list):
        text_fragments: list[str] = []
        reasoning_fragments: list[str] = []
        for block in content:
            if not isinstance(block, dict):
                continue
            kind = block.get("type")
            if kind == "text":
                fragment = block.get("text") or block.get("content") or ""
                sink = text_fragments
            elif kind == "reasoning":
                fragment = (
                    block.get("reasoning")
                    or block.get("text")
                    or block.get("content")
                    or ""
                )
                sink = reasoning_fragments
            elif kind in ("tool_call_chunk", "tool_use"):
                tool_chunks.append(block)
                continue
            else:
                continue
            # Only non-empty strings contribute to the joined output.
            if isinstance(fragment, str) and fragment:
                sink.append(fragment)
        parts["text"] = "".join(text_fragments)
        parts["reasoning"] = "".join(reasoning_fragments)

    additional = getattr(chunk, "additional_kwargs", None) or {}
    if isinstance(additional, dict):
        side_reasoning = additional.get("reasoning_content")
        if isinstance(side_reasoning, str) and side_reasoning:
            parts["reasoning"] += side_reasoning

    attr_tool_chunks = getattr(chunk, "tool_call_chunks", None)
    if isinstance(attr_tool_chunks, list):
        tool_chunks.extend(
            entry for entry in attr_tool_chunks if isinstance(entry, dict)
        )

    return parts
|
||||||
|
|
@ -0,0 +1,47 @@
|
||||||
|
"""Read the first interrupt payload from a LangGraph state snapshot."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def first_interrupt_value(state: Any) -> dict[str, Any] | None:
|
||||||
|
"""Return the first interrupt payload across all snapshot tasks."""
|
||||||
|
|
||||||
|
def _extract(candidate: Any) -> dict[str, Any] | None:
|
||||||
|
if isinstance(candidate, dict):
|
||||||
|
value = candidate.get("value", candidate)
|
||||||
|
return value if isinstance(value, dict) else None
|
||||||
|
value = getattr(candidate, "value", None)
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return value
|
||||||
|
if isinstance(candidate, list | tuple):
|
||||||
|
for item in candidate:
|
||||||
|
extracted = _extract(item)
|
||||||
|
if extracted is not None:
|
||||||
|
return extracted
|
||||||
|
return None
|
||||||
|
|
||||||
|
for task in getattr(state, "tasks", ()) or ():
|
||||||
|
try:
|
||||||
|
interrupts = getattr(task, "interrupts", ()) or ()
|
||||||
|
except (AttributeError, IndexError, TypeError):
|
||||||
|
interrupts = ()
|
||||||
|
if not interrupts:
|
||||||
|
extracted = _extract(task)
|
||||||
|
if extracted is not None:
|
||||||
|
return extracted
|
||||||
|
continue
|
||||||
|
for interrupt_item in interrupts:
|
||||||
|
extracted = _extract(interrupt_item)
|
||||||
|
if extracted is not None:
|
||||||
|
return extracted
|
||||||
|
|
||||||
|
try:
|
||||||
|
state_interrupts = getattr(state, "interrupts", ()) or ()
|
||||||
|
except (AttributeError, IndexError, TypeError):
|
||||||
|
state_interrupts = ()
|
||||||
|
extracted = _extract(state_interrupts)
|
||||||
|
if extracted is not None:
|
||||||
|
return extracted
|
||||||
|
return None
|
||||||
|
|
@ -0,0 +1,32 @@
|
||||||
|
"""Match buffered model tool-call chunks to a tool start when ids were missing."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def match_buffered_langchain_tool_call_id(
|
||||||
|
pending_tool_call_chunks: list[dict[str, Any]],
|
||||||
|
tool_name: str,
|
||||||
|
run_id: str,
|
||||||
|
lc_tool_call_id_by_run: dict[str, str],
|
||||||
|
) -> str | None:
|
||||||
|
matched_idx: int | None = None
|
||||||
|
for idx, tcc in enumerate(pending_tool_call_chunks):
|
||||||
|
if tcc.get("name") == tool_name and tcc.get("id"):
|
||||||
|
matched_idx = idx
|
||||||
|
break
|
||||||
|
if matched_idx is None:
|
||||||
|
for idx, tcc in enumerate(pending_tool_call_chunks):
|
||||||
|
if tcc.get("id"):
|
||||||
|
matched_idx = idx
|
||||||
|
break
|
||||||
|
if matched_idx is None:
|
||||||
|
return None
|
||||||
|
matched = pending_tool_call_chunks.pop(matched_idx)
|
||||||
|
candidate = matched.get("id")
|
||||||
|
if isinstance(candidate, str) and candidate:
|
||||||
|
if run_id:
|
||||||
|
lc_tool_call_id_by_run[run_id] = candidate
|
||||||
|
return candidate
|
||||||
|
return None
|
||||||
|
|
@ -0,0 +1,43 @@
|
||||||
|
"""Normalize filesystem tool payloads for SSE cards and messages."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def tool_output_to_text(tool_output: Any) -> str:
    """Render a tool output as display text.

    For a dict, a string ``"result"`` wins, then a string ``"error"``;
    otherwise the dict is JSON-encoded. Non-dict values are passed through
    ``str()``.

    Args:
        tool_output: Raw tool output of any type.

    Returns:
        The text representation. This never raises for unencodable dicts:
        values JSON cannot encode are stringified, and if encoding still
        fails (for example non-scalar keys or circular references) the result
        is ``str(tool_output)``.
    """
    if not isinstance(tool_output, dict):
        return str(tool_output)

    for key in ("result", "error"):
        value = tool_output.get(key)
        if isinstance(value, str):
            return value

    try:
        # default=str keeps the output JSON-shaped when a value is not
        # natively encodable instead of raising mid-stream.
        return json.dumps(tool_output, ensure_ascii=False, default=str)
    except (TypeError, ValueError):
        return str(tool_output)
|
||||||
|
|
||||||
|
|
||||||
|
def tool_output_has_error(tool_output: Any) -> bool:
    """Report whether a tool output represents a failure.

    A dict counts as an error when its ``"error"`` value is truthy, or when
    its ``"result"`` is a string starting with ``"error:"``. A plain string
    counts when it starts with ``"error:"``. The prefix check ignores case
    and surrounding whitespace. Any other type is never an error.

    Args:
        tool_output: Raw tool output of any type.

    Returns:
        ``True`` if the output looks like an error, else ``False``.
    """
    if isinstance(tool_output, dict):
        if tool_output.get("error"):
            return True
        text = tool_output.get("result")
    else:
        text = tool_output

    if not isinstance(text, str):
        return False
    return text.strip().lower().startswith("error:")
|
||||||
|
|
||||||
|
|
||||||
|
def extract_resolved_file_path(
|
||||||
|
*, tool_name: str, tool_output: Any, tool_input: Any | None = None
|
||||||
|
) -> str | None:
|
||||||
|
if isinstance(tool_output, dict):
|
||||||
|
path_value = tool_output.get("path")
|
||||||
|
if isinstance(path_value, str) and path_value.strip():
|
||||||
|
return path_value.strip()
|
||||||
|
if tool_name in ("write_file", "edit_file") and isinstance(tool_input, dict):
|
||||||
|
file_path = tool_input.get("file_path")
|
||||||
|
if isinstance(file_path, str) and file_path.strip():
|
||||||
|
return file_path.strip()
|
||||||
|
return None
|
||||||
28
surfsense_backend/app/tasks/chat/streaming/stream_result.py
Normal file
28
surfsense_backend/app/tasks/chat/streaming/stream_result.py
Normal file
|
|
@ -0,0 +1,28 @@
|
||||||
|
"""Mutable facts collected while streaming one agent turn."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StreamResult:
|
||||||
|
accumulated_text: str = ""
|
||||||
|
is_interrupted: bool = False
|
||||||
|
interrupt_value: dict[str, Any] | None = None
|
||||||
|
sandbox_files: list[str] = field(default_factory=list)
|
||||||
|
agent_called_update_memory: bool = False
|
||||||
|
request_id: str | None = None
|
||||||
|
turn_id: str = ""
|
||||||
|
filesystem_mode: str = "cloud"
|
||||||
|
client_platform: str = "web"
|
||||||
|
intent_detected: str = "chat_only"
|
||||||
|
intent_confidence: float = 0.0
|
||||||
|
write_attempted: bool = False
|
||||||
|
write_succeeded: bool = False
|
||||||
|
verification_succeeded: bool = False
|
||||||
|
commit_gate_passed: bool = True
|
||||||
|
commit_gate_reason: str = ""
|
||||||
|
assistant_message_id: int | None = None
|
||||||
|
content_builder: Any | None = field(default=None, repr=False)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue