""" Streaming task for the new SurfSense deep agent chat. This module streams responses from the deep agent using the Vercel AI SDK Data Stream Protocol (SSE format). Supports loading LLM configurations from: - YAML files (negative IDs for global configs) - NewLLMConfig database table (positive IDs for user-created configs with prompt settings) """ import ast import asyncio import contextlib import gc import json import logging import re import time from collections.abc import AsyncGenerator from dataclasses import dataclass, field from typing import Any from uuid import UUID import anyio from langchain_core.messages import HumanMessage from sqlalchemy import func from sqlalchemy.future import select from sqlalchemy.orm import selectinload from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent from app.agents.new_chat.checkpointer import get_checkpointer from app.agents.new_chat.llm_config import ( AgentConfig, create_chat_litellm_from_agent_config, create_chat_litellm_from_config, load_agent_config, load_global_llm_config_by_id, ) from app.agents.new_chat.memory_extraction import ( extract_and_save_memory, extract_and_save_team_memory, ) from app.db import ( ChatVisibility, NewChatMessage, NewChatThread, Report, SearchSourceConnectorType, SurfsenseDocsDocument, async_session_maker, shielded_async_session, ) from app.prompts import TITLE_GENERATION_PROMPT from app.services.chat_session_state_service import ( clear_ai_responding, set_ai_responding, ) from app.services.connector_service import ConnectorService from app.services.new_streaming_service import VercelStreamingService from app.utils.content_utils import bootstrap_history_from_db from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap _background_tasks: set[asyncio.Task] = set() _perf_log = get_perf_logger() def format_mentioned_surfsense_docs_as_context( documents: list[SurfsenseDocsDocument], ) -> str: """Format mentioned SurfSense documentation as context for 
the agent.""" if not documents: return "" context_parts = [""] context_parts.append( "The user has explicitly mentioned the following SurfSense documentation pages. " "These are official documentation about how to use SurfSense and should be used to answer questions about the application. " "Use [citation:CHUNK_ID] format for citations (e.g., [citation:doc-123])." ) for doc in documents: metadata_json = json.dumps({"source": doc.source}, ensure_ascii=False) context_parts.append("") context_parts.append("") context_parts.append(f" doc-{doc.id}") context_parts.append(" SURFSENSE_DOCS") context_parts.append(f" <![CDATA[{doc.title}]]>") context_parts.append(f" ") context_parts.append( f" " ) context_parts.append("") context_parts.append("") context_parts.append("") if hasattr(doc, "chunks") and doc.chunks: for chunk in doc.chunks: context_parts.append( f" " ) else: context_parts.append( f" " ) context_parts.append("") context_parts.append("") context_parts.append("") context_parts.append("") return "\n".join(context_parts) def extract_todos_from_deepagents(command_output) -> dict: """ Extract todos from deepagents' TodoListMiddleware Command output. deepagents returns a Command object with: - Command.update['todos'] = [{'content': '...', 'status': '...'}] Returns the todos directly (no transformation needed - UI matches deepagents format). 
""" todos_data = [] if hasattr(command_output, "update"): # It's a Command object from deepagents update = command_output.update todos_data = update.get("todos", []) elif isinstance(command_output, dict): # Already a dict - check if it has todos directly or in update if "todos" in command_output: todos_data = command_output.get("todos", []) elif "update" in command_output and isinstance(command_output["update"], dict): todos_data = command_output["update"].get("todos", []) return {"todos": todos_data} @dataclass class StreamResult: accumulated_text: str = "" is_interrupted: bool = False interrupt_value: dict[str, Any] | None = None sandbox_files: list[str] = field(default_factory=list) agent_called_update_memory: bool = False async def _stream_agent_events( agent: Any, config: dict[str, Any], input_data: Any, streaming_service: VercelStreamingService, result: StreamResult, step_prefix: str = "thinking", initial_step_id: str | None = None, initial_step_title: str = "", initial_step_items: list[str] | None = None, ) -> AsyncGenerator[str, None]: """Shared async generator that streams and formats astream_events from the agent. Yields SSE-formatted strings. After exhausting, inspect the ``result`` object for accumulated_text and interrupt state. Args: agent: The compiled LangGraph agent. config: LangGraph config dict (must include configurable.thread_id). input_data: The input to pass to agent.astream_events (dict or Command). streaming_service: VercelStreamingService instance for formatting events. result: Mutable StreamResult populated with accumulated_text / interrupt info. step_prefix: Prefix for thinking step IDs (e.g. "thinking" or "thinking-resume"). initial_step_id: If set, the helper inherits an already-active thinking step. initial_step_title: Title of the inherited thinking step. initial_step_items: Items of the inherited thinking step. Yields: SSE-formatted strings for each event. 
""" accumulated_text = "" current_text_id: str | None = None thinking_step_counter = 1 if initial_step_id else 0 tool_step_ids: dict[str, str] = {} completed_step_ids: set[str] = set() last_active_step_id: str | None = initial_step_id last_active_step_title: str = initial_step_title last_active_step_items: list[str] = initial_step_items or [] just_finished_tool: bool = False active_tool_depth: int = 0 # Track nesting: >0 means we're inside a tool called_update_memory: bool = False def next_thinking_step_id() -> str: nonlocal thinking_step_counter thinking_step_counter += 1 return f"{step_prefix}-{thinking_step_counter}" def complete_current_step() -> str | None: nonlocal last_active_step_id if last_active_step_id and last_active_step_id not in completed_step_ids: completed_step_ids.add(last_active_step_id) event = streaming_service.format_thinking_step( step_id=last_active_step_id, title=last_active_step_title, status="completed", items=last_active_step_items if last_active_step_items else None, ) last_active_step_id = None return event return None async for event in agent.astream_events(input_data, config=config, version="v2"): event_type = event.get("event", "") if event_type == "on_chat_model_stream": if active_tool_depth > 0: continue # Suppress inner-tool LLM tokens from leaking into chat if "surfsense:internal" in event.get("tags", []): continue # Suppress middleware-internal LLM tokens (e.g. 
KB search classification) chunk = event.get("data", {}).get("chunk") if chunk and hasattr(chunk, "content"): content = chunk.content if content and isinstance(content, str): if current_text_id is None: completion_event = complete_current_step() if completion_event: yield completion_event if just_finished_tool: last_active_step_id = None last_active_step_title = "" last_active_step_items = [] just_finished_tool = False current_text_id = streaming_service.generate_text_id() yield streaming_service.format_text_start(current_text_id) yield streaming_service.format_text_delta(current_text_id, content) accumulated_text += content elif event_type == "on_tool_start": active_tool_depth += 1 tool_name = event.get("name", "unknown_tool") run_id = event.get("run_id", "") tool_input = event.get("data", {}).get("input", {}) if current_text_id is not None: yield streaming_service.format_text_end(current_text_id) current_text_id = None if last_active_step_title != "Synthesizing response": completion_event = complete_current_step() if completion_event: yield completion_event just_finished_tool = False tool_step_id = next_thinking_step_id() tool_step_ids[run_id] = tool_step_id last_active_step_id = tool_step_id if tool_name == "ls": ls_path = ( tool_input.get("path", "/") if isinstance(tool_input, dict) else str(tool_input) ) last_active_step_title = "Listing files" last_active_step_items = [ls_path] yield streaming_service.format_thinking_step( step_id=tool_step_id, title="Listing files", status="in_progress", items=last_active_step_items, ) elif tool_name == "read_file": fp = ( tool_input.get("file_path", "") if isinstance(tool_input, dict) else str(tool_input) ) display_fp = fp if len(fp) <= 80 else "…" + fp[-77:] last_active_step_title = "Reading file" last_active_step_items = [display_fp] yield streaming_service.format_thinking_step( step_id=tool_step_id, title="Reading file", status="in_progress", items=last_active_step_items, ) elif tool_name == "write_file": fp = ( 
tool_input.get("file_path", "") if isinstance(tool_input, dict) else str(tool_input) ) display_fp = fp if len(fp) <= 80 else "…" + fp[-77:] last_active_step_title = "Writing file" last_active_step_items = [display_fp] yield streaming_service.format_thinking_step( step_id=tool_step_id, title="Writing file", status="in_progress", items=last_active_step_items, ) elif tool_name == "edit_file": fp = ( tool_input.get("file_path", "") if isinstance(tool_input, dict) else str(tool_input) ) display_fp = fp if len(fp) <= 80 else "…" + fp[-77:] last_active_step_title = "Editing file" last_active_step_items = [display_fp] yield streaming_service.format_thinking_step( step_id=tool_step_id, title="Editing file", status="in_progress", items=last_active_step_items, ) elif tool_name == "glob": pat = ( tool_input.get("pattern", "") if isinstance(tool_input, dict) else str(tool_input) ) base_path = ( tool_input.get("path", "/") if isinstance(tool_input, dict) else "/" ) last_active_step_title = "Searching files" last_active_step_items = [f"{pat} in {base_path}"] yield streaming_service.format_thinking_step( step_id=tool_step_id, title="Searching files", status="in_progress", items=last_active_step_items, ) elif tool_name == "grep": pat = ( tool_input.get("pattern", "") if isinstance(tool_input, dict) else str(tool_input) ) grep_path = ( tool_input.get("path", "") if isinstance(tool_input, dict) else "" ) display_pat = pat[:60] + ("…" if len(pat) > 60 else "") last_active_step_title = "Searching content" last_active_step_items = [ f'"{display_pat}"' + (f" in {grep_path}" if grep_path else "") ] yield streaming_service.format_thinking_step( step_id=tool_step_id, title="Searching content", status="in_progress", items=last_active_step_items, ) elif tool_name == "save_document": doc_title = ( tool_input.get("title", "") if isinstance(tool_input, dict) else str(tool_input) ) display_title = doc_title[:60] + ("…" if len(doc_title) > 60 else "") last_active_step_title = "Saving document" 
last_active_step_items = [display_title] yield streaming_service.format_thinking_step( step_id=tool_step_id, title="Saving document", status="in_progress", items=last_active_step_items, ) elif tool_name == "generate_image": prompt = ( tool_input.get("prompt", "") if isinstance(tool_input, dict) else str(tool_input) ) last_active_step_title = "Generating image" last_active_step_items = [ f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}" ] yield streaming_service.format_thinking_step( step_id=tool_step_id, title="Generating image", status="in_progress", items=last_active_step_items, ) elif tool_name == "scrape_webpage": url = ( tool_input.get("url", "") if isinstance(tool_input, dict) else str(tool_input) ) last_active_step_title = "Scraping webpage" last_active_step_items = [ f"URL: {url[:80]}{'...' if len(url) > 80 else ''}" ] yield streaming_service.format_thinking_step( step_id=tool_step_id, title="Scraping webpage", status="in_progress", items=last_active_step_items, ) elif tool_name == "generate_podcast": podcast_title = ( tool_input.get("podcast_title", "SurfSense Podcast") if isinstance(tool_input, dict) else "SurfSense Podcast" ) content_len = len( tool_input.get("source_content", "") if isinstance(tool_input, dict) else "" ) last_active_step_title = "Generating podcast" last_active_step_items = [ f"Title: {podcast_title}", f"Content: {content_len:,} characters", "Preparing audio generation...", ] yield streaming_service.format_thinking_step( step_id=tool_step_id, title="Generating podcast", status="in_progress", items=last_active_step_items, ) elif tool_name == "generate_report": report_topic = ( tool_input.get("topic", "Report") if isinstance(tool_input, dict) else "Report" ) is_revision = bool( isinstance(tool_input, dict) and tool_input.get("parent_report_id") ) step_title = "Revising report" if is_revision else "Generating report" last_active_step_title = step_title last_active_step_items = [ f"Topic: {report_topic}", "Analyzing source 
content...", ] yield streaming_service.format_thinking_step( step_id=tool_step_id, title=step_title, status="in_progress", items=last_active_step_items, ) elif tool_name in ("execute", "execute_code"): cmd = ( tool_input.get("command", "") if isinstance(tool_input, dict) else str(tool_input) ) display_cmd = cmd[:80] + ("…" if len(cmd) > 80 else "") last_active_step_title = "Running command" last_active_step_items = [f"$ {display_cmd}"] yield streaming_service.format_thinking_step( step_id=tool_step_id, title="Running command", status="in_progress", items=last_active_step_items, ) else: last_active_step_title = f"Using {tool_name.replace('_', ' ')}" last_active_step_items = [] yield streaming_service.format_thinking_step( step_id=tool_step_id, title=last_active_step_title, status="in_progress", ) tool_call_id = ( f"call_{run_id[:32]}" if run_id else streaming_service.generate_tool_call_id() ) yield streaming_service.format_tool_input_start(tool_call_id, tool_name) # Sanitize tool_input: strip runtime-injected non-serializable # values (e.g. LangChain ToolRuntime) before sending over SSE. 
if isinstance(tool_input, dict): _safe_input: dict[str, Any] = {} for _k, _v in tool_input.items(): try: json.dumps(_v) _safe_input[_k] = _v except (TypeError, ValueError, OverflowError): pass else: _safe_input = {"input": tool_input} yield streaming_service.format_tool_input_available( tool_call_id, tool_name, _safe_input, ) elif event_type == "on_tool_end": active_tool_depth = max(0, active_tool_depth - 1) run_id = event.get("run_id", "") tool_name = event.get("name", "unknown_tool") raw_output = event.get("data", {}).get("output", "") if tool_name == "update_memory": called_update_memory = True if hasattr(raw_output, "content"): content = raw_output.content if isinstance(content, str): try: tool_output = json.loads(content) except (json.JSONDecodeError, TypeError): tool_output = {"result": content} elif isinstance(content, dict): tool_output = content else: tool_output = {"result": str(content)} elif isinstance(raw_output, dict): tool_output = raw_output else: tool_output = {"result": str(raw_output) if raw_output else "completed"} tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown" original_step_id = tool_step_ids.get( run_id, f"{step_prefix}-unknown-{run_id[:8]}" ) completed_step_ids.add(original_step_id) if tool_name == "read_file": yield streaming_service.format_thinking_step( step_id=original_step_id, title="Reading file", status="completed", items=last_active_step_items, ) elif tool_name == "write_file": yield streaming_service.format_thinking_step( step_id=original_step_id, title="Writing file", status="completed", items=last_active_step_items, ) elif tool_name == "edit_file": yield streaming_service.format_thinking_step( step_id=original_step_id, title="Editing file", status="completed", items=last_active_step_items, ) elif tool_name == "glob": yield streaming_service.format_thinking_step( step_id=original_step_id, title="Searching files", status="completed", items=last_active_step_items, ) elif tool_name == "grep": yield 
streaming_service.format_thinking_step( step_id=original_step_id, title="Searching content", status="completed", items=last_active_step_items, ) elif tool_name == "save_document": result_str = ( tool_output.get("result", "") if isinstance(tool_output, dict) else str(tool_output) ) is_error = "Error" in result_str completed_items = [ *last_active_step_items, result_str[:80] if is_error else "Saved to knowledge base", ] yield streaming_service.format_thinking_step( step_id=original_step_id, title="Saving document", status="completed", items=completed_items, ) elif tool_name == "generate_image": if isinstance(tool_output, dict) and not tool_output.get("error"): completed_items = [ *last_active_step_items, "Image generated successfully", ] else: error_msg = ( tool_output.get("error", "Generation failed") if isinstance(tool_output, dict) else "Generation failed" ) completed_items = [*last_active_step_items, f"Error: {error_msg}"] yield streaming_service.format_thinking_step( step_id=original_step_id, title="Generating image", status="completed", items=completed_items, ) elif tool_name == "scrape_webpage": if isinstance(tool_output, dict): title = tool_output.get("title", "Webpage") word_count = tool_output.get("word_count", 0) has_error = "error" in tool_output if has_error: completed_items = [ *last_active_step_items, f"Error: {tool_output.get('error', 'Failed to scrape')[:50]}", ] else: completed_items = [ *last_active_step_items, f"Title: {title[:50]}{'...' 
if len(title) > 50 else ''}", f"Extracted: {word_count:,} words", ] else: completed_items = [*last_active_step_items, "Content extracted"] yield streaming_service.format_thinking_step( step_id=original_step_id, title="Scraping webpage", status="completed", items=completed_items, ) elif tool_name == "generate_podcast": podcast_status = ( tool_output.get("status", "unknown") if isinstance(tool_output, dict) else "unknown" ) podcast_title = ( tool_output.get("title", "Podcast") if isinstance(tool_output, dict) else "Podcast" ) if podcast_status == "processing": completed_items = [ f"Title: {podcast_title}", "Audio generation started", "Processing in background...", ] elif podcast_status == "already_generating": completed_items = [ f"Title: {podcast_title}", "Podcast already in progress", "Please wait for it to complete", ] elif podcast_status == "error": error_msg = ( tool_output.get("error", "Unknown error") if isinstance(tool_output, dict) else "Unknown error" ) completed_items = [ f"Title: {podcast_title}", f"Error: {error_msg[:50]}", ] else: completed_items = last_active_step_items yield streaming_service.format_thinking_step( step_id=original_step_id, title="Generating podcast", status="completed", items=completed_items, ) elif tool_name == "generate_video_presentation": vp_status = ( tool_output.get("status", "unknown") if isinstance(tool_output, dict) else "unknown" ) vp_title = ( tool_output.get("title", "Presentation") if isinstance(tool_output, dict) else "Presentation" ) if vp_status in ("pending", "generating"): completed_items = [ f"Title: {vp_title}", "Presentation generation started", "Processing in background...", ] elif vp_status == "failed": error_msg = ( tool_output.get("error", "Unknown error") if isinstance(tool_output, dict) else "Unknown error" ) completed_items = [ f"Title: {vp_title}", f"Error: {error_msg[:50]}", ] else: completed_items = last_active_step_items yield streaming_service.format_thinking_step( step_id=original_step_id, 
title="Generating video presentation", status="completed", items=completed_items, ) elif tool_name == "generate_report": report_status = ( tool_output.get("status", "unknown") if isinstance(tool_output, dict) else "unknown" ) report_title = ( tool_output.get("title", "Report") if isinstance(tool_output, dict) else "Report" ) word_count = ( tool_output.get("word_count", 0) if isinstance(tool_output, dict) else 0 ) is_revision = ( tool_output.get("is_revision", False) if isinstance(tool_output, dict) else False ) step_title = "Revising report" if is_revision else "Generating report" if report_status == "ready": completed_items = [ f"Topic: {report_title}", f"{word_count:,} words", "Report ready", ] elif report_status == "failed": error_msg = ( tool_output.get("error", "Unknown error") if isinstance(tool_output, dict) else "Unknown error" ) completed_items = [ f"Topic: {report_title}", f"Error: {error_msg[:50]}", ] else: completed_items = last_active_step_items yield streaming_service.format_thinking_step( step_id=original_step_id, title=step_title, status="completed", items=completed_items, ) elif tool_name in ("execute", "execute_code"): raw_text = ( tool_output.get("result", "") if isinstance(tool_output, dict) else str(tool_output) ) m = re.match(r"^Exit code:\s*(\d+)", raw_text) exit_code_val = int(m.group(1)) if m else None if exit_code_val is not None and exit_code_val == 0: completed_items = [ *last_active_step_items, "Completed successfully", ] elif exit_code_val is not None: completed_items = [ *last_active_step_items, f"Exit code: {exit_code_val}", ] else: completed_items = [*last_active_step_items, "Finished"] yield streaming_service.format_thinking_step( step_id=original_step_id, title="Running command", status="completed", items=completed_items, ) elif tool_name == "ls": if isinstance(tool_output, dict): ls_output = tool_output.get("result", "") elif isinstance(tool_output, str): ls_output = tool_output else: ls_output = str(tool_output) if tool_output 
else "" file_names: list[str] = [] if ls_output: paths: list[str] = [] try: parsed = ast.literal_eval(ls_output) if isinstance(parsed, list): paths = [str(p) for p in parsed] except (ValueError, SyntaxError): paths = [ line.strip() for line in ls_output.strip().split("\n") if line.strip() ] for p in paths: name = p.rstrip("/").split("/")[-1] if name and len(name) <= 40: file_names.append(name) elif name: file_names.append(name[:37] + "...") if file_names: if len(file_names) <= 5: completed_items = [f"[{name}]" for name in file_names] else: completed_items = [f"[{name}]" for name in file_names[:4]] completed_items.append(f"(+{len(file_names) - 4} more)") else: completed_items = ["No files found"] yield streaming_service.format_thinking_step( step_id=original_step_id, title="Listing files", status="completed", items=completed_items, ) else: yield streaming_service.format_thinking_step( step_id=original_step_id, title=f"Using {tool_name.replace('_', ' ')}", status="completed", items=last_active_step_items, ) just_finished_tool = True last_active_step_id = None last_active_step_title = "" last_active_step_items = [] if tool_name == "generate_podcast": yield streaming_service.format_tool_output_available( tool_call_id, tool_output if isinstance(tool_output, dict) else {"result": tool_output}, ) if ( isinstance(tool_output, dict) and tool_output.get("status") == "success" ): yield streaming_service.format_terminal_info( f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}", "success", ) else: error_msg = ( tool_output.get("error", "Unknown error") if isinstance(tool_output, dict) else "Unknown error" ) yield streaming_service.format_terminal_info( f"Podcast generation failed: {error_msg}", "error", ) elif tool_name == "generate_video_presentation": yield streaming_service.format_tool_output_available( tool_call_id, tool_output if isinstance(tool_output, dict) else {"result": tool_output}, ) if ( isinstance(tool_output, dict) and 
tool_output.get("status") == "pending" ): yield streaming_service.format_terminal_info( f"Video presentation queued: {tool_output.get('title', 'Presentation')}", "success", ) elif ( isinstance(tool_output, dict) and tool_output.get("status") == "failed" ): error_msg = ( tool_output.get("error", "Unknown error") if isinstance(tool_output, dict) else "Unknown error" ) yield streaming_service.format_terminal_info( f"Presentation generation failed: {error_msg}", "error", ) elif tool_name == "generate_image": yield streaming_service.format_tool_output_available( tool_call_id, tool_output if isinstance(tool_output, dict) else {"result": tool_output}, ) if isinstance(tool_output, dict): if tool_output.get("error"): yield streaming_service.format_terminal_info( f"Image generation failed: {tool_output['error'][:60]}", "error", ) else: yield streaming_service.format_terminal_info( "Image generated successfully", "success", ) elif tool_name == "scrape_webpage": if isinstance(tool_output, dict): display_output = { k: v for k, v in tool_output.items() if k != "content" } if "content" in tool_output: content = tool_output.get("content", "") display_output["content_preview"] = ( content[:500] + "..." if len(content) > 500 else content ) yield streaming_service.format_tool_output_available( tool_call_id, display_output, ) else: yield streaming_service.format_tool_output_available( tool_call_id, {"result": tool_output}, ) if isinstance(tool_output, dict) and "error" not in tool_output: title = tool_output.get("title", "Webpage") word_count = tool_output.get("word_count", 0) yield streaming_service.format_terminal_info( f"Scraped: {title[:40]}{'...' 
if len(title) > 40 else ''} ({word_count:,} words)", "success", ) else: error_msg = ( tool_output.get("error", "Failed to scrape") if isinstance(tool_output, dict) else "Failed to scrape" ) yield streaming_service.format_terminal_info( f"Scrape failed: {error_msg}", "error", ) elif tool_name == "generate_report": # Stream the full report result so frontend can render the ReportCard yield streaming_service.format_tool_output_available( tool_call_id, tool_output if isinstance(tool_output, dict) else {"result": tool_output}, ) # Send appropriate terminal message based on status if ( isinstance(tool_output, dict) and tool_output.get("status") == "ready" ): word_count = tool_output.get("word_count", 0) yield streaming_service.format_terminal_info( f"Report generated: {tool_output.get('title', 'Report')} ({word_count:,} words)", "success", ) else: error_msg = ( tool_output.get("error", "Unknown error") if isinstance(tool_output, dict) else "Unknown error" ) yield streaming_service.format_terminal_info( f"Report generation failed: {error_msg}", "error", ) elif tool_name == "generate_resume": yield streaming_service.format_tool_output_available( tool_call_id, tool_output if isinstance(tool_output, dict) else {"result": tool_output}, ) if ( isinstance(tool_output, dict) and tool_output.get("status") == "ready" ): yield streaming_service.format_terminal_info( f"Resume generated: {tool_output.get('title', 'Resume')}", "success", ) else: error_msg = ( tool_output.get("error", "Unknown error") if isinstance(tool_output, dict) else "Unknown error" ) yield streaming_service.format_terminal_info( f"Resume generation failed: {error_msg}", "error", ) elif tool_name in ( "create_notion_page", "update_notion_page", "delete_notion_page", "create_linear_issue", "update_linear_issue", "delete_linear_issue", "create_google_drive_file", "delete_google_drive_file", "create_onedrive_file", "delete_onedrive_file", "create_dropbox_file", "delete_dropbox_file", "create_gmail_draft", 
"update_gmail_draft", "send_gmail_email", "trash_gmail_email", "create_calendar_event", "update_calendar_event", "delete_calendar_event", "create_jira_issue", "update_jira_issue", "delete_jira_issue", "create_confluence_page", "update_confluence_page", "delete_confluence_page", ): yield streaming_service.format_tool_output_available( tool_call_id, tool_output if isinstance(tool_output, dict) else {"result": tool_output}, ) elif tool_name in ("execute", "execute_code"): raw_text = ( tool_output.get("result", "") if isinstance(tool_output, dict) else str(tool_output) ) exit_code: int | None = None output_text = raw_text m = re.match(r"^Exit code:\s*(\d+)", raw_text) if m: exit_code = int(m.group(1)) om = re.search(r"\nOutput:\n([\s\S]*)", raw_text) output_text = om.group(1) if om else "" thread_id_str = config.get("configurable", {}).get("thread_id", "") for sf_match in re.finditer( r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE ): fpath = sf_match.group(1).strip() if fpath and fpath not in result.sandbox_files: result.sandbox_files.append(fpath) yield streaming_service.format_tool_output_available( tool_call_id, { "exit_code": exit_code, "output": output_text, "thread_id": thread_id_str, }, ) elif tool_name == "web_search": xml = ( tool_output.get("result", str(tool_output)) if isinstance(tool_output, dict) else str(tool_output) ) citations: dict[str, dict[str, str]] = {} for m in re.finditer( r"<!\[CDATA\[(.*?)\]\]>\s*", xml, ): title, url = m.group(1).strip(), m.group(2).strip() if url.startswith("http") and url not in citations: citations[url] = {"title": title} for m in re.finditer( r"", xml, ): chunk_url, content = m.group(1).strip(), m.group(2).strip() if ( chunk_url.startswith("http") and chunk_url in citations and content ): citations[chunk_url]["snippet"] = ( content[:200] + "…" if len(content) > 200 else content ) yield streaming_service.format_tool_output_available( tool_call_id, {"status": "completed", "citations": citations}, ) else: yield 
streaming_service.format_tool_output_available( tool_call_id, {"status": "completed", "result_length": len(str(tool_output))}, ) yield streaming_service.format_terminal_info( f"Tool {tool_name} completed", "success" ) elif event_type == "on_custom_event" and event.get("name") == "report_progress": # Live progress updates from inside the generate_report tool data = event.get("data", {}) message = data.get("message", "") if message and last_active_step_id: phase = data.get("phase", "") # Always keep the "Topic: ..." line topic_items = [ item for item in last_active_step_items if item.startswith("Topic:") ] if phase in ("revising_section", "adding_section"): # During section-level ops: keep plan summary + show current op plan_items = [ item for item in last_active_step_items if item.startswith("Topic:") or item.startswith("Modifying ") or item.startswith("Adding ") or item.startswith("Removing ") ] # Only keep plan_items that don't end with "..." (not progress lines) plan_items = [ item for item in plan_items if not item.endswith("...") ] last_active_step_items = [*plan_items, message] else: # Phase transitions: replace everything after topic last_active_step_items = [*topic_items, message] yield streaming_service.format_thinking_step( step_id=last_active_step_id, title=last_active_step_title, status="in_progress", items=last_active_step_items, ) elif ( event_type == "on_custom_event" and event.get("name") == "document_created" ): data = event.get("data", {}) if data.get("id"): yield streaming_service.format_data( "documents-updated", { "action": "created", "document": data, }, ) elif event_type in ("on_chain_end", "on_agent_end"): if current_text_id is not None: yield streaming_service.format_text_end(current_text_id) current_text_id = None if current_text_id is not None: yield streaming_service.format_text_end(current_text_id) completion_event = complete_current_step() if completion_event: yield completion_event result.accumulated_text = accumulated_text 
result.agent_called_update_memory = called_update_memory state = await agent.aget_state(config) is_interrupted = state.tasks and any(task.interrupts for task in state.tasks) if is_interrupted: result.is_interrupted = True result.interrupt_value = state.tasks[0].interrupts[0].value yield streaming_service.format_interrupt_request(result.interrupt_value) async def stream_new_chat( user_query: str, search_space_id: int, chat_id: int, user_id: str | None = None, llm_config_id: int = -1, mentioned_document_ids: list[int] | None = None, mentioned_surfsense_doc_ids: list[int] | None = None, checkpoint_id: str | None = None, needs_history_bootstrap: bool = False, thread_visibility: ChatVisibility | None = None, current_user_display_name: str | None = None, disabled_tools: list[str] | None = None, ) -> AsyncGenerator[str, None]: """ Stream chat responses from the new SurfSense deep agent. This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming. The chat_id is used as LangGraph's thread_id for memory/checkpointing. The function creates and manages its own database session to guarantee proper cleanup even when Starlette's middleware cancels the task on client disconnect. 
Args: user_query: The user's query search_space_id: The search space ID chat_id: The chat ID (used as LangGraph thread_id for memory) user_id: The current user's UUID string (for memory tools and session state) llm_config_id: The LLM configuration ID (default: -1 for first global config) needs_history_bootstrap: If True, load message history from DB (for cloned chats) mentioned_document_ids: Optional list of document IDs mentioned with @ in the chat mentioned_surfsense_doc_ids: Optional list of SurfSense doc IDs mentioned with @ in the chat checkpoint_id: Optional checkpoint ID to rewind/fork from (for edit/reload operations) Yields: str: SSE formatted response strings """ streaming_service = VercelStreamingService() stream_result = StreamResult() _t_total = time.perf_counter() log_system_snapshot("stream_new_chat_START") from app.services.token_tracking_service import start_turn accumulator = start_turn() # Premium quota tracking state _premium_reserved = 0 _premium_request_id: str | None = None session = async_session_maker() try: # Mark AI as responding to this user for live collaboration if user_id: await set_ai_responding(session, chat_id, UUID(user_id)) # Load LLM config - supports both YAML (negative IDs) and database (positive IDs) agent_config: AgentConfig | None = None _t0 = time.perf_counter() if llm_config_id >= 0: # Positive ID: Load from NewLLMConfig database table agent_config = await load_agent_config( session=session, config_id=llm_config_id, search_space_id=search_space_id, ) if not agent_config: yield streaming_service.format_error( f"Failed to load NewLLMConfig with id {llm_config_id}" ) yield streaming_service.format_done() return # Create ChatLiteLLM from AgentConfig llm = create_chat_litellm_from_agent_config(agent_config) else: # Negative ID: Load from in-memory global configs (includes dynamic OpenRouter models) llm_config = load_global_llm_config_by_id(llm_config_id) if not llm_config: yield streaming_service.format_error( f"Failed to load 
LLM config with id {llm_config_id}" ) yield streaming_service.format_done() return # Create ChatLiteLLM from global config dict llm = create_chat_litellm_from_config(llm_config) agent_config = AgentConfig.from_yaml_config(llm_config) _perf_log.info( "[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)", time.perf_counter() - _t0, llm_config_id, ) # Premium quota reservation — applies to explicitly premium configs # AND Auto mode (which may route to premium models). _needs_premium_quota = ( agent_config is not None and user_id and (agent_config.is_premium or agent_config.is_auto_mode) ) if _needs_premium_quota: import uuid as _uuid from app.config import config as _app_config from app.services.token_quota_service import TokenQuotaService _premium_request_id = _uuid.uuid4().hex[:16] reserve_amount = min( agent_config.quota_reserve_tokens or _app_config.QUOTA_MAX_RESERVE_PER_CALL, _app_config.QUOTA_MAX_RESERVE_PER_CALL, ) async with shielded_async_session() as quota_session: quota_result = await TokenQuotaService.premium_reserve( db_session=quota_session, user_id=UUID(user_id), request_id=_premium_request_id, reserve_tokens=reserve_amount, ) _premium_reserved = reserve_amount if not quota_result.allowed: if agent_config.is_premium: yield streaming_service.format_error( "Premium token quota exceeded. Please purchase more tokens to continue using premium models." ) yield streaming_service.format_done() return # Auto mode: quota exhausted but we can still proceed # (the router may pick a free model). Reset reservation. 
_premium_request_id = None _premium_reserved = 0 if not llm: yield streaming_service.format_error("Failed to create LLM instance") yield streaming_service.format_done() return # Create connector service _t0 = time.perf_counter() connector_service = ConnectorService(session, search_space_id=search_space_id) firecrawl_api_key = None webcrawler_connector = await connector_service.get_connector_by_type( SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id ) if webcrawler_connector and webcrawler_connector.config: firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY") _perf_log.info( "[stream_new_chat] Connector service + firecrawl key in %.3fs", time.perf_counter() - _t0, ) # Get the PostgreSQL checkpointer for persistent conversation memory _t0 = time.perf_counter() checkpointer = await get_checkpointer() _perf_log.info( "[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0 ) visibility = thread_visibility or ChatVisibility.PRIVATE _t0 = time.perf_counter() agent = await create_surfsense_deep_agent( llm=llm, search_space_id=search_space_id, db_session=session, connector_service=connector_service, checkpointer=checkpointer, user_id=user_id, thread_id=chat_id, agent_config=agent_config, firecrawl_api_key=firecrawl_api_key, thread_visibility=visibility, disabled_tools=disabled_tools, mentioned_document_ids=mentioned_document_ids, ) _perf_log.info( "[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0 ) # Build input with message history langchain_messages = [] _t0 = time.perf_counter() # Bootstrap history for cloned chats (no LangGraph checkpoint exists yet) if needs_history_bootstrap: langchain_messages = await bootstrap_history_from_db( session, chat_id, thread_visibility=visibility ) thread_result = await session.execute( select(NewChatThread).filter(NewChatThread.id == chat_id) ) thread = thread_result.scalars().first() if thread: thread.needs_history_bootstrap = False await session.commit() # 
Mentioned KB documents are now handled by KnowledgeBaseSearchMiddleware # which merges them into the scoped filesystem with full document # structure. Only SurfSense docs and report context are inlined here. # Fetch mentioned SurfSense docs if any mentioned_surfsense_docs: list[SurfsenseDocsDocument] = [] if mentioned_surfsense_doc_ids: result = await session.execute( select(SurfsenseDocsDocument) .options(selectinload(SurfsenseDocsDocument.chunks)) .filter( SurfsenseDocsDocument.id.in_(mentioned_surfsense_doc_ids), ) ) mentioned_surfsense_docs = list(result.scalars().all()) # Fetch the most recent report(s) in this thread so the LLM can # easily find report_id for versioning decisions, instead of # having to dig through conversation history. recent_reports_result = await session.execute( select(Report) .filter( Report.thread_id == chat_id, Report.content.isnot(None), # exclude failed reports ) .order_by(Report.id.desc()) .limit(3) ) recent_reports = list(recent_reports_result.scalars().all()) # Format the user query with context (SurfSense docs + reports only) final_query = user_query context_parts = [] if mentioned_surfsense_docs: context_parts.append( format_mentioned_surfsense_docs_as_context(mentioned_surfsense_docs) ) # Surface report IDs prominently so the LLM doesn't have to # retrieve them from old tool responses in conversation history. 
if recent_reports: report_lines = [] for r in recent_reports: report_lines.append( f' - report_id={r.id}, title="{r.title}", ' f'style="{r.report_style or "detailed"}"' ) reports_listing = "\n".join(report_lines) context_parts.append( "\n" "Previously generated reports in this conversation:\n" f"{reports_listing}\n\n" "If the user wants to MODIFY, REVISE, UPDATE, or ADD to one of " "these reports, set parent_report_id to the relevant report_id above.\n" "If the user wants a completely NEW report on a different topic, " "leave parent_report_id unset.\n" "" ) if context_parts: context = "\n\n".join(context_parts) final_query = f"{context}\n\n{user_query}" if visibility == ChatVisibility.SEARCH_SPACE and current_user_display_name: final_query = f"**[{current_user_display_name}]:** {final_query}" # if messages: # # Convert frontend messages to LangChain format # for msg in messages: # if msg.role == "user": # langchain_messages.append(HumanMessage(content=msg.content)) # elif msg.role == "assistant": # langchain_messages.append(AIMessage(content=msg.content)) # else: # Fallback: just use the current user query with attachment context langchain_messages.append(HumanMessage(content=final_query)) input_state = { # Lets not pass this message atm because we are using the checkpointer to manage the conversation history # We will use this to simulate group chat functionality in the future "messages": langchain_messages, "search_space_id": search_space_id, } _perf_log.info( "[stream_new_chat] History bootstrap + doc/report queries in %.3fs", time.perf_counter() - _t0, ) # All pre-streaming DB reads are done. Commit to release the # transaction and its ACCESS SHARE locks so we don't block DDL # (e.g. migrations) for the entire duration of LLM streaming. # Tools that need DB access during streaming will start their own # short-lived transactions (or use isolated sessions). await session.commit() # Detach heavy ORM objects (documents with chunks, reports, etc.) 
# from the session identity map now that we've extracted the data # we need. This prevents them from accumulating in memory for the # entire duration of LLM streaming (which can be several minutes). session.expunge_all() _perf_log.info( "[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)", time.perf_counter() - _t_total, chat_id, ) # Configure LangGraph with thread_id for memory # If checkpoint_id is provided, fork from that checkpoint (for edit/reload) configurable = {"thread_id": str(chat_id)} if checkpoint_id: configurable["checkpoint_id"] = checkpoint_id config = { "configurable": configurable, "recursion_limit": 80, # Increase from default 25 to allow more tool iterations } # Start the message stream yield streaming_service.format_message_start() yield streaming_service.format_start_step() # Initial thinking step - analyzing the request if mentioned_surfsense_docs: initial_title = "Analyzing referenced content" action_verb = "Analyzing" else: initial_title = "Understanding your request" action_verb = "Processing" processing_parts = [] query_text = user_query[:80] + ("..." if len(user_query) > 80 else "") processing_parts.append(query_text) if mentioned_surfsense_docs: doc_names = [] for doc in mentioned_surfsense_docs: title = doc.title if len(title) > 30: title = title[:27] + "..." doc_names.append(title) if len(doc_names) == 1: processing_parts.append(f"[{doc_names[0]}]") else: processing_parts.append(f"[{len(doc_names)} docs]") initial_items = [f"{action_verb}: {' '.join(processing_parts)}"] initial_step_id = "thinking-1" yield streaming_service.format_thinking_step( step_id=initial_step_id, title=initial_title, status="in_progress", items=initial_items, ) # These ORM objects (with eagerly-loaded chunks) can be very large. # They're only needed to build context strings already copied into # final_query / langchain_messages — release them before streaming. 
del mentioned_surfsense_docs, recent_reports del langchain_messages, final_query # Check if this is the first assistant response so we can generate # a title in parallel with the agent stream (better UX than waiting # until after the full response). assistant_count_result = await session.execute( select(func.count(NewChatMessage.id)).filter( NewChatMessage.thread_id == chat_id, NewChatMessage.role == "assistant", ) ) is_first_response = (assistant_count_result.scalar() or 0) == 0 title_task: asyncio.Task[tuple[str | None, dict | None]] | None = None if is_first_response: async def _generate_title() -> tuple[str | None, dict | None]: """Generate a short title via litellm.acompletion. Returns (title, usage_dict). Usage is extracted directly from the response object because litellm fires its async callback via fire-and-forget ``create_task``, so the ``TokenTrackingCallback`` would run too late. We also blank the accumulator in this child-task context so the late callback doesn't double-count. 
""" try: from litellm import acompletion from app.services.llm_router_service import LLMRouterService from app.services.token_tracking_service import _turn_accumulator _turn_accumulator.set(None) prompt = TITLE_GENERATION_PROMPT.replace( "{user_query}", user_query[:500] ) messages = [{"role": "user", "content": prompt}] if getattr(llm, "model", None) == "auto": router = LLMRouterService.get_router() response = await router.acompletion( model="auto", messages=messages ) else: response = await acompletion( model=llm.model, messages=messages, api_key=getattr(llm, "api_key", None), api_base=getattr(llm, "api_base", None), ) usage_info = None usage = getattr(response, "usage", None) if usage: raw_model = getattr(llm, "model", "") or "" model_name = ( raw_model.split("/", 1)[-1] if "/" in raw_model else (raw_model or response.model or "unknown") ) usage_info = { "model": model_name, "prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0, "completion_tokens": getattr(usage, "completion_tokens", 0) or 0, "total_tokens": getattr(usage, "total_tokens", 0) or 0, } raw_title = response.choices[0].message.content.strip() if raw_title and len(raw_title) <= 100: return raw_title.strip("\"'"), usage_info return None, usage_info except Exception: logging.getLogger(__name__).exception( "[TitleGen] _generate_title failed" ) return None, None title_task = asyncio.create_task(_generate_title()) title_emitted = False _t_stream_start = time.perf_counter() _first_event_logged = False async for sse in _stream_agent_events( agent=agent, config=config, input_data=input_state, streaming_service=streaming_service, result=stream_result, step_prefix="thinking", initial_step_id=initial_step_id, initial_step_title=initial_title, initial_step_items=initial_items, ): if not _first_event_logged: _perf_log.info( "[stream_new_chat] First agent event in %.3fs (time since stream start), " "%.3fs (total since request start) (chat_id=%s)", time.perf_counter() - _t_stream_start, time.perf_counter() - 
_t_total, chat_id, ) _first_event_logged = True yield sse # Inject title update mid-stream as soon as the background task finishes if title_task is not None and title_task.done() and not title_emitted: generated_title, title_usage = title_task.result() if title_usage: accumulator.add(**title_usage) if generated_title: async with shielded_async_session() as title_session: title_thread_result = await title_session.execute( select(NewChatThread).filter(NewChatThread.id == chat_id) ) title_thread = title_thread_result.scalars().first() if title_thread: title_thread.title = generated_title await title_session.commit() yield streaming_service.format_thread_title_update( chat_id, generated_title ) title_emitted = True _perf_log.info( "[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)", time.perf_counter() - _t_stream_start, chat_id, ) log_system_snapshot("stream_new_chat_END") if stream_result.is_interrupted: if title_task is not None and not title_task.done(): title_task.cancel() usage_summary = accumulator.per_message_summary() _perf_log.info( "[token_usage] interrupted new_chat: calls=%d total=%d summary=%s", len(accumulator.calls), accumulator.grand_total, usage_summary, ) if usage_summary: yield streaming_service.format_data( "token-usage", { "usage": usage_summary, "prompt_tokens": accumulator.total_prompt_tokens, "completion_tokens": accumulator.total_completion_tokens, "total_tokens": accumulator.grand_total, "call_details": accumulator.serialized_calls(), }, ) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() return # If the title task didn't finish during streaming, await it now if title_task is not None and not title_emitted: generated_title, title_usage = await title_task if title_usage: accumulator.add(**title_usage) if generated_title: async with shielded_async_session() as title_session: title_thread_result = await title_session.execute( 
select(NewChatThread).filter(NewChatThread.id == chat_id) ) title_thread = title_thread_result.scalars().first() if title_thread: title_thread.title = generated_title await title_session.commit() yield streaming_service.format_thread_title_update( chat_id, generated_title ) # Finalize premium quota with actual tokens. # For Auto mode, only count tokens from calls that used premium models. if _premium_request_id and user_id: try: from app.services.token_quota_service import TokenQuotaService if agent_config and agent_config.is_auto_mode: from app.services.llm_router_service import LLMRouterService actual_premium_tokens = LLMRouterService.compute_premium_tokens( accumulator.calls ) else: actual_premium_tokens = accumulator.grand_total async with shielded_async_session() as quota_session: await TokenQuotaService.premium_finalize( db_session=quota_session, user_id=UUID(user_id), request_id=_premium_request_id, actual_tokens=actual_premium_tokens, reserved_tokens=_premium_reserved, ) except Exception: logging.getLogger(__name__).warning( "Failed to finalize premium quota for user %s", user_id, exc_info=True, ) usage_summary = accumulator.per_message_summary() _perf_log.info( "[token_usage] normal new_chat: calls=%d total=%d summary=%s", len(accumulator.calls), accumulator.grand_total, usage_summary, ) if usage_summary: yield streaming_service.format_data( "token-usage", { "usage": usage_summary, "prompt_tokens": accumulator.total_prompt_tokens, "completion_tokens": accumulator.total_completion_tokens, "total_tokens": accumulator.grand_total, "call_details": accumulator.serialized_calls(), }, ) # Fire background memory extraction if the agent didn't handle it. # Shared threads write to team memory; private threads write to user memory. 
if not stream_result.agent_called_update_memory: if visibility == ChatVisibility.SEARCH_SPACE: task = asyncio.create_task( extract_and_save_team_memory( user_message=user_query, search_space_id=search_space_id, llm=llm, author_display_name=current_user_display_name, ) ) _background_tasks.add(task) task.add_done_callback(_background_tasks.discard) elif user_id: task = asyncio.create_task( extract_and_save_memory( user_message=user_query, user_id=user_id, llm=llm, ) ) _background_tasks.add(task) task.add_done_callback(_background_tasks.discard) # Finish the step and message yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() except Exception as e: # Handle any errors import traceback error_message = f"Error during chat: {e!s}" print(f"[stream_new_chat] {error_message}") print(f"[stream_new_chat] Exception type: {type(e).__name__}") print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}") yield streaming_service.format_error(error_message) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() finally: # Shield the ENTIRE async cleanup from anyio cancel-scope # cancellation. Starlette's BaseHTTPMiddleware uses anyio task # groups; on client disconnect, it cancels the scope with # level-triggered cancellation — every unshielded `await` inside # the cancelled scope raises CancelledError immediately. Without # this shield the very first `await` (session.rollback) would # raise CancelledError, `except Exception` wouldn't catch it # (CancelledError is a BaseException), and the rest of the # finally block — including session.close() — would never run. 
with anyio.CancelScope(shield=True): # Release premium reservation if not finalized if _premium_request_id and _premium_reserved > 0 and user_id: try: from app.services.token_quota_service import TokenQuotaService async with shielded_async_session() as quota_session: await TokenQuotaService.premium_release( db_session=quota_session, user_id=UUID(user_id), reserved_tokens=_premium_reserved, ) _premium_reserved = 0 except Exception: logging.getLogger(__name__).warning( "Failed to release premium quota for user %s", user_id ) try: await session.rollback() await clear_ai_responding(session, chat_id) except Exception: try: async with shielded_async_session() as fresh_session: await clear_ai_responding(fresh_session, chat_id) except Exception: logging.getLogger(__name__).warning( "Failed to clear AI responding state for thread %s", chat_id ) with contextlib.suppress(Exception): session.expunge_all() with contextlib.suppress(Exception): await session.close() # Persist any sandbox-produced files to local storage so they # remain downloadable after the Daytona sandbox auto-deletes. if stream_result and stream_result.sandbox_files: with contextlib.suppress(Exception): from app.agents.new_chat.sandbox import ( is_sandbox_enabled, persist_and_delete_sandbox, ) if is_sandbox_enabled(): with anyio.CancelScope(shield=True): await persist_and_delete_sandbox( chat_id, stream_result.sandbox_files ) # Break circular refs held by the agent graph, tools, and LLM # wrappers so the GC can reclaim them in a single pass. 
agent = llm = connector_service = None input_state = stream_result = None session = None collected = gc.collect(0) + gc.collect(1) + gc.collect(2) if collected: _perf_log.info( "[stream_new_chat] gc.collect() reclaimed %d objects (chat_id=%s)", collected, chat_id, ) trim_native_heap() log_system_snapshot("stream_new_chat_END") async def stream_resume_chat( chat_id: int, search_space_id: int, decisions: list[dict], user_id: str | None = None, llm_config_id: int = -1, thread_visibility: ChatVisibility | None = None, ) -> AsyncGenerator[str, None]: streaming_service = VercelStreamingService() stream_result = StreamResult() _t_total = time.perf_counter() from app.services.token_tracking_service import start_turn accumulator = start_turn() session = async_session_maker() try: if user_id: await set_ai_responding(session, chat_id, UUID(user_id)) agent_config: AgentConfig | None = None _t0 = time.perf_counter() if llm_config_id >= 0: agent_config = await load_agent_config( session=session, config_id=llm_config_id, search_space_id=search_space_id, ) if not agent_config: yield streaming_service.format_error( f"Failed to load NewLLMConfig with id {llm_config_id}" ) yield streaming_service.format_done() return llm = create_chat_litellm_from_agent_config(agent_config) else: llm_config = load_global_llm_config_by_id(llm_config_id) if not llm_config: yield streaming_service.format_error( f"Failed to load LLM config with id {llm_config_id}" ) yield streaming_service.format_done() return llm = create_chat_litellm_from_config(llm_config) agent_config = AgentConfig.from_yaml_config(llm_config) _perf_log.info( "[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0 ) # Premium quota reservation (same logic as stream_new_chat) _resume_premium_reserved = 0 _resume_premium_request_id: str | None = None _resume_needs_premium = ( agent_config is not None and user_id and (agent_config.is_premium or agent_config.is_auto_mode) ) if _resume_needs_premium: import uuid as _uuid 
            from app.config import config as _app_config
            from app.services.token_quota_service import TokenQuotaService

            _resume_premium_request_id = _uuid.uuid4().hex[:16]
            # Reserve up-front, capped at QUOTA_MAX_RESERVE_PER_CALL; actual
            # usage is reconciled (finalized or released) after streaming.
            reserve_amount = min(
                agent_config.quota_reserve_tokens
                or _app_config.QUOTA_MAX_RESERVE_PER_CALL,
                _app_config.QUOTA_MAX_RESERVE_PER_CALL,
            )
            async with shielded_async_session() as quota_session:
                quota_result = await TokenQuotaService.premium_reserve(
                    db_session=quota_session,
                    user_id=UUID(user_id),
                    request_id=_resume_premium_request_id,
                    reserve_tokens=reserve_amount,
                )
            _resume_premium_reserved = reserve_amount
            if not quota_result.allowed:
                if agent_config.is_premium:
                    yield streaming_service.format_error(
                        "Premium token quota exceeded. Please purchase more tokens to continue using premium models."
                    )
                    yield streaming_service.format_done()
                    return
                # Auto mode: quota exhausted but the router may still pick a
                # free model — proceed without a reservation.
                _resume_premium_request_id = None
                _resume_premium_reserved = 0

        if not llm:
            yield streaming_service.format_error("Failed to create LLM instance")
            yield streaming_service.format_done()
            return

        # Connector service + optional Firecrawl key for web crawling tools
        _t0 = time.perf_counter()
        connector_service = ConnectorService(session, search_space_id=search_space_id)
        firecrawl_api_key = None
        webcrawler_connector = await connector_service.get_connector_by_type(
            SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
        )
        if webcrawler_connector and webcrawler_connector.config:
            firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
        _perf_log.info(
            "[stream_resume] Connector service + firecrawl key in %.3fs",
            time.perf_counter() - _t0,
        )

        # PostgreSQL checkpointer holds the paused graph state we resume from
        _t0 = time.perf_counter()
        checkpointer = await get_checkpointer()
        _perf_log.info(
            "[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
        )

        visibility = thread_visibility or ChatVisibility.PRIVATE
        _t0 = time.perf_counter()
        agent = await create_surfsense_deep_agent(
            llm=llm,
            search_space_id=search_space_id,
            db_session=session,
            connector_service=connector_service,
            checkpointer=checkpointer,
            user_id=user_id,
            thread_id=chat_id,
            agent_config=agent_config,
            firecrawl_api_key=firecrawl_api_key,
            thread_visibility=visibility,
        )
        _perf_log.info(
            "[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
        )

        # Release the transaction before streaming (same rationale as stream_new_chat).
        await session.commit()
        session.expunge_all()
        _perf_log.info(
            "[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)",
            time.perf_counter() - _t_total,
            chat_id,
        )

        from langgraph.types import Command

        config = {
            "configurable": {"thread_id": str(chat_id)},
            "recursion_limit": 80,
        }

        yield streaming_service.format_message_start()
        yield streaming_service.format_start_step()

        _t_stream_start = time.perf_counter()
        _first_event_logged = False
        # Resume the interrupted graph: Command(resume=...) feeds the user's
        # decisions back into the pending interrupt instead of a new query.
        async for sse in _stream_agent_events(
            agent=agent,
            config=config,
            input_data=Command(resume={"decisions": decisions}),
            streaming_service=streaming_service,
            result=stream_result,
            step_prefix="thinking-resume",
        ):
            if not _first_event_logged:
                _perf_log.info(
                    "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
                    time.perf_counter() - _t_stream_start,
                    time.perf_counter() - _t_total,
                    chat_id,
                )
                _first_event_logged = True
            yield sse

        _perf_log.info(
            "[stream_resume] Agent stream completed in %.3fs (chat_id=%s)",
            time.perf_counter() - _t_stream_start,
            chat_id,
        )

        if stream_result.is_interrupted:
            # Interrupted again mid-resume: report usage and close the stream,
            # leaving the checkpoint paused for another resume.
            usage_summary = accumulator.per_message_summary()
            _perf_log.info(
                "[token_usage] interrupted resume_chat: calls=%d total=%d summary=%s",
                len(accumulator.calls),
                accumulator.grand_total,
                usage_summary,
            )
            if usage_summary:
                yield streaming_service.format_data(
                    "token-usage",
                    {
                        "usage": usage_summary,
                        "prompt_tokens": accumulator.total_prompt_tokens,
                        "completion_tokens": accumulator.total_completion_tokens,
                        "total_tokens": accumulator.grand_total,
                        "call_details": accumulator.serialized_calls(),
                    },
                )
            yield streaming_service.format_finish_step()
            yield streaming_service.format_finish()
            yield streaming_service.format_done()
            return

        # Finalize premium quota for resume path
        if _resume_premium_request_id and user_id:
            try:
                from app.services.token_quota_service import TokenQuotaService

                if agent_config and agent_config.is_auto_mode:
                    from app.services.llm_router_service import LLMRouterService

                    # Auto mode: only tokens routed to premium models count
                    actual_premium_tokens = LLMRouterService.compute_premium_tokens(
                        accumulator.calls
                    )
                else:
                    actual_premium_tokens = accumulator.grand_total
                async with shielded_async_session() as quota_session:
                    await TokenQuotaService.premium_finalize(
                        db_session=quota_session,
                        user_id=UUID(user_id),
                        request_id=_resume_premium_request_id,
                        actual_tokens=actual_premium_tokens,
                        reserved_tokens=_resume_premium_reserved,
                    )
            except Exception:
                logging.getLogger(__name__).warning(
                    "Failed to finalize premium quota for user %s (resume)",
                    user_id,
                    exc_info=True,
                )

        usage_summary = accumulator.per_message_summary()
        _perf_log.info(
            "[token_usage] normal resume_chat: calls=%d total=%d summary=%s",
            len(accumulator.calls),
            accumulator.grand_total,
            usage_summary,
        )
        if usage_summary:
            yield streaming_service.format_data(
                "token-usage",
                {
                    "usage": usage_summary,
                    "prompt_tokens": accumulator.total_prompt_tokens,
                    "completion_tokens": accumulator.total_completion_tokens,
                    "total_tokens": accumulator.grand_total,
                    "call_details": accumulator.serialized_calls(),
                },
            )
        yield streaming_service.format_finish_step()
        yield streaming_service.format_finish()
        yield streaming_service.format_done()

    except Exception as e:
        import traceback

        error_message = f"Error during resume: {e!s}"
        print(f"[stream_resume_chat] {error_message}")
        print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}")
        yield streaming_service.format_error(error_message)
        yield streaming_service.format_finish_step()
        yield streaming_service.format_finish()
        yield streaming_service.format_done()
    finally:
        # Shield cleanup from anyio cancel-scope cancellation on client
        # disconnect (same rationale as stream_new_chat's finally block).
        with anyio.CancelScope(shield=True):
            # Release premium reservation if not finalized
            if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id:
                try:
                    from app.services.token_quota_service import TokenQuotaService

                    async with shielded_async_session() as quota_session:
                        await TokenQuotaService.premium_release(
                            db_session=quota_session,
                            user_id=UUID(user_id),
                            reserved_tokens=_resume_premium_reserved,
                        )
                    _resume_premium_reserved = 0
                except Exception:
                    logging.getLogger(__name__).warning(
                        "Failed to release premium quota for user %s (resume)", user_id
                    )

            try:
                await session.rollback()
                await clear_ai_responding(session, chat_id)
            except Exception:
                # The streaming session may be unusable — retry on a fresh one
                try:
                    async with shielded_async_session() as fresh_session:
                        await clear_ai_responding(fresh_session, chat_id)
                except Exception:
                    logging.getLogger(__name__).warning(
                        "Failed to clear AI responding state for thread %s", chat_id
                    )

            with contextlib.suppress(Exception):
                session.expunge_all()
            with contextlib.suppress(Exception):
                await session.close()

            # Drop references so the GC can reclaim the agent graph and LLM
            agent = llm = connector_service = None
            stream_result = None
            session = None
            collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
            if collected:
                _perf_log.info(
                    "[stream_resume] gc.collect() reclaimed %d objects (chat_id=%s)",
                    collected,
                    chat_id,
                )
            trim_native_heap()
            log_system_snapshot("stream_resume_chat_END")