# -----------------------------------------------------------------------------
# Patch target: surfsense_backend/app/routes/new_chat_routes.py
#
# The patch also touches that file's import section:
#   * adds `ResumeRequest` to the `from app.schemas.new_chat import (...)` list
#   * replaces the task import with
#       from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
# -----------------------------------------------------------------------------


# =============================================================================
# Resume Interrupted Chat Endpoint
# =============================================================================


@router.post("/threads/{thread_id}/resume")
async def resume_chat(
    thread_id: int,
    request: ResumeRequest,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Resume an interrupted chat thread and stream the continuation.

    The thread is loaded, the caller's permission is checked against the
    thread's search space, and the decisions carried by ``request`` are handed
    to ``stream_resume_chat``. The result is returned as a
    ``text/event-stream`` streaming response.

    Args:
        thread_id: Primary key of the ``NewChatThread`` to resume.
        request: Payload carrying ``search_space_id`` and ``decisions``.
        session: Database session (injected).
        user: Authenticated user (injected).

    Raises:
        HTTPException: 404 when the thread or search space does not exist,
            400 when ``request.search_space_id`` is not the thread's search
            space, 500 for any unexpected error. ``HTTPException`` instances
            raised by the permission/access helpers are re-raised untouched.
    """
    try:
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        thread = result.scalars().first()

        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")

        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_CREATE.value,
            "You don't have permission to chat in this search space",
        )

        await check_thread_access(session, thread, user)

        # The permission check above only covers the thread's own search space.
        # Refuse a different id in the body, otherwise the agent would run with
        # the connectors and LLM config of a space that was never authorised.
        if request.search_space_id != thread.search_space_id:
            raise HTTPException(
                status_code=400,
                detail="Search space does not match the thread's search space",
            )
        search_space_id = thread.search_space_id

        search_space_result = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == search_space_id)
        )
        search_space = search_space_result.scalars().first()

        if not search_space:
            raise HTTPException(status_code=404, detail="Search space not found")

        # -1 makes stream_resume_chat take its YAML-config branch.
        llm_config_id = (
            search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
        )

        decisions = [d.model_dump() for d in request.decisions]

        return StreamingResponse(
            stream_resume_chat(
                chat_id=thread_id,
                search_space_id=search_space_id,
                decisions=decisions,
                session=session,
                user_id=str(user.id),
                llm_config_id=llm_config_id,
                thread_visibility=thread.visibility,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    except HTTPException:
        raise
    except Exception as e:
        import traceback

        traceback.print_exc()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred during resume: {e!s}",
        ) from None


# -----------------------------------------------------------------------------
# Patch target: surfsense_backend/app/tasks/chat/stream_new_chat.py
# -----------------------------------------------------------------------------


def _text(value: object) -> str:
    """Return ``value`` as a string, mapping ``None`` to the empty string."""
    return "" if value is None else str(value)


def _clip(value: object, limit: int) -> str:
    """Return ``value`` as text cut to ``limit`` characters, adding "..." when cut."""
    text = _text(value)
    return f"{text[:limit]}{'...' if len(text) > limit else ''}"


def _number(value: object) -> int | float:
    """Return ``value`` when it is numeric, else 0, so ``:,`` formatting cannot fail."""
    return value if isinstance(value, (int, float)) else 0


def _normalize_tool_output(raw_output: object) -> object:
    """Convert a raw ``on_tool_end`` output into a JSON-like value.

    Objects with a ``content`` attribute are unwrapped: string content is parsed
    as JSON (falling back to ``{"result": content}``), dict content is returned
    as is, anything else is stringified. Note that parsed JSON is not
    necessarily a dict, so callers still need ``isinstance`` checks.
    """
    if hasattr(raw_output, "content"):
        content = raw_output.content
        if isinstance(content, str):
            try:
                return json.loads(content)
            except (json.JSONDecodeError, TypeError):
                return {"result": content}
        if isinstance(content, dict):
            return content
        return {"result": str(content)}
    if isinstance(raw_output, dict):
        return raw_output
    return {"result": str(raw_output) if raw_output else "completed"}


def _tool_start_step(tool_name: str, tool_input: object) -> tuple[str, list[str]]:
    """Return ``(title, items)`` of the in-progress thinking step for a tool call.

    Known tools always yield a non-empty item list; unknown tools yield ``[]``.
    """
    args = tool_input if isinstance(tool_input, dict) else None

    if tool_name == "search_knowledge_base":
        query = args.get("query", "") if args is not None else str(tool_input)
        return "Searching knowledge base", [f"Query: {_clip(query, 100)}"]

    if tool_name == "link_preview":
        url = args.get("url", "") if args is not None else str(tool_input)
        return "Fetching link preview", [f"URL: {_clip(url, 80)}"]

    if tool_name == "display_image":
        src = args.get("src", "") if args is not None else str(tool_input)
        title = args.get("title", "") if args is not None else ""
        return "Analyzing the image", [f"Analyzing: {_clip(title or src, 50)}"]

    if tool_name == "scrape_webpage":
        url = args.get("url", "") if args is not None else str(tool_input)
        return "Scraping webpage", [f"URL: {_clip(url, 80)}"]

    if tool_name == "generate_podcast":
        if args is not None:
            podcast_title = args.get("podcast_title", "SurfSense Podcast")
            source_content = args.get("source_content", "")
        else:
            podcast_title = "SurfSense Podcast"
            source_content = ""
        content_len = len(source_content or "")
        return "Generating podcast", [
            f"Title: {podcast_title}",
            f"Content: {content_len:,} characters",
            "Preparing audio generation...",
        ]

    return f"Using {tool_name.replace('_', ' ')}", []


def _tool_end_step(
    tool_name: str, tool_output: object, step_items: list[str]
) -> tuple[str, list[str]]:
    """Return ``(title, items)`` of the completed thinking step for a tool call.

    ``step_items`` are the items that were shown while the same tool call was
    in progress; most tools append their result summary to them.
    """
    out = tool_output if isinstance(tool_output, dict) else None

    if tool_name == "search_knowledge_base":
        result_info = "Search completed"
        if out is not None:
            result_len = _number(out.get("result_length", 0))
            if result_len > 0:
                result_info = f"Found relevant information ({result_len} chars)"
        return "Searching knowledge base", [*step_items, result_info]

    if tool_name == "link_preview":
        if out is None:
            return "Fetching link preview", [*step_items, "Preview loaded"]
        if "error" in out:
            return "Fetching link preview", [
                *step_items,
                f"Error: {out.get('error', 'Failed to fetch')}",
            ]
        domain = out.get("domain", "")
        return "Fetching link preview", [
            *step_items,
            f"Title: {_clip(out.get('title', 'Link'), 60)}",
            f"Domain: {domain}" if domain else "Preview loaded",
        ]

    if tool_name == "display_image":
        if out is None:
            return "Analyzing the image", [*step_items, "Image analyzed"]
        display_name = out.get("title", "") or out.get("alt", "Image")
        return "Analyzing the image", [
            *step_items,
            f"Analyzed: {_clip(display_name, 50)}",
        ]

    if tool_name == "scrape_webpage":
        if out is None:
            return "Scraping webpage", [*step_items, "Content extracted"]
        if "error" in out:
            return "Scraping webpage", [
                *step_items,
                f"Error: {_text(out.get('error', 'Failed to scrape'))[:50]}",
            ]
        word_count = _number(out.get("word_count", 0))
        return "Scraping webpage", [
            *step_items,
            f"Title: {_clip(out.get('title', 'Webpage'), 50)}",
            f"Extracted: {word_count:,} words",
        ]

    if tool_name == "generate_podcast":
        podcast_status = out.get("status", "unknown") if out is not None else "unknown"
        podcast_title = out.get("title", "Podcast") if out is not None else "Podcast"
        if podcast_status == "processing":
            return "Generating podcast", [
                f"Title: {podcast_title}",
                "Audio generation started",
                "Processing in background...",
            ]
        if podcast_status == "already_generating":
            return "Generating podcast", [
                f"Title: {podcast_title}",
                "Podcast already in progress",
                "Please wait for it to complete",
            ]
        if podcast_status == "error":
            error_msg = out.get("error", "Unknown error")
            return "Generating podcast", [
                f"Title: {podcast_title}",
                f"Error: {_text(error_msg)[:50]}",
            ]
        return "Generating podcast", step_items

    if tool_name == "ls":
        if out is not None:
            listing = out.get("result", "")
        elif isinstance(tool_output, str):
            listing = tool_output
        else:
            listing = str(tool_output) if tool_output else ""
        if not isinstance(listing, str):
            listing = str(listing) if listing else ""

        file_names: list[str] = []
        for line in listing.strip().split("\n"):
            # Keep only the last path component; a trailing "/" marks a directory.
            name = line.strip().rstrip("/").split("/")[-1]
            if not name:
                continue
            file_names.append(name if len(name) <= 40 else name[:37] + "...")

        if not file_names:
            return "Exploring files", ["No files found"]
        if len(file_names) <= 5:
            return "Exploring files", [f"[{name}]" for name in file_names]
        return "Exploring files", [
            *(f"[{name}]" for name in file_names[:4]),
            f"(+{len(file_names) - 4} more)",
        ]

    return f"Using {tool_name.replace('_', ' ')}", step_items


def _tool_result(
    tool_name: str, tool_output: object
) -> tuple[dict, tuple[str, str] | None]:
    """Return the tool-output payload and the optional ``(message, level)`` terminal info."""
    out = tool_output if isinstance(tool_output, dict) else None
    wrapped = out if out is not None else {"result": tool_output}

    if tool_name == "generate_podcast":
        # NOTE(review): only status "success" is reported as success here, while
        # the thinking step treats "processing" as a normal start - confirm
        # which statuses the podcast tool actually returns.
        if out is not None and out.get("status") == "success":
            return wrapped, (
                f"Podcast generated successfully: {out.get('title', 'Podcast')}",
                "success",
            )
        error_msg = (
            out.get("error", "Unknown error") if out is not None else "Unknown error"
        )
        return wrapped, (f"Podcast generation failed: {error_msg}", "error")

    if tool_name == "link_preview":
        if out is not None and "error" not in out:
            return wrapped, (
                f"Link preview loaded: {_clip(out.get('title', 'Link'), 50)}",
                "success",
            )
        error_msg = (
            out.get("error", "Failed to fetch") if out is not None else "Failed to fetch"
        )
        return wrapped, (f"Link preview failed: {error_msg}", "error")

    if tool_name == "display_image":
        if out is None:
            return wrapped, None
        title = out.get("title") or out.get("alt", "Image")
        return wrapped, (f"Image analyzed: {_clip(title, 40)}", "success")

    if tool_name == "scrape_webpage":
        if out is None:
            return {"result": tool_output}, (
                "Scrape failed: Failed to scrape",
                "error",
            )
        # The full page content is replaced by a short preview in the payload.
        display_output = {k: v for k, v in out.items() if k != "content"}
        if "content" in out:
            content = _text(out.get("content", ""))
            display_output["content_preview"] = (
                content[:500] + "..." if len(content) > 500 else content
            )
        if "error" not in out:
            word_count = _number(out.get("word_count", 0))
            return display_output, (
                f"Scraped: {_clip(out.get('title', 'Webpage'), 40)} ({word_count:,} words)",
                "success",
            )
        return display_output, (
            f"Scrape failed: {out.get('error', 'Failed to scrape')}",
            "error",
        )

    payload = {"status": "completed", "result_length": len(str(tool_output))}
    if tool_name == "search_knowledge_base":
        return payload, ("Knowledge base search completed", "success")
    return payload, (f"Tool {tool_name} completed", "success")


async def stream_resume_chat(
    chat_id: int,
    search_space_id: int,
    decisions: list[dict],
    session: AsyncSession,
    user_id: str | None = None,
    llm_config_id: int = -1,
    thread_visibility: ChatVisibility | None = None,
) -> AsyncGenerator[str, None]:
    """Resume an interrupted agent run and stream its events.

    The agent is rebuilt with the shared checkpointer and resumed with
    ``Command(resume={"decisions": decisions})`` on the thread identified by
    ``chat_id``. Model text, tool starts and tool results are translated into
    events produced by ``VercelStreamingService``. If the run stops on another
    interrupt, an interrupt request is emitted before the finish events.

    Args:
        chat_id: Thread id; also used as the checkpointer ``thread_id``.
        search_space_id: Search space the agent and connectors operate on.
        decisions: Decisions passed to the agent as the resume value.
        session: Database session used by the agent and helper calls.
        user_id: String UUID of the user; when given, the thread is flagged as
            "AI responding" for the duration of the stream.
        llm_config_id: ``>= 0`` loads an agent config via ``load_agent_config``;
            negative ids are loaded via ``load_llm_config_from_yaml``.
        thread_visibility: Visibility passed to the agent; defaults to
            ``ChatVisibility.PRIVATE`` when not given.

    Yields:
        Formatted stream events as strings. Errors are reported in-stream (an
        error event followed by the finish events) rather than raised.
    """
    streaming_service = VercelStreamingService()
    current_text_id: str | None = None

    try:
        if user_id:
            await set_ai_responding(session, chat_id, UUID(user_id))

        agent_config: AgentConfig | None = None
        if llm_config_id >= 0:
            agent_config = await load_agent_config(
                session=session,
                config_id=llm_config_id,
                search_space_id=search_space_id,
            )
            if not agent_config:
                yield streaming_service.format_error(
                    f"Failed to load NewLLMConfig with id {llm_config_id}"
                )
                yield streaming_service.format_done()
                return
            llm = create_chat_litellm_from_agent_config(agent_config)
        else:
            llm_config = load_llm_config_from_yaml(llm_config_id=llm_config_id)
            if not llm_config:
                yield streaming_service.format_error(
                    f"Failed to load LLM config with id {llm_config_id}"
                )
                yield streaming_service.format_done()
                return
            llm = create_chat_litellm_from_config(llm_config)
            agent_config = AgentConfig.from_yaml_config(llm_config)

        if not llm:
            yield streaming_service.format_error("Failed to create LLM instance")
            yield streaming_service.format_done()
            return

        connector_service = ConnectorService(session, search_space_id=search_space_id)

        from app.db import SearchSourceConnectorType

        firecrawl_api_key = None
        webcrawler_connector = await connector_service.get_connector_by_type(
            SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
        )
        if webcrawler_connector and webcrawler_connector.config:
            firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")

        checkpointer = await get_checkpointer()
        visibility = thread_visibility or ChatVisibility.PRIVATE

        agent = await create_surfsense_deep_agent(
            llm=llm,
            search_space_id=search_space_id,
            db_session=session,
            connector_service=connector_service,
            checkpointer=checkpointer,
            user_id=user_id,
            thread_id=chat_id,
            agent_config=agent_config,
            firecrawl_api_key=firecrawl_api_key,
            thread_visibility=visibility,
        )

        from langgraph.types import Command

        config = {
            "configurable": {"thread_id": str(chat_id)},
            "recursion_limit": 80,
        }

        yield streaming_service.format_message_start()
        yield streaming_service.format_start_step()

        thinking_step_counter = 0
        # Per tool call (keyed by run_id): thinking-step id, the items shown
        # while it was running, and the tool-call id announced at start. Keeping
        # them per run stops parallel tool calls from borrowing each other's data.
        tool_step_ids: dict[str, str] = {}
        tool_step_items: dict[str, list[str]] = {}
        tool_call_ids: dict[str, str] = {}
        completed_step_ids: set[str] = set()
        last_active_step_id: str | None = None
        last_active_step_title = ""
        last_active_step_items: list[str] = []

        def next_thinking_step_id() -> str:
            nonlocal thinking_step_counter
            thinking_step_counter += 1
            return f"thinking-resume-{thinking_step_counter}"

        def complete_current_step() -> str | None:
            """Return a "completed" event for the active step, once per step."""
            nonlocal last_active_step_id
            if last_active_step_id and last_active_step_id not in completed_step_ids:
                completed_step_ids.add(last_active_step_id)
                event = streaming_service.format_thinking_step(
                    step_id=last_active_step_id,
                    title=last_active_step_title,
                    status="completed",
                    items=last_active_step_items,
                )
                last_active_step_id = None
                return event
            return None

        async for event in agent.astream_events(
            Command(resume={"decisions": decisions}), config=config, version="v2"
        ):
            event_type = event.get("event", "")

            if event_type == "on_chat_model_stream":
                chunk = event.get("data", {}).get("chunk")
                if chunk and hasattr(chunk, "content"):
                    content = chunk.content
                    if content and isinstance(content, str):
                        if current_text_id is None:
                            # First text after a step: close the step, open a text part.
                            completion_event = complete_current_step()
                            if completion_event:
                                yield completion_event
                            current_text_id = streaming_service.generate_text_id()
                            yield streaming_service.format_text_start(current_text_id)
                        yield streaming_service.format_text_delta(
                            current_text_id, content
                        )

            elif event_type == "on_tool_start":
                tool_name = event.get("name", "unknown_tool")
                run_id = event.get("run_id", "")
                tool_input = event.get("data", {}).get("input", {})

                if current_text_id is not None:
                    yield streaming_service.format_text_end(current_text_id)
                    current_text_id = None

                completion_event = complete_current_step()
                if completion_event:
                    yield completion_event

                tool_step_id = next_thinking_step_id()
                tool_step_ids[run_id] = tool_step_id
                last_active_step_id = tool_step_id
                last_active_step_title, last_active_step_items = _tool_start_step(
                    tool_name, tool_input
                )
                tool_step_items[run_id] = last_active_step_items

                if last_active_step_items:
                    yield streaming_service.format_thinking_step(
                        step_id=tool_step_id,
                        title=last_active_step_title,
                        status="in_progress",
                        items=last_active_step_items,
                    )
                else:
                    # Unknown tools have no items; the argument is left out.
                    yield streaming_service.format_thinking_step(
                        step_id=tool_step_id,
                        title=last_active_step_title,
                        status="in_progress",
                    )

                tool_call_id = (
                    f"call_{run_id[:32]}"
                    if run_id
                    else streaming_service.generate_tool_call_id()
                )
                # Remembered so on_tool_end reports under the very same id.
                tool_call_ids[run_id] = tool_call_id
                yield streaming_service.format_tool_input_start(tool_call_id, tool_name)
                yield streaming_service.format_tool_input_available(
                    tool_call_id,
                    tool_name,
                    tool_input
                    if isinstance(tool_input, dict)
                    else {"input": tool_input},
                )

            elif event_type == "on_tool_end":
                run_id = event.get("run_id", "")
                tool_name = event.get("name", "unknown_tool")
                tool_output = _normalize_tool_output(
                    event.get("data", {}).get("output", "")
                )

                tool_call_id = tool_call_ids.pop(run_id, None) or (
                    f"call_{run_id[:32]}" if run_id else "call_unknown"
                )
                original_step_id = tool_step_ids.get(
                    run_id, f"thinking-unknown-{run_id[:8]}"
                )
                completed_step_ids.add(original_step_id)
                step_items = tool_step_items.pop(run_id, last_active_step_items)

                step_title, completed_items = _tool_end_step(
                    tool_name, tool_output, step_items
                )
                yield streaming_service.format_thinking_step(
                    step_id=original_step_id,
                    title=step_title,
                    status="completed",
                    items=completed_items,
                )

                last_active_step_id = None
                last_active_step_title = ""
                last_active_step_items = []

                payload, terminal_info = _tool_result(tool_name, tool_output)
                yield streaming_service.format_tool_output_available(
                    tool_call_id, payload
                )
                if terminal_info is not None:
                    message, level = terminal_info
                    yield streaming_service.format_terminal_info(message, level)

            elif event_type in ("on_chain_end", "on_agent_end"):
                if current_text_id is not None:
                    yield streaming_service.format_text_end(current_text_id)
                    current_text_id = None

        if current_text_id is not None:
            yield streaming_service.format_text_end(current_text_id)

        completion_event = complete_current_step()
        if completion_event:
            yield completion_event

        state = await agent.aget_state(config)
        # The pending interrupt is not necessarily on the first task, so pick
        # the first task that actually carries one.
        interrupted_task = next(
            (task for task in (state.tasks or ()) if task.interrupts), None
        )
        if interrupted_task is not None:
            yield streaming_service.format_interrupt_request(
                interrupted_task.interrupts[0].value
            )

        yield streaming_service.format_finish_step()
        yield streaming_service.format_finish()
        yield streaming_service.format_done()

    except Exception as e:
        import traceback

        error_message = f"Error during resume: {e!s}"
        print(f"[stream_resume_chat] {error_message}")
        print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}")
        if current_text_id is not None:
            yield streaming_service.format_text_end(current_text_id)
        yield streaming_service.format_error(error_message)
        yield streaming_service.format_finish_step()
        yield streaming_service.format_finish()
        yield streaming_service.format_done()

    finally:
        # Always clear the "AI responding" flag, even on error or early return.
        await clear_ai_responding(session, chat_id)