diff --git a/surfsense_backend/app/tasks/chat/streaming/__init__.py b/surfsense_backend/app/tasks/chat/streaming/__init__.py new file mode 100644 index 000000000..bb06cc021 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/__init__.py @@ -0,0 +1,3 @@ +"""Chat streaming orchestrator and event relay.""" + +from __future__ import annotations diff --git a/surfsense_backend/app/tasks/chat/streaming/errors/__init__.py b/surfsense_backend/app/tasks/chat/streaming/errors/__init__.py new file mode 100644 index 000000000..02284d4b0 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/errors/__init__.py @@ -0,0 +1,3 @@ +"""Error classification, structured logging, and terminal-error SSE emission.""" + +from __future__ import annotations diff --git a/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py b/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py new file mode 100644 index 000000000..3af2b9f9f --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py @@ -0,0 +1,187 @@ +"""Classify stream exceptions for logging and client error payloads.""" + +from __future__ import annotations + +import json +import logging +import time +from typing import Any, Literal + +from app.agents.new_chat.errors import BusyError +from app.agents.new_chat.middleware.busy_mutex import ( + get_cancel_state, + is_cancel_requested, +) + +TURN_CANCELLING_INITIAL_DELAY_MS = 200 +TURN_CANCELLING_BACKOFF_FACTOR = 2 +TURN_CANCELLING_MAX_DELAY_MS = 1500 + + +def compute_turn_cancelling_retry_delay(attempt: int) -> int: + if attempt < 1: + attempt = 1 + delay = TURN_CANCELLING_INITIAL_DELAY_MS * ( + TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1) + ) + return min(delay, TURN_CANCELLING_MAX_DELAY_MS) + + +def log_chat_stream_error( + *, + flow: Literal["new", "resume", "regenerate"], + error_kind: str, + error_code: str | None, + severity: Literal["info", "warn", "error"], + is_expected: bool, + request_id: str | None, + thread_id: int | None, + search_space_id: int | None, + user_id: str | None, + message: str, + extra: dict[str, Any] | None = None, +) -> None: + payload: dict[str, Any] = { + "event": "chat_stream_error", + "flow": flow, + "error_kind": error_kind, + "error_code": error_code, + "severity": severity, + "is_expected": is_expected, + "request_id": request_id or "unknown", + "thread_id": thread_id, + "search_space_id": search_space_id, + "user_id": user_id, + "message": message, + } + if extra: + payload.update(extra) + + logger = logging.getLogger(__name__) + rendered = json.dumps(payload, ensure_ascii=False) + if severity == "error": + logger.error("[chat_stream_error] %s", rendered) + elif severity == "warn": + logger.warning("[chat_stream_error] %s", rendered) + else: + logger.info("[chat_stream_error] %s", rendered) + + +def _parse_error_payload(message: str) -> dict[str, Any] | None: + candidates = [message] + first_brace_idx = message.find("{") + if first_brace_idx >= 0: + candidates.append(message[first_brace_idx:]) + + for candidate in candidates: + try: + parsed = json.loads(candidate) + if isinstance(parsed, dict): + return parsed + except Exception: + continue + return None + + +def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None: + if not isinstance(parsed, dict): + return None + candidates: list[Any] = [parsed.get("code")] + nested = parsed.get("error") + if isinstance(nested, dict): + candidates.append(nested.get("code")) + for value in candidates: + try: + if value is None: + continue + return int(value) + except Exception: + continue + return None + + +def is_provider_rate_limited(exc: BaseException) -> bool: + """Return True if the exception looks like an upstream HTTP 429 / rate limit.""" + raw = str(exc) + lowered = raw.lower() + if "ratelimit" in type(exc).__name__.lower(): + return True + parsed = _parse_error_payload(raw) + provider_code = _extract_provider_error_code(parsed) + if provider_code == 429: + return True + + provider_error_type = "" + if parsed: + top_type = parsed.get("type") + if isinstance(top_type, str): + provider_error_type = top_type.lower() + nested = parsed.get("error") + if isinstance(nested, dict): + nested_type = nested.get("type") + if isinstance(nested_type, str): + provider_error_type = nested_type.lower() + if provider_error_type == "rate_limit_error": + return True + + return ( + "rate limited" in lowered + or "rate-limited" in lowered + or "temporarily rate-limited upstream" in lowered + ) + + +def classify_stream_exception( + exc: Exception, + *, + flow_label: str, +) -> tuple[ + str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None +]: + """Return kind, code, severity, expected flag, message, and optional extra dict.""" + raw = str(exc) + if isinstance(exc, BusyError) or "Thread is busy with another request" in raw: + busy_thread_id = str(exc.request_id) if isinstance(exc, BusyError) else None + if busy_thread_id and is_cancel_requested(busy_thread_id): + cancel_state = get_cancel_state(busy_thread_id) + attempt = cancel_state[0] if cancel_state else 1 + retry_after_ms = compute_turn_cancelling_retry_delay(attempt) + retry_after_at = int(time.time() * 1000) + retry_after_ms + return ( + "thread_busy", + "TURN_CANCELLING", + "info", + True, + "A previous response is still stopping. Please try again in a moment.", + { + "retry_after_ms": retry_after_ms, + "retry_after_at": retry_after_at, + }, + ) + return ( + "thread_busy", + "THREAD_BUSY", + "warn", + True, + "Another response is still finishing for this thread. Please try again in a moment.", + None, + ) + + if is_provider_rate_limited(exc): + return ( + "rate_limited", + "RATE_LIMITED", + "warn", + True, + "This model is temporarily rate-limited. Please try again in a few seconds or switch models.", + None, + ) + + return ( + "server_error", + "SERVER_ERROR", + "error", + False, + f"Error during {flow_label}: {raw}", + None, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/errors/emitter.py b/surfsense_backend/app/tasks/chat/streaming/errors/emitter.py new file mode 100644 index 000000000..95806ab87 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/errors/emitter.py @@ -0,0 +1,38 @@ +"""Emit one terminal error SSE frame and log via the stream error classifier.""" + +from __future__ import annotations + +from typing import Any, Literal + +from .classifier import log_chat_stream_error + + +def emit_stream_terminal_error( + *, + streaming_service: Any, + flow: Literal["new", "resume", "regenerate"], + request_id: str | None, + thread_id: int, + search_space_id: int, + user_id: str | None, + message: str, + error_kind: str = "server_error", + error_code: str = "SERVER_ERROR", + severity: Literal["info", "warn", "error"] = "error", + is_expected: bool = False, + extra: dict[str, Any] | None = None, +) -> str: + log_chat_stream_error( + flow=flow, + error_kind=error_kind, + error_code=error_code, + severity=severity, + is_expected=is_expected, + request_id=request_id, + thread_id=thread_id, + search_space_id=search_space_id, + user_id=user_id, + message=message, + extra=extra, + ) + return streaming_service.format_error(message, error_code=error_code, extra=extra) diff --git a/surfsense_backend/app/tasks/chat/streaming/helpers/__init__.py b/surfsense_backend/app/tasks/chat/streaming/helpers/__init__.py new file mode 100644 index 000000000..151dfdaac --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/helpers/__init__.py @@ -0,0 +1,3 @@ +"""Pure helpers for chat streaming.""" + +from __future__ import annotations diff --git a/surfsense_backend/app/tasks/chat/streaming/helpers/chunk_parts.py b/surfsense_backend/app/tasks/chat/streaming/helpers/chunk_parts.py new file mode 100644 index 000000000..48b44fc1d --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/helpers/chunk_parts.py @@ -0,0 +1,60 @@ +"""Split a model chunk into text, reasoning, and tool-call fragment lists.""" + +from __future__ import annotations + +from typing import Any + + +def extract_chunk_parts(chunk: Any) -> dict[str, Any]: + """Return dict with keys text, reasoning, and tool_call_chunks (merged from chunk fields).""" + out: dict[str, Any] = {"text": "", "reasoning": "", "tool_call_chunks": []} + if chunk is None: + return out + + content = getattr(chunk, "content", None) + if isinstance(content, str): + if content: + out["text"] = content + elif isinstance(content, list): + text_parts: list[str] = [] + reasoning_parts: list[str] = [] + for block in content: + if not isinstance(block, dict): + continue + block_type = block.get("type") + if block_type == "text": + value = block.get("text") or block.get("content") or "" + if isinstance(value, str) and value: + text_parts.append(value) + elif block_type == "reasoning": + value = ( + block.get("reasoning") + or block.get("text") + or block.get("content") + or "" + ) + if isinstance(value, str) and value: + reasoning_parts.append(value) + elif block_type in ("tool_call_chunk", "tool_use"): + out["tool_call_chunks"].append(block) + if text_parts: + out["text"] = "".join(text_parts) + if reasoning_parts: + out["reasoning"] = "".join(reasoning_parts) + + additional = getattr(chunk, "additional_kwargs", None) or {} + if isinstance(additional, dict): + extra_reasoning = additional.get("reasoning_content") + if isinstance(extra_reasoning, str) and extra_reasoning: + existing = out["reasoning"] + out["reasoning"] = ( + (existing + extra_reasoning) if existing else extra_reasoning + ) + + extra_tool_chunks = getattr(chunk, "tool_call_chunks", None) + if isinstance(extra_tool_chunks, list): + for tcc in extra_tool_chunks: + if isinstance(tcc, dict): + out["tool_call_chunks"].append(tcc) + + return out diff --git a/surfsense_backend/app/tasks/chat/streaming/helpers/interrupt_inspector.py b/surfsense_backend/app/tasks/chat/streaming/helpers/interrupt_inspector.py new file mode 100644 index 000000000..dca099b3f --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/helpers/interrupt_inspector.py @@ -0,0 +1,47 @@ +"""Read the first interrupt payload from a LangGraph state snapshot.""" + +from __future__ import annotations + +from typing import Any + + +def first_interrupt_value(state: Any) -> dict[str, Any] | None: + """Return the first interrupt payload across all snapshot tasks.""" + + def _extract(candidate: Any) -> dict[str, Any] | None: + if isinstance(candidate, dict): + value = candidate.get("value", candidate) + return value if isinstance(value, dict) else None + value = getattr(candidate, "value", None) + if isinstance(value, dict): + return value + if isinstance(candidate, list | tuple): + for item in candidate: + extracted = _extract(item) + if extracted is not None: + return extracted + return None + + for task in getattr(state, "tasks", ()) or (): + try: + interrupts = getattr(task, "interrupts", ()) or () + except (AttributeError, IndexError, TypeError): + interrupts = () + if not interrupts: + extracted = _extract(task) + if extracted is not None: + return extracted + continue + for interrupt_item in interrupts: + extracted = _extract(interrupt_item) + if extracted is not None: + return extracted + + try: + state_interrupts = getattr(state, "interrupts", ()) or () + except (AttributeError, IndexError, TypeError): + state_interrupts = () + extracted = _extract(state_interrupts) + if extracted is not None: + return extracted + return None diff --git a/surfsense_backend/app/tasks/chat/streaming/helpers/tool_call_matching.py b/surfsense_backend/app/tasks/chat/streaming/helpers/tool_call_matching.py new file mode 100644 index 000000000..fbe4c94b7 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/helpers/tool_call_matching.py @@ -0,0 +1,32 @@ +"""Match buffered model tool-call chunks to a tool start when ids were missing.""" + +from __future__ import annotations + +from typing import Any + + +def match_buffered_langchain_tool_call_id( + pending_tool_call_chunks: list[dict[str, Any]], + tool_name: str, + run_id: str, + lc_tool_call_id_by_run: dict[str, str], +) -> str | None: + matched_idx: int | None = None + for idx, tcc in enumerate(pending_tool_call_chunks): + if tcc.get("name") == tool_name and tcc.get("id"): + matched_idx = idx + break + if matched_idx is None: + for idx, tcc in enumerate(pending_tool_call_chunks): + if tcc.get("id"): + matched_idx = idx + break + if matched_idx is None: + return None + matched = pending_tool_call_chunks.pop(matched_idx) + candidate = matched.get("id") + if isinstance(candidate, str) and candidate: + if run_id: + lc_tool_call_id_by_run[run_id] = candidate + return candidate + return None diff --git a/surfsense_backend/app/tasks/chat/streaming/helpers/tool_output.py b/surfsense_backend/app/tasks/chat/streaming/helpers/tool_output.py new file mode 100644 index 000000000..a7c401dee --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/helpers/tool_output.py @@ -0,0 +1,43 @@ +"""Normalize filesystem tool payloads for SSE cards and messages.""" + +from __future__ import annotations + +import json +from typing import Any + + +def tool_output_to_text(tool_output: Any) -> str: + if isinstance(tool_output, dict): + if isinstance(tool_output.get("result"), str): + return tool_output["result"] + if isinstance(tool_output.get("error"), str): + return tool_output["error"] + return json.dumps(tool_output, ensure_ascii=False) + return str(tool_output) + + +def tool_output_has_error(tool_output: Any) -> bool: + if isinstance(tool_output, dict): + if tool_output.get("error"): + return True + result = tool_output.get("result") + return bool( + isinstance(result, str) and result.strip().lower().startswith("error:") + ) + if isinstance(tool_output, str): + return tool_output.strip().lower().startswith("error:") + return False + + +def extract_resolved_file_path( + *, tool_name: str, tool_output: Any, tool_input: Any | None = None +) -> str | None: + if isinstance(tool_output, dict): + path_value = tool_output.get("path") + if isinstance(path_value, str) and path_value.strip(): + return path_value.strip() + if tool_name in ("write_file", "edit_file") and isinstance(tool_input, dict): + file_path = tool_input.get("file_path") + if isinstance(file_path, str) and file_path.strip(): + return file_path.strip() + return None diff --git a/surfsense_backend/app/tasks/chat/streaming/stream_result.py b/surfsense_backend/app/tasks/chat/streaming/stream_result.py new file mode 100644 index 000000000..8ea3bd295 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/stream_result.py @@ -0,0 +1,28 @@ +"""Mutable facts collected while streaming one agent turn.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class StreamResult: + accumulated_text: str = "" + is_interrupted: bool = False + interrupt_value: dict[str, Any] | None = None + sandbox_files: list[str] = field(default_factory=list) + agent_called_update_memory: bool = False + request_id: str | None = None + turn_id: str = "" + filesystem_mode: str = "cloud" + client_platform: str = "web" + intent_detected: str = "chat_only" + intent_confidence: float = 0.0 + write_attempted: bool = False + write_succeeded: bool = False + verification_succeeded: bool = False + commit_gate_passed: bool = True + commit_gate_reason: str = "" + assistant_message_id: int | None = None + content_builder: Any | None = field(default=None, repr=False)