diff --git a/.vscode/settings.json b/.vscode/settings.json index 05bd30702..7da4b54f8 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,4 +1,9 @@ { "biome.configurationPath": "./surfsense_web/biome.json", - "deepscan.ignoreConfirmWarning": true + "deepscan.ignoreConfirmWarning": true, + "python.defaultInterpreterPath": "${workspaceFolder}/surfsense_backend/.venv/bin/python", + "basedpyright.analysis.extraPaths": [ + "${workspaceFolder}/surfsense_backend" + ], + "python-envs.pythonProjects": [] } \ No newline at end of file diff --git a/docker/.env.example b/docker/.env.example index fd56bdccc..aba15f13f 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -324,7 +324,6 @@ SURFSENSE_ENABLE_ACTION_LOG=true SURFSENSE_ENABLE_REVERT_ROUTE=true SURFSENSE_ENABLE_PERMISSION=true SURFSENSE_ENABLE_DOOM_LOOP=true -SURFSENSE_ENABLE_STREAM_PARITY_V2=true # Periodic connector sync interval (default: 5m) # SCHEDULE_CHECKER_INTERVAL=5m diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index ba89059c8..3d442973c 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -315,14 +315,6 @@ LANGSMITH_PROJECT=surfsense # SURFSENSE_ENABLE_ACTION_LOG=false # SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships -# Streaming parity v2 — opt in to LangChain's structured AIMessageChunk -# content (typed reasoning blocks, tool-input deltas) and propagate the -# real tool_call_id to the SSE layer. When OFF, the stream falls back to -# the str-only text path and synthetic "call_" tool-call ids. -# Schema migrations 135/136 ship unconditionally because they are -# forward-compatible. -# SURFSENSE_ENABLE_STREAM_PARITY_V2=false - # Plugins # SURFSENSE_ENABLE_PLUGIN_LOADER=false # Comma-separated allowlist of plugin entry-point names diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/config.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/config.py index 16211686c..ac232b92a 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/config.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/config.py @@ -6,12 +6,19 @@ exposes the side-channel ``stream_resume_chat`` uses to ferry resume payloads. from __future__ import annotations +import logging from typing import Any from langchain.tools import ToolRuntime from .constants import DEFAULT_SUBAGENT_RECURSION_LIMIT +logger = logging.getLogger(__name__) + +# langgraph stores the parent task's scratchpad under this configurable key; +# subagents inherit the chain via ``parent_scratchpad`` fallback. +_LANGGRAPH_SCRATCHPAD_KEY = "__pregel_scratchpad" + def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]: """RunnableConfig for the nested invoke; raises ``recursion_limit`` to the parent's budget.""" @@ -42,3 +49,42 @@ def has_surfsense_resume(runtime: ToolRuntime) -> bool: if not isinstance(configurable, dict): return False return "surfsense_resume_value" in configurable + + +def drain_parent_null_resume(runtime: ToolRuntime) -> None: + """Consume the parent's lingering ``NULL_TASK_ID/RESUME`` write before delegating. + + ``stream_resume_chat`` wakes the main agent with + ``Command(resume={"decisions": [...]})`` so the propagated + ``_lg_interrupt(...)`` can return. 
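+    (A typical resume payload here looks like
+    ``{"decisions": [{"type": "approve"}, {"type": "reject"}]}``, one
+    decision per hanging tool call; values shown for illustration.)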
langgraph stores that payload as the + parent task's ``null_resume`` pending write, which only gets consumed + *after* ``subagent.[a]invoke`` returns (when the post-call propagation + re-fires). While the subagent is mid-execution, any *new* ``interrupt()`` + inside it (e.g. a follow-up tool call after a mixed approve/reject) walks + ``subagent_scratchpad → parent_scratchpad.get_null_resume`` and picks up + the parent's still-live decisions — mismatching against a different number + of hanging tool calls and crashing ``HumanInTheLoopMiddleware``. + + Draining the write here closes that cross-graph leak so subagent + interrupts pause cleanly and re-propagate as a fresh approval card. + """ + cfg = runtime.config or {} + configurable = cfg.get("configurable") if isinstance(cfg, dict) else None + if not isinstance(configurable, dict): + return + scratchpad = configurable.get(_LANGGRAPH_SCRATCHPAD_KEY) + if scratchpad is None: + return + consume = getattr(scratchpad, "get_null_resume", None) + if not callable(consume): + return + try: + consume(True) + except Exception: + # Defensive: if langgraph's internal scratchpad shape changes we don't + # want to break the resume path. Worst case the original ValueError + # still surfaces — same behavior as before this fix. + logger.debug( + "drain_parent_null_resume: scratchpad.get_null_resume raised", + exc_info=True, + ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/task_tool.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/task_tool.py index 5668f8ddb..7c0dd8624 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/task_tool.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/task_tool.py @@ -20,6 +20,7 @@ from langgraph.types import Command from .config import ( consume_surfsense_resume, + drain_parent_null_resume, has_surfsense_resume, subagent_invoke_config, ) @@ -157,6 +158,9 @@ def build_task_tool_with_parent_config( ) expected = hitlrequest_action_count(pending_value) resume_value = fan_out_decisions_to_match(resume_value, expected) + # Prevent the parent's resume payload from leaking into subagent + # interrupts via langgraph's parent_scratchpad fallback. + drain_parent_null_resume(runtime) result = subagent.invoke( build_resume_command(resume_value, pending_id), config=sub_config, @@ -221,6 +225,9 @@ def build_task_tool_with_parent_config( ) expected = hitlrequest_action_count(pending_value) resume_value = fan_out_decisions_to_match(resume_value, expected) + # Prevent the parent's resume payload from leaking into subagent + # interrupts via langgraph's parent_scratchpad fallback. 
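+            # (Illustratively, the drain reduces to:
+            #     scratchpad = config["configurable"]["__pregel_scratchpad"]
+            #     scratchpad.get_null_resume(True)  # consume and discard
+            # so a fresh interrupt() inside the subagent pauses cleanly.)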
+ drain_parent_null_resume(runtime) result = await subagent.ainvoke( build_resume_command(resume_value, pending_id), config=sub_config, diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/__init__.py index 768738118..dc721013a 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/__init__.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/__init__.py @@ -1,11 +1,3 @@ -"""Jira tools for creating, updating, and deleting issues.""" +"""Jira route: native tool factories are empty; MCP supplies tools when configured.""" -from .create_issue import create_create_jira_issue_tool -from .delete_issue import create_delete_jira_issue_tool -from .update_issue import create_update_jira_issue_tool - -__all__ = [ - "create_create_jira_issue_tool", - "create_delete_jira_issue_tool", - "create_update_jira_issue_tool", -] +__all__: list[str] = [] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py index 342f120be..08b0e005e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py @@ -6,29 +6,9 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import ( ToolsPermissions, ) -from .create_issue import create_create_jira_issue_tool -from .delete_issue import create_delete_jira_issue_tool -from .update_issue import create_update_jira_issue_tool - def load_tools( *, dependencies: dict[str, Any] | None = None, **kwargs: Any ) -> ToolsPermissions: - d = {**(dependencies or {}), **kwargs} - common = { - "db_session": d["db_session"], - "search_space_id": d["search_space_id"], - "user_id": d["user_id"], - "connector_id": d.get("connector_id"), - } - create = create_create_jira_issue_tool(**common) - update = create_update_jira_issue_tool(**common) - delete = create_delete_jira_issue_tool(**common) - return { - "allow": [], - "ask": [ - {"name": getattr(create, "name", "") or "", "tool": create}, - {"name": getattr(update, "name", "") or "", "tool": update}, - {"name": getattr(delete, "name", "") or "", "tool": delete}, - ], - } + _ = {**(dependencies or {}), **kwargs} + return {"allow": [], "ask": []} diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/__init__.py index 31acf1e2a..5b464a9df 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/__init__.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/__init__.py @@ -1,11 +1,3 @@ -"""Linear tools for creating, updating, and deleting issues.""" +"""Linear route: native tool factories are empty; MCP supplies tools when configured.""" -from .create_issue import create_create_linear_issue_tool -from .delete_issue import create_delete_linear_issue_tool -from .update_issue import create_update_linear_issue_tool - -__all__ = [ - "create_create_linear_issue_tool", - "create_delete_linear_issue_tool", - "create_update_linear_issue_tool", -] +__all__: list[str] = [] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/create_issue.py 
b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/create_issue.py deleted file mode 100644 index ff254e133..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/create_issue.py +++ /dev/null @@ -1,248 +0,0 @@ -import logging -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import request_approval -from app.connectors.linear_connector import LinearAPIError, LinearConnector -from app.services.linear import LinearToolMetadataService - -logger = logging.getLogger(__name__) - - -def create_create_linear_issue_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, - connector_id: int | None = None, -): - """ - Factory function to create the create_linear_issue tool. - - Args: - db_session: Database session for accessing the Linear connector - search_space_id: Search space ID to find the Linear connector - user_id: User ID for fetching user-specific context - connector_id: Optional specific connector ID (if known) - - Returns: - Configured create_linear_issue tool - """ - - @tool - async def create_linear_issue( - title: str, - description: str | None = None, - ) -> dict[str, Any]: - """Create a new issue in Linear. - - Use this tool when the user explicitly asks to create, add, or file - a new issue / ticket / task in Linear. The user MUST describe the issue - before you call this tool. If the request is vague, ask what the issue - should be about. Never call this tool without a clear topic from the user. - - Args: - title: Short, descriptive issue title. Infer from the user's request. - description: Optional markdown body for the issue. Generate from context. - - Returns: - Dictionary with: - - status: "success", "rejected", or "error" - - issue_id: Linear issue UUID (if success) - - identifier: Human-readable ID like "ENG-42" (if success) - - url: URL to the created issue (if success) - - message: Result message - - IMPORTANT: If status is "rejected", the user explicitly declined the action. - Respond with a brief acknowledgment (e.g., "Understood, I won't create the issue.") - and move on. Do NOT retry, troubleshoot, or suggest alternatives. - - Examples: - - "Create a Linear issue for the login bug" - - "File a ticket about the payment timeout problem" - - "Add an issue for the broken search feature" - """ - logger.info(f"create_linear_issue called: title='{title}'") - - if db_session is None or search_space_id is None or user_id is None: - logger.error( - "Linear tool not properly configured - missing required parameters" - ) - return { - "status": "error", - "message": "Linear tool not properly configured. Please contact support.", - } - - try: - metadata_service = LinearToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - logger.error(f"Failed to fetch creation context: {context['error']}") - return {"status": "error", "message": context["error"]} - - workspaces = context.get("workspaces", []) - if workspaces and all(w.get("auth_expired") for w in workspaces): - logger.warning("All Linear accounts have expired authentication") - return { - "status": "auth_error", - "message": "All connected Linear accounts need re-authentication. 
Please re-authenticate in your connector settings.", - "connector_type": "linear", - } - - logger.info(f"Requesting approval for creating Linear issue: '{title}'") - result = request_approval( - action_type="linear_issue_creation", - tool_name="create_linear_issue", - params={ - "title": title, - "description": description, - "team_id": None, - "state_id": None, - "assignee_id": None, - "priority": None, - "label_ids": [], - "connector_id": connector_id, - }, - context=context, - ) - - if result.rejected: - logger.info("Linear issue creation rejected by user") - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_title = result.params.get("title", title) - final_description = result.params.get("description", description) - final_team_id = result.params.get("team_id") - final_state_id = result.params.get("state_id") - final_assignee_id = result.params.get("assignee_id") - final_priority = result.params.get("priority") - final_label_ids = result.params.get("label_ids") or [] - final_connector_id = result.params.get("connector_id", connector_id) - - if not final_title or not final_title.strip(): - logger.error("Title is empty or contains only whitespace") - return {"status": "error", "message": "Issue title cannot be empty."} - if not final_team_id: - return { - "status": "error", - "message": "A team must be selected to create an issue.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - actual_connector_id = final_connector_id - if actual_connector_id is None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Linear connector found. 
Please connect Linear in your workspace settings.", - } - actual_connector_id = connector.id - logger.info(f"Found Linear connector: id={actual_connector_id}") - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == actual_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Linear connector is invalid or has been disconnected.", - } - logger.info(f"Validated Linear connector: id={actual_connector_id}") - - logger.info( - f"Creating Linear issue with final params: title='{final_title}'" - ) - linear_client = LinearConnector( - session=db_session, connector_id=actual_connector_id - ) - result = await linear_client.create_issue( - team_id=final_team_id, - title=final_title, - description=final_description, - state_id=final_state_id, - assignee_id=final_assignee_id, - priority=final_priority, - label_ids=final_label_ids if final_label_ids else None, - ) - - if result.get("status") == "error": - logger.error(f"Failed to create Linear issue: {result.get('message')}") - return {"status": "error", "message": result.get("message")} - - logger.info( - f"Linear issue created: {result.get('identifier')} - {result.get('title')}" - ) - - kb_message_suffix = "" - try: - from app.services.linear import LinearKBSyncService - - kb_service = LinearKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - issue_id=result.get("id"), - issue_identifier=result.get("identifier", ""), - issue_title=result.get("title", final_title), - issue_url=result.get("url"), - description=final_description, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." - else: - kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." - - return { - "status": "success", - "issue_id": result.get("id"), - "identifier": result.get("identifier"), - "url": result.get("url"), - "message": (result.get("message", "") + kb_message_suffix), - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - - logger.error(f"Error creating Linear issue: {e}", exc_info=True) - if isinstance(e, ValueError | LinearAPIError): - message = str(e) - else: - message = ( - "Something went wrong while creating the issue. Please try again." 
- ) - return {"status": "error", "message": message} - - return create_linear_issue diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/delete_issue.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/delete_issue.py deleted file mode 100644 index 29ef0cdf2..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/delete_issue.py +++ /dev/null @@ -1,245 +0,0 @@ -import logging -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import request_approval -from app.connectors.linear_connector import LinearAPIError, LinearConnector -from app.services.linear import LinearToolMetadataService - -logger = logging.getLogger(__name__) - - -def create_delete_linear_issue_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, - connector_id: int | None = None, -): - """ - Factory function to create the delete_linear_issue tool. - - Args: - db_session: Database session for accessing the Linear connector - search_space_id: Search space ID to find the Linear connector - user_id: User ID for finding the correct Linear connector - connector_id: Optional specific connector ID (if known) - - Returns: - Configured delete_linear_issue tool - """ - - @tool - async def delete_linear_issue( - issue_ref: str, - delete_from_kb: bool = False, - ) -> dict[str, Any]: - """Archive (delete) a Linear issue. - - Use this tool when the user asks to delete, remove, or archive a Linear issue. - Note that Linear archives issues rather than permanently deleting them - (they can be restored from the archive). - - - Args: - issue_ref: The issue to delete. Can be the issue title (e.g. "Fix login bug"), - the identifier (e.g. "ENG-42"), or the full document title - (e.g. "ENG-42: Fix login bug"). - delete_from_kb: Whether to also remove the issue from the knowledge base. - Default is False. Set to True to remove from both Linear - and the knowledge base. - - Returns: - Dictionary with: - - status: "success", "rejected", "not_found", or "error" - - identifier: Human-readable ID like "ENG-42" (if success) - - message: Success or error message - - deleted_from_kb: Whether the issue was also removed from the knowledge base (if success) - - IMPORTANT: - - If status is "rejected", the user explicitly declined the action. - Respond with a brief acknowledgment (e.g., "Understood, I won't delete the issue.") - and move on. Do NOT ask for alternatives or troubleshoot. - - If status is "not_found", inform the user conversationally using the exact message - provided. Do NOT treat this as an error. Simply relay the message and ask the user - to verify the issue title or identifier, or check if it has been indexed. - Examples: - - "Delete the 'Fix login bug' Linear issue" - - "Archive ENG-42" - - "Remove the 'Old payment flow' issue from Linear" - """ - logger.info( - f"delete_linear_issue called: issue_ref='{issue_ref}', delete_from_kb={delete_from_kb}" - ) - - if db_session is None or search_space_id is None or user_id is None: - logger.error( - "Linear tool not properly configured - missing required parameters" - ) - return { - "status": "error", - "message": "Linear tool not properly configured. 
Please contact support.", - } - - try: - metadata_service = LinearToolMetadataService(db_session) - context = await metadata_service.get_delete_context( - search_space_id, user_id, issue_ref - ) - - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): - logger.warning(f"Auth expired for delete context: {error_msg}") - return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "linear", - } - if "not found" in error_msg.lower(): - logger.warning(f"Issue not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - else: - logger.error(f"Failed to fetch delete context: {error_msg}") - return {"status": "error", "message": error_msg} - - issue_id = context["issue"]["id"] - issue_identifier = context["issue"].get("identifier", "") - document_id = context["issue"]["document_id"] - connector_id_from_context = context.get("workspace", {}).get("id") - - logger.info( - f"Requesting approval for deleting Linear issue: '{issue_ref}' " - f"(id={issue_id}, delete_from_kb={delete_from_kb})" - ) - result = request_approval( - action_type="linear_issue_deletion", - tool_name="delete_linear_issue", - params={ - "issue_id": issue_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - logger.info("Linear issue deletion rejected by user") - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_issue_id = result.params.get("issue_id", issue_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - logger.info( - f"Deleting Linear issue with final params: issue_id={final_issue_id}, " - f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}" - ) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if final_connector_id: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, - ) - ) - connector = result.scalars().first() - if not connector: - logger.error( - f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" - ) - return { - "status": "error", - "message": "Selected Linear connector is invalid or has been disconnected.", - } - actual_connector_id = connector.id - logger.info(f"Validated Linear connector: id={actual_connector_id}") - else: - logger.error("No connector found for this issue") - return { - "status": "error", - "message": "No connector found for this issue.", - } - - linear_client = LinearConnector( - session=db_session, connector_id=actual_connector_id - ) - - result = await linear_client.archive_issue(issue_id=final_issue_id) - - logger.info( - f"archive_issue result: {result.get('status')} - {result.get('message', '')}" - ) - - deleted_from_kb = False - if ( - result.get("status") == "success" - and final_delete_from_kb - and document_id - ): - try: - from app.db import Document - - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - if document: - await 
db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - result["warning"] = ( - f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}" - ) - - if result.get("status") == "success": - result["deleted_from_kb"] = deleted_from_kb - if issue_identifier: - result["message"] = ( - f"Issue {issue_identifier} archived successfully." - ) - if deleted_from_kb: - result["message"] = ( - f"{result.get('message', '')} Also removed from the knowledge base." - ) - - return result - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - - logger.error(f"Error deleting Linear issue: {e}", exc_info=True) - if isinstance(e, ValueError | LinearAPIError): - message = str(e) - else: - message = ( - "Something went wrong while deleting the issue. Please try again." - ) - return {"status": "error", "message": message} - - return delete_linear_issue diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py index f1ee49964..08b0e005e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py @@ -6,29 +6,9 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import ( ToolsPermissions, ) -from .create_issue import create_create_linear_issue_tool -from .delete_issue import create_delete_linear_issue_tool -from .update_issue import create_update_linear_issue_tool - def load_tools( *, dependencies: dict[str, Any] | None = None, **kwargs: Any ) -> ToolsPermissions: - d = {**(dependencies or {}), **kwargs} - common = { - "db_session": d["db_session"], - "search_space_id": d["search_space_id"], - "user_id": d["user_id"], - "connector_id": d.get("connector_id"), - } - create = create_create_linear_issue_tool(**common) - update = create_update_linear_issue_tool(**common) - delete = create_delete_linear_issue_tool(**common) - return { - "allow": [], - "ask": [ - {"name": getattr(create, "name", "") or "", "tool": create}, - {"name": getattr(update, "name", "") or "", "tool": update}, - {"name": getattr(delete, "name", "") or "", "tool": delete}, - ], - } + _ = {**(dependencies or {}), **kwargs} + return {"allow": [], "ask": []} diff --git a/surfsense_backend/app/agents/new_chat/feature_flags.py b/surfsense_backend/app/agents/new_chat/feature_flags.py index b3dc0fa82..3cea051ef 100644 --- a/surfsense_backend/app/agents/new_chat/feature_flags.py +++ b/surfsense_backend/app/agents/new_chat/feature_flags.py @@ -28,7 +28,6 @@ Defaults: SURFSENSE_ENABLE_PERMISSION=true SURFSENSE_ENABLE_DOOM_LOOP=true SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call - SURFSENSE_ENABLE_STREAM_PARITY_V2=true Master kill-switch (overrides everything else): @@ -88,15 +87,6 @@ class AgentFeatureFlags: enable_action_log: bool = True enable_revert_route: bool = True - # Streaming parity v2 — opt in to LangChain's structured - # ``AIMessageChunk`` content (typed reasoning blocks, tool-input - # deltas) and propagate the real ``tool_call_id`` to the SSE layer. 
- # When OFF the ``stream_new_chat`` task falls back to the str-only - # text path and the synthetic ``call_`` tool-call id (no - # ``langchainToolCallId`` propagation). Schema migrations 135/136 - # ship unconditionally because they're forward-compatible. - enable_stream_parity_v2: bool = True - # Plugins enable_plugin_loader: bool = False @@ -169,7 +159,6 @@ class AgentFeatureFlags: enable_kb_planner_runnable=False, enable_action_log=False, enable_revert_route=False, - enable_stream_parity_v2=False, enable_plugin_loader=False, enable_otel=False, enable_agent_cache=False, @@ -208,10 +197,6 @@ class AgentFeatureFlags: # Snapshot / revert enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", True), enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", True), - # Streaming parity v2 - enable_stream_parity_v2=_env_bool( - "SURFSENSE_ENABLE_STREAM_PARITY_V2", True - ), # Plugins enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False), # Observability diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index ad96654f5..743b5b849 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -71,7 +71,10 @@ from app.schemas.new_chat import ( TokenUsageSummary, TurnStatusResponse, ) -from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat +from app.tasks.chat.stream_new_chat import ( + stream_new_chat, + stream_resume_chat, +) from app.users import current_active_user from app.utils.perf import get_perf_logger from app.utils.rbac import check_permission diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index 95d183433..fe8dab076 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -380,7 +380,7 @@ class ResumeRequest(BaseModel): "/regenerate. Resume reuses the original interrupted user " "turn so the server does not write a new user message. " "Currently unused but accepted to keep request bodies " - "uniform across the three streaming entrypoints." + "uniform across new-message, regenerate, and resume stream routes." ), ) diff --git a/surfsense_backend/app/services/new_streaming_service.py b/surfsense_backend/app/services/new_streaming_service.py index 55129668c..ba0cb8753 100644 --- a/surfsense_backend/app/services/new_streaming_service.py +++ b/surfsense_backend/app/services/new_streaming_service.py @@ -456,6 +456,8 @@ class VercelStreamingService: title: str, status: str = "in_progress", items: list[str] | None = None, + *, + metadata: dict[str, Any] | None = None, ) -> str: """ Format a thinking step for chain-of-thought display (SurfSense specific). @@ -469,15 +471,15 @@ class VercelStreamingService: Returns: str: SSE formatted thinking step data part """ - return self.format_data( - "thinking-step", - { - "id": step_id, - "title": title, - "status": status, - "items": items or [], - }, - ) + payload: dict[str, Any] = { + "id": step_id, + "title": title, + "status": status, + "items": items or [], + } + if metadata: + payload["metadata"] = metadata + return self.format_data("thinking-step", payload) def format_thread_title_update(self, thread_id: int, title: str) -> str: """ @@ -601,6 +603,7 @@ class VercelStreamingService: tool_name: str, *, langchain_tool_call_id: str | None = None, + metadata: dict[str, Any] | None = None, ) -> str: """ Format the start of tool input streaming. 
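The optional ``metadata`` kwarg added across these formatters threads SurfSense-specific context through a tool card without touching the AI SDK protocol fields. A minimal sketch of the intended call shape, assuming an existing ``svc: VercelStreamingService`` (argument values illustrative):

    sse = svc.format_tool_input_start(
        "call_ab12cd",                         # card id reused across the event family
        "create_jira_issue",
        langchain_tool_call_id="toolu_01XYZ",  # authoritative LC id, when the provider streams one
        metadata={"subagent_type": "jira"},    # SurfSense-added payload; omitted when falsy
    )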
@@ -608,15 +611,14 @@ class VercelStreamingService: Args: tool_call_id: The unique tool call identifier. May be EITHER the synthetic ``call_`` id derived from LangGraph - ``run_id`` (legacy / ``SURFSENSE_ENABLE_STREAM_PARITY_V2`` - OFF, or the unmatched-fallback path under parity_v2) OR - the authoritative LangChain ``tool_call.id`` (parity_v2 - path: when the provider streams ``tool_call_chunks`` we - register the ``index`` and reuse the lc-id as the card - id so live ``tool-input-delta`` events can be routed - without a downstream join). Either way, the same id is - preserved across ``tool-input-start`` / ``-delta`` / - ``-available`` / ``tool-output-available`` for one call. + ``run_id`` (unmatched chunk fallback when no ``index`` was + registered) OR the authoritative LangChain ``tool_call.id`` + (when the provider streams ``tool_call_chunks`` we register + the ``index`` and reuse the lc-id as the card id so live + ``tool-input-delta`` events route without a downstream join). + Either way, the same id is preserved across + ``tool-input-start`` / ``-delta`` / ``-available`` / + ``tool-output-available`` for one call. tool_name: The name of the tool being called. langchain_tool_call_id: Optional authoritative LangChain ``tool_call.id``. When set, surfaces as @@ -636,6 +638,8 @@ class VercelStreamingService: } if langchain_tool_call_id: payload["langchainToolCallId"] = langchain_tool_call_id + if metadata: + payload["metadata"] = metadata return self._format_sse(payload) def format_tool_input_delta(self, tool_call_id: str, input_text_delta: str) -> str: @@ -667,6 +671,7 @@ class VercelStreamingService: input_data: dict[str, Any], *, langchain_tool_call_id: str | None = None, + metadata: dict[str, Any] | None = None, ) -> str: """ Format the completion of tool input. @@ -692,6 +697,8 @@ class VercelStreamingService: } if langchain_tool_call_id: payload["langchainToolCallId"] = langchain_tool_call_id + if metadata: + payload["metadata"] = metadata return self._format_sse(payload) def format_tool_output_available( @@ -700,6 +707,7 @@ class VercelStreamingService: output: Any, *, langchain_tool_call_id: str | None = None, + metadata: dict[str, Any] | None = None, ) -> str: """ Format tool execution output. @@ -726,6 +734,8 @@ class VercelStreamingService: } if langchain_tool_call_id: payload["langchainToolCallId"] = langchain_tool_call_id + if metadata: + payload["metadata"] = metadata return self._format_sse(payload) # ========================================================================= diff --git a/surfsense_backend/app/services/streaming/__init__.py b/surfsense_backend/app/services/streaming/__init__.py new file mode 100644 index 000000000..3ec9b9cf1 --- /dev/null +++ b/surfsense_backend/app/services/streaming/__init__.py @@ -0,0 +1,20 @@ +"""Single-responsibility split of the streaming SSE protocol. + +Layout: +* ``envelope/`` - SSE wire framing + ID generators +* ``emitter/`` - identity of the agent that emitted an event + runtime registry +* ``events/`` - one module per SSE event family +* ``service.py`` - composition root used when emitting chat SSE +* ``interrupt_correlation.py`` - id-aware lookup over LangGraph state + +Naming on the wire: +* AI SDK protocol fields keep their existing camelCase + (``toolCallId``, ``messageId``, ``inputTextDelta``, ``langchainToolCallId``). +* Every SurfSense-added field uses ``snake_case``, including the + top-level ``emitted_by`` envelope and all inner ``data`` payloads. 
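+
+Example frame (illustrative values):
+
+    data: {"type": "text-delta", "id": "text_1a2b", "delta": "Hi",
+    "emitted_by": {"level": "subagent", "subagent_type": "jira",
+    "subagent_run_id": "subagent_3c4d", "parent_tool_call_id": "call_5e6f"}}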
+ +Production chat uses ``app.services.new_streaming_service`` from +``app.tasks.chat.stream_new_chat`` and related routes. +""" + +from __future__ import annotations diff --git a/surfsense_backend/app/services/streaming/emitter/__init__.py b/surfsense_backend/app/services/streaming/emitter/__init__.py new file mode 100644 index 000000000..7814894f3 --- /dev/null +++ b/surfsense_backend/app/services/streaming/emitter/__init__.py @@ -0,0 +1,29 @@ +"""Identity of the agent that emitted a streamed event. + +The wire field is ``emitted_by``; the Python identity is :class:`Emitter`. +``EmitterRegistry`` resolves which emitter owns a LangGraph event, with +LangGraph's own namespace metadata as the primary key and a parent_ids +walk as a fallback for cases where context vars don't propagate. +""" + +from __future__ import annotations + +from .emitter import ( + MAIN_EMITTER, + Emitter, + EmitterLevel, + attach_emitted_by, + main_emitter, + subagent_emitter, +) +from .registry import EmitterRegistry + +__all__ = [ + "MAIN_EMITTER", + "Emitter", + "EmitterLevel", + "EmitterRegistry", + "attach_emitted_by", + "main_emitter", + "subagent_emitter", +] diff --git a/surfsense_backend/app/services/streaming/emitter/emitter.py b/surfsense_backend/app/services/streaming/emitter/emitter.py new file mode 100644 index 000000000..08f625a69 --- /dev/null +++ b/surfsense_backend/app/services/streaming/emitter/emitter.py @@ -0,0 +1,61 @@ +"""Identity payload describing which agent produced a stream event.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Literal + +EmitterLevel = Literal["main", "subagent"] + + +@dataclass(frozen=True) +class Emitter: + level: EmitterLevel + subagent_type: str | None = None + subagent_run_id: str | None = None + parent_tool_call_id: str | None = None + extra: dict[str, Any] = field(default_factory=dict) + + def to_payload(self) -> dict[str, Any]: + payload: dict[str, Any] = {"level": self.level} + if self.subagent_type is not None: + payload["subagent_type"] = self.subagent_type + if self.subagent_run_id is not None: + payload["subagent_run_id"] = self.subagent_run_id + if self.parent_tool_call_id is not None: + payload["parent_tool_call_id"] = self.parent_tool_call_id + if self.extra: + payload.update(self.extra) + return payload + + +MAIN_EMITTER = Emitter(level="main") + + +def main_emitter() -> Emitter: + return MAIN_EMITTER + + +def subagent_emitter( + *, + subagent_type: str, + subagent_run_id: str, + parent_tool_call_id: str | None = None, + extra: dict[str, Any] | None = None, +) -> Emitter: + return Emitter( + level="subagent", + subagent_type=subagent_type, + subagent_run_id=subagent_run_id, + parent_tool_call_id=parent_tool_call_id, + extra=dict(extra or {}), + ) + + +def attach_emitted_by( + payload: dict[str, Any], emitter: Emitter | None +) -> dict[str, Any]: + if emitter is None: + return payload + payload["emitted_by"] = emitter.to_payload() + return payload diff --git a/surfsense_backend/app/services/streaming/emitter/registry.py b/surfsense_backend/app/services/streaming/emitter/registry.py new file mode 100644 index 000000000..cd3e10cdd --- /dev/null +++ b/surfsense_backend/app/services/streaming/emitter/registry.py @@ -0,0 +1,51 @@ +"""Resolve which agent owns a streamed event from its LangGraph run lineage.""" + +from __future__ import annotations + +from collections.abc import Iterable + +from .emitter import Emitter, main_emitter + + +class EmitterRegistry: + def __init__(self) -> None: + 
self._by_run_id: dict[str, Emitter] = {} + + def register(self, run_id: str, emitter: Emitter) -> None: + if not run_id: + return + self._by_run_id[run_id] = emitter + + def unregister(self, run_id: str) -> Emitter | None: + if not run_id: + return None + return self._by_run_id.pop(run_id, None) + + def get(self, run_id: str | None) -> Emitter | None: + if not run_id: + return None + return self._by_run_id.get(run_id) + + def resolve( + self, + *, + run_id: str | None, + parent_ids: Iterable[str] | None, + ) -> Emitter: + own = self.get(run_id) + if own is not None: + return own + if parent_ids: + for ancestor in reversed(list(parent_ids)): + emitter = self.get(ancestor) + if emitter is not None: + return emitter + return main_emitter() + + def has_active_subagents(self) -> bool: + return any( + emitter.level == "subagent" for emitter in self._by_run_id.values() + ) + + def clear(self) -> None: + self._by_run_id.clear() diff --git a/surfsense_backend/app/services/streaming/envelope/__init__.py b/surfsense_backend/app/services/streaming/envelope/__init__.py new file mode 100644 index 000000000..862e84c8d --- /dev/null +++ b/surfsense_backend/app/services/streaming/envelope/__init__.py @@ -0,0 +1,23 @@ +"""Wire framing layer.""" + +from __future__ import annotations + +from .identifiers import ( + generate_message_id, + generate_reasoning_id, + generate_subagent_run_id, + generate_text_id, + generate_tool_call_id, +) +from .sse import format_done, format_sse, get_response_headers + +__all__ = [ + "format_done", + "format_sse", + "generate_message_id", + "generate_reasoning_id", + "generate_subagent_run_id", + "generate_text_id", + "generate_tool_call_id", + "get_response_headers", +] diff --git a/surfsense_backend/app/services/streaming/envelope/identifiers.py b/surfsense_backend/app/services/streaming/envelope/identifiers.py new file mode 100644 index 000000000..2fdd6ff09 --- /dev/null +++ b/surfsense_backend/app/services/streaming/envelope/identifiers.py @@ -0,0 +1,25 @@ +"""Prefixed UUID generators for stream parts.""" + +from __future__ import annotations + +import uuid + + +def generate_message_id() -> str: + return f"msg_{uuid.uuid4().hex}" + + +def generate_text_id() -> str: + return f"text_{uuid.uuid4().hex}" + + +def generate_reasoning_id() -> str: + return f"reasoning_{uuid.uuid4().hex}" + + +def generate_tool_call_id() -> str: + return f"call_{uuid.uuid4().hex}" + + +def generate_subagent_run_id() -> str: + return f"subagent_{uuid.uuid4().hex}" diff --git a/surfsense_backend/app/services/streaming/envelope/sse.py b/surfsense_backend/app/services/streaming/envelope/sse.py new file mode 100644 index 000000000..508fc1b1c --- /dev/null +++ b/surfsense_backend/app/services/streaming/envelope/sse.py @@ -0,0 +1,25 @@ +"""Server-Sent-Events wire framing.""" + +from __future__ import annotations + +import json +from typing import Any + + +def format_sse(data: Any) -> str: + if isinstance(data, str): + return f"data: {data}\n\n" + return f"data: {json.dumps(data)}\n\n" + + +def format_done() -> str: + return "data: [DONE]\n\n" + + +def get_response_headers() -> dict[str, str]: + return { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "x-vercel-ai-ui-message-stream": "v1", + } diff --git a/surfsense_backend/app/services/streaming/events/__init__.py b/surfsense_backend/app/services/streaming/events/__init__.py new file mode 100644 index 000000000..91a8ff854 --- /dev/null +++ b/surfsense_backend/app/services/streaming/events/__init__.py @@ 
-0,0 +1,29 @@ +"""SSE event payload formatters, one module per event family.""" + +from __future__ import annotations + +from . import ( + action_log, + data, + error, + interrupt, + lifecycle, + reasoning, + source, + subagent_lifecycle, + text, + tool, +) + +__all__ = [ + "action_log", + "data", + "error", + "interrupt", + "lifecycle", + "reasoning", + "source", + "subagent_lifecycle", + "text", + "tool", +] diff --git a/surfsense_backend/app/services/streaming/events/action_log.py b/surfsense_backend/app/services/streaming/events/action_log.py new file mode 100644 index 000000000..0a8e46f0a --- /dev/null +++ b/surfsense_backend/app/services/streaming/events/action_log.py @@ -0,0 +1,24 @@ +"""Action-log events relayed from ``ActionLogMiddleware`` custom dispatches.""" + +from __future__ import annotations + +from typing import Any + +from ..emitter import Emitter +from .data import format_data + + +def format_action_log( + payload: dict[str, Any], + *, + emitter: Emitter | None = None, +) -> str: + return format_data("action-log", payload, emitter=emitter) + + +def format_action_log_updated( + payload: dict[str, Any], + *, + emitter: Emitter | None = None, +) -> str: + return format_data("action-log-updated", payload, emitter=emitter) diff --git a/surfsense_backend/app/services/streaming/events/data.py b/surfsense_backend/app/services/streaming/events/data.py new file mode 100644 index 000000000..f6e190578 --- /dev/null +++ b/surfsense_backend/app/services/streaming/events/data.py @@ -0,0 +1,118 @@ +"""Generic ``data-*`` envelopes and SurfSense-specific data parts. + +Inner ``data`` dict fields use snake_case. Legacy ``threadId`` / +``messageId`` keys are preserved where they cross the AI SDK boundary. +""" + +from __future__ import annotations + +from typing import Any + +from ..emitter import Emitter, attach_emitted_by +from ..envelope import format_sse + + +def format_data( + data_type: str, + data: Any, + *, + emitter: Emitter | None = None, +) -> str: + payload: dict[str, Any] = {"type": f"data-{data_type}", "data": data} + return format_sse(attach_emitted_by(payload, emitter)) + + +def format_terminal_info( + text: str, + *, + message_type: str = "info", + emitter: Emitter | None = None, +) -> str: + return format_data( + "terminal-info", + {"text": text, "type": message_type}, + emitter=emitter, + ) + + +def format_further_questions( + questions: list[str], + *, + emitter: Emitter | None = None, +) -> str: + return format_data("further-questions", {"questions": questions}, emitter=emitter) + + +def format_thinking_step( + *, + step_id: str, + title: str, + status: str = "in_progress", + items: list[str] | None = None, + emitter: Emitter | None = None, +) -> str: + return format_data( + "thinking-step", + { + "id": step_id, + "title": title, + "status": status, + "items": items or [], + }, + emitter=emitter, + ) + + +def format_thread_title_update( + *, + thread_id: int, + title: str, + emitter: Emitter | None = None, +) -> str: + return format_data( + "thread-title-update", + {"threadId": thread_id, "title": title}, + emitter=emitter, + ) + + +def format_turn_info( + *, + chat_turn_id: str, + emitter: Emitter | None = None, +) -> str: + return format_data("turn-info", {"chat_turn_id": chat_turn_id}, emitter=emitter) + + +def format_turn_status( + *, + status: str, + emitter: Emitter | None = None, +) -> str: + return format_data("turn-status", {"status": status}, emitter=emitter) + + +def format_user_message_id( + *, + message_id: str, + turn_id: str, + emitter: Emitter | None = 
None, +) -> str: + return format_data( + "user-message-id", + {"message_id": message_id, "turn_id": turn_id}, + emitter=emitter, + ) + + +def format_assistant_message_id( + *, + message_id: str, + turn_id: str, + emitter: Emitter | None = None, +) -> str: + return format_data( + "assistant-message-id", + {"message_id": message_id, "turn_id": turn_id}, + emitter=emitter, + ) diff --git a/surfsense_backend/app/services/streaming/events/error.py b/surfsense_backend/app/services/streaming/events/error.py new file mode 100644 index 000000000..a1e8e01ca --- /dev/null +++ b/surfsense_backend/app/services/streaming/events/error.py @@ -0,0 +1,23 @@ +"""Single terminal error path chat streaming must route through.""" + +from __future__ import annotations + +from typing import Any + +from ..emitter import Emitter, attach_emitted_by +from ..envelope import format_sse + + +def format_error( + error_text: str, + *, + error_code: str | None = None, + extra: dict[str, Any] | None = None, + emitter: Emitter | None = None, +) -> str: + payload: dict[str, Any] = {"type": "error", "errorText": error_text} + if error_code: + payload["errorCode"] = error_code + if extra: + payload.update(extra) + return format_sse(attach_emitted_by(payload, emitter)) diff --git a/surfsense_backend/app/services/streaming/events/interrupt.py b/surfsense_backend/app/services/streaming/events/interrupt.py new file mode 100644 index 000000000..0334b10b3 --- /dev/null +++ b/surfsense_backend/app/services/streaming/events/interrupt.py @@ -0,0 +1,56 @@ +"""Interrupt-request events with a single canonical payload shape.""" + +from __future__ import annotations + +from typing import Any + +from ..emitter import Emitter +from .data import format_data + + +def normalize_interrupt_payload(interrupt_value: dict[str, Any]) -> dict[str, Any]: + if "action_requests" in interrupt_value and "review_configs" in interrupt_value: + return interrupt_value + + interrupt_type = interrupt_value.get("type", "unknown") + message = interrupt_value.get("message") + action = interrupt_value.get("action", {}) or {} + context = interrupt_value.get("context", {}) or {} + + normalized: dict[str, Any] = { + "action_requests": [ + { + "name": action.get("tool", "unknown_tool"), + "args": action.get("params", {}), + } + ], + "review_configs": [ + { + "action_name": action.get("tool", "unknown_tool"), + "allowed_decisions": ["approve", "edit", "reject"], + } + ], + "interrupt_type": interrupt_type, + "context": context, + } + if message: + normalized["message"] = message + return normalized + + +def format_interrupt_request( + interrupt_value: dict[str, Any], + *, + interrupt_id: str | None = None, + pending_interrupt_count: int | None = None, + chat_turn_id: str | None = None, + emitter: Emitter | None = None, +) -> str: + payload = normalize_interrupt_payload(interrupt_value) + if interrupt_id is not None: + payload["interrupt_id"] = interrupt_id + if pending_interrupt_count is not None: + payload["pending_interrupt_count"] = pending_interrupt_count + if chat_turn_id is not None: + payload["chat_turn_id"] = chat_turn_id + return format_data("interrupt-request", payload, emitter=emitter) diff --git a/surfsense_backend/app/services/streaming/events/lifecycle.py b/surfsense_backend/app/services/streaming/events/lifecycle.py new file mode 100644 index 000000000..019718b67 --- /dev/null +++ b/surfsense_backend/app/services/streaming/events/lifecycle.py @@ -0,0 +1,29 @@ +"""High-level message and step lifecycle events. 
+
+Wire verbs are fixed by the AI SDK protocol (``start`` / ``finish`` for
+the whole message, ``start-step`` / ``finish-step`` for each step).
+Python helpers always read ``format_<subject>_<verb>`` so pairs are
+visible at the call site.
+"""
+
+from __future__ import annotations
+
+from ..emitter import Emitter, attach_emitted_by
+from ..envelope import format_sse
+
+
+def format_message_start(message_id: str, *, emitter: Emitter | None = None) -> str:
+    payload = {"type": "start", "messageId": message_id}
+    return format_sse(attach_emitted_by(payload, emitter))
+
+
+def format_message_finish(*, emitter: Emitter | None = None) -> str:
+    return format_sse(attach_emitted_by({"type": "finish"}, emitter))
+
+
+def format_step_start(*, emitter: Emitter | None = None) -> str:
+    return format_sse(attach_emitted_by({"type": "start-step"}, emitter))
+
+
+def format_step_finish(*, emitter: Emitter | None = None) -> str:
+    return format_sse(attach_emitted_by({"type": "finish-step"}, emitter))
diff --git a/surfsense_backend/app/services/streaming/events/reasoning.py b/surfsense_backend/app/services/streaming/events/reasoning.py
new file mode 100644
index 000000000..5b912d43a
--- /dev/null
+++ b/surfsense_backend/app/services/streaming/events/reasoning.py
@@ -0,0 +1,36 @@
+"""Reasoning-block streaming events."""
+
+from __future__ import annotations
+
+from ..emitter import Emitter, attach_emitted_by
+from ..envelope import format_sse
+
+
+def format_reasoning_start(
+    reasoning_id: str, *, emitter: Emitter | None = None
+) -> str:
+    return format_sse(
+        attach_emitted_by({"type": "reasoning-start", "id": reasoning_id}, emitter)
+    )
+
+
+def format_reasoning_delta(
+    reasoning_id: str,
+    delta: str,
+    *,
+    emitter: Emitter | None = None,
+) -> str:
+    return format_sse(
+        attach_emitted_by(
+            {"type": "reasoning-delta", "id": reasoning_id, "delta": delta},
+            emitter,
+        )
+    )
+
+
+def format_reasoning_end(
+    reasoning_id: str, *, emitter: Emitter | None = None
+) -> str:
+    return format_sse(
+        attach_emitted_by({"type": "reasoning-end", "id": reasoning_id}, emitter)
+    )
diff --git a/surfsense_backend/app/services/streaming/events/source.py b/surfsense_backend/app/services/streaming/events/source.py
new file mode 100644
index 000000000..54541e8d2
--- /dev/null
+++ b/surfsense_backend/app/services/streaming/events/source.py
@@ -0,0 +1,59 @@
+"""Source and file reference events."""
+
+from __future__ import annotations
+
+from typing import Any
+
+from ..emitter import Emitter, attach_emitted_by
+from ..envelope import format_sse
+
+
+def format_source_url(
+    url: str,
+    *,
+    source_id: str | None = None,
+    title: str | None = None,
+    emitter: Emitter | None = None,
+) -> str:
+    payload: dict[str, Any] = {
+        "type": "source-url",
+        "sourceId": source_id or url,
+        "url": url,
+    }
+    if title:
+        payload["title"] = title
+    return format_sse(attach_emitted_by(payload, emitter))
+
+
+def format_source_document(
+    source_id: str,
+    *,
+    media_type: str = "file",
+    title: str | None = None,
+    description: str | None = None,
+    emitter: Emitter | None = None,
+) -> str:
+    payload: dict[str, Any] = {
+        "type": "source-document",
+        "sourceId": source_id,
+        "mediaType": media_type,
+    }
+    if title:
+        payload["title"] = title
+    if description:
+        payload["description"] = description
+    return format_sse(attach_emitted_by(payload, emitter))
+
+
+def format_file(
+    url: str,
+    media_type: str,
+    *,
+    emitter: Emitter | None = None,
+) -> str:
+    payload: dict[str, Any] = {
+        "type": "file",
+        "url": url,
+        "mediaType": media_type,
+    }
+    
return format_sse(attach_emitted_by(payload, emitter)) diff --git a/surfsense_backend/app/services/streaming/events/subagent_lifecycle.py b/surfsense_backend/app/services/streaming/events/subagent_lifecycle.py new file mode 100644 index 000000000..6dd2d4eab --- /dev/null +++ b/surfsense_backend/app/services/streaming/events/subagent_lifecycle.py @@ -0,0 +1,86 @@ +"""Sub-agent lifecycle events the FE pairs into one timeline lane. + +A sub-agent run is a high-level boundary (a whole agent invocation), +so we use the ``start`` / ``finish`` verb pair, matching how the AI SDK +spells message- and step-level lifecycles. +""" + +from __future__ import annotations + +from typing import Any + +from ..emitter import Emitter +from .data import format_data + + +def format_subagent_start( + *, + subagent_run_id: str, + subagent_type: str, + parent_tool_call_id: str, + chat_turn_id: str | None = None, + description: str | None = None, + started_at: str | None = None, + emitter: Emitter | None = None, +) -> str: + payload: dict[str, Any] = { + "subagent_run_id": subagent_run_id, + "subagent_type": subagent_type, + "parent_tool_call_id": parent_tool_call_id, + } + if chat_turn_id is not None: + payload["chat_turn_id"] = chat_turn_id + if description is not None: + payload["description"] = description + if started_at is not None: + payload["started_at"] = started_at + return format_data("subagent-start", payload, emitter=emitter) + + +def format_subagent_finish( + *, + subagent_run_id: str, + subagent_type: str, + parent_tool_call_id: str, + status: str = "completed", + ended_at: str | None = None, + duration_ms: int | None = None, + emitter: Emitter | None = None, +) -> str: + payload: dict[str, Any] = { + "subagent_run_id": subagent_run_id, + "subagent_type": subagent_type, + "parent_tool_call_id": parent_tool_call_id, + "status": status, + } + if ended_at is not None: + payload["ended_at"] = ended_at + if duration_ms is not None: + payload["duration_ms"] = duration_ms + return format_data("subagent-finish", payload, emitter=emitter) + + +def format_subagent_error( + *, + subagent_run_id: str, + subagent_type: str, + parent_tool_call_id: str, + error_text: str, + error_type: str | None = None, + ended_at: str | None = None, + duration_ms: int | None = None, + emitter: Emitter | None = None, +) -> str: + payload: dict[str, Any] = { + "subagent_run_id": subagent_run_id, + "subagent_type": subagent_type, + "parent_tool_call_id": parent_tool_call_id, + "error_text": error_text, + } + if error_type is not None: + payload["error_type"] = error_type + if ended_at is not None: + payload["ended_at"] = ended_at + if duration_ms is not None: + payload["duration_ms"] = duration_ms + return format_data("subagent-error", payload, emitter=emitter) diff --git a/surfsense_backend/app/services/streaming/events/text.py b/surfsense_backend/app/services/streaming/events/text.py new file mode 100644 index 000000000..3baebdebb --- /dev/null +++ b/surfsense_backend/app/services/streaming/events/text.py @@ -0,0 +1,31 @@ +"""Text-block streaming events.""" + +from __future__ import annotations + +from ..emitter import Emitter, attach_emitted_by +from ..envelope import format_sse + + +def format_text_start(text_id: str, *, emitter: Emitter | None = None) -> str: + return format_sse( + attach_emitted_by({"type": "text-start", "id": text_id}, emitter) + ) + + +def format_text_delta( + text_id: str, + delta: str, + *, + emitter: Emitter | None = None, +) -> str: + return format_sse( + attach_emitted_by( + {"type": "text-delta", "id": 
text_id, "delta": delta}, emitter + ) + ) + + +def format_text_end(text_id: str, *, emitter: Emitter | None = None) -> str: + return format_sse( + attach_emitted_by({"type": "text-end", "id": text_id}, emitter) + ) diff --git a/surfsense_backend/app/services/streaming/events/tool.py b/surfsense_backend/app/services/streaming/events/tool.py new file mode 100644 index 000000000..c85dc061b --- /dev/null +++ b/surfsense_backend/app/services/streaming/events/tool.py @@ -0,0 +1,80 @@ +"""Tool-call streaming events. + +``toolCallId`` and ``langchainToolCallId`` are AI SDK protocol fields +and stay camelCase. Sub-agent provenance rides on the snake_case +top-level ``emitted_by`` envelope added by :func:`attach_emitted_by`. +""" + +from __future__ import annotations + +from typing import Any + +from ..emitter import Emitter, attach_emitted_by +from ..envelope import format_sse + + +def format_tool_input_start( + tool_call_id: str, + tool_name: str, + *, + langchain_tool_call_id: str | None = None, + emitter: Emitter | None = None, +) -> str: + payload: dict[str, Any] = { + "type": "tool-input-start", + "toolCallId": tool_call_id, + "toolName": tool_name, + } + if langchain_tool_call_id: + payload["langchainToolCallId"] = langchain_tool_call_id + return format_sse(attach_emitted_by(payload, emitter)) + + +def format_tool_input_delta( + tool_call_id: str, + input_text_delta: str, + *, + emitter: Emitter | None = None, +) -> str: + payload: dict[str, Any] = { + "type": "tool-input-delta", + "toolCallId": tool_call_id, + "inputTextDelta": input_text_delta, + } + return format_sse(attach_emitted_by(payload, emitter)) + + +def format_tool_input_available( + tool_call_id: str, + tool_name: str, + input_data: dict[str, Any], + *, + langchain_tool_call_id: str | None = None, + emitter: Emitter | None = None, +) -> str: + payload: dict[str, Any] = { + "type": "tool-input-available", + "toolCallId": tool_call_id, + "toolName": tool_name, + "input": input_data, + } + if langchain_tool_call_id: + payload["langchainToolCallId"] = langchain_tool_call_id + return format_sse(attach_emitted_by(payload, emitter)) + + +def format_tool_output_available( + tool_call_id: str, + output: Any, + *, + langchain_tool_call_id: str | None = None, + emitter: Emitter | None = None, +) -> str: + payload: dict[str, Any] = { + "type": "tool-output-available", + "toolCallId": tool_call_id, + "output": output, + } + if langchain_tool_call_id: + payload["langchainToolCallId"] = langchain_tool_call_id + return format_sse(attach_emitted_by(payload, emitter)) diff --git a/surfsense_backend/app/services/streaming/interrupt_correlation.py b/surfsense_backend/app/services/streaming/interrupt_correlation.py new file mode 100644 index 000000000..3045dfb6a --- /dev/null +++ b/surfsense_backend/app/services/streaming/interrupt_correlation.py @@ -0,0 +1,84 @@ +"""Id-aware lookup of pending LangGraph interrupts (replaces first-wins).""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True) +class PendingInterrupt: + interrupt_id: str | None + value: dict[str, Any] + source_task_id: str | None = None + + +def list_pending_interrupts(state: Any) -> list[PendingInterrupt]: + out: list[PendingInterrupt] = [] + + for task in getattr(state, "tasks", None) or (): + task_id = _safe_str(getattr(task, "id", None)) + for it in getattr(task, "interrupts", None) or (): + value = _coerce_interrupt_value(it) + if value is None: + continue + interrupt_id = _safe_str(getattr(it, "id", None)) 
+ out.append( + PendingInterrupt( + interrupt_id=interrupt_id, + value=value, + source_task_id=task_id, + ) + ) + + for it in getattr(state, "interrupts", None) or (): + value = _coerce_interrupt_value(it) + if value is None: + continue + interrupt_id = _safe_str(getattr(it, "id", None)) + out.append(PendingInterrupt(interrupt_id=interrupt_id, value=value)) + + return out + + +def get_pending_interrupt_by_id( + state: Any, interrupt_id: str +) -> PendingInterrupt | None: + for pending in list_pending_interrupts(state): + if pending.interrupt_id == interrupt_id: + return pending + return None + + +def get_pending_interrupt_for_tool_call( + state: Any, tool_call_id: str +) -> PendingInterrupt | None: + for pending in list_pending_interrupts(state): + actions = pending.value.get("action_requests") + if not isinstance(actions, list): + continue + for action in actions: + if not isinstance(action, dict): + continue + if action.get("tool_call_id") == tool_call_id: + return pending + return None + + +def first_pending_interrupt(state: Any) -> PendingInterrupt | None: + """Explicit opt-in to legacy first-wins; prefer the id-aware helpers above.""" + pending = list_pending_interrupts(state) + return pending[0] if pending else None + + +def _coerce_interrupt_value(item: Any) -> dict[str, Any] | None: + if isinstance(item, dict): + return item if item else None + value = getattr(item, "value", None) + if isinstance(value, dict): + return value if value else None + return None + + +def _safe_str(value: Any) -> str | None: + return value if isinstance(value, str) and value else None diff --git a/surfsense_backend/app/services/streaming/service.py b/surfsense_backend/app/services/streaming/service.py new file mode 100644 index 000000000..5a75a1b2d --- /dev/null +++ b/surfsense_backend/app/services/streaming/service.py @@ -0,0 +1,414 @@ +"""Composition root: bundles every formatter + a per-invocation emitter registry.""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any + +from . 
import envelope +from .emitter import Emitter, EmitterRegistry +from .events import ( + action_log, + data, + error, + interrupt, + lifecycle, + reasoning, + source, + subagent_lifecycle, + text, + tool, +) + + +class StreamingService: + def __init__(self) -> None: + self._message_id: str | None = None + self.emitter_registry = EmitterRegistry() + + @property + def message_id(self) -> str | None: + return self._message_id + + def begin_message(self, message_id: str | None = None) -> str: + self._message_id = message_id or envelope.generate_message_id() + return self._message_id + + @staticmethod + def generate_text_id() -> str: + return envelope.generate_text_id() + + @staticmethod + def generate_reasoning_id() -> str: + return envelope.generate_reasoning_id() + + @staticmethod + def generate_tool_call_id() -> str: + return envelope.generate_tool_call_id() + + @staticmethod + def generate_subagent_run_id() -> str: + return envelope.generate_subagent_run_id() + + @staticmethod + def get_response_headers() -> dict[str, str]: + return envelope.get_response_headers() + + @staticmethod + def format_done() -> str: + return envelope.format_done() + + def resolve_emitter( + self, + *, + run_id: str | None, + parent_ids: Iterable[str] | None, + ) -> Emitter: + return self.emitter_registry.resolve(run_id=run_id, parent_ids=parent_ids) + + def format_message_start( + self, + message_id: str | None = None, + *, + emitter: Emitter | None = None, + ) -> str: + chosen = self.begin_message(message_id) + return lifecycle.format_message_start(chosen, emitter=emitter) + + def format_message_finish(self, *, emitter: Emitter | None = None) -> str: + return lifecycle.format_message_finish(emitter=emitter) + + def format_step_start(self, *, emitter: Emitter | None = None) -> str: + return lifecycle.format_step_start(emitter=emitter) + + def format_step_finish(self, *, emitter: Emitter | None = None) -> str: + return lifecycle.format_step_finish(emitter=emitter) + + def format_text_start( + self, text_id: str, *, emitter: Emitter | None = None + ) -> str: + return text.format_text_start(text_id, emitter=emitter) + + def format_text_delta( + self, text_id: str, delta: str, *, emitter: Emitter | None = None + ) -> str: + return text.format_text_delta(text_id, delta, emitter=emitter) + + def format_text_end( + self, text_id: str, *, emitter: Emitter | None = None + ) -> str: + return text.format_text_end(text_id, emitter=emitter) + + def format_reasoning_start( + self, reasoning_id: str, *, emitter: Emitter | None = None + ) -> str: + return reasoning.format_reasoning_start(reasoning_id, emitter=emitter) + + def format_reasoning_delta( + self, + reasoning_id: str, + delta: str, + *, + emitter: Emitter | None = None, + ) -> str: + return reasoning.format_reasoning_delta(reasoning_id, delta, emitter=emitter) + + def format_reasoning_end( + self, reasoning_id: str, *, emitter: Emitter | None = None + ) -> str: + return reasoning.format_reasoning_end(reasoning_id, emitter=emitter) + + def format_tool_input_start( + self, + tool_call_id: str, + tool_name: str, + *, + langchain_tool_call_id: str | None = None, + emitter: Emitter | None = None, + ) -> str: + return tool.format_tool_input_start( + tool_call_id, + tool_name, + langchain_tool_call_id=langchain_tool_call_id, + emitter=emitter, + ) + + def format_tool_input_delta( + self, + tool_call_id: str, + input_text_delta: str, + *, + emitter: Emitter | None = None, + ) -> str: + return tool.format_tool_input_delta( + tool_call_id, input_text_delta, emitter=emitter + ) + 
+ def format_tool_input_available( + self, + tool_call_id: str, + tool_name: str, + input_data: dict[str, Any], + *, + langchain_tool_call_id: str | None = None, + emitter: Emitter | None = None, + ) -> str: + return tool.format_tool_input_available( + tool_call_id, + tool_name, + input_data, + langchain_tool_call_id=langchain_tool_call_id, + emitter=emitter, + ) + + def format_tool_output_available( + self, + tool_call_id: str, + output: Any, + *, + langchain_tool_call_id: str | None = None, + emitter: Emitter | None = None, + ) -> str: + return tool.format_tool_output_available( + tool_call_id, + output, + langchain_tool_call_id=langchain_tool_call_id, + emitter=emitter, + ) + + def format_source_url( + self, + url: str, + *, + source_id: str | None = None, + title: str | None = None, + emitter: Emitter | None = None, + ) -> str: + return source.format_source_url( + url, source_id=source_id, title=title, emitter=emitter + ) + + def format_source_document( + self, + source_id: str, + *, + media_type: str = "file", + title: str | None = None, + description: str | None = None, + emitter: Emitter | None = None, + ) -> str: + return source.format_source_document( + source_id, + media_type=media_type, + title=title, + description=description, + emitter=emitter, + ) + + def format_file( + self, url: str, media_type: str, *, emitter: Emitter | None = None + ) -> str: + return source.format_file(url, media_type, emitter=emitter) + + def format_data( + self, data_type: str, payload: Any, *, emitter: Emitter | None = None + ) -> str: + return data.format_data(data_type, payload, emitter=emitter) + + def format_terminal_info( + self, + text_value: str, + *, + message_type: str = "info", + emitter: Emitter | None = None, + ) -> str: + return data.format_terminal_info( + text_value, message_type=message_type, emitter=emitter + ) + + def format_further_questions( + self, + questions: list[str], + *, + emitter: Emitter | None = None, + ) -> str: + return data.format_further_questions(questions, emitter=emitter) + + def format_thinking_step( + self, + *, + step_id: str, + title: str, + status: str = "in_progress", + items: list[str] | None = None, + emitter: Emitter | None = None, + ) -> str: + return data.format_thinking_step( + step_id=step_id, + title=title, + status=status, + items=items, + emitter=emitter, + ) + + def format_thread_title_update( + self, + *, + thread_id: int, + title: str, + emitter: Emitter | None = None, + ) -> str: + return data.format_thread_title_update( + thread_id=thread_id, title=title, emitter=emitter + ) + + def format_turn_info( + self, + *, + chat_turn_id: str, + emitter: Emitter | None = None, + ) -> str: + return data.format_turn_info(chat_turn_id=chat_turn_id, emitter=emitter) + + def format_turn_status( + self, + *, + status: str, + emitter: Emitter | None = None, + ) -> str: + return data.format_turn_status(status=status, emitter=emitter) + + def format_user_message_id( + self, + *, + message_id: str, + turn_id: str, + emitter: Emitter | None = None, + ) -> str: + return data.format_user_message_id( + message_id=message_id, turn_id=turn_id, emitter=emitter + ) + + def format_assistant_message_id( + self, + *, + message_id: str, + turn_id: str, + emitter: Emitter | None = None, + ) -> str: + return data.format_assistant_message_id( + message_id=message_id, turn_id=turn_id, emitter=emitter + ) + + def format_error( + self, + error_text: str, + *, + error_code: str | None = None, + extra: dict[str, Any] | None = None, + emitter: Emitter | None = None, + ) -> str: + 
return error.format_error( + error_text, + error_code=error_code, + extra=extra, + emitter=emitter, + ) + + def format_interrupt_request( + self, + interrupt_value: dict[str, Any], + *, + interrupt_id: str | None = None, + pending_interrupt_count: int | None = None, + chat_turn_id: str | None = None, + emitter: Emitter | None = None, + ) -> str: + return interrupt.format_interrupt_request( + interrupt_value, + interrupt_id=interrupt_id, + pending_interrupt_count=pending_interrupt_count, + chat_turn_id=chat_turn_id, + emitter=emitter, + ) + + def format_subagent_start( + self, + *, + subagent_run_id: str, + subagent_type: str, + parent_tool_call_id: str, + chat_turn_id: str | None = None, + description: str | None = None, + started_at: str | None = None, + emitter: Emitter | None = None, + ) -> str: + return subagent_lifecycle.format_subagent_start( + subagent_run_id=subagent_run_id, + subagent_type=subagent_type, + parent_tool_call_id=parent_tool_call_id, + chat_turn_id=chat_turn_id, + description=description, + started_at=started_at, + emitter=emitter, + ) + + def format_subagent_finish( + self, + *, + subagent_run_id: str, + subagent_type: str, + parent_tool_call_id: str, + status: str = "completed", + ended_at: str | None = None, + duration_ms: int | None = None, + emitter: Emitter | None = None, + ) -> str: + return subagent_lifecycle.format_subagent_finish( + subagent_run_id=subagent_run_id, + subagent_type=subagent_type, + parent_tool_call_id=parent_tool_call_id, + status=status, + ended_at=ended_at, + duration_ms=duration_ms, + emitter=emitter, + ) + + def format_subagent_error( + self, + *, + subagent_run_id: str, + subagent_type: str, + parent_tool_call_id: str, + error_text: str, + error_type: str | None = None, + ended_at: str | None = None, + duration_ms: int | None = None, + emitter: Emitter | None = None, + ) -> str: + return subagent_lifecycle.format_subagent_error( + subagent_run_id=subagent_run_id, + subagent_type=subagent_type, + parent_tool_call_id=parent_tool_call_id, + error_text=error_text, + error_type=error_type, + ended_at=ended_at, + duration_ms=duration_ms, + emitter=emitter, + ) + + def format_action_log( + self, + payload: dict[str, Any], + *, + emitter: Emitter | None = None, + ) -> str: + return action_log.format_action_log(payload, emitter=emitter) + + def format_action_log_updated( + self, + payload: dict[str, Any], + *, + emitter: Emitter | None = None, + ) -> str: + return action_log.format_action_log_updated(payload, emitter=emitter) diff --git a/surfsense_backend/app/tasks/chat/content_builder.py b/surfsense_backend/app/tasks/chat/content_builder.py index 041cab286..f0804159a 100644 --- a/surfsense_backend/app/tasks/chat/content_builder.py +++ b/surfsense_backend/app/tasks/chat/content_builder.py @@ -51,6 +51,21 @@ logger = logging.getLogger(__name__) _MEANINGFUL_PART_TYPES: frozenset[str] = frozenset({"text", "reasoning", "tool-call"}) +def _merge_tool_part_metadata(part: dict[str, Any], metadata: dict[str, Any] | None) -> None: + """Shallow-merge ``metadata`` into ``part["metadata"]``; first key wins. + + Used for tool-call linkage (``spanId``, ``thinkingStepId``, …): a later + event must not overwrite an existing key so chunk order vs ``on_tool_start`` + stays stable. + """ + if not metadata: + return + md = part.setdefault("metadata", {}) + for k, v in metadata.items(): + if k not in md: + md[k] = v + + class AssistantContentBuilder: """Server-side projection of ``surfsense_web/lib/chat/streaming-state.ts``. 
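
The first-write-wins contract of ``_merge_tool_part_metadata`` is easiest to see concretely (illustrative, not part of the diff; the ids are made up):

part = {"type": "tool-call", "toolCallId": "call_abc123"}
_merge_tool_part_metadata(part, {"spanId": "span-1"})  # first writer sets spanId
_merge_tool_part_metadata(part, {"spanId": "span-2", "thinkingStepId": "step-3"})
# spanId from the first call survives; only the missing key is added.
assert part["metadata"] == {"spanId": "span-1", "thinkingStepId": "step-3"}
_merge_tool_part_metadata(part, None)                  # falsy metadata: no-op
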
@@ -61,6 +76,7 @@ class AssistantContentBuilder: | { type: "reasoning"; text: string } | { type: "tool-call"; toolCallId: str; toolName: str; args: dict; result?: any; argsText?: str; langchainToolCallId?: str; + metadata?: { spanId?: str; thinkingStepId?: str; ... }; state?: "aborted" } | { type: "data-thinking-steps"; data: { steps: ThinkingStepData[] } } | { type: "data-step-separator"; data: { stepIndex: int } } @@ -85,8 +101,8 @@ class AssistantContentBuilder: self._current_text_idx: int = -1 self._current_reasoning_idx: int = -1 # ``ui_id``-keyed indexes for tool-call parts. ``ui_id`` is the - # synthetic ``call_`` (legacy) or the LangChain - # ``tool_call.id`` (parity_v2) — same key the streaming layer + # synthetic ``call_`` (chunk fallback) or the LangChain + # ``tool_call.id`` (indexed chunk path) — same key the streaming layer # threads through every ``tool-input-*`` / ``tool-output-*`` event. self._tool_call_idx_by_ui_id: dict[str, int] = {} # Live argsText accumulator (concatenated ``tool-input-delta`` chunks) @@ -177,21 +193,27 @@ class AssistantContentBuilder: ui_id: str, tool_name: str, langchain_tool_call_id: str | None, + *, + metadata: dict[str, Any] | None = None, ) -> None: - """Register a tool-call card. Args are filled in by later events.""" + """Register a tool-call card. Args are filled in by later events. + + Optional ``metadata`` (``spanId``, ``thinkingStepId``, …) is stored on the + part; duplicate ``tool-input-start`` calls merge with first-key-wins. + """ if not ui_id: return - # Skip duplicate registration: parity_v2 may emit + # Skip duplicate registration: the stream may emit # ``tool-input-start`` from both ``on_chat_model_stream`` # (when tool_call_chunks register a name) and ``on_tool_start`` # (the canonical path). The FE de-dupes via ``toolCallIndices``; # we mirror that here. if ui_id in self._tool_call_idx_by_ui_id: - if langchain_tool_call_id: - idx = self._tool_call_idx_by_ui_id[ui_id] - part = self.parts[idx] - if not part.get("langchainToolCallId"): - part["langchainToolCallId"] = langchain_tool_call_id + idx = self._tool_call_idx_by_ui_id[ui_id] + part = self.parts[idx] + if langchain_tool_call_id and not part.get("langchainToolCallId"): + part["langchainToolCallId"] = langchain_tool_call_id + _merge_tool_part_metadata(part, metadata) return part: dict[str, Any] = { @@ -202,6 +224,8 @@ class AssistantContentBuilder: } if langchain_tool_call_id: part["langchainToolCallId"] = langchain_tool_call_id + if metadata: + part["metadata"] = dict(metadata) self.parts.append(part) self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1 @@ -235,6 +259,8 @@ class AssistantContentBuilder: tool_name: str, args: dict[str, Any], langchain_tool_call_id: str | None, + *, + metadata: dict[str, Any] | None = None, ) -> None: """Finalize the tool-call card's input. @@ -243,7 +269,7 @@ class AssistantContentBuilder: pretty-printed JSON, sets the full ``args`` dict, and backfills ``langchainToolCallId`` if it wasn't known at ``tool-input-start`` time. Also creates the card if no prior ``tool-input-start`` registered it - (legacy parity_v2-OFF / late-registration paths). + (late-registration when no prior ``tool-input-start``). """ if not ui_id: return @@ -264,6 +290,7 @@ class AssistantContentBuilder: part["argsText"] = final_args_text if langchain_tool_call_id and not part.get("langchainToolCallId"): part["langchainToolCallId"] = langchain_tool_call_id + _merge_tool_part_metadata(part, metadata) return # No prior tool-input-start: register the card now. 
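
Taken together, the duplicate-registration path above behaves like this (illustrative, not part of the diff; assumes ``AssistantContentBuilder()`` needs no constructor arguments):

builder = AssistantContentBuilder()
builder.on_tool_input_start(
    "toolu_01", "write_file", None, metadata={"thinkingStepId": "step-2"}
)
# Second emit site (``on_tool_start``) re-registers the same ui_id: the
# LangChain id is backfilled, metadata merges first-write-wins, and no
# second card is appended.
builder.on_tool_input_start(
    "toolu_01", "write_file", "toolu_01",
    metadata={"thinkingStepId": "step-9", "spanId": "span-1"},
)
cards = [p for p in builder.parts if p["type"] == "tool-call"]
assert len(cards) == 1
assert cards[0]["langchainToolCallId"] == "toolu_01"
assert cards[0]["metadata"] == {"thinkingStepId": "step-2", "spanId": "span-1"}
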
@@ -276,6 +303,7 @@ class AssistantContentBuilder: } if langchain_tool_call_id: new_part["langchainToolCallId"] = langchain_tool_call_id + _merge_tool_part_metadata(new_part, metadata) self.parts.append(new_part) self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1 @@ -287,6 +315,8 @@ class AssistantContentBuilder: ui_id: str, output: Any, langchain_tool_call_id: str | None, + *, + metadata: dict[str, Any] | None = None, ) -> None: """Attach the tool's output (``result``) to the matching card. @@ -305,6 +335,7 @@ class AssistantContentBuilder: part["result"] = output if langchain_tool_call_id and not part.get("langchainToolCallId"): part["langchainToolCallId"] = langchain_tool_call_id + _merge_tool_part_metadata(part, metadata) # ------------------------------------------------------------------ # Thinking steps & step separators @@ -316,6 +347,8 @@ class AssistantContentBuilder: title: str, status: str, items: list[str] | None, + *, + metadata: dict[str, Any] | None = None, ) -> None: """Update / insert the singleton ``data-thinking-steps`` part. @@ -328,12 +361,14 @@ class AssistantContentBuilder: if not step_id: return - new_step = { + new_step: dict[str, Any] = { "id": step_id, "title": title or "", "status": status or "in_progress", "items": list(items) if items else [], } + if metadata: + new_step["metadata"] = dict(metadata) # Find existing data-thinking-steps part. existing_idx = -1 @@ -347,6 +382,8 @@ class AssistantContentBuilder: replaced = False for i, step in enumerate(current_steps): if step.get("id") == step_id: + if not metadata and step.get("metadata"): + new_step["metadata"] = dict(step["metadata"]) current_steps[i] = new_step replaced = True break diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 1a2f38077..8e135179a 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -9,13 +9,11 @@ Supports loading LLM configurations from: - NewLLMConfig database table (positive IDs for user-created configs with prompt settings) """ -import ast import asyncio import contextlib import gc import json import logging -import re import time from collections.abc import AsyncGenerator from dataclasses import dataclass, field @@ -33,7 +31,6 @@ from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent from app.agents.new_chat.checkpointer import get_checkpointer from app.agents.new_chat.context import SurfSenseContextSchema from app.agents.new_chat.errors import BusyError -from app.agents.new_chat.feature_flags import get_flags from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection from app.agents.new_chat.llm_config import ( AgentConfig, @@ -77,6 +74,7 @@ from app.services.chat_session_state_service import ( ) from app.services.connector_service import ConnectorService from app.services.new_streaming_service import VercelStreamingService +from app.tasks.chat.streaming.graph_stream.event_stream import stream_output from app.utils.content_utils import bootstrap_history_from_db from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap from app.utils.user_message_multimodal import build_human_message_content @@ -729,9 +727,9 @@ def _legacy_match_lc_id( ) -> str | None: """Best-effort match a buffered ``tool_call_chunk`` to a tool name. 
- Pure extract of the legacy in-line match used at ``on_tool_start`` for - parity_v2-OFF and unmatched (chunk path didn't register an index for - this call) tools. Pops the next id-bearing chunk whose ``name`` + Pure extract of the in-line match used at ``on_tool_start`` when the + chunk path didn't register an index for this call. Pops the next + id-bearing chunk whose ``name`` matches ``tool_name`` (or any id-bearing chunk as a fallback) and returns its id. Mutates ``pending_tool_call_chunks`` and ``lc_tool_call_id_by_run`` in place. @@ -803,1505 +801,22 @@ async def _stream_agent_events( Yields: SSE-formatted strings for each event. """ - accumulated_text = "" - current_text_id: str | None = None - thinking_step_counter = 1 if initial_step_id else 0 - tool_step_ids: dict[str, str] = {} - completed_step_ids: set[str] = set() - last_active_step_id: str | None = initial_step_id - last_active_step_title: str = initial_step_title - last_active_step_items: list[str] = initial_step_items or [] - just_finished_tool: bool = False - active_tool_depth: int = 0 # Track nesting: >0 means we're inside a tool - called_update_memory: bool = False + async for sse in stream_output( + agent=agent, + config=config, + input_data=input_data, + streaming_service=streaming_service, + result=result, + step_prefix=step_prefix, + initial_step_id=initial_step_id, + initial_step_title=initial_step_title, + initial_step_items=initial_step_items, + content_builder=content_builder, + runtime_context=runtime_context, + ): + yield sse - # Reasoning-block streaming. We open a reasoning block on the - # first reasoning delta of a step, append deltas as they arrive, and - # close it when text starts (the model has switched to writing its - # answer) or ``on_chat_model_end`` fires for the model node. Reuses - # the same Vercel format-helpers as text-start/delta/end. - current_reasoning_id: str | None = None - - # Streaming-parity v2 feature flag. When OFF we keep the legacy - # shape: str-only content, no reasoning blocks, no - # ``langchainToolCallId`` propagation. The schema migrations - # (135 / 136) ship unconditionally because they're forward-compatible. - parity_v2 = bool(get_flags().enable_stream_parity_v2) - - # Best-effort attach of LangChain ``tool_call_id`` to the synthetic - # ``call_`` card id we already emit. We accumulate - # ``tool_call_chunks`` from ``on_chat_model_stream``, key them by - # name, and pop the next unconsumed entry at ``on_tool_start``. The - # authoritative id is later filled in at ``on_tool_end`` from - # ``ToolMessage.tool_call_id``. Under parity_v2 we ALSO short-circuit - # this list for chunks that already registered into ``index_to_meta`` - # below — so this list is reserved for the parity_v2-OFF / unmatched - # fallback path only and never re-pops a chunk we already streamed. - pending_tool_call_chunks: list[dict[str, Any]] = [] - lc_tool_call_id_by_run: dict[str, str] = {} - file_path_by_run: dict[str, str] = {} - - # parity_v2 only: live tool-call argument streaming. ``index_to_meta`` - # is keyed by the chunk's ``index`` field — LangChain - # ``ToolCallChunk``s for the same call share an index but only the - # first chunk carries id+name (subsequent ones are id=None, - # name=None, args=""). We register an index when both id and - # name are observed on a chunk (per ToolCallChunk semantics they - # arrive together on the first chunk), then route every later chunk - # at that index to the same ``ui_id`` as a ``tool-input-delta``. 
- # ``ui_tool_call_id_by_run`` maps LangGraph ``run_id`` to the - # ``ui_id`` used for that call's ``tool-input-start`` so the matching - # ``tool-output-available`` (emitted from ``on_tool_end``) lands on - # the same card. - index_to_meta: dict[int, dict[str, str]] = {} - ui_tool_call_id_by_run: dict[str, str] = {} - - # Per-tool-end mutable cache for the LangChain tool_call_id resolved - # at ``on_tool_end``. ``_emit_tool_output`` reads this so every - # ``format_tool_output_available`` call automatically carries the - # authoritative id without duplicating the kwarg at every call site. - current_lc_tool_call_id: dict[str, str | None] = {"value": None} - - def _emit_tool_output(call_id: str, output: Any) -> str: - # Drive the builder before formatting the SSE so the in-memory - # ContentPart[] mirror sees the result attached to the same - # card the FE will render. Builder method is a no-op when - # ``content_builder`` is None (anonymous / legacy paths). - if content_builder is not None: - content_builder.on_tool_output_available( - call_id, output, current_lc_tool_call_id["value"] - ) - return streaming_service.format_tool_output_available( - call_id, - output, - langchain_tool_call_id=current_lc_tool_call_id["value"], - ) - - def _emit_thinking_step( - *, - step_id: str, - title: str, - status: str = "in_progress", - items: list[str] | None = None, - ) -> str: - """Format a thinking-step SSE event and notify the builder. - - Single helper used at every ``format_thinking_step`` yield site - in this generator. Drives ``AssistantContentBuilder.on_thinking_step`` - first so the FE-mirror state lands the update before the SSE - carrying the same data leaves the wire — order matches the FE - pipeline (``processSharedStreamEvent`` updates state, then - flushes). Builder call is a no-op when ``content_builder`` is - None (anonymous / legacy paths). - """ - if content_builder is not None: - content_builder.on_thinking_step(step_id, title, status, items) - return streaming_service.format_thinking_step( - step_id=step_id, - title=title, - status=status, - items=items, - ) - - def next_thinking_step_id() -> str: - nonlocal thinking_step_counter - thinking_step_counter += 1 - return f"{step_prefix}-{thinking_step_counter}" - - def complete_current_step() -> str | None: - nonlocal last_active_step_id - if last_active_step_id and last_active_step_id not in completed_step_ids: - completed_step_ids.add(last_active_step_id) - event = _emit_thinking_step( - step_id=last_active_step_id, - title=last_active_step_title, - status="completed", - items=last_active_step_items if last_active_step_items else None, - ) - last_active_step_id = None - return event - return None - - # Per-invocation runtime context (Phase 1.5). When supplied, - # ``KnowledgePriorityMiddleware`` reads ``mentioned_document_ids`` - # from ``runtime.context`` instead of its constructor closure — the - # prerequisite that lets the compiled-agent cache (Phase 1) reuse a - # single graph across turns. Astream_events_kwargs stays empty when - # callers leave ``runtime_context`` as ``None`` to preserve the - # legacy code path bit-for-bit. 
- astream_kwargs: dict[str, Any] = {"config": config, "version": "v2"} - if runtime_context is not None: - astream_kwargs["context"] = runtime_context - - async for event in agent.astream_events(input_data, **astream_kwargs): - event_type = event.get("event", "") - - if event_type == "on_chat_model_stream": - if active_tool_depth > 0: - continue # Suppress inner-tool LLM tokens from leaking into chat - if "surfsense:internal" in event.get("tags", []): - continue # Suppress middleware-internal LLM tokens (e.g. KB search classification) - chunk = event.get("data", {}).get("chunk") - if not chunk: - continue - parts = _extract_chunk_parts(chunk) - - reasoning_delta = parts["reasoning"] - text_delta = parts["text"] - - # Reasoning streaming. Open a reasoning block on first - # delta; append every subsequent delta until text begins. - # When text starts we close the reasoning block first so the - # frontend sees the natural hand-off. Gated behind the - # parity-v2 flag so legacy deployments keep today's shape. - if parity_v2 and reasoning_delta: - if current_text_id is not None: - yield streaming_service.format_text_end(current_text_id) - if content_builder is not None: - content_builder.on_text_end(current_text_id) - current_text_id = None - if current_reasoning_id is None: - completion_event = complete_current_step() - if completion_event: - yield completion_event - if just_finished_tool: - last_active_step_id = None - last_active_step_title = "" - last_active_step_items = [] - just_finished_tool = False - current_reasoning_id = streaming_service.generate_reasoning_id() - yield streaming_service.format_reasoning_start(current_reasoning_id) - if content_builder is not None: - content_builder.on_reasoning_start(current_reasoning_id) - yield streaming_service.format_reasoning_delta( - current_reasoning_id, reasoning_delta - ) - if content_builder is not None: - content_builder.on_reasoning_delta( - current_reasoning_id, reasoning_delta - ) - - if text_delta: - if current_reasoning_id is not None: - yield streaming_service.format_reasoning_end(current_reasoning_id) - if content_builder is not None: - content_builder.on_reasoning_end(current_reasoning_id) - current_reasoning_id = None - if current_text_id is None: - completion_event = complete_current_step() - if completion_event: - yield completion_event - if just_finished_tool: - last_active_step_id = None - last_active_step_title = "" - last_active_step_items = [] - just_finished_tool = False - current_text_id = streaming_service.generate_text_id() - yield streaming_service.format_text_start(current_text_id) - if content_builder is not None: - content_builder.on_text_start(current_text_id) - yield streaming_service.format_text_delta(current_text_id, text_delta) - accumulated_text += text_delta - if content_builder is not None: - content_builder.on_text_delta(current_text_id, text_delta) - - # Live tool-call argument streaming. Runs AFTER text/reasoning - # processing so chunks containing both stay in their natural - # wire order (text → text-end → tool-input-start). Active - # text/reasoning are closed inside the registration branch - # before ``tool-input-start`` so the frontend sees a clean - # part boundary even when providers interleave. - if parity_v2 and parts["tool_call_chunks"]: - for tcc in parts["tool_call_chunks"]: - idx = tcc.get("index") - - # Register this index when we first see id+name - # TOGETHER. 
Per LangChain ToolCallChunk semantics the - # first chunk for a tool call carries both fields - # together; later chunks have id=None, name=None and - # only ``args``. Requiring BOTH keeps wire - # ``tool-input-start`` always carrying a real - # toolName (assistant-ui's typed tool-part dispatch - # keys off it). - if idx is not None and idx not in index_to_meta: - lc_id = tcc.get("id") - name = tcc.get("name") - if lc_id and name: - ui_id = lc_id - - # Close active text/reasoning so wire - # ordering stays clean even on providers - # that interleave text and tool-call chunks - # within the same stream window. - if current_text_id is not None: - yield streaming_service.format_text_end(current_text_id) - if content_builder is not None: - content_builder.on_text_end(current_text_id) - current_text_id = None - if current_reasoning_id is not None: - yield streaming_service.format_reasoning_end( - current_reasoning_id - ) - if content_builder is not None: - content_builder.on_reasoning_end( - current_reasoning_id - ) - current_reasoning_id = None - - index_to_meta[idx] = { - "ui_id": ui_id, - "lc_id": lc_id, - "name": name, - } - yield streaming_service.format_tool_input_start( - ui_id, - name, - langchain_tool_call_id=lc_id, - ) - if content_builder is not None: - content_builder.on_tool_input_start(ui_id, name, lc_id) - - # Emit args delta for any chunk at a registered - # index (including idless continuations). Once an - # index is owned by ``index_to_meta`` we DO NOT - # append to ``pending_tool_call_chunks`` — that list - # is reserved for the parity_v2-OFF / unmatched - # fallback path so it never re-pops chunks already - # consumed here (skip-append). - meta = index_to_meta.get(idx) if idx is not None else None - if meta: - args_chunk = tcc.get("args") or "" - if args_chunk: - yield streaming_service.format_tool_input_delta( - meta["ui_id"], args_chunk - ) - if content_builder is not None: - content_builder.on_tool_input_delta( - meta["ui_id"], args_chunk - ) - else: - pending_tool_call_chunks.append(tcc) - - elif event_type == "on_tool_start": - active_tool_depth += 1 - tool_name = event.get("name", "unknown_tool") - run_id = event.get("run_id", "") - tool_input = event.get("data", {}).get("input", {}) - if tool_name in ("write_file", "edit_file"): - result.write_attempted = True - if isinstance(tool_input, dict): - file_path = tool_input.get("file_path") - if isinstance(file_path, str) and file_path.strip() and run_id: - file_path_by_run[run_id] = file_path.strip() - - if current_text_id is not None: - yield streaming_service.format_text_end(current_text_id) - if content_builder is not None: - content_builder.on_text_end(current_text_id) - current_text_id = None - - if last_active_step_title != "Synthesizing response": - completion_event = complete_current_step() - if completion_event: - yield completion_event - - just_finished_tool = False - tool_step_id = next_thinking_step_id() - tool_step_ids[run_id] = tool_step_id - last_active_step_id = tool_step_id - - if tool_name == "ls": - ls_path = ( - tool_input.get("path", "/") - if isinstance(tool_input, dict) - else str(tool_input) - ) - last_active_step_title = "Listing files" - last_active_step_items = [ls_path] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Listing files", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "read_file": - fp = ( - tool_input.get("file_path", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - display_fp = fp if len(fp) <= 80 else "…" + fp[-77:] 
- last_active_step_title = "Reading file" - last_active_step_items = [display_fp] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Reading file", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "write_file": - fp = ( - tool_input.get("file_path", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - display_fp = fp if len(fp) <= 80 else "…" + fp[-77:] - last_active_step_title = "Writing file" - last_active_step_items = [display_fp] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Writing file", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "edit_file": - fp = ( - tool_input.get("file_path", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - display_fp = fp if len(fp) <= 80 else "…" + fp[-77:] - last_active_step_title = "Editing file" - last_active_step_items = [display_fp] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Editing file", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "glob": - pat = ( - tool_input.get("pattern", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - base_path = ( - tool_input.get("path", "/") if isinstance(tool_input, dict) else "/" - ) - last_active_step_title = "Searching files" - last_active_step_items = [f"{pat} in {base_path}"] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Searching files", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "grep": - pat = ( - tool_input.get("pattern", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - grep_path = ( - tool_input.get("path", "") if isinstance(tool_input, dict) else "" - ) - display_pat = pat[:60] + ("…" if len(pat) > 60 else "") - last_active_step_title = "Searching content" - last_active_step_items = [ - f'"{display_pat}"' + (f" in {grep_path}" if grep_path else "") - ] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Searching content", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "rm": - rm_path = ( - tool_input.get("path", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - display_path = rm_path if len(rm_path) <= 80 else "…" + rm_path[-77:] - last_active_step_title = "Deleting file" - last_active_step_items = [display_path] if display_path else [] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Deleting file", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "rmdir": - rmdir_path = ( - tool_input.get("path", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - display_path = ( - rmdir_path if len(rmdir_path) <= 80 else "…" + rmdir_path[-77:] - ) - last_active_step_title = "Deleting folder" - last_active_step_items = [display_path] if display_path else [] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Deleting folder", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "mkdir": - mkdir_path = ( - tool_input.get("path", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - display_path = ( - mkdir_path if len(mkdir_path) <= 80 else "…" + mkdir_path[-77:] - ) - last_active_step_title = "Creating folder" - last_active_step_items = [display_path] if display_path else [] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Creating folder", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "move_file": - src = ( - tool_input.get("source_path", "") - if 
isinstance(tool_input, dict) - else "" - ) - dst = ( - tool_input.get("destination_path", "") - if isinstance(tool_input, dict) - else "" - ) - display_src = src if len(src) <= 60 else "…" + src[-57:] - display_dst = dst if len(dst) <= 60 else "…" + dst[-57:] - last_active_step_title = "Moving file" - last_active_step_items = ( - [f"{display_src} → {display_dst}"] if src or dst else [] - ) - yield _emit_thinking_step( - step_id=tool_step_id, - title="Moving file", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "write_todos": - todos = ( - tool_input.get("todos", []) if isinstance(tool_input, dict) else [] - ) - todo_count = len(todos) if isinstance(todos, list) else 0 - last_active_step_title = "Planning tasks" - last_active_step_items = ( - [f"{todo_count} task{'s' if todo_count != 1 else ''}"] - if todo_count - else [] - ) - yield _emit_thinking_step( - step_id=tool_step_id, - title="Planning tasks", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "save_document": - doc_title = ( - tool_input.get("title", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - display_title = doc_title[:60] + ("…" if len(doc_title) > 60 else "") - last_active_step_title = "Saving document" - last_active_step_items = [display_title] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Saving document", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "generate_image": - prompt = ( - tool_input.get("prompt", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - last_active_step_title = "Generating image" - last_active_step_items = [ - f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}" - ] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Generating image", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "scrape_webpage": - url = ( - tool_input.get("url", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - last_active_step_title = "Scraping webpage" - last_active_step_items = [ - f"URL: {url[:80]}{'...' 
if len(url) > 80 else ''}" - ] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Scraping webpage", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "generate_podcast": - podcast_title = ( - tool_input.get("podcast_title", "SurfSense Podcast") - if isinstance(tool_input, dict) - else "SurfSense Podcast" - ) - content_len = len( - tool_input.get("source_content", "") - if isinstance(tool_input, dict) - else "" - ) - last_active_step_title = "Generating podcast" - last_active_step_items = [ - f"Title: {podcast_title}", - f"Content: {content_len:,} characters", - "Preparing audio generation...", - ] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Generating podcast", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "generate_report": - report_topic = ( - tool_input.get("topic", "Report") - if isinstance(tool_input, dict) - else "Report" - ) - is_revision = bool( - isinstance(tool_input, dict) and tool_input.get("parent_report_id") - ) - step_title = "Revising report" if is_revision else "Generating report" - last_active_step_title = step_title - last_active_step_items = [ - f"Topic: {report_topic}", - "Analyzing source content...", - ] - yield _emit_thinking_step( - step_id=tool_step_id, - title=step_title, - status="in_progress", - items=last_active_step_items, - ) - elif tool_name in ("execute", "execute_code"): - cmd = ( - tool_input.get("command", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - display_cmd = cmd[:80] + ("…" if len(cmd) > 80 else "") - last_active_step_title = "Running command" - last_active_step_items = [f"$ {display_cmd}"] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Running command", - status="in_progress", - items=last_active_step_items, - ) - else: - # Fallback for tools without a curated thinking-step title - # (typically connector tools, MCP-registered tools, or - # newly added tools that haven't been wired up here yet). - # Render the snake_cased name as a sentence-cased phrase - # so non-technical users see e.g. "Send gmail email" - # rather than the raw identifier "send_gmail_email". - last_active_step_title = ( - tool_name.replace("_", " ").strip().capitalize() or tool_name - ) - last_active_step_items = [] - yield _emit_thinking_step( - step_id=tool_step_id, - title=last_active_step_title, - status="in_progress", - ) - - # Resolve the card identity. If the chunk-emission loop - # already registered an ``index`` for this tool call (parity_v2 - # path), reuse the same ui_id so the card sees: - # tool-input-start → deltas… → tool-input-available → - # tool-output-available all keyed by lc_id. Otherwise fall - # back to the synthetic ``call_`` id and the legacy - # best-effort match against ``pending_tool_call_chunks``. - matched_meta: dict[str, str] | None = None - if parity_v2: - # FIFO over indices 0,1,2…; first unassigned same-name - # match wins. Handles parallel same-name calls (e.g. two - # write_file calls) deterministically as long as the - # model interleaves on_tool_start in the same order it - # streamed the args. 
- taken_ui_ids = set(ui_tool_call_id_by_run.values()) - for meta in index_to_meta.values(): - if meta["name"] == tool_name and meta["ui_id"] not in taken_ui_ids: - matched_meta = meta - break - - tool_call_id: str - langchain_tool_call_id: str | None = None - if matched_meta is not None: - tool_call_id = matched_meta["ui_id"] - langchain_tool_call_id = matched_meta["lc_id"] - # ``tool-input-start`` already fired during chunk - # emission — skip the duplicate. No pruning is needed - # because the chunk-emission loop intentionally never - # appends registered-index chunks to - # ``pending_tool_call_chunks`` (skip-append). - if run_id: - lc_tool_call_id_by_run[run_id] = matched_meta["lc_id"] - else: - tool_call_id = ( - f"call_{run_id[:32]}" - if run_id - else streaming_service.generate_tool_call_id() - ) - # Legacy fallback: parity_v2 OFF, or parity_v2 ON but the - # provider didn't stream tool_call_chunks for this call - # (no index registered). Run the existing best-effort - # match BEFORE emitting start so we still attach an - # authoritative ``langchainToolCallId`` when possible. - if parity_v2: - langchain_tool_call_id = _legacy_match_lc_id( - pending_tool_call_chunks, - tool_name, - run_id, - lc_tool_call_id_by_run, - ) - yield streaming_service.format_tool_input_start( - tool_call_id, - tool_name, - langchain_tool_call_id=langchain_tool_call_id, - ) - if content_builder is not None: - content_builder.on_tool_input_start( - tool_call_id, tool_name, langchain_tool_call_id - ) - - if run_id: - ui_tool_call_id_by_run[run_id] = tool_call_id - - # Sanitize tool_input: strip runtime-injected non-serializable - # values (e.g. LangChain ToolRuntime) before sending over SSE. - if isinstance(tool_input, dict): - _safe_input: dict[str, Any] = {} - for _k, _v in tool_input.items(): - try: - json.dumps(_v) - _safe_input[_k] = _v - except (TypeError, ValueError, OverflowError): - pass - else: - _safe_input = {"input": tool_input} - yield streaming_service.format_tool_input_available( - tool_call_id, - tool_name, - _safe_input, - langchain_tool_call_id=langchain_tool_call_id, - ) - if content_builder is not None: - content_builder.on_tool_input_available( - tool_call_id, - tool_name, - _safe_input, - langchain_tool_call_id, - ) - - elif event_type == "on_tool_end": - active_tool_depth = max(0, active_tool_depth - 1) - run_id = event.get("run_id", "") - tool_name = event.get("name", "unknown_tool") - raw_output = event.get("data", {}).get("output", "") - staged_file_path = file_path_by_run.pop(run_id, None) if run_id else None - - if tool_name == "update_memory": - called_update_memory = True - - if hasattr(raw_output, "content"): - content = raw_output.content - if isinstance(content, str): - try: - tool_output = json.loads(content) - except (json.JSONDecodeError, TypeError): - tool_output = {"result": content} - elif isinstance(content, dict): - tool_output = content - else: - tool_output = {"result": str(content)} - elif isinstance(raw_output, dict): - tool_output = raw_output - else: - tool_output = {"result": str(raw_output) if raw_output else "completed"} - - if tool_name in ("write_file", "edit_file"): - if _tool_output_has_error(tool_output): - # Keep successful evidence if a previous write/edit in this turn succeeded. 
- pass - else: - result.write_succeeded = True - result.verification_succeeded = True - - # Look up the SAME card id used at on_tool_start (either the - # parity_v2 lc-id-derived ui_id or the legacy synthetic - # ``call_``) so the output event always lands on the - # same card as start/delta/available. Fallback preserves the - # legacy synthetic shape for parity_v2-OFF / unknown-run paths. - tool_call_id = ui_tool_call_id_by_run.get( - run_id, - f"call_{run_id[:32]}" if run_id else "call_unknown", - ) - original_step_id = tool_step_ids.get( - run_id, f"{step_prefix}-unknown-{run_id[:8]}" - ) - completed_step_ids.add(original_step_id) - - # Authoritative LangChain tool_call_id from the returned - # ``ToolMessage``. Falls back to whatever we matched - # at ``on_tool_start`` time (kept in ``lc_tool_call_id_by_run``) - # if the output isn't a ToolMessage. The value is stored in - # ``current_lc_tool_call_id`` so ``_emit_tool_output`` - # picks it up for every output emit below. - # - # Emitted in BOTH parity_v2 and legacy modes: the chat tool - # card needs the LangChain id to match against the - # ``data-action-log`` SSE event (keyed by ``lc_tool_call_id``) - # so the inline Revert button can light up. Reading - # ``raw_output.tool_call_id`` is a cheap, non-mutating attribute - # access that is safe regardless of feature-flag state. - current_lc_tool_call_id["value"] = None - authoritative = getattr(raw_output, "tool_call_id", None) - if isinstance(authoritative, str) and authoritative: - current_lc_tool_call_id["value"] = authoritative - if run_id: - lc_tool_call_id_by_run[run_id] = authoritative - elif run_id and run_id in lc_tool_call_id_by_run: - current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id] - - if tool_name == "read_file": - yield _emit_thinking_step( - step_id=original_step_id, - title="Reading file", - status="completed", - items=last_active_step_items, - ) - elif tool_name == "write_file": - yield _emit_thinking_step( - step_id=original_step_id, - title="Writing file", - status="completed", - items=last_active_step_items, - ) - elif tool_name == "edit_file": - yield _emit_thinking_step( - step_id=original_step_id, - title="Editing file", - status="completed", - items=last_active_step_items, - ) - elif tool_name == "glob": - yield _emit_thinking_step( - step_id=original_step_id, - title="Searching files", - status="completed", - items=last_active_step_items, - ) - elif tool_name == "grep": - yield _emit_thinking_step( - step_id=original_step_id, - title="Searching content", - status="completed", - items=last_active_step_items, - ) - elif tool_name == "rm": - yield _emit_thinking_step( - step_id=original_step_id, - title="Deleting file", - status="completed", - items=last_active_step_items, - ) - elif tool_name == "rmdir": - yield _emit_thinking_step( - step_id=original_step_id, - title="Deleting folder", - status="completed", - items=last_active_step_items, - ) - elif tool_name == "mkdir": - yield _emit_thinking_step( - step_id=original_step_id, - title="Creating folder", - status="completed", - items=last_active_step_items, - ) - elif tool_name == "move_file": - yield _emit_thinking_step( - step_id=original_step_id, - title="Moving file", - status="completed", - items=last_active_step_items, - ) - elif tool_name == "write_todos": - yield _emit_thinking_step( - step_id=original_step_id, - title="Planning tasks", - status="completed", - items=last_active_step_items, - ) - elif tool_name == "save_document": - result_str = ( - tool_output.get("result", "") - if 
isinstance(tool_output, dict) - else str(tool_output) - ) - is_error = "Error" in result_str - completed_items = [ - *last_active_step_items, - result_str[:80] if is_error else "Saved to knowledge base", - ] - yield _emit_thinking_step( - step_id=original_step_id, - title="Saving document", - status="completed", - items=completed_items, - ) - elif tool_name == "generate_image": - if isinstance(tool_output, dict) and not tool_output.get("error"): - completed_items = [ - *last_active_step_items, - "Image generated successfully", - ] - else: - error_msg = ( - tool_output.get("error", "Generation failed") - if isinstance(tool_output, dict) - else "Generation failed" - ) - completed_items = [*last_active_step_items, f"Error: {error_msg}"] - yield _emit_thinking_step( - step_id=original_step_id, - title="Generating image", - status="completed", - items=completed_items, - ) - elif tool_name == "scrape_webpage": - if isinstance(tool_output, dict): - title = tool_output.get("title", "Webpage") - word_count = tool_output.get("word_count", 0) - has_error = "error" in tool_output - if has_error: - completed_items = [ - *last_active_step_items, - f"Error: {tool_output.get('error', 'Failed to scrape')[:50]}", - ] - else: - completed_items = [ - *last_active_step_items, - f"Title: {title[:50]}{'...' if len(title) > 50 else ''}", - f"Extracted: {word_count:,} words", - ] - else: - completed_items = [*last_active_step_items, "Content extracted"] - yield _emit_thinking_step( - step_id=original_step_id, - title="Scraping webpage", - status="completed", - items=completed_items, - ) - elif tool_name == "generate_podcast": - podcast_status = ( - tool_output.get("status", "unknown") - if isinstance(tool_output, dict) - else "unknown" - ) - podcast_title = ( - tool_output.get("title", "Podcast") - if isinstance(tool_output, dict) - else "Podcast" - ) - if podcast_status in ("pending", "generating", "processing"): - completed_items = [ - f"Title: {podcast_title}", - "Podcast generation started", - "Processing in background...", - ] - elif podcast_status == "already_generating": - completed_items = [ - f"Title: {podcast_title}", - "Podcast already in progress", - "Please wait for it to complete", - ] - elif podcast_status in ("failed", "error"): - error_msg = ( - tool_output.get("error", "Unknown error") - if isinstance(tool_output, dict) - else "Unknown error" - ) - completed_items = [ - f"Title: {podcast_title}", - f"Error: {error_msg[:50]}", - ] - elif podcast_status in ("ready", "success"): - completed_items = [ - f"Title: {podcast_title}", - "Podcast ready", - ] - else: - completed_items = last_active_step_items - yield _emit_thinking_step( - step_id=original_step_id, - title="Generating podcast", - status="completed", - items=completed_items, - ) - elif tool_name == "generate_video_presentation": - vp_status = ( - tool_output.get("status", "unknown") - if isinstance(tool_output, dict) - else "unknown" - ) - vp_title = ( - tool_output.get("title", "Presentation") - if isinstance(tool_output, dict) - else "Presentation" - ) - if vp_status in ("pending", "generating"): - completed_items = [ - f"Title: {vp_title}", - "Presentation generation started", - "Processing in background...", - ] - elif vp_status == "failed": - error_msg = ( - tool_output.get("error", "Unknown error") - if isinstance(tool_output, dict) - else "Unknown error" - ) - completed_items = [ - f"Title: {vp_title}", - f"Error: {error_msg[:50]}", - ] - else: - completed_items = last_active_step_items - yield _emit_thinking_step( - 
step_id=original_step_id, - title="Generating video presentation", - status="completed", - items=completed_items, - ) - elif tool_name == "generate_report": - report_status = ( - tool_output.get("status", "unknown") - if isinstance(tool_output, dict) - else "unknown" - ) - report_title = ( - tool_output.get("title", "Report") - if isinstance(tool_output, dict) - else "Report" - ) - word_count = ( - tool_output.get("word_count", 0) - if isinstance(tool_output, dict) - else 0 - ) - is_revision = ( - tool_output.get("is_revision", False) - if isinstance(tool_output, dict) - else False - ) - step_title = "Revising report" if is_revision else "Generating report" - - if report_status == "ready": - completed_items = [ - f"Topic: {report_title}", - f"{word_count:,} words", - "Report ready", - ] - elif report_status == "failed": - error_msg = ( - tool_output.get("error", "Unknown error") - if isinstance(tool_output, dict) - else "Unknown error" - ) - completed_items = [ - f"Topic: {report_title}", - f"Error: {error_msg[:50]}", - ] - else: - completed_items = last_active_step_items - - yield _emit_thinking_step( - step_id=original_step_id, - title=step_title, - status="completed", - items=completed_items, - ) - elif tool_name in ("execute", "execute_code"): - raw_text = ( - tool_output.get("result", "") - if isinstance(tool_output, dict) - else str(tool_output) - ) - m = re.match(r"^Exit code:\s*(\d+)", raw_text) - exit_code_val = int(m.group(1)) if m else None - if exit_code_val is not None and exit_code_val == 0: - completed_items = [ - *last_active_step_items, - "Completed successfully", - ] - elif exit_code_val is not None: - completed_items = [ - *last_active_step_items, - f"Exit code: {exit_code_val}", - ] - else: - completed_items = [*last_active_step_items, "Finished"] - yield _emit_thinking_step( - step_id=original_step_id, - title="Running command", - status="completed", - items=completed_items, - ) - elif tool_name == "ls": - if isinstance(tool_output, dict): - ls_output = tool_output.get("result", "") - elif isinstance(tool_output, str): - ls_output = tool_output - else: - ls_output = str(tool_output) if tool_output else "" - file_names: list[str] = [] - if ls_output: - paths: list[str] = [] - try: - parsed = ast.literal_eval(ls_output) - if isinstance(parsed, list): - paths = [str(p) for p in parsed] - except (ValueError, SyntaxError): - paths = [ - line.strip() - for line in ls_output.strip().split("\n") - if line.strip() - ] - for p in paths: - name = p.rstrip("/").split("/")[-1] - if name and len(name) <= 40: - file_names.append(name) - elif name: - file_names.append(name[:37] + "...") - if file_names: - if len(file_names) <= 5: - completed_items = [f"[{name}]" for name in file_names] - else: - completed_items = [f"[{name}]" for name in file_names[:4]] - completed_items.append(f"(+{len(file_names) - 4} more)") - else: - completed_items = ["No files found"] - yield _emit_thinking_step( - step_id=original_step_id, - title="Listing files", - status="completed", - items=completed_items, - ) - else: - # Fallback completion title — see the matching in-progress - # branch above for the wording rationale. 
- fallback_title = ( - tool_name.replace("_", " ").strip().capitalize() or tool_name - ) - yield _emit_thinking_step( - step_id=original_step_id, - title=fallback_title, - status="completed", - items=last_active_step_items, - ) - - just_finished_tool = True - last_active_step_id = None - last_active_step_title = "" - last_active_step_items = [] - - if tool_name == "generate_podcast": - yield _emit_tool_output( - tool_call_id, - tool_output - if isinstance(tool_output, dict) - else {"result": tool_output}, - ) - if isinstance(tool_output, dict) and tool_output.get("status") in ( - "pending", - "generating", - "processing", - ): - yield streaming_service.format_terminal_info( - f"Podcast queued: {tool_output.get('title', 'Podcast')}", - "success", - ) - elif isinstance(tool_output, dict) and tool_output.get("status") in ( - "ready", - "success", - ): - yield streaming_service.format_terminal_info( - f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}", - "success", - ) - elif isinstance(tool_output, dict) and tool_output.get("status") in ( - "failed", - "error", - ): - error_msg = tool_output.get("error", "Unknown error") - yield streaming_service.format_terminal_info( - f"Podcast generation failed: {error_msg}", - "error", - ) - elif tool_name == "generate_video_presentation": - yield _emit_tool_output( - tool_call_id, - tool_output - if isinstance(tool_output, dict) - else {"result": tool_output}, - ) - if ( - isinstance(tool_output, dict) - and tool_output.get("status") == "pending" - ): - yield streaming_service.format_terminal_info( - f"Video presentation queued: {tool_output.get('title', 'Presentation')}", - "success", - ) - elif ( - isinstance(tool_output, dict) - and tool_output.get("status") == "failed" - ): - error_msg = ( - tool_output.get("error", "Unknown error") - if isinstance(tool_output, dict) - else "Unknown error" - ) - yield streaming_service.format_terminal_info( - f"Presentation generation failed: {error_msg}", - "error", - ) - elif tool_name == "generate_image": - yield _emit_tool_output( - tool_call_id, - tool_output - if isinstance(tool_output, dict) - else {"result": tool_output}, - ) - if isinstance(tool_output, dict): - if tool_output.get("error"): - yield streaming_service.format_terminal_info( - f"Image generation failed: {tool_output['error'][:60]}", - "error", - ) - else: - yield streaming_service.format_terminal_info( - "Image generated successfully", - "success", - ) - elif tool_name == "scrape_webpage": - if isinstance(tool_output, dict): - display_output = { - k: v for k, v in tool_output.items() if k != "content" - } - if "content" in tool_output: - content = tool_output.get("content", "") - display_output["content_preview"] = ( - content[:500] + "..." if len(content) > 500 else content - ) - yield _emit_tool_output( - tool_call_id, - display_output, - ) - else: - yield _emit_tool_output( - tool_call_id, - {"result": tool_output}, - ) - if isinstance(tool_output, dict) and "error" not in tool_output: - title = tool_output.get("title", "Webpage") - word_count = tool_output.get("word_count", 0) - yield streaming_service.format_terminal_info( - f"Scraped: {title[:40]}{'...' 
if len(title) > 40 else ''} ({word_count:,} words)", - "success", - ) - else: - error_msg = ( - tool_output.get("error", "Failed to scrape") - if isinstance(tool_output, dict) - else "Failed to scrape" - ) - yield streaming_service.format_terminal_info( - f"Scrape failed: {error_msg}", - "error", - ) - elif tool_name in ("write_file", "edit_file"): - resolved_path = _extract_resolved_file_path( - tool_name=tool_name, - tool_output=tool_output, - tool_input={"file_path": staged_file_path} - if staged_file_path - else None, - ) - result_text = _tool_output_to_text(tool_output) - if _tool_output_has_error(tool_output): - yield _emit_tool_output( - tool_call_id, - { - "status": "error", - "error": result_text, - "path": resolved_path, - }, - ) - else: - yield _emit_tool_output( - tool_call_id, - { - "status": "completed", - "path": resolved_path, - "result": result_text, - }, - ) - elif tool_name == "generate_report": - # Stream the full report result so frontend can render the ReportCard - yield _emit_tool_output( - tool_call_id, - tool_output - if isinstance(tool_output, dict) - else {"result": tool_output}, - ) - # Send appropriate terminal message based on status - if ( - isinstance(tool_output, dict) - and tool_output.get("status") == "ready" - ): - word_count = tool_output.get("word_count", 0) - yield streaming_service.format_terminal_info( - f"Report generated: {tool_output.get('title', 'Report')} ({word_count:,} words)", - "success", - ) - else: - error_msg = ( - tool_output.get("error", "Unknown error") - if isinstance(tool_output, dict) - else "Unknown error" - ) - yield streaming_service.format_terminal_info( - f"Report generation failed: {error_msg}", - "error", - ) - elif tool_name == "generate_resume": - yield _emit_tool_output( - tool_call_id, - tool_output - if isinstance(tool_output, dict) - else {"result": tool_output}, - ) - if ( - isinstance(tool_output, dict) - and tool_output.get("status") == "ready" - ): - yield streaming_service.format_terminal_info( - f"Resume generated: {tool_output.get('title', 'Resume')}", - "success", - ) - else: - error_msg = ( - tool_output.get("error", "Unknown error") - if isinstance(tool_output, dict) - else "Unknown error" - ) - yield streaming_service.format_terminal_info( - f"Resume generation failed: {error_msg}", - "error", - ) - elif tool_name in ( - "create_notion_page", - "update_notion_page", - "delete_notion_page", - "create_linear_issue", - "update_linear_issue", - "delete_linear_issue", - "create_google_drive_file", - "delete_google_drive_file", - "create_onedrive_file", - "delete_onedrive_file", - "create_dropbox_file", - "delete_dropbox_file", - "create_gmail_draft", - "update_gmail_draft", - "send_gmail_email", - "trash_gmail_email", - "create_calendar_event", - "update_calendar_event", - "delete_calendar_event", - "create_jira_issue", - "update_jira_issue", - "delete_jira_issue", - "create_confluence_page", - "update_confluence_page", - "delete_confluence_page", - ): - yield _emit_tool_output( - tool_call_id, - tool_output - if isinstance(tool_output, dict) - else {"result": tool_output}, - ) - elif tool_name in ("execute", "execute_code"): - raw_text = ( - tool_output.get("result", "") - if isinstance(tool_output, dict) - else str(tool_output) - ) - exit_code: int | None = None - output_text = raw_text - m = re.match(r"^Exit code:\s*(\d+)", raw_text) - if m: - exit_code = int(m.group(1)) - om = re.search(r"\nOutput:\n([\s\S]*)", raw_text) - output_text = om.group(1) if om else "" - thread_id_str = config.get("configurable", 
{}).get("thread_id", "") - - for sf_match in re.finditer( - r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE - ): - fpath = sf_match.group(1).strip() - if fpath and fpath not in result.sandbox_files: - result.sandbox_files.append(fpath) - - yield _emit_tool_output( - tool_call_id, - { - "exit_code": exit_code, - "output": output_text, - "thread_id": thread_id_str, - }, - ) - elif tool_name == "web_search": - xml = ( - tool_output.get("result", str(tool_output)) - if isinstance(tool_output, dict) - else str(tool_output) - ) - citations: dict[str, dict[str, str]] = {} - for m in re.finditer( - r"<title><!\[CDATA\[(.*?)\]\]></title>\s*<url><!\[CDATA\[(.*?)\]\]></url>", - xml, - ): - title, url = m.group(1).strip(), m.group(2).strip() - if url.startswith("http") and url not in citations: - citations[url] = {"title": title} - for m in re.finditer( - r'<chunk url="(.*?)">([\s\S]*?)</chunk>', - xml, - ): - chunk_url, content = m.group(1).strip(), m.group(2).strip() - if ( - chunk_url.startswith("http") - and chunk_url in citations - and content - ): - citations[chunk_url]["snippet"] = ( - content[:200] + "…" if len(content) > 200 else content - ) - yield _emit_tool_output( - tool_call_id, - {"status": "completed", "citations": citations}, - ) - else: - yield _emit_tool_output( - tool_call_id, - {"status": "completed", "result_length": len(str(tool_output))}, - ) - yield streaming_service.format_terminal_info( - f"Tool {tool_name} completed", "success" - ) - - elif event_type == "on_custom_event" and event.get("name") == "report_progress": - # Live progress updates from inside the generate_report tool - data = event.get("data", {}) - message = data.get("message", "") - if message and last_active_step_id: - phase = data.get("phase", "") - # Always keep the "Topic: ..." line - topic_items = [ - item for item in last_active_step_items if item.startswith("Topic:") - ] - - if phase in ("revising_section", "adding_section"): - # During section-level ops: keep plan summary + show current op - plan_items = [ - item - for item in last_active_step_items - if item.startswith("Topic:") - or item.startswith("Modifying ") - or item.startswith("Adding ") - or item.startswith("Removing ") - ] - # Only keep plan_items that don't end with "..." (not progress lines) - plan_items = [ - item for item in plan_items if not item.endswith("...") - ] - last_active_step_items = [*plan_items, message] - else: - # Phase transitions: replace everything after topic - last_active_step_items = [*topic_items, message] - - yield _emit_thinking_step( - step_id=last_active_step_id, - title=last_active_step_title, - status="in_progress", - items=last_active_step_items, - ) - - elif ( - event_type == "on_custom_event" and event.get("name") == "document_created" - ): - data = event.get("data", {}) - if data.get("id"): - yield streaming_service.format_data( - "documents-updated", - { - "action": "created", - "document": data, - }, - ) - - elif event_type == "on_custom_event" and event.get("name") == "action_log": - # Surface a freshly committed AgentActionLog row so the chat - # tool card can render its Revert button immediately. - data = event.get("data", {}) - if data.get("id") is not None: - yield streaming_service.format_data("action-log", data) - - elif ( - event_type == "on_custom_event" - and event.get("name") == "action_log_updated" - ): - # Reversibility flipped in kb_persistence after the SAVEPOINT - # for a destructive op (rm/rmdir/move/edit/write) committed. - # Frontend uses this to flip the card's Revert - # button on without re-fetching the actions list.
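An aside on the ``execute`` branch above: the sandbox returns one plain-text blob, and the handler recovers structure from it with regexes. A self-contained sketch of that parsing, assuming the ``Exit code: N`` / ``Output:`` / ``SANDBOX_FILE:`` layout the removed code matches against:

import re

def parse_sandbox_result(raw_text: str) -> tuple[int | None, str, list[str]]:
    # "Exit code: 0\nOutput:\nok\nSANDBOX_FILE: /tmp/a.csv"
    #   -> (0, "ok\nSANDBOX_FILE: /tmp/a.csv", ["/tmp/a.csv"])
    exit_code: int | None = None
    output_text = raw_text
    m = re.match(r"^Exit code:\s*(\d+)", raw_text)
    if m:
        exit_code = int(m.group(1))
        om = re.search(r"\nOutput:\n([\s\S]*)", raw_text)
        output_text = om.group(1) if om else ""
    sandbox_files = [
        line.strip()
        for line in re.findall(r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE)
    ]
    return exit_code, output_text, sandbox_files

assert parse_sandbox_result("Exit code: 0\nOutput:\nok\nSANDBOX_FILE: /tmp/a.csv") == (
    0,
    "ok\nSANDBOX_FILE: /tmp/a.csv",
    ["/tmp/a.csv"],
)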
- data = event.get("data", {}) - if data.get("id") is not None: - yield streaming_service.format_data("action-log-updated", data) - - elif event_type in ("on_chain_end", "on_agent_end"): - if current_text_id is not None: - yield streaming_service.format_text_end(current_text_id) - if content_builder is not None: - content_builder.on_text_end(current_text_id) - current_text_id = None - - if current_text_id is not None: - yield streaming_service.format_text_end(current_text_id) - if content_builder is not None: - content_builder.on_text_end(current_text_id) - - completion_event = complete_current_step() - if completion_event: - yield completion_event + accumulated_text = result.accumulated_text state = await agent.aget_state(config) state_values = getattr(state, "values", {}) or {} @@ -2397,7 +912,6 @@ async def _stream_agent_events( result.commit_gate_reason = "" result.accumulated_text = accumulated_text - result.agent_called_update_memory = called_update_memory _log_file_contract("turn_outcome", result) interrupt_value = _first_interrupt_value(state) diff --git a/surfsense_backend/app/tasks/chat/streaming/__init__.py b/surfsense_backend/app/tasks/chat/streaming/__init__.py new file mode 100644 index 000000000..70c99342a --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/__init__.py @@ -0,0 +1,3 @@ +"""Chat streaming helpers (e.g. LangGraph → SSE relay under ``graph_stream``).""" + +from __future__ import annotations diff --git a/surfsense_backend/app/tasks/chat/streaming/errors/__init__.py b/surfsense_backend/app/tasks/chat/streaming/errors/__init__.py new file mode 100644 index 000000000..02284d4b0 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/errors/__init__.py @@ -0,0 +1,3 @@ +"""Error classification, structured logging, and terminal-error SSE emission.""" + +from __future__ import annotations diff --git a/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py b/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py new file mode 100644 index 000000000..3af2b9f9f --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py @@ -0,0 +1,187 @@ +"""Classify stream exceptions for logging and client error payloads.""" + +from __future__ import annotations + +import json +import logging +import time +from typing import Any, Literal + +from app.agents.new_chat.errors import BusyError +from app.agents.new_chat.middleware.busy_mutex import ( + get_cancel_state, + is_cancel_requested, +) + +TURN_CANCELLING_INITIAL_DELAY_MS = 200 +TURN_CANCELLING_BACKOFF_FACTOR = 2 +TURN_CANCELLING_MAX_DELAY_MS = 1500 + + +def compute_turn_cancelling_retry_delay(attempt: int) -> int: + if attempt < 1: + attempt = 1 + delay = TURN_CANCELLING_INITIAL_DELAY_MS * ( + TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1) + ) + return min(delay, TURN_CANCELLING_MAX_DELAY_MS) + + +def log_chat_stream_error( + *, + flow: Literal["new", "resume", "regenerate"], + error_kind: str, + error_code: str | None, + severity: Literal["info", "warn", "error"], + is_expected: bool, + request_id: str | None, + thread_id: int | None, + search_space_id: int | None, + user_id: str | None, + message: str, + extra: dict[str, Any] | None = None, +) -> None: + payload: dict[str, Any] = { + "event": "chat_stream_error", + "flow": flow, + "error_kind": error_kind, + "error_code": error_code, + "severity": severity, + "is_expected": is_expected, + "request_id": request_id or "unknown", + "thread_id": thread_id, + "search_space_id": search_space_id, + "user_id": user_id, + 
"message": message, + } + if extra: + payload.update(extra) + + logger = logging.getLogger(__name__) + rendered = json.dumps(payload, ensure_ascii=False) + if severity == "error": + logger.error("[chat_stream_error] %s", rendered) + elif severity == "warn": + logger.warning("[chat_stream_error] %s", rendered) + else: + logger.info("[chat_stream_error] %s", rendered) + + +def _parse_error_payload(message: str) -> dict[str, Any] | None: + candidates = [message] + first_brace_idx = message.find("{") + if first_brace_idx >= 0: + candidates.append(message[first_brace_idx:]) + + for candidate in candidates: + try: + parsed = json.loads(candidate) + if isinstance(parsed, dict): + return parsed + except Exception: + continue + return None + + +def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None: + if not isinstance(parsed, dict): + return None + candidates: list[Any] = [parsed.get("code")] + nested = parsed.get("error") + if isinstance(nested, dict): + candidates.append(nested.get("code")) + for value in candidates: + try: + if value is None: + continue + return int(value) + except Exception: + continue + return None + + +def is_provider_rate_limited(exc: BaseException) -> bool: + """Return True if the exception looks like an upstream HTTP 429 / rate limit.""" + raw = str(exc) + lowered = raw.lower() + if "ratelimit" in type(exc).__name__.lower(): + return True + parsed = _parse_error_payload(raw) + provider_code = _extract_provider_error_code(parsed) + if provider_code == 429: + return True + + provider_error_type = "" + if parsed: + top_type = parsed.get("type") + if isinstance(top_type, str): + provider_error_type = top_type.lower() + nested = parsed.get("error") + if isinstance(nested, dict): + nested_type = nested.get("type") + if isinstance(nested_type, str): + provider_error_type = nested_type.lower() + if provider_error_type == "rate_limit_error": + return True + + return ( + "rate limited" in lowered + or "rate-limited" in lowered + or "temporarily rate-limited upstream" in lowered + ) + + +def classify_stream_exception( + exc: Exception, + *, + flow_label: str, +) -> tuple[ + str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None +]: + """Return kind, code, severity, expected flag, message, and optional extra dict.""" + raw = str(exc) + if isinstance(exc, BusyError) or "Thread is busy with another request" in raw: + busy_thread_id = str(exc.request_id) if isinstance(exc, BusyError) else None + if busy_thread_id and is_cancel_requested(busy_thread_id): + cancel_state = get_cancel_state(busy_thread_id) + attempt = cancel_state[0] if cancel_state else 1 + retry_after_ms = compute_turn_cancelling_retry_delay(attempt) + retry_after_at = int(time.time() * 1000) + retry_after_ms + return ( + "thread_busy", + "TURN_CANCELLING", + "info", + True, + "A previous response is still stopping. Please try again in a moment.", + { + "retry_after_ms": retry_after_ms, + "retry_after_at": retry_after_at, + }, + ) + return ( + "thread_busy", + "THREAD_BUSY", + "warn", + True, + "Another response is still finishing for this thread. Please try again in a moment.", + None, + ) + + if is_provider_rate_limited(exc): + return ( + "rate_limited", + "RATE_LIMITED", + "warn", + True, + "This model is temporarily rate-limited. 
Please try again in a few seconds or switch models.", + None, + ) + + return ( + "server_error", + "SERVER_ERROR", + "error", + False, + f"Error during {flow_label}: {raw}", + None, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/errors/emitter.py b/surfsense_backend/app/tasks/chat/streaming/errors/emitter.py new file mode 100644 index 000000000..95806ab87 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/errors/emitter.py @@ -0,0 +1,38 @@ +"""Emit one terminal error SSE frame and log via the stream error classifier.""" + +from __future__ import annotations + +from typing import Any, Literal + +from .classifier import log_chat_stream_error + + +def emit_stream_terminal_error( + *, + streaming_service: Any, + flow: Literal["new", "resume", "regenerate"], + request_id: str | None, + thread_id: int, + search_space_id: int, + user_id: str | None, + message: str, + error_kind: str = "server_error", + error_code: str = "SERVER_ERROR", + severity: Literal["info", "warn", "error"] = "error", + is_expected: bool = False, + extra: dict[str, Any] | None = None, +) -> str: + log_chat_stream_error( + flow=flow, + error_kind=error_kind, + error_code=error_code, + severity=severity, + is_expected=is_expected, + request_id=request_id, + thread_id=thread_id, + search_space_id=search_space_id, + user_id=user_id, + message=message, + extra=extra, + ) + return streaming_service.format_error(message, error_code=error_code, extra=extra) diff --git a/surfsense_backend/app/tasks/chat/streaming/graph_stream/__init__.py b/surfsense_backend/app/tasks/chat/streaming/graph_stream/__init__.py new file mode 100644 index 000000000..e3bf0426c --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/graph_stream/__init__.py @@ -0,0 +1,21 @@ +"""LangGraph ``astream_events`` → SSE (``stream_output`` + ``StreamingResult``). + +Imports are lazy to avoid a circular import with ``relay.event_relay``. 
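The lazy-import note above leans on PEP 562's module-level ``__getattr__`` (the real hook follows just below). A minimal standalone sketch of the pattern, with hypothetical package and attribute names:

# mypkg/__init__.py — hypothetical package with a PEP 562 lazy export.
from __future__ import annotations

__all__ = ["HeavyThing"]


def __getattr__(name: str):
    # Called only when `name` is not found by normal lookup, i.e. on the
    # first access to `mypkg.HeavyThing`; deferring the submodule import
    # to that moment is what breaks import cycles between siblings.
    if name == "HeavyThing":
        from mypkg.heavy import HeavyThing  # hypothetical submodule

        return HeavyThing
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")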
+""" + +from __future__ import annotations + +__all__ = ["StreamingResult", "stream_output"] + + +def __getattr__(name: str): + if name == "stream_output": + from app.tasks.chat.streaming.graph_stream.event_stream import stream_output + + return stream_output + if name == "StreamingResult": + from app.tasks.chat.streaming.graph_stream.result import StreamingResult + + return StreamingResult + msg = f"module {__name__!r} has no attribute {name!r}" + raise AttributeError(msg) diff --git a/surfsense_backend/app/tasks/chat/streaming/graph_stream/event_stream.py b/surfsense_backend/app/tasks/chat/streaming/graph_stream/event_stream.py new file mode 100644 index 000000000..9a309f9d7 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/graph_stream/event_stream.py @@ -0,0 +1,51 @@ +"""Run LangGraph event streams through ``EventRelay``.""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import Any + +from app.tasks.chat.streaming.graph_stream.result import StreamingResult +from app.tasks.chat.streaming.relay.event_relay import EventRelay +from app.tasks.chat.streaming.relay.state import AgentEventRelayState + + +async def stream_output( + *, + agent: Any, + config: dict[str, Any], + input_data: Any, + streaming_service: Any, + result: StreamingResult, + step_prefix: str = "thinking", + initial_step_id: str | None = None, + initial_step_title: str = "", + initial_step_items: list[str] | None = None, + content_builder: Any | None = None, + runtime_context: Any = None, +) -> AsyncIterator[str]: + """Yield SSE frames from agent ``astream_events`` via ``EventRelay``.""" + state = AgentEventRelayState.for_invocation( + initial_step_id=initial_step_id, + initial_step_title=initial_step_title, + initial_step_items=initial_step_items, + ) + + astream_kwargs: dict[str, Any] = {"config": config, "version": "v2"} + if runtime_context is not None: + astream_kwargs["context"] = runtime_context + + events = agent.astream_events(input_data, **astream_kwargs) + relay = EventRelay(streaming_service=streaming_service) + async for frame in relay.relay( + events, + state=state, + result=result, + step_prefix=step_prefix, + content_builder=content_builder, + config=config, + ): + yield frame + + result.accumulated_text = state.accumulated_text + result.agent_called_update_memory = state.called_update_memory diff --git a/surfsense_backend/app/tasks/chat/streaming/graph_stream/result.py b/surfsense_backend/app/tasks/chat/streaming/graph_stream/result.py new file mode 100644 index 000000000..40404e9d0 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/graph_stream/result.py @@ -0,0 +1,28 @@ +"""Mutable facts collected while relaying one agent stream (``stream_output``).""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class StreamingResult: + accumulated_text: str = "" + is_interrupted: bool = False + interrupt_value: dict[str, Any] | None = None + sandbox_files: list[str] = field(default_factory=list) + agent_called_update_memory: bool = False + request_id: str | None = None + turn_id: str = "" + filesystem_mode: str = "cloud" + client_platform: str = "web" + intent_detected: str = "chat_only" + intent_confidence: float = 0.0 + write_attempted: bool = False + write_succeeded: bool = False + verification_succeeded: bool = False + commit_gate_passed: bool = True + commit_gate_reason: str = "" + assistant_message_id: int | None = None + content_builder: Any | None = 
diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/__init__.py new file mode 100644 index 000000000..3e2165932 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/__init__.py @@ -0,0 +1,3 @@ +"""LangGraph stream handlers by event kind.""" + +from __future__ import annotations diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/chain_end.py b/surfsense_backend/app/tasks/chat/streaming/handlers/chain_end.py new file mode 100644 index 000000000..c61058ac7 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/chain_end.py @@ -0,0 +1,23 @@ +"""Close open text when a LangGraph chain or agent node finishes.""" + +from __future__ import annotations + +from collections.abc import Iterator +from typing import Any + +from app.tasks.chat.streaming.relay.state import AgentEventRelayState + + +def iter_chain_end_frames( + _event: dict[str, Any], + *, + state: AgentEventRelayState, + streaming_service: Any, + content_builder: Any | None, +) -> Iterator[str]: + """Close the open text stream if one is open.""" + if state.current_text_id is not None: + yield streaming_service.format_text_end(state.current_text_id) + if content_builder is not None: + content_builder.on_text_end(state.current_text_id) + state.current_text_id = None diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/chat_model_stream.py b/surfsense_backend/app/tasks/chat/streaming/handlers/chat_model_stream.py new file mode 100644 index 000000000..c3f6d6d59 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/chat_model_stream.py @@ -0,0 +1,159 @@ +"""Chat model stream: text, reasoning, and tool-call chunk SSE.""" + +from __future__ import annotations + +from collections.abc import Iterator +from typing import Any + +from app.tasks.chat.streaming.helpers.chunk_parts import extract_chunk_parts +from app.tasks.chat.streaming.relay.state import AgentEventRelayState +from app.tasks.chat.streaming.relay.task_span import ensure_pending_task_span_for_lc +from app.tasks.chat.streaming.relay.thinking_step_completion import ( + complete_active_thinking_step, +) + + +def iter_chat_model_stream_frames( + event: dict[str, Any], + *, + state: AgentEventRelayState, + streaming_service: Any, + content_builder: Any | None, + step_prefix: str, +) -> Iterator[str]: + """SSE frames for one chat-model chunk.""" + if state.active_tool_depth > 0: + return + if "surfsense:internal" in event.get("tags", []): + return + chunk = event.get("data", {}).get("chunk") + if not chunk: + return + parts = extract_chunk_parts(chunk) + + reasoning_delta = parts["reasoning"] + text_delta = parts["text"] + + if reasoning_delta: + if state.current_text_id is not None: + yield streaming_service.format_text_end(state.current_text_id) + if content_builder is not None: + content_builder.on_text_end(state.current_text_id) + state.current_text_id = None + if state.current_reasoning_id is None: + comp, new_active = complete_active_thinking_step( + state=state, + streaming_service=streaming_service, + content_builder=content_builder, + last_active_step_id=state.last_active_step_id, + last_active_step_title=state.last_active_step_title, + last_active_step_items=state.last_active_step_items, + completed_step_ids=state.completed_step_ids, + ) + if comp: + yield comp + state.last_active_step_id = new_active + if state.just_finished_tool: + state.last_active_step_id = None + state.last_active_step_title
= "" + state.last_active_step_items = [] + state.just_finished_tool = False + state.current_reasoning_id = streaming_service.generate_reasoning_id() + yield streaming_service.format_reasoning_start(state.current_reasoning_id) + if content_builder is not None: + content_builder.on_reasoning_start(state.current_reasoning_id) + yield streaming_service.format_reasoning_delta( + state.current_reasoning_id, reasoning_delta + ) + if content_builder is not None: + content_builder.on_reasoning_delta( + state.current_reasoning_id, reasoning_delta + ) + + if text_delta: + if state.current_reasoning_id is not None: + yield streaming_service.format_reasoning_end(state.current_reasoning_id) + if content_builder is not None: + content_builder.on_reasoning_end(state.current_reasoning_id) + state.current_reasoning_id = None + if state.current_text_id is None: + comp, new_active = complete_active_thinking_step( + state=state, + streaming_service=streaming_service, + content_builder=content_builder, + last_active_step_id=state.last_active_step_id, + last_active_step_title=state.last_active_step_title, + last_active_step_items=state.last_active_step_items, + completed_step_ids=state.completed_step_ids, + ) + if comp: + yield comp + state.last_active_step_id = new_active + if state.just_finished_tool: + state.last_active_step_id = None + state.last_active_step_title = "" + state.last_active_step_items = [] + state.just_finished_tool = False + state.current_text_id = streaming_service.generate_text_id() + yield streaming_service.format_text_start(state.current_text_id) + if content_builder is not None: + content_builder.on_text_start(state.current_text_id) + yield streaming_service.format_text_delta(state.current_text_id, text_delta) + state.accumulated_text += text_delta + if content_builder is not None: + content_builder.on_text_delta(state.current_text_id, text_delta) + + if parts["tool_call_chunks"]: + for tcc in parts["tool_call_chunks"]: + idx = tcc.get("index") + + if idx is not None and idx not in state.index_to_meta: + lc_id = tcc.get("id") + name = tcc.get("name") + if lc_id and name: + ui_id = lc_id + tool_input_metadata: dict[str, Any] | None = None + if name == "task": + sid = ensure_pending_task_span_for_lc(state, str(lc_id)) + tool_input_metadata = {"spanId": sid} + + if state.current_text_id is not None: + yield streaming_service.format_text_end(state.current_text_id) + if content_builder is not None: + content_builder.on_text_end(state.current_text_id) + state.current_text_id = None + if state.current_reasoning_id is not None: + yield streaming_service.format_reasoning_end( + state.current_reasoning_id + ) + if content_builder is not None: + content_builder.on_reasoning_end(state.current_reasoning_id) + state.current_reasoning_id = None + + state.index_to_meta[idx] = { + "ui_id": ui_id, + "lc_id": lc_id, + "name": name, + } + yield streaming_service.format_tool_input_start( + ui_id, + name, + langchain_tool_call_id=lc_id, + metadata=tool_input_metadata, + ) + if content_builder is not None: + content_builder.on_tool_input_start( + ui_id, name, lc_id, metadata=tool_input_metadata + ) + + meta = state.index_to_meta.get(idx) if idx is not None else None + if meta: + args_chunk = tcc.get("args") or "" + if args_chunk: + yield streaming_service.format_tool_input_delta( + meta["ui_id"], args_chunk + ) + if content_builder is not None: + content_builder.on_tool_input_delta(meta["ui_id"], args_chunk) + else: + state.pending_tool_call_chunks.append(tcc) diff --git 
a/surfsense_backend/app/tasks/chat/streaming/handlers/custom_event_dispatch.py b/surfsense_backend/app/tasks/chat/streaming/handlers/custom_event_dispatch.py new file mode 100644 index 000000000..69f4b8a24 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/custom_event_dispatch.py @@ -0,0 +1,57 @@ +"""Custom graph events routed to SSE (documents, action logs, report progress).""" + +from __future__ import annotations + +from collections.abc import Iterator +from typing import Any + +from app.tasks.chat.streaming.handlers.custom_events import ( + handle_action_log, + handle_action_log_updated, + handle_document_created, + handle_report_progress, +) +from app.tasks.chat.streaming.relay.state import AgentEventRelayState + + +def iter_custom_event_frames( + event: dict[str, Any], + *, + state: AgentEventRelayState, + streaming_service: Any, + content_builder: Any | None, +) -> Iterator[str]: + """Yield any SSE produced by ad-hoc graph events (documents, action logs, report progress).""" + name = event.get("name") + data = event.get("data", {}) + + if name == "report_progress": + frame, state.last_active_step_items = handle_report_progress( + data, + last_active_step_id=state.last_active_step_id, + last_active_step_title=state.last_active_step_title, + last_active_step_items=state.last_active_step_items, + streaming_service=streaming_service, + content_builder=content_builder, + thinking_metadata=state.span_metadata_if_active(), + ) + if frame: + yield frame + return + + if name == "document_created": + frame = handle_document_created(data, streaming_service=streaming_service) + if frame: + yield frame + return + + if name == "action_log": + frame = handle_action_log(data, streaming_service=streaming_service) + if frame: + yield frame + return + + if name == "action_log_updated": + frame = handle_action_log_updated(data, streaming_service=streaming_service) + if frame: + yield frame diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/custom_events.py b/surfsense_backend/app/tasks/chat/streaming/handlers/custom_events.py new file mode 100644 index 000000000..e48e2c493 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/custom_events.py @@ -0,0 +1,79 @@ +"""Custom-event payloads turned into SSE (no model/tool stream handling).""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame + + +def handle_report_progress( + data: dict[str, Any], + *, + last_active_step_id: str | None, + last_active_step_title: str, + last_active_step_items: list[str], + streaming_service: Any, + content_builder: Any | None, + thinking_metadata: dict[str, Any] | None = None, +) -> tuple[str | None, list[str]]: + """Update report step items; may emit one thinking SSE frame. + + Returns (frame or None, items list after update). 
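To make that pruning contract concrete before the implementation below, a worked example of the ``revising_section`` path (the item values are mine):

items = ["Topic: Q3 sales", "Modifying section 2", "Writing section 2..."]
# Keep plan lines (Topic:/Modifying /Adding /Removing ) but drop trailing
# "..." progress lines, then append the fresh progress message.
plan_items = [
    item
    for item in items
    if item.startswith(("Topic:", "Modifying ", "Adding ", "Removing "))
    and not item.endswith("...")
]
assert plan_items == ["Topic: Q3 sales", "Modifying section 2"]
new_items = [*plan_items, "Rewrote section 2 intro"]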
+ """ + message = data.get("message", "") + if not message or not last_active_step_id: + return None, last_active_step_items + + phase = data.get("phase", "") + topic_items = [ + item for item in last_active_step_items if item.startswith("Topic:") + ] + + if phase in ("revising_section", "adding_section"): + plan_items = [ + item + for item in last_active_step_items + if item.startswith("Topic:") + or item.startswith("Modifying ") + or item.startswith("Adding ") + or item.startswith("Removing ") + ] + plan_items = [item for item in plan_items if not item.endswith("...")] + new_items = [*plan_items, message] + else: + new_items = [*topic_items, message] + + frame = emit_thinking_step_frame( + streaming_service=streaming_service, + content_builder=content_builder, + step_id=last_active_step_id, + title=last_active_step_title, + status="in_progress", + items=new_items, + metadata=thinking_metadata, + ) + return frame, new_items + + +def handle_document_created(data: dict[str, Any], *, streaming_service: Any) -> str | None: + if not data.get("id"): + return None + return streaming_service.format_data( + "documents-updated", + {"action": "created", "document": data}, + ) + + +def handle_action_log(data: dict[str, Any], *, streaming_service: Any) -> str | None: + if data.get("id") is None: + return None + return streaming_service.format_data("action-log", data) + + +def handle_action_log_updated( + data: dict[str, Any], *, streaming_service: Any +) -> str | None: + if data.get("id") is None: + return None + return streaming_service.format_data("action-log-updated", data) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tool_end.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tool_end.py new file mode 100644 index 000000000..421c67a6d --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tool_end.py @@ -0,0 +1,121 @@ +"""Tool end: thinking completion, tool output, and terminal SSE.""" + +from __future__ import annotations + +import json +from collections.abc import Iterator +from typing import Any + +from app.tasks.chat.streaming.handlers.tools import ( + ToolCompletionEmissionContext, + iter_tool_completion_emission_frames, + resolve_tool_completed_thinking_step, +) +from app.tasks.chat.streaming.helpers.tool_output import tool_output_has_error +from app.tasks.chat.streaming.relay.state import AgentEventRelayState +from app.tasks.chat.streaming.relay.task_span import clear_task_span_if_delegating_task_ended +from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame + + +def iter_tool_end_frames( + event: dict[str, Any], + *, + state: AgentEventRelayState, + streaming_service: Any, + content_builder: Any | None, + result: Any, + step_prefix: str, + config: dict[str, Any], +) -> Iterator[str]: + """SSE frames when one tool run finishes.""" + state.active_tool_depth = max(0, state.active_tool_depth - 1) + run_id = event.get("run_id", "") + tool_name = event.get("name", "unknown_tool") + raw_output = event.get("data", {}).get("output", "") + staged_file_path = ( + state.file_path_by_run.pop(run_id, None) if run_id else None + ) + + if tool_name == "update_memory": + state.called_update_memory = True + + if hasattr(raw_output, "content"): + content = raw_output.content + if isinstance(content, str): + try: + tool_output = json.loads(content) + except (json.JSONDecodeError, TypeError): + tool_output = {"result": content} + elif isinstance(content, dict): + tool_output = content + else: + tool_output = {"result": str(content)} + 
elif isinstance(raw_output, dict): + tool_output = raw_output + else: + tool_output = {"result": str(raw_output) if raw_output else "completed"} + + if tool_name in ("write_file", "edit_file"): + # Treat any non-error write/edit result as a successful, + # verified write for this turn's bookkeeping. + if not tool_output_has_error(tool_output): + result.write_succeeded = True + result.verification_succeeded = True + + tool_call_id = state.ui_tool_call_id_by_run.get( + run_id, + f"call_{run_id[:32]}" if run_id else "call_unknown", + ) + original_step_id = state.tool_step_ids.get( + run_id, f"{step_prefix}-unknown-{run_id[:8]}" + ) + state.completed_step_ids.add(original_step_id) + + holder = state.current_lc_tool_call_id + holder["value"] = None + authoritative = getattr(raw_output, "tool_call_id", None) + if isinstance(authoritative, str) and authoritative: + holder["value"] = authoritative + if run_id: + state.lc_tool_call_id_by_run[run_id] = authoritative + elif run_id and run_id in state.lc_tool_call_id_by_run: + holder["value"] = state.lc_tool_call_id_by_run[run_id] + + items = state.last_active_step_items + title, completed_items = resolve_tool_completed_thinking_step( + tool_name, tool_output, items + ) + yield emit_thinking_step_frame( + streaming_service=streaming_service, + content_builder=content_builder, + step_id=original_step_id, + title=title, + status="completed", + items=completed_items, + metadata=state.span_metadata_if_active(), + ) + + state.just_finished_tool = True + state.last_active_step_id = None + state.last_active_step_title = "" + state.last_active_step_items = [] + + emission_ctx = ToolCompletionEmissionContext( + tool_name=tool_name, + tool_call_id=tool_call_id, + tool_output=tool_output, + streaming_service=streaming_service, + content_builder=content_builder, + langchain_tool_call_id_holder=holder, + stream_result=result, + langgraph_config=config, + staged_workspace_file_path=staged_file_path, + tool_metadata=state.tool_activity_metadata( + thinking_step_id=original_step_id, + ), + ) + yield from iter_tool_completion_emission_frames(emission_ctx) + + clear_task_span_if_delegating_task_ended( + state, tool_name=tool_name, run_id=run_id + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tool_output_frame.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tool_output_frame.py new file mode 100644 index 000000000..4cd8e3274 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tool_output_frame.py @@ -0,0 +1,29 @@ +"""Emit tool-output SSE and optional assistant content updates.""" + +from __future__ import annotations + +from typing import Any + + +def emit_tool_output_available_frame( + *, + streaming_service: Any, + content_builder: Any | None, + langchain_id_holder: dict[str, str | None], + call_id: str, + output: Any, + tool_metadata: dict[str, Any] | None = None, +) -> str: + if content_builder is not None: + content_builder.on_tool_output_available( + call_id, + output, + langchain_id_holder["value"], + metadata=tool_metadata, + ) + return streaming_service.format_tool_output_available( + call_id, + output, + langchain_tool_call_id=langchain_id_holder["value"], + metadata=tool_metadata, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tool_start.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tool_start.py new file mode 100644 index 000000000..e0cac307c --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tool_start.py @@ -0,0 +1,161 @@ +"""Tool start: thinking-step and tool-input SSE.""" + +from __future__ import annotations + +import json +from collections.abc
import Iterator +from typing import Any + +from app.tasks.chat.streaming.handlers.tools import resolve_tool_start_thinking +from app.tasks.chat.streaming.helpers.tool_call_matching import ( + match_buffered_langchain_tool_call_id, +) +from app.tasks.chat.streaming.relay.state import AgentEventRelayState +from app.tasks.chat.streaming.relay.task_span import open_task_span +from app.tasks.chat.streaming.relay.thinking_step_completion import ( + complete_active_thinking_step, +) +from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame + + +def iter_tool_start_frames( + event: dict[str, Any], + *, + state: AgentEventRelayState, + streaming_service: Any, + content_builder: Any | None, + result: Any, + step_prefix: str, +) -> Iterator[str]: + """SSE frames for the start of one tool run.""" + state.active_tool_depth += 1 + tool_name = event.get("name", "unknown_tool") + run_id = event.get("run_id", "") + tool_input = event.get("data", {}).get("input", {}) + if tool_name in ("write_file", "edit_file"): + result.write_attempted = True + if isinstance(tool_input, dict): + file_path = tool_input.get("file_path") + if isinstance(file_path, str) and file_path.strip() and run_id: + state.file_path_by_run[run_id] = file_path.strip() + + if state.current_text_id is not None: + yield streaming_service.format_text_end(state.current_text_id) + if content_builder is not None: + content_builder.on_text_end(state.current_text_id) + state.current_text_id = None + + if state.last_active_step_title != "Synthesizing response": + comp, new_active = complete_active_thinking_step( + state=state, + streaming_service=streaming_service, + content_builder=content_builder, + last_active_step_id=state.last_active_step_id, + last_active_step_title=state.last_active_step_title, + last_active_step_items=state.last_active_step_items, + completed_step_ids=state.completed_step_ids, + ) + if comp: + yield comp + state.last_active_step_id = new_active + + state.just_finished_tool = False + tool_step_id = state.next_thinking_step_id(step_prefix) + state.tool_step_ids[run_id] = tool_step_id + state.last_active_step_id = tool_step_id + + matched_meta: dict[str, str] | None = None + taken_ui_ids = set(state.ui_tool_call_id_by_run.values()) + for meta in state.index_to_meta.values(): + if meta["name"] == tool_name and meta["ui_id"] not in taken_ui_ids: + matched_meta = meta + break + + tool_call_id: str + langchain_tool_call_id: str | None = None + if matched_meta is not None: + tool_call_id = matched_meta["ui_id"] + langchain_tool_call_id = matched_meta["lc_id"] + if run_id: + state.lc_tool_call_id_by_run[run_id] = matched_meta["lc_id"] + else: + tool_call_id = ( + f"call_{run_id[:32]}" + if run_id + else streaming_service.generate_tool_call_id() + ) + langchain_tool_call_id = match_buffered_langchain_tool_call_id( + state.pending_tool_call_chunks, + tool_name, + run_id, + state.lc_tool_call_id_by_run, + ) + + if tool_name == "task": + open_task_span( + state, + run_id=run_id, + langchain_tool_call_id=langchain_tool_call_id, + ) + + span_md = state.span_metadata_if_active() + tool_md = state.tool_activity_metadata(thinking_step_id=tool_step_id) + + if matched_meta is None: + yield streaming_service.format_tool_input_start( + tool_call_id, + tool_name, + langchain_tool_call_id=langchain_tool_call_id, + metadata=tool_md, + ) + if content_builder is not None: + content_builder.on_tool_input_start( + tool_call_id, + tool_name, + langchain_tool_call_id, + metadata=tool_md, + ) + + thinking = 
resolve_tool_start_thinking(tool_name, tool_input) + state.last_active_step_title = thinking.title + state.last_active_step_items = thinking.items + frame_kw: dict[str, Any] = { + "streaming_service": streaming_service, + "content_builder": content_builder, + "step_id": tool_step_id, + "title": thinking.title, + "status": "in_progress", + "metadata": span_md, + } + if thinking.include_items_on_frame: + frame_kw["items"] = thinking.items + yield emit_thinking_step_frame(**frame_kw) + + if run_id: + state.ui_tool_call_id_by_run[run_id] = tool_call_id + + if isinstance(tool_input, dict): + _safe_input: dict[str, Any] = {} + for _k, _v in tool_input.items(): + try: + json.dumps(_v) + _safe_input[_k] = _v + except (TypeError, ValueError, OverflowError): + pass + else: + _safe_input = {"input": tool_input} + yield streaming_service.format_tool_input_available( + tool_call_id, + tool_name, + _safe_input, + langchain_tool_call_id=langchain_tool_call_id, + metadata=tool_md, + ) + if content_builder is not None: + content_builder.on_tool_input_available( + tool_call_id, + tool_name, + _safe_input, + langchain_tool_call_id, + metadata=tool_md, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/__init__.py new file mode 100644 index 000000000..4b191c100 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/__init__.py @@ -0,0 +1,23 @@ +"""Per-tool streaming: thinking-step and completion emission.""" + +from __future__ import annotations + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) +from app.tasks.chat.streaming.handlers.tools.registry import ( + iter_tool_completion_emission_frames, + resolve_tool_completed_thinking_step, + resolve_tool_start_thinking, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + +__all__ = [ + "ToolCompletionEmissionContext", + "ToolStartThinking", + "iter_tool_completion_emission_frames", + "resolve_tool_completed_thinking_step", + "resolve_tool_start_thinking", +] diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/emission.py new file mode 100644 index 000000000..8e19dc224 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/emission.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + out = ctx.tool_output + payload = out if isinstance(out, dict) else {"result": out} + yield ctx.emit_tool_output_card(payload) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/thinking.py 
b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/thinking.py new file mode 100644 index 000000000..7e9dd8b96 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/thinking.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.default import ( + thinking as default_thinking, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + return default_thinking.resolve_start_thinking(tool_name, tool_input) + + +def resolve_completed_thinking( + tool_name: str, tool_output: Any, last_items: list[str], +) -> tuple[str, list[str]]: + return default_thinking.resolve_completed_thinking( + tool_name, tool_output, last_items + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/tool_names.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/tool_names.py new file mode 100644 index 000000000..ab698b32d --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/tool_names.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +SHARED_CONNECTOR_TOOLS: frozenset[str] = frozenset( + { + "create_calendar_event", + "create_confluence_page", + "create_dropbox_file", + "create_gmail_draft", + "create_google_drive_file", + "create_jira_issue", + "create_linear_issue", + "create_notion_page", + "create_onedrive_file", + "delete_calendar_event", + "delete_confluence_page", + "delete_dropbox_file", + "delete_google_drive_file", + "delete_jira_issue", + "delete_linear_issue", + "delete_notion_page", + "delete_onedrive_file", + "send_gmail_email", + "trash_gmail_email", + "update_calendar_event", + "update_confluence_page", + "update_gmail_draft", + "update_jira_issue", + "update_linear_issue", + "update_notion_page", + } +) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/default/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/default/__init__.py new file mode 100644 index 000000000..5e84a37f4 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/default/__init__.py @@ -0,0 +1,3 @@ +"""Fallback tool package.""" + +from __future__ import annotations diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/default/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/default/emission.py new file mode 100644 index 000000000..e24c619a7 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/default/emission.py @@ -0,0 +1,24 @@ +"""Default tool-output card and a short completion terminal line.""" + +from __future__ import annotations + +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + yield ctx.emit_tool_output_card( + { + "status": "completed", + "result_length": len(str(ctx.tool_output)), + }, + ) + yield ctx.streaming_service.format_terminal_info( + f"Tool {ctx.tool_name} completed", + "success", + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/default/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/default/thinking.py new file mode 100644 index 000000000..46d15a4e7 --- /dev/null +++ 
b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/default/thinking.py @@ -0,0 +1,23 @@ +"""Fallback thinking-step copy for unknown tools and connectors without custom UI.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_input + title = tool_name.replace("_", " ").strip().capitalize() or tool_name + return ToolStartThinking(title=title, items=[], include_items_on_frame=False) + + +def resolve_completed_thinking( + tool_name: str, tool_output: Any, last_items: list[str] +) -> tuple[str, list[str]]: + del tool_output + title = tool_name.replace("_", " ").strip().capitalize() or tool_name + return (title, last_items) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_image/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_image/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_image/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_image/emission.py new file mode 100644 index 000000000..762f75cca --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_image/emission.py @@ -0,0 +1,28 @@ +"""generate_image: tool card + terminal summary.""" + +from __future__ import annotations + +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + out = ctx.tool_output + payload = out if isinstance(out, dict) else {"result": out} + yield ctx.emit_tool_output_card(payload) + if isinstance(out, dict): + if out.get("error"): + yield ctx.streaming_service.format_terminal_info( + f"Image generation failed: {out['error'][:60]}", + "error", + ) + else: + yield ctx.streaming_service.format_terminal_info( + "Image generated successfully", + "success", + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_image/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_image/thinking.py new file mode 100644 index 000000000..9675cb0f2 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_image/thinking.py @@ -0,0 +1,39 @@ +"""generate_image: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.deliverables.shared.tool_input import ( + as_tool_input_dict, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + prompt = d.get("prompt", "") if isinstance(tool_input, dict) else str(tool_input) + return ToolStartThinking( + title="Generating image", + items=[f"Prompt: {prompt[:80]}{'...' 
if len(prompt) > 80 else ''}"], + ) + + +def resolve_completed_thinking( + tool_name: str, tool_output: Any, last_items: list[str], +) -> tuple[str, list[str]]: + del tool_name + items = last_items + if isinstance(tool_output, dict) and not tool_output.get("error"): + completed = [*items, "Image generated successfully"] + else: + error_msg = ( + tool_output.get("error", "Generation failed") + if isinstance(tool_output, dict) + else "Generation failed" + ) + completed = [*items, f"Error: {error_msg}"] + return ("Generating image", completed) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/emission.py new file mode 100644 index 000000000..f1a1e9c37 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/emission.py @@ -0,0 +1,37 @@ +"""generate_podcast: tool card + queue / success / failure terminal lines.""" + +from __future__ import annotations + +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + out = ctx.tool_output + payload = out if isinstance(out, dict) else {"result": out} + yield ctx.emit_tool_output_card(payload) + if isinstance(out, dict) and out.get("status") in ( + "pending", + "generating", + "processing", + ): + yield ctx.streaming_service.format_terminal_info( + f"Podcast queued: {out.get('title', 'Podcast')}", + "success", + ) + elif isinstance(out, dict) and out.get("status") in ("ready", "success"): + yield ctx.streaming_service.format_terminal_info( + f"Podcast generated successfully: {out.get('title', 'Podcast')}", + "success", + ) + elif isinstance(out, dict) and out.get("status") in ("failed", "error"): + error_msg = out.get("error", "Unknown error") + yield ctx.streaming_service.format_terminal_info( + f"Podcast generation failed: {error_msg}", + "error", + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/thinking.py new file mode 100644 index 000000000..b92e0c91f --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/thinking.py @@ -0,0 +1,80 @@ +"""generate_podcast: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.deliverables.shared.tool_input import ( + as_tool_input_dict, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + podcast_title = ( + d.get("podcast_title", "SurfSense Podcast") + if isinstance(tool_input, dict) + else "SurfSense Podcast" + ) + content_len = len( + d.get("source_content", "") if isinstance(tool_input, dict) else "" + ) + return ToolStartThinking( + title="Generating podcast", + items=[ + f"Title: 
{podcast_title}", + f"Content: {content_len:,} characters", + "Preparing audio generation...", + ], + ) + + +def resolve_completed_thinking( + tool_name: str, tool_output: Any, last_items: list[str], +) -> tuple[str, list[str]]: + del tool_name + items = last_items + podcast_status = ( + tool_output.get("status", "unknown") + if isinstance(tool_output, dict) + else "unknown" + ) + podcast_title = ( + tool_output.get("title", "Podcast") + if isinstance(tool_output, dict) + else "Podcast" + ) + if podcast_status in ("pending", "generating", "processing"): + completed = [ + f"Title: {podcast_title}", + "Podcast generation started", + "Processing in background...", + ] + elif podcast_status == "already_generating": + completed = [ + f"Title: {podcast_title}", + "Podcast already in progress", + "Please wait for it to complete", + ] + elif podcast_status in ("failed", "error"): + error_msg = ( + tool_output.get("error", "Unknown error") + if isinstance(tool_output, dict) + else "Unknown error" + ) + completed = [ + f"Title: {podcast_title}", + f"Error: {error_msg[:50]}", + ] + elif podcast_status in ("ready", "success"): + completed = [ + f"Title: {podcast_title}", + "Podcast ready", + ] + else: + completed = items + return ("Generating podcast", completed) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_report/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_report/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_report/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_report/emission.py new file mode 100644 index 000000000..1c5c71b8b --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_report/emission.py @@ -0,0 +1,33 @@ +"""generate_report: full payload + terminal line.""" + +from __future__ import annotations + +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + out = ctx.tool_output + payload = out if isinstance(out, dict) else {"result": out} + yield ctx.emit_tool_output_card(payload) + if isinstance(out, dict) and out.get("status") == "ready": + word_count = out.get("word_count", 0) + yield ctx.streaming_service.format_terminal_info( + f"Report generated: {out.get('title', 'Report')} ({word_count:,} words)", + "success", + ) + else: + error_msg = ( + out.get("error", "Unknown error") + if isinstance(out, dict) + else "Unknown error" + ) + yield ctx.streaming_service.format_terminal_info( + f"Report generation failed: {error_msg}", + "error", + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_report/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_report/thinking.py new file mode 100644 index 000000000..f912350f8 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_report/thinking.py @@ -0,0 +1,77 @@ +"""generate_report: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.deliverables.shared.tool_input import ( + as_tool_input_dict, +) +from app.tasks.chat.streaming.handlers.tools.shared.model 
import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + report_topic = ( + d.get("topic", "Report") if isinstance(tool_input, dict) else "Report" + ) + is_revision = bool( + isinstance(tool_input, dict) and tool_input.get("parent_report_id") + ) + step_title = "Revising report" if is_revision else "Generating report" + return ToolStartThinking( + title=step_title, + items=[f"Topic: {report_topic}", "Analyzing source content..."], + ) + + +def resolve_completed_thinking( + tool_name: str, tool_output: Any, last_items: list[str], +) -> tuple[str, list[str]]: + del tool_name + items = last_items + report_status = ( + tool_output.get("status", "unknown") + if isinstance(tool_output, dict) + else "unknown" + ) + report_title = ( + tool_output.get("title", "Report") + if isinstance(tool_output, dict) + else "Report" + ) + word_count = ( + tool_output.get("word_count", 0) + if isinstance(tool_output, dict) + else 0 + ) + is_revision = ( + tool_output.get("is_revision", False) + if isinstance(tool_output, dict) + else False + ) + step_title = "Revising report" if is_revision else "Generating report" + + if report_status == "ready": + completed = [ + f"Topic: {report_title}", + f"{word_count:,} words", + "Report ready", + ] + elif report_status == "failed": + error_msg = ( + tool_output.get("error", "Unknown error") + if isinstance(tool_output, dict) + else "Unknown error" + ) + completed = [ + f"Topic: {report_title}", + f"Error: {error_msg[:50]}", + ] + else: + completed = items + + return (step_title, completed) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_resume/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_resume/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_resume/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_resume/emission.py new file mode 100644 index 000000000..dc8d3c7fc --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_resume/emission.py @@ -0,0 +1,32 @@ +"""generate_resume: full payload + terminal line.""" + +from __future__ import annotations + +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + out = ctx.tool_output + payload = out if isinstance(out, dict) else {"result": out} + yield ctx.emit_tool_output_card(payload) + if isinstance(out, dict) and out.get("status") == "ready": + yield ctx.streaming_service.format_terminal_info( + f"Resume generated: {out.get('title', 'Resume')}", + "success", + ) + else: + error_msg = ( + out.get("error", "Unknown error") + if isinstance(out, dict) + else "Unknown error" + ) + yield ctx.streaming_service.format_terminal_info( + f"Resume generation failed: {error_msg}", + "error", + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_resume/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_resume/thinking.py new file mode 100644 index 000000000..e81a80679 --- /dev/null +++ 
b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_resume/thinking.py @@ -0,0 +1,24 @@ +"""generate_resume: generic thinking titles and items.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.default import ( + thinking as default_thinking, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + return default_thinking.resolve_start_thinking(tool_name, tool_input) + + +def resolve_completed_thinking( + tool_name: str, tool_output: Any, last_items: list[str], +) -> tuple[str, list[str]]: + return default_thinking.resolve_completed_thinking( + tool_name, tool_output, last_items + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/emission.py new file mode 100644 index 000000000..21e27d4c3 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/emission.py @@ -0,0 +1,28 @@ +"""generate_video_presentation: tool card + terminal line.""" + +from __future__ import annotations + +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + out = ctx.tool_output + payload = out if isinstance(out, dict) else {"result": out} + yield ctx.emit_tool_output_card(payload) + if isinstance(out, dict) and out.get("status") == "pending": + yield ctx.streaming_service.format_terminal_info( + f"Video presentation queued: {out.get('title', 'Presentation')}", + "success", + ) + elif isinstance(out, dict) and out.get("status") == "failed": + error_msg = out.get("error", "Unknown error") + yield ctx.streaming_service.format_terminal_info( + f"Presentation generation failed: {error_msg}", + "error", + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/thinking.py new file mode 100644 index 000000000..5c5aa977d --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/thinking.py @@ -0,0 +1,52 @@ +"""generate_video_presentation: generic in-progress thinking; completion is status-driven.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.default import ( + thinking as default_thinking, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + return default_thinking.resolve_start_thinking(tool_name, tool_input) + + +def resolve_completed_thinking( + tool_name: str, tool_output: Any, last_items: list[str], +) -> tuple[str, list[str]]: + del tool_name + items = last_items + 
vp_status = ( + tool_output.get("status", "unknown") + if isinstance(tool_output, dict) + else "unknown" + ) + vp_title = ( + tool_output.get("title", "Presentation") + if isinstance(tool_output, dict) + else "Presentation" + ) + if vp_status in ("pending", "generating"): + completed = [ + f"Title: {vp_title}", + "Presentation generation started", + "Processing in background...", + ] + elif vp_status == "failed": + error_msg = ( + tool_output.get("error", "Unknown error") + if isinstance(tool_output, dict) + else "Unknown error" + ) + completed = [ + f"Title: {vp_title}", + f"Error: {error_msg[:50]}", + ] + else: + completed = items + return ("Generating video presentation", completed) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/save_document/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/save_document/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/save_document/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/save_document/emission.py new file mode 100644 index 000000000..68c93dede --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/save_document/emission.py @@ -0,0 +1,16 @@ +"""save_document: default completion card and terminal line.""" + +from __future__ import annotations + +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.default import emission as _default +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + yield from _default.iter_completion_emission_frames(ctx) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/save_document/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/save_document/thinking.py new file mode 100644 index 000000000..77059a28c --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/save_document/thinking.py @@ -0,0 +1,38 @@ +"""save_document: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.deliverables.shared.tool_input import ( + as_tool_input_dict, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + doc_title = d.get("title", "") if isinstance(tool_input, dict) else str(tool_input) + display_title = doc_title[:60] + ("…" if len(doc_title) > 60 else "") + return ToolStartThinking(title="Saving document", items=[display_title]) + + +def resolve_completed_thinking( + tool_name: str, tool_output: Any, last_items: list[str], +) -> tuple[str, list[str]]: + del tool_name + items = last_items + result_str = ( + tool_output.get("result", "") + if isinstance(tool_output, dict) + else str(tool_output) + ) + is_error = "Error" in result_str + completed = [ + *items, + result_str[:80] if is_error else "Saved to knowledge base", + ] + return ("Saving document", completed) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/shared/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/shared/__init__.py new 
file mode 100644
index 000000000..e69de29bb
diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/shared/tool_input.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/shared/tool_input.py
new file mode 100644
index 000000000..1303cf09f
--- /dev/null
+++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/shared/tool_input.py
@@ -0,0 +1,9 @@
+"""Tool-call args for deliverable thinking modules."""
+
+from __future__ import annotations
+
+from typing import Any
+
+
+def as_tool_input_dict(tool_input: Any) -> dict[str, Any]:
+    return tool_input if isinstance(tool_input, dict) else {}
diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/tool_names.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/tool_names.py
new file mode 100644
index 000000000..5924af196
--- /dev/null
+++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/tool_names.py
@@ -0,0 +1,12 @@
+from __future__ import annotations
+
+DELIVERABLE_TOOLS: frozenset[str] = frozenset(
+    {
+        "generate_image",
+        "generate_podcast",
+        "generate_report",
+        "generate_resume",
+        "generate_video_presentation",
+        "save_document",
+    }
+)
diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/emission_context.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/emission_context.py
new file mode 100644
index 000000000..baa1d7336
--- /dev/null
+++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/emission_context.py
@@ -0,0 +1,36 @@
+"""Context for one tool-completion emission pass."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+
+from app.tasks.chat.streaming.handlers.tool_output_frame import (
+    emit_tool_output_available_frame,
+)
+
+
+@dataclass
+class ToolCompletionEmissionContext:
+    """Streaming service, tool output, and ids for completion frames."""
+
+    tool_name: str
+    tool_call_id: str
+    tool_output: Any
+    streaming_service: Any
+    content_builder: Any | None
+    langchain_tool_call_id_holder: dict[str, str | None]
+    stream_result: Any
+    langgraph_config: dict[str, Any]
+    staged_workspace_file_path: str | None
+    tool_metadata: dict[str, Any] | None = None
+
+    def emit_tool_output_card(self, payload: Any) -> str:
+        return emit_tool_output_available_frame(
+            streaming_service=self.streaming_service,
+            content_builder=self.content_builder,
+            langchain_id_holder=self.langchain_tool_call_id_holder,
+            call_id=self.tool_call_id,
+            output=payload,
+            tool_metadata=self.tool_metadata,
+        )
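For orientation: every per-tool emission module below implements the same contract against this context. A minimal sketch for a hypothetical tool named my_tool (the module path, payload keys, and terminal copy are illustrative, not part of this patch):

    """Sketch only: the shape of a my_tool/emission.py handler."""

    from __future__ import annotations

    from collections.abc import Iterator

    from app.tasks.chat.streaming.handlers.tools.emission_context import (
        ToolCompletionEmissionContext,
    )


    def iter_completion_emission_frames(
        ctx: ToolCompletionEmissionContext,
    ) -> Iterator[str]:
        out = ctx.tool_output
        # Emit the tool-output card first; terminal lines are optional extras.
        yield ctx.emit_tool_output_card(out if isinstance(out, dict) else {"result": out})
        yield ctx.streaming_service.format_terminal_info("my_tool finished", "success")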
diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/edit_file/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/edit_file/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/edit_file/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/edit_file/thinking.py
new file mode 100644
index 000000000..8669107db
--- /dev/null
+++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/edit_file/thinking.py
@@ -0,0 +1,27 @@
+"""edit_file: thinking-step copy."""
+
+from __future__ import annotations
+
+from typing import Any
+
+from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import (
+    as_tool_input_dict,
+    truncate_path,
+)
+from app.tasks.chat.streaming.handlers.tools.shared.model import (
+    ToolStartThinking,
+)
+
+
+def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
+    del tool_name
+    d = as_tool_input_dict(tool_input)
+    fp = d.get("file_path", "") if isinstance(tool_input, dict) else str(tool_input)
+    return ToolStartThinking(title="Editing file", items=[truncate_path(fp)])
+
+
+def resolve_completed_thinking(
+    tool_name: str, tool_output: Any, last_items: list[str],
+) -> tuple[str, list[str]]:
+    del tool_output, tool_name
+    return ("Editing file", last_items)
diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/execute/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/execute/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/execute/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/execute/emission.py
new file mode 100644
index 000000000..0ff87a907
--- /dev/null
+++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/execute/emission.py
@@ -0,0 +1,41 @@
+"""execute: exit code, stdout, sandbox file hints."""
+
+from __future__ import annotations
+
+import re
+from collections.abc import Iterator
+
+from app.tasks.chat.streaming.handlers.tools.emission_context import (
+    ToolCompletionEmissionContext,
+)
+
+
+def iter_completion_emission_frames(
+    ctx: ToolCompletionEmissionContext,
+) -> Iterator[str]:
+    out = ctx.tool_output
+    raw_text = out.get("result", "") if isinstance(out, dict) else str(out)
+    exit_code: int | None = None
+    output_text = raw_text
+    m = re.match(r"^Exit code:\s*(\d+)", raw_text)
+    if m:
+        exit_code = int(m.group(1))
+    om = re.search(r"\nOutput:\n([\s\S]*)", raw_text)
+    # Keep the raw text on the card when the "Output:" marker is absent.
+    output_text = om.group(1) if om else raw_text
+    thread_id_str = ctx.langgraph_config.get("configurable", {}).get("thread_id", "")
+
+    for sf_match in re.finditer(
+        r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE
+    ):
+        fpath = sf_match.group(1).strip()
+        if fpath and fpath not in ctx.stream_result.sandbox_files:
+            ctx.stream_result.sandbox_files.append(fpath)
+
+    yield ctx.emit_tool_output_card(
+        {
+            "exit_code": exit_code,
+            "output": output_text,
+            "thread_id": thread_id_str,
+        },
+    )
diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/execute/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/execute/thinking.py
new file mode 100644
index 000000000..2c8aa296b
--- /dev/null
+++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/execute/thinking.py
@@ -0,0 +1,42 @@
+"""execute: sandbox command thinking + completion lines."""
+
+from __future__ import annotations
+
+import re
+from typing import Any
+
+from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import (
+    as_tool_input_dict,
+)
+from app.tasks.chat.streaming.handlers.tools.shared.model import (
+    ToolStartThinking,
+)
+
+
+def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
+    del tool_name
+    d = as_tool_input_dict(tool_input)
+    cmd = d.get("command", "") if isinstance(tool_input, dict) else str(tool_input)
+    display_cmd = cmd[:80] + ("…" if len(cmd) > 80 else "")
+    return ToolStartThinking(title="Running command", items=[f"$ {display_cmd}"])
+
+
+def 
resolve_completed_thinking( + tool_name: str, tool_output: Any, last_items: list[str], +) -> tuple[str, list[str]]: + del tool_name + items = last_items + raw_text = ( + tool_output.get("result", "") + if isinstance(tool_output, dict) + else str(tool_output) + ) + m = re.match(r"^Exit code:\s*(\d+)", raw_text) + exit_code_val = int(m.group(1)) if m else None + if exit_code_val is not None and exit_code_val == 0: + completed = [*items, "Completed successfully"] + elif exit_code_val is not None: + completed = [*items, f"Exit code: {exit_code_val}"] + else: + completed = [*items, "Finished"] + return ("Running command", completed) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/glob/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/glob/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/glob/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/glob/thinking.py new file mode 100644 index 000000000..f5a57beac --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/glob/thinking.py @@ -0,0 +1,27 @@ +"""glob: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import ( + as_tool_input_dict, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + pat = d.get("pattern", "") if isinstance(tool_input, dict) else str(tool_input) + base = d.get("path", "/") if isinstance(tool_input, dict) else "/" + return ToolStartThinking(title="Searching files", items=[f"{pat} in {base}"]) + + +def resolve_completed_thinking( + tool_name: str, tool_output: Any, last_items: list[str], +) -> tuple[str, list[str]]: + del tool_output, tool_name + return ("Searching files", last_items) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/grep/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/grep/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/grep/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/grep/thinking.py new file mode 100644 index 000000000..da0864177 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/grep/thinking.py @@ -0,0 +1,31 @@ +"""grep: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import ( + as_tool_input_dict, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + pat = d.get("pattern", "") if isinstance(tool_input, dict) else str(tool_input) + grep_path = d.get("path", "") if isinstance(tool_input, dict) else "" + display_pat = pat[:60] + ("…" if len(pat) > 60 else "") + return ToolStartThinking( + title="Searching content", + items=[f'"{display_pat}"' + (f" in {grep_path}" if grep_path else "")], + ) + + +def resolve_completed_thinking( + tool_name: str, tool_output: Any, 
last_items: list[str], +) -> tuple[str, list[str]]: + del tool_output, tool_name + return ("Searching content", last_items) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/ls/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/ls/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/ls/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/ls/thinking.py new file mode 100644 index 000000000..80c547b5a --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/ls/thinking.py @@ -0,0 +1,59 @@ +"""ls: thinking-step copy for directory listing.""" + +from __future__ import annotations + +import ast +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + if isinstance(tool_input, dict): + path = tool_input.get("path", "/") + else: + path = str(tool_input) + return ToolStartThinking(title="Listing files", items=[path]) + + +def resolve_completed_thinking( + tool_name: str, tool_output: Any, last_items: list[str], +) -> tuple[str, list[str]]: + del tool_name + if isinstance(tool_output, dict): + ls_output = tool_output.get("result", "") + elif isinstance(tool_output, str): + ls_output = tool_output + else: + ls_output = str(tool_output) if tool_output else "" + file_names: list[str] = [] + if ls_output: + paths: list[str] = [] + try: + parsed = ast.literal_eval(ls_output) + if isinstance(parsed, list): + paths = [str(p) for p in parsed] + except (ValueError, SyntaxError): + paths = [ + line.strip() + for line in ls_output.strip().split("\n") + if line.strip() + ] + for p in paths: + name = p.rstrip("/").split("/")[-1] + if name and len(name) <= 40: + file_names.append(name) + elif name: + file_names.append(name[:37] + "...") + if file_names: + if len(file_names) <= 5: + completed = [f"[{name}]" for name in file_names] + else: + completed = [f"[{name}]" for name in file_names[:4]] + completed.append(f"(+{len(file_names) - 4} more)") + else: + completed = ["No files found"] + return ("Listing files", completed) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/mkdir/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/mkdir/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/mkdir/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/mkdir/thinking.py new file mode 100644 index 000000000..3a3707698 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/mkdir/thinking.py @@ -0,0 +1,27 @@ +"""mkdir: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import ( + as_tool_input_dict, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + p = d.get("path", "") if isinstance(tool_input, dict) else str(tool_input) + display = p if len(p) <= 80 else "…" + p[-77:] + return ToolStartThinking(title="Creating folder", items=[display] if 
display else []) + + +def resolve_completed_thinking( + tool_name: str, tool_output: Any, last_items: list[str], +) -> tuple[str, list[str]]: + del tool_output, tool_name + return ("Creating folder", last_items) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/move_file/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/move_file/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/move_file/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/move_file/thinking.py new file mode 100644 index 000000000..192a789f4 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/move_file/thinking.py @@ -0,0 +1,33 @@ +"""move_file: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import ( + as_tool_input_dict, + truncate_middle, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + src = d.get("source_path", "") if isinstance(tool_input, dict) else "" + dst = d.get("destination_path", "") if isinstance(tool_input, dict) else "" + display_src = truncate_middle(src, max_len=60) + display_dst = truncate_middle(dst, max_len=60) + return ToolStartThinking( + title="Moving file", + items=[f"{display_src} → {display_dst}"] if src or dst else [], + ) + + +def resolve_completed_thinking( + tool_name: str, tool_output: Any, last_items: list[str], +) -> tuple[str, list[str]]: + del tool_output, tool_name + return ("Moving file", last_items) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/read_file/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/read_file/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/read_file/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/read_file/thinking.py new file mode 100644 index 000000000..3f4290ad7 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/read_file/thinking.py @@ -0,0 +1,27 @@ +"""read_file: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import ( + as_tool_input_dict, + truncate_path, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + fp = d.get("file_path", "") if isinstance(tool_input, dict) else str(tool_input) + return ToolStartThinking(title="Reading file", items=[truncate_path(fp)]) + + +def resolve_completed_thinking( + tool_name: str, tool_output: Any, last_items: list[str], +) -> tuple[str, list[str]]: + del tool_output, tool_name + return ("Reading file", last_items) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/rm/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/rm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git 
a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/rm/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/rm/thinking.py new file mode 100644 index 000000000..a82a44e6f --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/rm/thinking.py @@ -0,0 +1,28 @@ +"""rm: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import ( + as_tool_input_dict, + truncate_path, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + rm_path = d.get("path", "") if isinstance(tool_input, dict) else str(tool_input) + display = truncate_path(rm_path) + return ToolStartThinking(title="Deleting file", items=[display] if display else []) + + +def resolve_completed_thinking( + tool_name: str, tool_output: Any, last_items: list[str], +) -> tuple[str, list[str]]: + del tool_output, tool_name + return ("Deleting file", last_items) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/rmdir/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/rmdir/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/rmdir/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/rmdir/thinking.py new file mode 100644 index 000000000..6c97904b7 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/rmdir/thinking.py @@ -0,0 +1,27 @@ +"""rmdir: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import ( + as_tool_input_dict, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + p = d.get("path", "") if isinstance(tool_input, dict) else str(tool_input) + display = p if len(p) <= 80 else "…" + p[-77:] + return ToolStartThinking(title="Deleting folder", items=[display] if display else []) + + +def resolve_completed_thinking( + tool_name: str, tool_output: Any, last_items: list[str], +) -> tuple[str, list[str]]: + del tool_output, tool_name + return ("Deleting folder", last_items) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/shared/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/shared/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/shared/tool_input.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/shared/tool_input.py new file mode 100644 index 000000000..507782283 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/shared/tool_input.py @@ -0,0 +1,17 @@ +"""Tool-call args + display truncation for filesystem thinking modules.""" + +from __future__ import annotations + +from typing import Any + + +def as_tool_input_dict(tool_input: Any) -> dict[str, Any]: + return tool_input if isinstance(tool_input, dict) else {} + + +def truncate_path(fp: str, *, 
max_len: int = 80) -> str: + return fp if len(fp) <= max_len else "…" + fp[-(max_len - 3) :] + + +def truncate_middle(s: str, *, max_len: int = 60) -> str: + return s if len(s) <= max_len else "…" + s[-(max_len - 3) :] diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/tool_names.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/tool_names.py new file mode 100644 index 000000000..e2ad33736 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/tool_names.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +FILESYSTEM_TOOLS: frozenset[str] = frozenset( + { + "read_file", + "glob", + "grep", + "ls", + "mkdir", + "move_file", + "rm", + "rmdir", + "write_todos", + "write_file", + "edit_file", + "execute", + } +) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_file/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_file/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_file/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_file/emission.py new file mode 100644 index 000000000..820235379 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_file/emission.py @@ -0,0 +1,43 @@ +"""write_file: path + status envelope on the tool-output card.""" + +from __future__ import annotations + +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) +from app.tasks.chat.streaming.helpers.tool_output import ( + extract_resolved_file_path, + tool_output_has_error, + tool_output_to_text, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + resolved_path = extract_resolved_file_path( + tool_name=ctx.tool_name, + tool_output=ctx.tool_output, + tool_input={"file_path": ctx.staged_workspace_file_path} + if ctx.staged_workspace_file_path + else None, + ) + result_text = tool_output_to_text(ctx.tool_output) + if tool_output_has_error(ctx.tool_output): + yield ctx.emit_tool_output_card( + { + "status": "error", + "error": result_text, + "path": resolved_path, + }, + ) + else: + yield ctx.emit_tool_output_card( + { + "status": "completed", + "path": resolved_path, + "result": result_text, + }, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_file/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_file/thinking.py new file mode 100644 index 000000000..43bc8a65f --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_file/thinking.py @@ -0,0 +1,27 @@ +"""write_file: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import ( + as_tool_input_dict, + truncate_path, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + fp = d.get("file_path", "") if isinstance(tool_input, dict) else str(tool_input) + return ToolStartThinking(title="Writing file", items=[truncate_path(fp)]) + + +def resolve_completed_thinking( + tool_name: str, 
tool_output: Any, last_items: list[str], +) -> tuple[str, list[str]]: + del tool_output, tool_name + return ("Writing file", last_items) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_todos/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_todos/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_todos/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_todos/thinking.py new file mode 100644 index 000000000..43e533daa --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_todos/thinking.py @@ -0,0 +1,34 @@ +"""write_todos: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import ( + as_tool_input_dict, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + todos = d.get("todos", []) if isinstance(tool_input, dict) else [] + todo_count = len(todos) if isinstance(todos, list) else 0 + return ToolStartThinking( + title="Planning tasks", + items=( + [f"{todo_count} task{'s' if todo_count != 1 else ''}"] + if todo_count + else [] + ), + ) + + +def resolve_completed_thinking( + tool_name: str, tool_output: Any, last_items: list[str], +) -> tuple[str, list[str]]: + del tool_output, tool_name + return ("Planning tasks", last_items) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/registry.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/registry.py new file mode 100644 index 000000000..c0568f870 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/registry.py @@ -0,0 +1,88 @@ +"""Resolve thinking and emission modules by tool name.""" + +from __future__ import annotations + +import importlib +from collections.abc import Iterator +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.connector.shared.tool_names import ( + SHARED_CONNECTOR_TOOLS, +) +from app.tasks.chat.streaming.handlers.tools.deliverables.tool_names import ( + DELIVERABLE_TOOLS, +) +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) +from app.tasks.chat.streaming.handlers.tools.filesystem.tool_names import ( + FILESYSTEM_TOOLS, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + +_BASE = "app.tasks.chat.streaming.handlers.tools" +_CONNECTOR_SHARED = "connector.shared" + +_THINKING_ALIAS: dict[str, str] = { + "execute_code": "filesystem.execute", +} +_EMISSION_ALIAS: dict[str, str] = { + "edit_file": "filesystem.write_file", + "execute_code": "filesystem.execute", +} + + +def _thinking_module(tool_name: str) -> str: + if tool_name in SHARED_CONNECTOR_TOOLS: + return _CONNECTOR_SHARED + if tool_name in FILESYSTEM_TOOLS: + return f"filesystem.{tool_name}" + if tool_name in DELIVERABLE_TOOLS: + return f"deliverables.{tool_name}" + return _THINKING_ALIAS.get(tool_name, tool_name) + + +def _emission_module(tool_name: str) -> str: + if tool_name in _EMISSION_ALIAS: + return _EMISSION_ALIAS[tool_name] + if tool_name in SHARED_CONNECTOR_TOOLS: + return _CONNECTOR_SHARED + if tool_name in DELIVERABLE_TOOLS: + return 
f"deliverables.{tool_name}" + if tool_name in FILESYSTEM_TOOLS: + return f"filesystem.{tool_name}" + return tool_name + + +def _import_thinking(tool_name: str): + try: + return importlib.import_module(f"{_BASE}.{_thinking_module(tool_name)}.thinking") + except ModuleNotFoundError: + return importlib.import_module(f"{_BASE}.default.thinking") + + +def _import_emission(tool_name: str): + try: + return importlib.import_module(f"{_BASE}.{_emission_module(tool_name)}.emission") + except ModuleNotFoundError: + return importlib.import_module(f"{_BASE}.default.emission") + + +def resolve_tool_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + return _import_thinking(tool_name).resolve_start_thinking(tool_name, tool_input) + + +def resolve_tool_completed_thinking_step( + tool_name: str, tool_output: Any, last_items: list[str] +) -> tuple[str, list[str]]: + return _import_thinking(tool_name).resolve_completed_thinking( + tool_name, tool_output, last_items + ) + + +def iter_tool_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + yield from _import_emission(ctx.tool_name).iter_completion_emission_frames(ctx) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/scrape_webpage/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/scrape_webpage/emission.py new file mode 100644 index 000000000..293d2a1e9 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/scrape_webpage/emission.py @@ -0,0 +1,43 @@ +"""scrape_webpage: redacted payload + terminal summary.""" + +from __future__ import annotations + +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + out = ctx.tool_output + if isinstance(out, dict): + display_output = {k: v for k, v in out.items() if k != "content"} + if "content" in out: + content = out.get("content", "") + display_output["content_preview"] = ( + content[:500] + "..." if len(content) > 500 else content + ) + yield ctx.emit_tool_output_card(display_output) + else: + yield ctx.emit_tool_output_card({"result": out}) + + if isinstance(out, dict) and "error" not in out: + title = out.get("title", "Webpage") + word_count = out.get("word_count", 0) + yield ctx.streaming_service.format_terminal_info( + f"Scraped: {title[:40]}{'...' 
diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/scrape_webpage/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/scrape_webpage/emission.py
new file mode 100644
index 000000000..293d2a1e9
--- /dev/null
+++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/scrape_webpage/emission.py
@@ -0,0 +1,43 @@
+"""scrape_webpage: redacted payload + terminal summary."""
+
+from __future__ import annotations
+
+from collections.abc import Iterator
+
+from app.tasks.chat.streaming.handlers.tools.emission_context import (
+    ToolCompletionEmissionContext,
+)
+
+
+def iter_completion_emission_frames(
+    ctx: ToolCompletionEmissionContext,
+) -> Iterator[str]:
+    out = ctx.tool_output
+    if isinstance(out, dict):
+        display_output = {k: v for k, v in out.items() if k != "content"}
+        if "content" in out:
+            content = out.get("content", "")
+            display_output["content_preview"] = (
+                content[:500] + "..." if len(content) > 500 else content
+            )
+        yield ctx.emit_tool_output_card(display_output)
+    else:
+        yield ctx.emit_tool_output_card({"result": out})
+
+    if isinstance(out, dict) and "error" not in out:
+        title = out.get("title", "Webpage")
+        word_count = out.get("word_count", 0)
+        yield ctx.streaming_service.format_terminal_info(
+            f"Scraped: {title[:40]}{'...' if len(title) > 40 else ''} ({word_count:,} words)",
+            "success",
+        )
+    else:
+        error_msg = (
+            out.get("error", "Failed to scrape")
+            if isinstance(out, dict)
+            else "Failed to scrape"
+        )
+        yield ctx.streaming_service.format_terminal_info(
+            f"Scrape failed: {error_msg}",
+            "error",
+        )
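The content redaction above keeps the tool-output card small. With illustrative values:

    # {"title": "Docs", "content": <12,000 chars>, "word_count": 2100}
    #   -> card payload {"title": "Docs", "word_count": 2100,
    #                    "content_preview": <first 500 chars> + "..."}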
if len(title) > 50 else ''}", + f"Extracted: {word_count:,} words", + ] + else: + completed = [*items, "Content extracted"] + return ("Scraping webpage", completed) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/shared/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/shared/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/shared/model.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/shared/model.py new file mode 100644 index 000000000..047a84374 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/shared/model.py @@ -0,0 +1,12 @@ +"""In-progress thinking-step title and bullet lines.""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class ToolStartThinking: + title: str + items: list[str] + include_items_on_frame: bool = True diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/web_search/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/web_search/emission.py new file mode 100644 index 000000000..eccaed708 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/web_search/emission.py @@ -0,0 +1,41 @@ +"""web_search: citations parsed from provider XML.""" + +from __future__ import annotations + +import re +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + out = ctx.tool_output + xml = out.get("result", str(out)) if isinstance(out, dict) else str(out) + citations: dict[str, dict[str, str]] = {} + for m in re.finditer( + r"<!\[CDATA\[(.*?)\]\]>\s*", + xml, + ): + title, url = m.group(1).strip(), m.group(2).strip() + if url.startswith("http") and url not in citations: + citations[url] = {"title": title} + for m in re.finditer( + r"", + xml, + ): + chunk_url, content = m.group(1).strip(), m.group(2).strip() + if ( + chunk_url.startswith("http") + and chunk_url in citations + and content + ): + citations[chunk_url]["snippet"] = ( + content[:200] + "…" if len(content) > 200 else content + ) + yield ctx.emit_tool_output_card( + {"status": "completed", "citations": citations}, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/helpers/__init__.py b/surfsense_backend/app/tasks/chat/streaming/helpers/__init__.py new file mode 100644 index 000000000..151dfdaac --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/helpers/__init__.py @@ -0,0 +1,3 @@ +"""Pure helpers for chat streaming.""" + +from __future__ import annotations diff --git a/surfsense_backend/app/tasks/chat/streaming/helpers/chunk_parts.py b/surfsense_backend/app/tasks/chat/streaming/helpers/chunk_parts.py new file mode 100644 index 000000000..48b44fc1d --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/helpers/chunk_parts.py @@ -0,0 +1,60 @@ +"""Split a model chunk into text, reasoning, and tool-call fragment lists.""" + +from __future__ import annotations + +from typing import Any + + +def extract_chunk_parts(chunk: Any) -> dict[str, Any]: + """Return dict with keys text, reasoning, and tool_call_chunks (merged from chunk fields).""" + out: dict[str, Any] = {"text": "", "reasoning": "", "tool_call_chunks": []} + if chunk is None: + return out + + content = getattr(chunk, "content", None) + if isinstance(content, 
diff --git a/surfsense_backend/app/tasks/chat/streaming/helpers/__init__.py b/surfsense_backend/app/tasks/chat/streaming/helpers/__init__.py
new file mode 100644
index 000000000..151dfdaac
--- /dev/null
+++ b/surfsense_backend/app/tasks/chat/streaming/helpers/__init__.py
@@ -0,0 +1,3 @@
+"""Pure helpers for chat streaming."""
+
+from __future__ import annotations
diff --git a/surfsense_backend/app/tasks/chat/streaming/helpers/chunk_parts.py b/surfsense_backend/app/tasks/chat/streaming/helpers/chunk_parts.py
new file mode 100644
index 000000000..48b44fc1d
--- /dev/null
+++ b/surfsense_backend/app/tasks/chat/streaming/helpers/chunk_parts.py
@@ -0,0 +1,60 @@
+"""Split a model chunk into text, reasoning, and tool-call fragment lists."""
+
+from __future__ import annotations
+
+from typing import Any
+
+
+def extract_chunk_parts(chunk: Any) -> dict[str, Any]:
+    """Return dict with keys text, reasoning, and tool_call_chunks (merged from chunk fields)."""
+    out: dict[str, Any] = {"text": "", "reasoning": "", "tool_call_chunks": []}
+    if chunk is None:
+        return out
+
+    content = getattr(chunk, "content", None)
+    if isinstance(content, str):
+        if content:
+            out["text"] = content
+    elif isinstance(content, list):
+        text_parts: list[str] = []
+        reasoning_parts: list[str] = []
+        for block in content:
+            if not isinstance(block, dict):
+                continue
+            block_type = block.get("type")
+            if block_type == "text":
+                value = block.get("text") or block.get("content") or ""
+                if isinstance(value, str) and value:
+                    text_parts.append(value)
+            elif block_type == "reasoning":
+                value = (
+                    block.get("reasoning")
+                    or block.get("text")
+                    or block.get("content")
+                    or ""
+                )
+                if isinstance(value, str) and value:
+                    reasoning_parts.append(value)
+            elif block_type in ("tool_call_chunk", "tool_use"):
+                out["tool_call_chunks"].append(block)
+        if text_parts:
+            out["text"] = "".join(text_parts)
+        if reasoning_parts:
+            out["reasoning"] = "".join(reasoning_parts)
+
+    additional = getattr(chunk, "additional_kwargs", None) or {}
+    if isinstance(additional, dict):
+        extra_reasoning = additional.get("reasoning_content")
+        if isinstance(extra_reasoning, str) and extra_reasoning:
+            existing = out["reasoning"]
+            out["reasoning"] = (
+                (existing + extra_reasoning) if existing else extra_reasoning
+            )
+
+    extra_tool_chunks = getattr(chunk, "tool_call_chunks", None)
+    if isinstance(extra_tool_chunks, list):
+        for tcc in extra_tool_chunks:
+            if isinstance(tcc, dict):
+                out["tool_call_chunks"].append(tcc)
+
+    return out
diff --git a/surfsense_backend/app/tasks/chat/streaming/helpers/interrupt_inspector.py b/surfsense_backend/app/tasks/chat/streaming/helpers/interrupt_inspector.py
new file mode 100644
index 000000000..dca099b3f
--- /dev/null
+++ b/surfsense_backend/app/tasks/chat/streaming/helpers/interrupt_inspector.py
@@ -0,0 +1,47 @@
+"""Read the first interrupt payload from a LangGraph state snapshot."""
+
+from __future__ import annotations
+
+from typing import Any
+
+
+def first_interrupt_value(state: Any) -> dict[str, Any] | None:
+    """Return the first interrupt payload across all snapshot tasks."""
+
+    def _extract(candidate: Any) -> dict[str, Any] | None:
+        if isinstance(candidate, dict):
+            value = candidate.get("value", candidate)
+            return value if isinstance(value, dict) else None
+        value = getattr(candidate, "value", None)
+        if isinstance(value, dict):
+            return value
+        if isinstance(candidate, list | tuple):
+            for item in candidate:
+                extracted = _extract(item)
+                if extracted is not None:
+                    return extracted
+        return None
+
+    for task in getattr(state, "tasks", ()) or ():
+        try:
+            interrupts = getattr(task, "interrupts", ()) or ()
+        except (AttributeError, IndexError, TypeError):
+            interrupts = ()
+        if not interrupts:
+            extracted = _extract(task)
+            if extracted is not None:
+                return extracted
+            continue
+        for interrupt_item in interrupts:
+            extracted = _extract(interrupt_item)
+            if extracted is not None:
+                return extracted
+
+    try:
+        state_interrupts = getattr(state, "interrupts", ()) or ()
+    except (AttributeError, IndexError, TypeError):
+        state_interrupts = ()
+    extracted = _extract(state_interrupts)
+    if extracted is not None:
+        return extracted
+    return None
diff --git a/surfsense_backend/app/tasks/chat/streaming/helpers/tool_call_matching.py b/surfsense_backend/app/tasks/chat/streaming/helpers/tool_call_matching.py
new file mode 100644
index 000000000..fbe4c94b7
--- /dev/null
+++ b/surfsense_backend/app/tasks/chat/streaming/helpers/tool_call_matching.py
@@ -0,0 +1,32 @@
+"""Match buffered model tool-call chunks to a tool start when ids were missing."""
+
+from __future__ import annotations
+
+from typing import Any
+
+
+def 
match_buffered_langchain_tool_call_id( + pending_tool_call_chunks: list[dict[str, Any]], + tool_name: str, + run_id: str, + lc_tool_call_id_by_run: dict[str, str], +) -> str | None: + matched_idx: int | None = None + for idx, tcc in enumerate(pending_tool_call_chunks): + if tcc.get("name") == tool_name and tcc.get("id"): + matched_idx = idx + break + if matched_idx is None: + for idx, tcc in enumerate(pending_tool_call_chunks): + if tcc.get("id"): + matched_idx = idx + break + if matched_idx is None: + return None + matched = pending_tool_call_chunks.pop(matched_idx) + candidate = matched.get("id") + if isinstance(candidate, str) and candidate: + if run_id: + lc_tool_call_id_by_run[run_id] = candidate + return candidate + return None diff --git a/surfsense_backend/app/tasks/chat/streaming/helpers/tool_output.py b/surfsense_backend/app/tasks/chat/streaming/helpers/tool_output.py new file mode 100644 index 000000000..a7c401dee --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/helpers/tool_output.py @@ -0,0 +1,43 @@ +"""Normalize filesystem tool payloads for SSE cards and messages.""" + +from __future__ import annotations + +import json +from typing import Any + + +def tool_output_to_text(tool_output: Any) -> str: + if isinstance(tool_output, dict): + if isinstance(tool_output.get("result"), str): + return tool_output["result"] + if isinstance(tool_output.get("error"), str): + return tool_output["error"] + return json.dumps(tool_output, ensure_ascii=False) + return str(tool_output) + + +def tool_output_has_error(tool_output: Any) -> bool: + if isinstance(tool_output, dict): + if tool_output.get("error"): + return True + result = tool_output.get("result") + return bool( + isinstance(result, str) and result.strip().lower().startswith("error:") + ) + if isinstance(tool_output, str): + return tool_output.strip().lower().startswith("error:") + return False + + +def extract_resolved_file_path( + *, tool_name: str, tool_output: Any, tool_input: Any | None = None +) -> str | None: + if isinstance(tool_output, dict): + path_value = tool_output.get("path") + if isinstance(path_value, str) and path_value.strip(): + return path_value.strip() + if tool_name in ("write_file", "edit_file") and isinstance(tool_input, dict): + file_path = tool_input.get("file_path") + if isinstance(file_path, str) and file_path.strip(): + return file_path.strip() + return None diff --git a/surfsense_backend/app/tasks/chat/streaming/relay/__init__.py b/surfsense_backend/app/tasks/chat/streaming/relay/__init__.py new file mode 100644 index 000000000..18eda9a6d --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/relay/__init__.py @@ -0,0 +1,23 @@ +"""Relay: thinking steps, tool bookkeeping, and ``EventRelay``. + +Package imports are lazy so ``relay.thinking_step_sse`` (and siblings) can load +without pulling in ``event_relay`` (which imports handler modules that may +import those siblings). 
+""" + +from __future__ import annotations + +__all__ = ["EventRelay", "EventRelayConfig"] + + +def __getattr__(name: str): + if name == "EventRelay": + from app.tasks.chat.streaming.relay.event_relay import EventRelay + + return EventRelay + if name == "EventRelayConfig": + from app.tasks.chat.streaming.relay.event_relay import EventRelayConfig + + return EventRelayConfig + msg = f"module {__name__!r} has no attribute {name!r}" + raise AttributeError(msg) diff --git a/surfsense_backend/app/tasks/chat/streaming/relay/event_relay.py b/surfsense_backend/app/tasks/chat/streaming/relay/event_relay.py new file mode 100644 index 000000000..03d6a66e6 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/relay/event_relay.py @@ -0,0 +1,128 @@ +"""Turn LangGraph astream_events into SSE strings via the handler modules.""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from dataclasses import dataclass, field +from typing import Any + +from app.services.streaming.emitter import EmitterRegistry +from app.tasks.chat.streaming.graph_stream.result import StreamingResult +from app.tasks.chat.streaming.handlers.chain_end import iter_chain_end_frames +from app.tasks.chat.streaming.handlers.chat_model_stream import ( + iter_chat_model_stream_frames, +) +from app.tasks.chat.streaming.handlers.custom_event_dispatch import ( + iter_custom_event_frames, +) +from app.tasks.chat.streaming.handlers.tool_end import iter_tool_end_frames +from app.tasks.chat.streaming.handlers.tool_start import iter_tool_start_frames +from app.tasks.chat.streaming.relay.state import AgentEventRelayState +from app.tasks.chat.streaming.relay.thinking_step_completion import ( + complete_active_thinking_step, +) + + +@dataclass +class EventRelayConfig: + """Optional relay tuning (sub-agent tools, text suppression).""" + + subagent_entry_tool_names: frozenset[str] = field( + default_factory=lambda: frozenset({"task"}) + ) + suppress_main_text_inside_tools: bool = True + + +class EventRelay: + """Dispatches graph events to streaming handlers and optional emitters.""" + + def __init__( + self, + *, + streaming_service: Any, + config: EventRelayConfig | None = None, + ) -> None: + self.streaming_service = streaming_service + self.config = config or EventRelayConfig() + reg = getattr(streaming_service, "emitter_registry", None) + self.emitter_registry = reg if reg is not None else EmitterRegistry() + + async def relay( + self, + events: AsyncIterator[dict[str, Any]], + *, + state: AgentEventRelayState, + result: StreamingResult, + step_prefix: str = "thinking", + content_builder: Any | None = None, + config: dict[str, Any] | None = None, + ) -> AsyncIterator[str]: + """Yield SSE for each event from the async iterator, then finalize text/thinking.""" + graph_config = config or {} + async for event in events: + event_type = event.get("event", "") + if event_type == "on_chat_model_stream": + for frame in iter_chat_model_stream_frames( + event, + state=state, + streaming_service=self.streaming_service, + content_builder=content_builder, + step_prefix=step_prefix, + ): + yield frame + elif event_type == "on_tool_start": + for frame in iter_tool_start_frames( + event, + state=state, + streaming_service=self.streaming_service, + content_builder=content_builder, + result=result, + step_prefix=step_prefix, + ): + yield frame + elif event_type == "on_tool_end": + for frame in iter_tool_end_frames( + event, + state=state, + streaming_service=self.streaming_service, + content_builder=content_builder, + 
result=result, + step_prefix=step_prefix, + config=graph_config, + ): + yield frame + elif event_type == "on_custom_event": + for frame in iter_custom_event_frames( + event, + state=state, + streaming_service=self.streaming_service, + content_builder=content_builder, + ): + yield frame + elif event_type in ("on_chain_end", "on_agent_end"): + for frame in iter_chain_end_frames( + event, + state=state, + streaming_service=self.streaming_service, + content_builder=content_builder, + ): + yield frame + + if state.current_text_id is not None: + yield self.streaming_service.format_text_end(state.current_text_id) + if content_builder is not None: + content_builder.on_text_end(state.current_text_id) + state.current_text_id = None + + completion_event, new_active = complete_active_thinking_step( + state=state, + streaming_service=self.streaming_service, + content_builder=content_builder, + last_active_step_id=state.last_active_step_id, + last_active_step_title=state.last_active_step_title, + last_active_step_items=state.last_active_step_items, + completed_step_ids=state.completed_step_ids, + ) + if completion_event: + yield completion_event + state.last_active_step_id = new_active diff --git a/surfsense_backend/app/tasks/chat/streaming/relay/state.py b/surfsense_backend/app/tasks/chat/streaming/relay/state.py new file mode 100644 index 000000000..27898403d --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/relay/state.py @@ -0,0 +1,98 @@ +"""Mutable counters and maps for one agent stream.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class AgentEventRelayState: + """Tracks text, thinking steps, tool depth, and pending tool-call metadata. + + **Task span (`spanId`)** — ``active_span_id`` groups steps and tools for one + open delegating ``task`` episode. ``active_task_run_id`` is the LangGraph + ``run_id`` of that ``task`` so the span clears only when that run ends, not + when child tools end. Open/close uses ``relay.task_span`` helpers. + + **Tool ↔ thinking link (`thinkingStepId`)** — Each tool run gets a thinking-row + id (``tool_step_ids[run_id]``, emitted as ``data-thinking-step`` ``data.id``). + ``tool_activity_metadata`` supplies ``metadata`` for ``tool-input-start`` / + ``tool-input-available`` (``handlers.tool_start``) and + ``tool-output-available`` (``handlers.tool_end``). + """ + + accumulated_text: str = "" + current_text_id: str | None = None + thinking_step_counter: int = 0 + tool_step_ids: dict[str, str] = field(default_factory=dict) + completed_step_ids: set[str] = field(default_factory=set) + last_active_step_id: str | None = None + last_active_step_title: str = "" + last_active_step_items: list[str] = field(default_factory=list) + just_finished_tool: bool = False + active_tool_depth: int = 0 + called_update_memory: bool = False + current_reasoning_id: str | None = None + pending_tool_call_chunks: list[dict[str, Any]] = field(default_factory=list) + lc_tool_call_id_by_run: dict[str, str] = field(default_factory=dict) + file_path_by_run: dict[str, str] = field(default_factory=dict) + index_to_meta: dict[int, dict[str, str]] = field(default_factory=dict) + ui_tool_call_id_by_run: dict[str, str] = field(default_factory=dict) + current_lc_tool_call_id: dict[str, str | None] = field( + default_factory=lambda: {"value": None} + ) + # Open ``task`` delegation span (one id shared by nested activity); unset outside. 
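+    # Invariant: both fields below are set together by ``task_span.open_task_span``
+    # and cleared together by ``task_span.clear_task_span_if_delegating_task_ended``.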
+ active_span_id: str | None = None + active_task_run_id: str | None = None + # Span id minted when a ``task`` tool_call_chunk registers (before ``on_tool_start``). + pending_task_span_by_lc: dict[str, str] = field(default_factory=dict) + + def span_metadata_if_active(self) -> dict[str, Any] | None: + """``{"spanId": ...}`` when a span is active; ``None`` otherwise.""" + if self.active_span_id: + return {"spanId": self.active_span_id} + return None + + def tool_activity_metadata( + self, *, thinking_step_id: str | None + ) -> dict[str, Any] | None: + """Build ``metadata`` for tool SSE and ``tool-call`` persistence. + + Contract (keys omitted when not applicable): + + - ``spanId`` (str): present while a task-delegation span is active + (same value as ``span_metadata_if_active()``). + - ``thinkingStepId`` (str): equals the thinking-step row ``id`` for this + tool (``data-thinking-step`` payload ``data.id`` on the wire). + + Returns ``None`` if neither applies. Whitespace-only + ``thinking_step_id`` is ignored. + """ + out: dict[str, Any] = {} + if self.active_span_id: + out["spanId"] = self.active_span_id + tid = (thinking_step_id or "").strip() + if tid: + out["thinkingStepId"] = tid + return out if out else None + + @classmethod + def for_invocation( + cls, + *, + initial_step_id: str | None = None, + initial_step_title: str = "", + initial_step_items: list[str] | None = None, + ) -> AgentEventRelayState: + counter = 1 if initial_step_id else 0 + return cls( + thinking_step_counter=counter, + last_active_step_id=initial_step_id, + last_active_step_title=initial_step_title, + last_active_step_items=list(initial_step_items or []), + ) + + def next_thinking_step_id(self, step_prefix: str) -> str: + self.thinking_step_counter += 1 + return f"{step_prefix}-{self.thinking_step_counter}" diff --git a/surfsense_backend/app/tasks/chat/streaming/relay/task_span.py b/surfsense_backend/app/tasks/chat/streaming/relay/task_span.py new file mode 100644 index 000000000..c4cdf24ba --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/relay/task_span.py @@ -0,0 +1,74 @@ +"""Open/close ``active_span_id`` around a delegating ``task`` tool run.""" + +from __future__ import annotations + +import uuid + +from app.tasks.chat.streaming.relay.state import AgentEventRelayState + + +def new_span_id() -> str: + """One delegation-episode id (shared by activity under an open ``task``).""" + return f"spn_{uuid.uuid4().hex}" + + +def _run_key(run_id: str) -> str: + return (run_id or "").strip() + + +def _lc_key(langchain_tool_call_id: str | None) -> str: + return (langchain_tool_call_id or "").strip() + + +def ensure_pending_task_span_for_lc(state: AgentEventRelayState, lc_id: str) -> str: + """Return span id for this LangChain tool call id, storing it in ``pending`` if new. + + Used from ``chat_model_stream`` when the first ``task`` chunk registers so + early ``tool-input-start`` can carry ``metadata.spanId`` before ``on_tool_start``. + """ + key = _lc_key(lc_id) + if not key: + return new_span_id() + existing = state.pending_task_span_by_lc.get(key) + if existing: + return existing + sid = new_span_id() + state.pending_task_span_by_lc[key] = sid + return sid + + +def open_task_span( + state: AgentEventRelayState, + *, + run_id: str, + langchain_tool_call_id: str | None = None, +) -> str: + """Set ``active_span_id`` from pending (same lc) or mint; remember ``active_task_run_id``. + + Call when the ``task`` tool **starts**. 
Nested ``task`` is not supported: + a second call replaces the previous span without restoring it. + """ + key = _lc_key(langchain_tool_call_id) + sid: str | None = state.pending_task_span_by_lc.pop(key, None) if key else None + if not sid: + sid = new_span_id() + state.active_span_id = sid + state.active_task_run_id = _run_key(run_id) or None + return sid + + +def clear_task_span_if_delegating_task_ended( + state: AgentEventRelayState, + *, + tool_name: str, + run_id: str, +) -> None: + """Clear span state only when this event is the end of the opening ``task`` run.""" + if tool_name != "task": + return + if state.active_task_run_id is None: + return + if state.active_task_run_id != _run_key(run_id): + return + state.active_span_id = None + state.active_task_run_id = None diff --git a/surfsense_backend/app/tasks/chat/streaming/relay/thinking_step_completion.py b/surfsense_backend/app/tasks/chat/streaming/relay/thinking_step_completion.py new file mode 100644 index 000000000..ad0930341 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/relay/thinking_step_completion.py @@ -0,0 +1,34 @@ +"""Close the in-progress thinking step with a completed status frame.""" + +from __future__ import annotations + +from typing import Any + +from .state import AgentEventRelayState +from .thinking_step_sse import emit_thinking_step_frame + + +def complete_active_thinking_step( + *, + state: AgentEventRelayState, + streaming_service: Any, + content_builder: Any | None, + last_active_step_id: str | None, + last_active_step_title: str, + last_active_step_items: list[str], + completed_step_ids: set[str], +) -> tuple[str | None, str | None]: + """Emit a completed thinking-step frame once; return (frame or None, next active step id).""" + if last_active_step_id and last_active_step_id not in completed_step_ids: + completed_step_ids.add(last_active_step_id) + event = emit_thinking_step_frame( + streaming_service=streaming_service, + content_builder=content_builder, + step_id=last_active_step_id, + title=last_active_step_title, + status="completed", + items=last_active_step_items if last_active_step_items else None, + metadata=state.span_metadata_if_active(), + ) + return event, None + return None, last_active_step_id diff --git a/surfsense_backend/app/tasks/chat/streaming/relay/thinking_step_sse.py b/surfsense_backend/app/tasks/chat/streaming/relay/thinking_step_sse.py new file mode 100644 index 000000000..6737f536b --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/relay/thinking_step_sse.py @@ -0,0 +1,28 @@ +"""Thinking-step SSE plus optional content-builder updates.""" + +from __future__ import annotations + +from typing import Any + + +def emit_thinking_step_frame( + *, + streaming_service: Any, + content_builder: Any | None, + step_id: str, + title: str, + status: str = "in_progress", + items: list[str] | None = None, + metadata: dict[str, Any] | None = None, +) -> str: + if content_builder is not None: + content_builder.on_thinking_step( + step_id, title, status, items, metadata=metadata + ) + return streaming_service.format_thinking_step( + step_id=step_id, + title=title, + status=status, + items=items, + metadata=metadata, + ) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py index 6800be2af..099aea882 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py @@ -31,7 +31,6 @@ def 
_clear_all(monkeypatch: pytest.MonkeyPatch) -> None: "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "SURFSENSE_ENABLE_ACTION_LOG", "SURFSENSE_ENABLE_REVERT_ROUTE", - "SURFSENSE_ENABLE_STREAM_PARITY_V2", "SURFSENSE_ENABLE_PLUGIN_LOADER", "SURFSENSE_ENABLE_OTEL", "SURFSENSE_ENABLE_AGENT_CACHE", @@ -61,7 +60,6 @@ def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) -> assert flags.enable_kb_planner_runnable is True assert flags.enable_action_log is True assert flags.enable_revert_route is True - assert flags.enable_stream_parity_v2 is True assert flags.enable_plugin_loader is False assert flags.enable_otel is False # Phase 2: agent cache is now default-on (the prerequisite tool @@ -127,7 +125,6 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) -> "enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG", "enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE", - "enable_stream_parity_v2": "SURFSENSE_ENABLE_STREAM_PARITY_V2", "enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER", "enable_otel": "SURFSENSE_ENABLE_OTEL", } diff --git a/surfsense_backend/tests/unit/services/streaming/__init__.py b/surfsense_backend/tests/unit/services/streaming/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/services/streaming/test_emitter.py b/surfsense_backend/tests/unit/services/streaming/test_emitter.py new file mode 100644 index 000000000..6c4e1ff58 --- /dev/null +++ b/surfsense_backend/tests/unit/services/streaming/test_emitter.py @@ -0,0 +1,79 @@ +"""Pin the wire compactness rule and the top-level ``emitted_by`` field name.""" + +from __future__ import annotations + +import pytest + +from app.services.streaming.emitter import ( + Emitter, + attach_emitted_by, + main_emitter, + subagent_emitter, +) + +pytestmark = pytest.mark.unit + + +def test_main_emitter_payload_contains_only_level() -> None: + payload = main_emitter().to_payload() + assert payload == {"level": "main"} + + +def test_subagent_emitter_payload_includes_all_set_fields() -> None: + payload = subagent_emitter( + subagent_type="deliverables", + subagent_run_id="subagent_abc", + parent_tool_call_id="call_xyz", + ).to_payload() + assert payload == { + "level": "subagent", + "subagent_type": "deliverables", + "subagent_run_id": "subagent_abc", + "parent_tool_call_id": "call_xyz", + } + + +def test_subagent_emitter_payload_omits_unset_optional_fields() -> None: + """parent_tool_call_id is None when the run is started outside a tool boundary.""" + payload = Emitter( + level="subagent", + subagent_type="email", + subagent_run_id="subagent_1", + ).to_payload() + assert "parent_tool_call_id" not in payload + assert payload["subagent_type"] == "email" + + +def test_extra_fields_merge_into_payload() -> None: + """Future extension fields (e.g. 
lane colour, label) flow through ``extra``.""" + emitter = subagent_emitter( + subagent_type="search", + subagent_run_id="r1", + extra={"label": "Web Search"}, + ) + assert emitter.to_payload()["label"] == "Web Search" + + +def test_attach_emitted_by_with_none_is_noop() -> None: + payload = {"type": "text-delta", "delta": "hi"} + result = attach_emitted_by(payload, None) + assert "emitted_by" not in result + assert result is payload + + +def test_attach_emitted_by_adds_payload_under_snake_case_top_level_key() -> None: + payload = {"type": "text-delta", "delta": "hi"} + attach_emitted_by( + payload, + subagent_emitter( + subagent_type="x", + subagent_run_id="y", + parent_tool_call_id="z", + ), + ) + assert payload["emitted_by"] == { + "level": "subagent", + "subagent_type": "x", + "subagent_run_id": "y", + "parent_tool_call_id": "z", + } diff --git a/surfsense_backend/tests/unit/services/streaming/test_emitter_registry.py b/surfsense_backend/tests/unit/services/streaming/test_emitter_registry.py new file mode 100644 index 000000000..e459c946a --- /dev/null +++ b/surfsense_backend/tests/unit/services/streaming/test_emitter_registry.py @@ -0,0 +1,111 @@ +"""Pin the parent_ids walk + parallel sub-agent isolation that drives lane attribution.""" + +from __future__ import annotations + +import pytest + +from app.services.streaming.emitter import ( + Emitter, + EmitterRegistry, + main_emitter, + subagent_emitter, +) + +pytestmark = pytest.mark.unit + + +def _sub(run_id: str, kind: str = "deliverables") -> Emitter: + return subagent_emitter( + subagent_type=kind, + subagent_run_id=f"sub_{run_id}", + parent_tool_call_id=f"call_{run_id}", + ) + + +def test_unregistered_event_resolves_to_main_emitter() -> None: + registry = EmitterRegistry() + resolved = registry.resolve(run_id="run_1", parent_ids=["root"]) + assert resolved is main_emitter() + + +def test_event_owned_by_registered_run_id_returns_that_emitter() -> None: + registry = EmitterRegistry() + emitter = _sub("a") + registry.register("run_task_a", emitter) + assert registry.resolve(run_id="run_task_a", parent_ids=[]) is emitter + + +def test_descendant_resolves_via_parent_ids_chain() -> None: + """A model-call event nested under the task tool inherits its sub-agent emitter.""" + registry = EmitterRegistry() + emitter = _sub("a") + registry.register("run_task_a", emitter) + descendant = registry.resolve( + run_id="run_chat_model", + parent_ids=["root", "run_agent", "run_task_a"], + ) + assert descendant is emitter + + +def test_nearest_registered_ancestor_wins_over_distant_ones() -> None: + """Inner sub-agents owe their emitter to the nearest task tool, not the outer one.""" + registry = EmitterRegistry() + outer = _sub("outer", kind="planner") + inner = _sub("inner", kind="email") + registry.register("run_outer", outer) + registry.register("run_inner", inner) + resolved = registry.resolve( + run_id="run_inner_tool", + parent_ids=["root", "run_outer", "run_inner"], + ) + assert resolved is inner + + +def test_parallel_subagents_do_not_bleed_into_each_other() -> None: + """Two concurrent task tools each own their own descendant events.""" + registry = EmitterRegistry() + a = _sub("a", kind="search") + b = _sub("b", kind="email") + registry.register("run_task_a", a) + registry.register("run_task_b", b) + + from_a = registry.resolve(run_id="x", parent_ids=["root", "run_task_a"]) + from_b = registry.resolve(run_id="y", parent_ids=["root", "run_task_b"]) + from_main = registry.resolve(run_id="z", parent_ids=["root"]) + + assert from_a is a + assert 
from_b is b + assert from_main is main_emitter() + + +def test_unregister_releases_run_id_so_descendants_fall_back_to_main() -> None: + registry = EmitterRegistry() + emitter = _sub("a") + registry.register("run_task_a", emitter) + registry.unregister("run_task_a") + assert registry.resolve(run_id="x", parent_ids=["run_task_a"]) is main_emitter() + + +def test_unregister_returns_the_previously_registered_emitter() -> None: + """Lets callers emit ``data-subagent-finish`` carrying the same emitter they opened with.""" + registry = EmitterRegistry() + emitter = _sub("a") + registry.register("run_task_a", emitter) + assert registry.unregister("run_task_a") is emitter + + +def test_has_active_subagents_tracks_open_lanes() -> None: + registry = EmitterRegistry() + assert not registry.has_active_subagents() + registry.register("run_task_a", _sub("a")) + assert registry.has_active_subagents() + registry.unregister("run_task_a") + assert not registry.has_active_subagents() + + +def test_empty_run_id_and_parent_ids_resolves_to_main() -> None: + """Defensive: events without identifiers always belong to the main lane.""" + registry = EmitterRegistry() + registry.register("run_task_a", _sub("a")) + assert registry.resolve(run_id=None, parent_ids=None) is main_emitter() + assert registry.resolve(run_id="", parent_ids=[]) is main_emitter() diff --git a/surfsense_backend/tests/unit/services/streaming/test_interrupt_correlation.py b/surfsense_backend/tests/unit/services/streaming/test_interrupt_correlation.py new file mode 100644 index 000000000..edf4ecb9a --- /dev/null +++ b/surfsense_backend/tests/unit/services/streaming/test_interrupt_correlation.py @@ -0,0 +1,164 @@ +"""Pin id-aware pending-interrupt lookup that replaces the buggy first-wins.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import pytest + +from app.services.streaming.interrupt_correlation import ( + PendingInterrupt, + first_pending_interrupt, + get_pending_interrupt_by_id, + get_pending_interrupt_for_tool_call, + list_pending_interrupts, +) + +pytestmark = pytest.mark.unit + + +@dataclass +class _Interrupt: + value: dict[str, Any] + id: str | None = None + + +@dataclass +class _Task: + interrupts: tuple[_Interrupt, ...] = () + id: str | None = None + + +@dataclass +class _State: + tasks: tuple[_Task, ...] = () + interrupts: tuple[_Interrupt, ...] 
= () + + +def _hitl(name: str, tool_call_id: str | None = None) -> dict[str, Any]: + """Minimal LangChain HITLRequest payload for one action.""" + action: dict[str, Any] = {"name": name, "args": {}} + if tool_call_id is not None: + action["tool_call_id"] = tool_call_id + return { + "action_requests": [action], + "review_configs": [{"action_name": name, "allowed_decisions": ["approve"]}], + } + + +def test_empty_state_has_no_pending_interrupts() -> None: + state = _State() + assert list_pending_interrupts(state) == [] + assert first_pending_interrupt(state) is None + + +def test_single_pending_interrupt_in_task_is_returned() -> None: + state = _State( + tasks=( + _Task( + id="task_1", + interrupts=(_Interrupt(value=_hitl("send_email"), id="int_1"),), + ), + ) + ) + pending = list_pending_interrupts(state) + assert len(pending) == 1 + assert pending[0] == PendingInterrupt( + interrupt_id="int_1", + value=_hitl("send_email"), + source_task_id="task_1", + ) + + +def test_pending_interrupts_returned_in_task_then_root_order() -> None: + """Determinism matters: callers iterate in this order to render the UI.""" + state = _State( + tasks=( + _Task( + id="task_a", + interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),), + ), + _Task( + id="task_b", + interrupts=(_Interrupt(value=_hitl("b"), id="int_b"),), + ), + ), + interrupts=(_Interrupt(value=_hitl("c"), id="int_c"),), + ) + pending = list_pending_interrupts(state) + ids = [p.interrupt_id for p in pending] + assert ids == ["int_a", "int_b", "int_c"] + + +def test_get_by_id_finds_the_right_interrupt_under_parallel_load() -> None: + """Replacing first-wins: id-aware lookup MUST pick the requested one.""" + state = _State( + tasks=( + _Task(interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),)), + _Task(interrupts=(_Interrupt(value=_hitl("b"), id="int_b"),)), + _Task(interrupts=(_Interrupt(value=_hitl("c"), id="int_c"),)), + ) + ) + found = get_pending_interrupt_by_id(state, "int_b") + assert found is not None + assert found.value["action_requests"][0]["name"] == "b" + + +def test_get_by_id_returns_none_when_id_is_not_pending() -> None: + state = _State( + tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),)),) + ) + assert get_pending_interrupt_by_id(state, "missing") is None + + +def test_get_by_tool_call_id_matches_action_request_payload() -> None: + """HITLRequest carries ``tool_call_id`` per action; lookup uses that.""" + state = _State( + tasks=( + _Task( + interrupts=( + _Interrupt( + value=_hitl("a", tool_call_id="call_xxx"), id="int_a" + ), + _Interrupt( + value=_hitl("b", tool_call_id="call_yyy"), id="int_b" + ), + ) + ), + ) + ) + found = get_pending_interrupt_for_tool_call(state, "call_yyy") + assert found is not None + assert found.interrupt_id == "int_b" + + +def test_first_pending_interrupt_matches_legacy_first_wins_behaviour() -> None: + """Sequential-turn safety: the explicit shortcut still returns the first.""" + state = _State( + tasks=(_Task(interrupts=(_Interrupt(value=_hitl("first"), id="int_1"),)),), + interrupts=(_Interrupt(value=_hitl("second"), id="int_2"),), + ) + first = first_pending_interrupt(state) + assert first is not None + assert first.interrupt_id == "int_1" + + +def test_interrupt_without_id_falls_back_to_none() -> None: + """Snapshots from older LangGraph versions may omit ``id`` — preserve that.""" + state = _State( + tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id=None),)),) + ) + pending = list_pending_interrupts(state) + assert len(pending) == 1 + assert pending[0].interrupt_id is 
None + + +def test_non_dict_interrupt_values_are_ignored() -> None: + """Defensive: a non-dict value should not crash the iteration.""" + + class _Raw: + value = "not a dict" + + state = _State(tasks=(_Task(interrupts=(_Raw(),)),)) # type: ignore[arg-type] + assert list_pending_interrupts(state) == [] diff --git a/surfsense_backend/tests/unit/services/streaming/test_interrupt_events.py b/surfsense_backend/tests/unit/services/streaming/test_interrupt_events.py new file mode 100644 index 000000000..dbdd607bf --- /dev/null +++ b/surfsense_backend/tests/unit/services/streaming/test_interrupt_events.py @@ -0,0 +1,91 @@ +"""Pin interrupt-payload normalisation and the optional correlation fields on the wire.""" + +from __future__ import annotations + +import json + +import pytest + +from app.services.streaming.events.interrupt import ( + format_interrupt_request, + normalize_interrupt_payload, +) + +pytestmark = pytest.mark.unit + + +def _decode(frame: str) -> dict: + body = frame.removeprefix("data: ").removesuffix("\n\n") + return json.loads(body) + + +def test_hitlrequest_shape_is_passed_through_unchanged() -> None: + raw = { + "action_requests": [{"name": "send_email", "args": {"to": "a@b"}}], + "review_configs": [ + {"action_name": "send_email", "allowed_decisions": ["approve"]} + ], + } + assert normalize_interrupt_payload(raw) == raw + + +def test_custom_interrupt_primitive_is_converted_to_canonical_shape() -> None: + raw = { + "type": "permission", + "message": "Allow send?", + "action": {"tool": "send_email", "params": {"to": "a@b"}}, + "context": {"reason": "destructive"}, + } + out = normalize_interrupt_payload(raw) + assert out["action_requests"] == [ + {"name": "send_email", "args": {"to": "a@b"}} + ] + assert out["review_configs"] == [ + { + "action_name": "send_email", + "allowed_decisions": ["approve", "edit", "reject"], + } + ] + assert out["interrupt_type"] == "permission" + assert out["message"] == "Allow send?" 
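+    # the original ``context`` block is carried through verbatim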
+ assert out["context"] == {"reason": "destructive"} + + +def test_custom_interrupt_without_message_omits_message_key() -> None: + """Optional fields stay optional on the wire; FE does not see ``"message": None``.""" + raw = {"action": {"tool": "send_email"}} + out = normalize_interrupt_payload(raw) + assert "message" not in out + + +def test_custom_interrupt_without_tool_falls_back_to_unknown_tool() -> None: + """Defensive: a malformed ``action`` block must not crash the relay.""" + out = normalize_interrupt_payload({"type": "x", "action": {}}) + assert out["action_requests"][0]["name"] == "unknown_tool" + assert out["review_configs"][0]["action_name"] == "unknown_tool" + + +def test_format_interrupt_request_carries_correlation_fields_on_the_wire() -> None: + frame = format_interrupt_request( + {"action_requests": [], "review_configs": []}, + interrupt_id="int_42", + pending_interrupt_count=3, + chat_turn_id="turn_99", + ) + payload = _decode(frame) + assert payload["type"] == "data-interrupt-request" + inner = payload["data"] + assert inner["interrupt_id"] == "int_42" + assert inner["pending_interrupt_count"] == 3 + assert inner["chat_turn_id"] == "turn_99" + + +def test_format_interrupt_request_omits_correlation_fields_when_unset() -> None: + """Backward compat: legacy single-interrupt callers don't have to supply ids.""" + frame = format_interrupt_request( + {"action_requests": [], "review_configs": []}, + ) + inner = _decode(frame)["data"] + assert "interrupt_id" not in inner + assert "pending_interrupt_count" not in inner + assert "chat_turn_id" not in inner diff --git a/surfsense_backend/tests/unit/services/streaming/test_service_emitter_propagation.py b/surfsense_backend/tests/unit/services/streaming/test_service_emitter_propagation.py new file mode 100644 index 000000000..b381f13bc --- /dev/null +++ b/surfsense_backend/tests/unit/services/streaming/test_service_emitter_propagation.py @@ -0,0 +1,142 @@ +"""Pin that sub-agent emitter reaches every wire event the relay emits.""" + +from __future__ import annotations + +import json + +import pytest + +from app.services.streaming.emitter import subagent_emitter +from app.services.streaming.service import StreamingService + +pytestmark = pytest.mark.unit + + +def _decode(frame: str) -> dict: + body = frame.removeprefix("data: ").removesuffix("\n\n") + return json.loads(body) + + +@pytest.fixture +def service() -> StreamingService: + return StreamingService() + + +@pytest.fixture +def sub_emitter(): + return subagent_emitter( + subagent_type="deliverables", + subagent_run_id="sub_xyz", + parent_tool_call_id="call_parent", + ) + + +def test_text_delta_carries_subagent_emitter_on_the_wire(service, sub_emitter) -> None: + payload = _decode(service.format_text_delta("text_1", "hi", emitter=sub_emitter)) + assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz" + assert payload["delta"] == "hi" + + +def test_reasoning_delta_carries_subagent_emitter_on_the_wire( + service, sub_emitter +) -> None: + payload = _decode( + service.format_reasoning_delta("r_1", "thinking", emitter=sub_emitter) + ) + assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz" + + +def test_tool_input_start_carries_subagent_emitter_and_lc_id( + service, sub_emitter +) -> None: + payload = _decode( + service.format_tool_input_start( + "call_1", + "send_email", + langchain_tool_call_id="lc_1", + emitter=sub_emitter, + ) + ) + assert payload["emitted_by"]["subagent_type"] == "deliverables" + assert payload["langchainToolCallId"] == "lc_1" + assert 
payload["toolName"] == "send_email" + + +def test_tool_output_available_carries_subagent_emitter(service, sub_emitter) -> None: + payload = _decode( + service.format_tool_output_available( + "call_1", {"ok": True}, emitter=sub_emitter + ) + ) + assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz" + assert payload["output"] == {"ok": True} + + +def test_thinking_step_carries_subagent_emitter(service, sub_emitter) -> None: + payload = _decode( + service.format_thinking_step( + step_id="s1", + title="Sending email", + status="in_progress", + emitter=sub_emitter, + ) + ) + assert payload["type"] == "data-thinking-step" + assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz" + + +def test_action_log_carries_subagent_emitter(service, sub_emitter) -> None: + payload = _decode( + service.format_action_log( + {"id": 1, "tool_name": "send_email", "reversible": False}, + emitter=sub_emitter, + ) + ) + assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz" + assert payload["data"]["tool_name"] == "send_email" + + +def test_subagent_lifecycle_events_share_run_id_for_pairing( + service, sub_emitter +) -> None: + start = _decode( + service.format_subagent_start( + subagent_run_id="sub_xyz", + subagent_type="deliverables", + parent_tool_call_id="call_parent", + emitter=sub_emitter, + ) + ) + finish = _decode( + service.format_subagent_finish( + subagent_run_id="sub_xyz", + subagent_type="deliverables", + parent_tool_call_id="call_parent", + emitter=sub_emitter, + ) + ) + assert start["data"]["subagent_run_id"] == finish["data"]["subagent_run_id"] + assert start["type"] == "data-subagent-start" + assert finish["type"] == "data-subagent-finish" + + +def test_main_emitter_events_omit_emitted_by_field(service) -> None: + payload = _decode(service.format_text_delta("text_1", "hi")) + assert "emitted_by" not in payload + + +def test_resolve_emitter_through_service_uses_registry(service, sub_emitter) -> None: + service.emitter_registry.register("run_task_1", sub_emitter) + resolved = service.resolve_emitter( + run_id="run_chat_model", + parent_ids=["root", "run_task_1"], + ) + assert resolved is sub_emitter + + +def test_message_id_is_assigned_on_message_start_and_reused(service) -> None: + frame = service.format_message_start() + payload = _decode(frame) + assigned = payload["messageId"] + assert assigned.startswith("msg_") + assert service.message_id == assigned diff --git a/surfsense_backend/tests/unit/services/streaming/test_sse_envelope.py b/surfsense_backend/tests/unit/services/streaming/test_sse_envelope.py new file mode 100644 index 000000000..511e4575a --- /dev/null +++ b/surfsense_backend/tests/unit/services/streaming/test_sse_envelope.py @@ -0,0 +1,51 @@ +"""Pin the exact SSE wire bytes the FE parser depends on.""" + +from __future__ import annotations + +import json + +import pytest + +from app.services.streaming.envelope import ( + format_done, + format_sse, + get_response_headers, +) + +pytestmark = pytest.mark.unit + + +class TestFormatSse: + def test_dict_payload_is_json_serialised(self) -> None: + frame = format_sse({"type": "start", "messageId": "msg_1"}) + assert frame.startswith("data: ") + assert frame.endswith("\n\n") + body = frame[len("data: ") : -2] + assert json.loads(body) == {"type": "start", "messageId": "msg_1"} + + def test_string_payload_is_emitted_verbatim(self) -> None: + frame = format_sse('{"already":"json"}') + assert frame == 'data: {"already":"json"}\n\n' + + def test_nested_payload_round_trips(self) -> None: + payload = { + "type": "data-action-log", + 
"data": {"id": 7, "tool_name": "ls", "reversible": False}, + } + frame = format_sse(payload) + body = frame.removeprefix("data: ").removesuffix("\n\n") + assert json.loads(body) == payload + + +class TestFormatDone: + def test_done_marker_is_literal(self) -> None: + assert format_done() == "data: [DONE]\n\n" + + +class TestResponseHeaders: + def test_headers_pin_ai_sdk_v1_protocol(self) -> None: + headers = get_response_headers() + assert headers["Content-Type"] == "text/event-stream" + assert headers["Cache-Control"] == "no-cache" + assert headers["Connection"] == "keep-alive" + assert headers["x-vercel-ai-ui-message-stream"] == "v1" diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/__init__.py b/surfsense_backend/tests/unit/tasks/chat/streaming/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py new file mode 100644 index 000000000..023c8b999 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py @@ -0,0 +1,292 @@ +"""Pin Stage 1 extractions as faithful copies of the old helpers. + +Extractions under ``app.tasks.chat.streaming`` are compared to +``app.tasks.chat.stream_new_chat`` helpers. +For each Stage 1 extraction we assert the new function returns the same +output as the old one for a representative input set. The moment the +two diverge - intentionally or otherwise - this file fails loudly so +the divergence is reviewed rather than shipped silently. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import Any + +import pytest + +from app.agents.new_chat.errors import BusyError +from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel +from app.tasks.chat.stream_new_chat import ( + _classify_stream_exception as old_classify, + _emit_stream_terminal_error as old_emit_terminal_error, + _extract_chunk_parts as old_extract_chunk_parts, + _extract_resolved_file_path as old_extract_resolved_file_path, + _first_interrupt_value as old_first_interrupt_value, + _tool_output_has_error as old_tool_output_has_error, + _tool_output_to_text as old_tool_output_to_text, +) +from app.tasks.chat.streaming.errors.classifier import ( + classify_stream_exception as new_classify, +) +from app.tasks.chat.streaming.errors.emitter import ( + emit_stream_terminal_error as new_emit_terminal_error, +) +from app.tasks.chat.streaming.helpers.chunk_parts import ( + extract_chunk_parts as new_extract_chunk_parts, +) +from app.tasks.chat.streaming.helpers.interrupt_inspector import ( + first_interrupt_value as new_first_interrupt_value, +) +from app.tasks.chat.streaming.helpers.tool_output import ( + extract_resolved_file_path as new_extract_resolved_file_path, + tool_output_has_error as new_tool_output_has_error, + tool_output_to_text as new_tool_output_to_text, +) + +pytestmark = pytest.mark.unit + + +# ---------------------------------------------------------------- chunk parts + + +@dataclass +class _Chunk: + content: Any = "" + additional_kwargs: dict[str, Any] = field(default_factory=dict) + tool_call_chunks: list[dict[str, Any]] = field(default_factory=list) + + +_CHUNK_CASES: list[Any] = [ + None, + _Chunk(content=""), + _Chunk(content="hello"), + _Chunk(content=42), # invalid type, defensively coerced to empty + _Chunk( + content=[ + {"type": "text", "text": "Hello "}, + {"type": "text", "text": 
"world"}, + ] + ), + _Chunk( + content=[ + {"type": "reasoning", "reasoning": "hmm "}, + {"type": "reasoning", "text": "still"}, + {"type": "text", "text": "answer"}, + ] + ), + _Chunk( + content=[ + {"type": "tool_call_chunk", "id": "c1", "name": "x", "args": "{"}, + {"type": "tool_use", "id": "c2", "name": "y"}, + {"type": "image_url", "url": "ignored"}, + ] + ), + _Chunk( + content="visible", + additional_kwargs={"reasoning_content": "private"}, + ), + _Chunk( + tool_call_chunks=[ + {"id": None, "name": None, "args": '{"a":1}', "index": 0}, + {"id": "c", "name": "n", "args": "}", "index": 0}, + ] + ), + _Chunk( + content=[{"type": "tool_call_chunk", "id": "from-block", "name": "x"}], + tool_call_chunks=[{"id": "from-attr", "name": "y"}], + ), +] + + +@pytest.mark.parametrize("chunk", _CHUNK_CASES) +def test_extract_chunk_parts_matches_old_implementation(chunk: Any) -> None: + assert new_extract_chunk_parts(chunk) == old_extract_chunk_parts(chunk) + + +# ---------------------------------------------------------- interrupt inspector + + +@dataclass +class _Interrupt: + value: dict[str, Any] + + +@dataclass +class _Task: + interrupts: tuple[Any, ...] = () + + +@dataclass +class _State: + tasks: tuple[Any, ...] = () + interrupts: tuple[Any, ...] = () + + +_INTERRUPT_CASES: list[Any] = [ + _State(), + _State(tasks=(_Task(interrupts=(_Interrupt(value={"name": "send"}),)),)), + # Multiple tasks: must return the FIRST one in iteration order. + _State( + tasks=( + _Task(interrupts=(_Interrupt(value={"name": "first"}),)), + _Task(interrupts=(_Interrupt(value={"name": "second"}),)), + ) + ), + # Empty task interrupts -> falls back to root state.interrupts. + _State( + tasks=(_Task(interrupts=()),), + interrupts=(_Interrupt(value={"name": "root"}),), + ), + # Interrupts as plain dicts (not wrapper objects). + _State(interrupts=({"value": {"name": "dict_root"}},)), + # A defective task whose `.interrupts` raises - must be tolerated. + _State(tasks=(object(),)), +] + + +@pytest.mark.parametrize("state", _INTERRUPT_CASES) +def test_first_interrupt_value_matches_old_implementation(state: Any) -> None: + assert new_first_interrupt_value(state) == old_first_interrupt_value(state) + + +# ----------------------------------------------------------- error classifier + + +def _classify_cases() -> list[Exception]: + """Inputs that the FE depends on being mapped to specific error codes.""" + return [ + Exception("totally generic error"), + Exception( + '{"error":{"type":"rate_limit_error","message":"slow down"}}' + ), + Exception( + 'OpenrouterException - {"error":{"message":"Provider returned error",' + '"code":429}}' + ), + BusyError(request_id="thread-busy-parity"), + Exception("Thread is busy with another request"), + ] + + +@pytest.mark.parametrize("exc", _classify_cases()) +def test_classify_stream_exception_matches_old_implementation( + exc: Exception, +) -> None: + new = new_classify(exc, flow_label="parity-test") + old = old_classify(exc, flow_label="parity-test") + # Strip the wall-clock retry timestamp before comparing — both + # implementations call ``time.time()`` independently and the call + # order is enough to differ by 1 ms in practice. Every other field + # in the tuple must match exactly. 
+ new_extra = dict(new[5]) if isinstance(new[5], dict) else new[5] + old_extra = dict(old[5]) if isinstance(old[5], dict) else old[5] + if isinstance(new_extra, dict) and isinstance(old_extra, dict): + new_extra.pop("retry_after_at", None) + old_extra.pop("retry_after_at", None) + assert new[:5] == old[:5] + assert new_extra == old_extra + + +def test_classify_turn_cancelling_branch_parity() -> None: + """The TURN_CANCELLING branch reads cancel state for the busy thread id; + both implementations must agree on retry-window semantics, not just the + plain THREAD_BUSY code.""" + thread_id = "parity-cancelling-thread" + reset_cancel(thread_id) + request_cancel(thread_id) + exc = BusyError(request_id=thread_id) + new = new_classify(exc, flow_label="parity-test") + old = old_classify(exc, flow_label="parity-test") + assert new[0] == old[0] == "thread_busy" + assert new[1] == old[1] == "TURN_CANCELLING" + assert isinstance(new[5], dict) and isinstance(old[5], dict) + assert new[5]["retry_after_ms"] == old[5]["retry_after_ms"] + + +# ------------------------------------------------------------ terminal emitter + + +class _FakeStreamingService: + """Duck-types ``format_error`` for both old and new emitters.""" + + def __init__(self) -> None: + self.calls: list[dict[str, Any]] = [] + + def format_error( + self, message: str, *, error_code: str, extra: dict[str, Any] | None = None + ) -> str: + self.calls.append( + {"message": message, "error_code": error_code, "extra": extra} + ) + return f"data: {{\"type\":\"error\",\"errorText\":\"{message}\"}}\n\n" + + +def test_emit_stream_terminal_error_matches_old_output_and_logs(caplog) -> None: + """The new emitter must produce the same SSE frame and log the same + structured payload as the old one for the same arguments.""" + args: dict[str, Any] = { + "flow": "new", + "request_id": "req-parity", + "thread_id": 7, + "search_space_id": 9, + "user_id": "user-parity", + "message": "boom", + "error_kind": "server_error", + "error_code": "SERVER_ERROR", + "severity": "error", + "is_expected": False, + "extra": {"foo": "bar"}, + } + + new_svc = _FakeStreamingService() + old_svc = _FakeStreamingService() + + with caplog.at_level(logging.ERROR): + new_frame = new_emit_terminal_error(streaming_service=new_svc, **args) + old_frame = old_emit_terminal_error(streaming_service=old_svc, **args) + + assert new_frame == old_frame + assert new_svc.calls == old_svc.calls + chat_error_records = [ + r for r in caplog.records if "[chat_stream_error]" in r.message + ] + # One log line per emit call (two emits -> two records). 
+ assert len(chat_error_records) == 2 + + +# ---------------------------------------------------------------- tool output + + +def test_tool_output_helpers_match_old_implementation() -> None: + samples: list[Any] = [ + {"result": "ok"}, + {"error": "bad"}, + {"result": "Error: x"}, + "Error: plain", + "fine", + {"nested": {"a": 1}}, + ] + for s in samples: + assert new_tool_output_to_text(s) == old_tool_output_to_text(s) + assert new_tool_output_has_error(s) == old_tool_output_has_error(s) + + assert new_extract_resolved_file_path( + tool_name="write_file", + tool_output={"path": " /tmp/x "}, + tool_input=None, + ) == old_extract_resolved_file_path( + tool_name="write_file", + tool_output={"path": " /tmp/x "}, + tool_input=None, + ) + assert new_extract_resolved_file_path( + tool_name="write_file", + tool_output={}, + tool_input={"file_path": " /fallback "}, + ) == old_extract_resolved_file_path( + tool_name="write_file", + tool_output={}, + tool_input={"file_path": " /fallback "}, + ) diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_2_parity.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_2_parity.py new file mode 100644 index 000000000..3ee1ab622 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_2_parity.py @@ -0,0 +1,241 @@ +"""Parity tests for Stage 2 extractions (tool matching, thinking step, custom events).""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from app.tasks.chat.stream_new_chat import _legacy_match_lc_id as old_legacy_match +from app.tasks.chat.streaming.handlers.custom_events import ( + handle_action_log, + handle_action_log_updated, + handle_document_created, + handle_report_progress, +) +from app.tasks.chat.streaming.helpers.tool_call_matching import ( + match_buffered_langchain_tool_call_id as new_legacy_match, +) +from app.tasks.chat.streaming.relay.state import AgentEventRelayState +from app.tasks.chat.streaming.relay.thinking_step_completion import ( + complete_active_thinking_step, +) +from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame + +pytestmark = pytest.mark.unit + + +def _copy_chunk_buffer(raw: list[dict[str, Any]]) -> list[dict[str, Any]]: + return [dict(x) for x in raw] + + +def test_legacy_tool_call_match_matches_old_implementation() -> None: + cases: list[tuple[list[dict[str, Any]], str, str, dict[str, str]]] = [ + ( + [ + {"name": "write_file", "id": "lc-a"}, + {"name": "other", "id": "lc-b"}, + ], + "write_file", + "run-1", + {}, + ), + ( + [{"name": "x", "id": None}, {"name": "y", "id": "lc-fallback"}], + "write_file", + "run-2", + {}, + ), + ([{"name": "no_id"}], "write_file", "run-3", {}), + ] + for chunks_template, tool_name, run_id, lc_map_seed in cases: + old_chunks = _copy_chunk_buffer(chunks_template) + new_chunks = _copy_chunk_buffer(chunks_template) + old_map = dict(lc_map_seed) + new_map = dict(lc_map_seed) + old_out = old_legacy_match(old_chunks, tool_name, run_id, old_map) + new_out = new_legacy_match(new_chunks, tool_name, run_id, new_map) + assert new_out == old_out + assert new_chunks == old_chunks + assert new_map == old_map + + +def test_emit_thinking_step_frame_invokes_builder_before_service() -> None: + order: list[str] = [] + builder = MagicMock() + + def on_ts(*args: Any, **kwargs: Any) -> None: + order.append("builder") + + builder.on_thinking_step.side_effect = on_ts + + svc = MagicMock() + + def fmt(**kwargs: Any) -> str: + 
order.append("service") + return "frame" + + svc.format_thinking_step.side_effect = fmt + + out = emit_thinking_step_frame( + streaming_service=svc, + content_builder=builder, + step_id="thinking-1", + title="Working", + status="in_progress", + items=["a"], + ) + assert out == "frame" + assert order == ["builder", "service"] + builder.on_thinking_step.assert_called_once() + svc.format_thinking_step.assert_called_once() + + +def test_emit_thinking_step_frame_skips_builder_when_none() -> None: + svc = MagicMock(return_value="x") + svc.format_thinking_step.return_value = "frame" + assert ( + emit_thinking_step_frame( + streaming_service=svc, + content_builder=None, + step_id="s", + title="t", + ) + == "frame" + ) + svc.format_thinking_step.assert_called_once() + + +def test_complete_active_thinking_step_mirrors_closure_semantics() -> None: + svc = MagicMock() + svc.format_thinking_step.return_value = "done-frame" + completed: set[str] = set() + relay_state = AgentEventRelayState.for_invocation() + + frame, new_id = complete_active_thinking_step( + state=relay_state, + streaming_service=svc, + content_builder=None, + last_active_step_id="thinking-1", + last_active_step_title="T", + last_active_step_items=["x"], + completed_step_ids=completed, + ) + assert frame == "done-frame" + assert new_id is None + assert "thinking-1" in completed + + frame2, id2 = complete_active_thinking_step( + state=relay_state, + streaming_service=svc, + content_builder=None, + last_active_step_id="thinking-1", + last_active_step_title="T", + last_active_step_items=[], + completed_step_ids=completed, + ) + assert frame2 is None + assert id2 == "thinking-1" + + +def test_agent_event_relay_state_factory_matches_counter_rule() -> None: + s0 = AgentEventRelayState.for_invocation() + assert s0.thinking_step_counter == 0 + assert s0.last_active_step_id is None + + s1 = AgentEventRelayState.for_invocation( + initial_step_id="thinking-resume-1", + initial_step_title="Inherited", + initial_step_items=["Topic: X"], + ) + assert s1.thinking_step_counter == 1 + assert s1.last_active_step_id == "thinking-resume-1" + assert s1.next_thinking_step_id("thinking") == "thinking-2" + + +@pytest.mark.parametrize( + ("phase", "message", "start_items", "expected_tail"), + [ + ( + "revising_section", + "progress line", + ["Topic: Foo", "Modifying bar", "stale..."], + ["Topic: Foo", "Modifying bar", "progress line"], + ), + ( + "other", + "phase msg", + ["Topic: Foo", "old line"], + ["Topic: Foo", "phase msg"], + ), + ], +) +def test_report_progress_items_match_reference( + phase: str, + message: str, + start_items: list[str], + expected_tail: list[str], +) -> None: + svc = MagicMock() + svc.format_thinking_step.return_value = "sse" + + items = list(start_items) + frame, new_items = handle_report_progress( + {"message": message, "phase": phase}, + last_active_step_id="step-1", + last_active_step_title="Report", + last_active_step_items=items, + streaming_service=svc, + content_builder=None, + ) + assert frame == "sse" + assert new_items == expected_tail + kwargs = svc.format_thinking_step.call_args.kwargs + assert kwargs["items"] == expected_tail + + +def test_report_progress_noop_when_missing_message_or_step() -> None: + svc = MagicMock() + items = ["Topic: A"] + f1, i1 = handle_report_progress( + {"message": "", "phase": "x"}, + last_active_step_id="s", + last_active_step_title="t", + last_active_step_items=items, + streaming_service=svc, + content_builder=None, + ) + assert f1 is None and i1 is items + + f2, i2 = handle_report_progress( + 
{"message": "m", "phase": "x"}, + last_active_step_id=None, + last_active_step_title="t", + last_active_step_items=items, + streaming_service=svc, + content_builder=None, + ) + assert f2 is None and i2 is items + + +def test_document_action_handlers_match_format_data_guards() -> None: + svc = MagicMock() + svc.format_data.return_value = "data-frame" + + assert handle_document_created({}, streaming_service=svc) is None + assert handle_document_created({"id": 0}, streaming_service=svc) is None + handle_document_created({"id": 42, "title": "x"}, streaming_service=svc) + svc.format_data.assert_called_with( + "documents-updated", {"action": "created", "document": {"id": 42, "title": "x"}} + ) + + svc.reset_mock() + assert handle_action_log({"id": None}, streaming_service=svc) is None + handle_action_log({"id": 1}, streaming_service=svc) + svc.format_data.assert_called_once_with("action-log", {"id": 1}) + + svc.reset_mock() + assert handle_action_log_updated({"id": None}, streaming_service=svc) is None + handle_action_log_updated({"id": 2}, streaming_service=svc) + svc.format_data.assert_called_once_with("action-log-updated", {"id": 2}) diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stream_output.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stream_output.py new file mode 100644 index 000000000..9fb876dd7 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stream_output.py @@ -0,0 +1,116 @@ +"""Tests for ``stream_output`` (LangGraph events → SSE).""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +import pytest + +from app.tasks.chat.streaming.graph_stream import stream_output +from app.tasks.chat.streaming.graph_stream.result import StreamingResult + +pytestmark = pytest.mark.unit + + +@dataclass +class _Chunk: + content: Any = "" + additional_kwargs: dict[str, Any] = field(default_factory=dict) + tool_call_chunks: list[dict[str, Any]] = field(default_factory=list) + + +class _StreamingService: + def __init__(self) -> None: + self._text_idx = 0 + + def generate_text_id(self) -> str: + self._text_idx += 1 + return f"text-{self._text_idx}" + + def format_text_start(self, text_id: str) -> str: + return f"text_start:{text_id}" + + def format_text_delta(self, text_id: str, text: str) -> str: + return f"text_delta:{text_id}:{text}" + + def format_text_end(self, text_id: str) -> str: + return f"text_end:{text_id}" + + +class _Agent: + def __init__(self, events: list[dict[str, Any]]) -> None: + self.events = list(events) + self.calls: list[tuple[Any, dict[str, Any]]] = [] + + async def astream_events(self, input_data: Any, **kwargs: Any): + self.calls.append((input_data, kwargs)) + for event in self.events: + yield event + + +async def _collect(stream: Any) -> list[str]: + out: list[str] = [] + async for x in stream: + out.append(x) + return out + + +async def test_stream_output_emits_text_lifecycle_and_updates_result() -> None: + service = _StreamingService() + agent = _Agent( + [ + {"event": "on_chat_model_stream", "data": {"chunk": _Chunk(content="Hello")}}, + {"event": "on_chat_model_stream", "data": {"chunk": _Chunk(content=" world")}}, + ] + ) + result = StreamingResult() + + frames = await _collect( + stream_output( + agent=agent, + config={"configurable": {"thread_id": "t-1"}}, + input_data={"messages": []}, + streaming_service=service, + result=result, + ) + ) + + assert frames == [ + "text_start:text-1", + "text_delta:text-1:Hello", + "text_delta:text-1: world", + 
"text_end:text-1", + ] + assert result.accumulated_text == "Hello world" + assert result.agent_called_update_memory is False + + +async def test_stream_output_passes_runtime_context_to_agent() -> None: + service = _StreamingService() + + class _ContextAwareAgent: + async def astream_events(self, input_data: Any, **kwargs: Any): + del input_data + text = "ctx-ok" if kwargs.get("context") else "ctx-missing" + yield {"event": "on_chat_model_stream", "data": {"chunk": _Chunk(text)}} + + agent = _ContextAwareAgent() + result = StreamingResult() + + frames = await _collect( + stream_output( + agent=agent, + config={"configurable": {"thread_id": "t-2"}}, + input_data={"messages": []}, + streaming_service=service, + result=result, + runtime_context={"mentioned_document_ids": [1, 2]}, + ) + ) + + assert frames == [ + "text_start:text-1", + "text_delta:text-1:ctx-ok", + "text_end:text-1", + ] diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_task_span.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_task_span.py new file mode 100644 index 000000000..349c9879c --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_task_span.py @@ -0,0 +1,69 @@ +"""Unit tests for ``task_span`` open/close helpers.""" + +from __future__ import annotations + +import pytest + +from app.tasks.chat.streaming.relay.state import AgentEventRelayState +from app.tasks.chat.streaming.relay.task_span import ( + clear_task_span_if_delegating_task_ended, + ensure_pending_task_span_for_lc, + open_task_span, +) + +pytestmark = pytest.mark.unit + + +def test_open_task_span_sets_span_and_run_id() -> None: + state = AgentEventRelayState.for_invocation() + sid = open_task_span(state, run_id="run-abc") + assert sid.startswith("spn_") + assert state.active_span_id == sid + assert state.active_task_run_id == "run-abc" + assert state.span_metadata_if_active() == {"spanId": sid} + + +def test_clear_ignored_for_non_task_tool() -> None: + state = AgentEventRelayState.for_invocation() + open_task_span(state, run_id="run-1") + sid = state.active_span_id + clear_task_span_if_delegating_task_ended( + state, tool_name="web_search", run_id="run-1" + ) + assert state.active_span_id == sid + assert state.active_task_run_id == "run-1" + + +def test_clear_ignored_when_task_run_id_mismatches() -> None: + state = AgentEventRelayState.for_invocation() + open_task_span(state, run_id="run-open") + clear_task_span_if_delegating_task_ended(state, tool_name="task", run_id="run-other") + assert state.active_span_id is not None + assert state.active_task_run_id == "run-open" + + +def test_clear_on_matching_task_end() -> None: + state = AgentEventRelayState.for_invocation() + open_task_span(state, run_id="run-x") + clear_task_span_if_delegating_task_ended(state, tool_name="task", run_id="run-x") + assert state.active_span_id is None + assert state.active_task_run_id is None + assert state.span_metadata_if_active() is None + + +def test_clear_noop_when_no_open_span() -> None: + state = AgentEventRelayState.for_invocation() + clear_task_span_if_delegating_task_ended(state, tool_name="task", run_id="run-x") + assert state.active_span_id is None + + +def test_pending_then_open_reuses_same_span_id() -> None: + state = AgentEventRelayState.for_invocation() + sid_pending = ensure_pending_task_span_for_lc(state, "lc-task-1") + assert state.pending_task_span_by_lc["lc-task-1"] == sid_pending + sid_active = open_task_span( + state, run_id="run-1", langchain_tool_call_id="lc-task-1" + ) + assert sid_active == sid_pending + 
assert state.active_span_id == sid_pending + assert "lc-task-1" not in state.pending_task_span_by_lc diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_tool_activity_metadata.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_tool_activity_metadata.py new file mode 100644 index 000000000..c2e68dacd --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_tool_activity_metadata.py @@ -0,0 +1,42 @@ +"""Unit tests for ``AgentEventRelayState.tool_activity_metadata``.""" + +from __future__ import annotations + +import pytest + +from app.tasks.chat.streaming.relay.state import AgentEventRelayState +from app.tasks.chat.streaming.relay.task_span import open_task_span + +pytestmark = pytest.mark.unit + + +def test_returns_none_when_no_span_and_no_thinking_step() -> None: + state = AgentEventRelayState.for_invocation() + assert state.tool_activity_metadata(thinking_step_id=None) is None + assert state.tool_activity_metadata(thinking_step_id="") is None + assert state.tool_activity_metadata(thinking_step_id=" ") is None + + +def test_thinking_step_id_only() -> None: + state = AgentEventRelayState.for_invocation() + assert state.tool_activity_metadata(thinking_step_id="thinking-3") == { + "thinkingStepId": "thinking-3", + } + + +def test_span_only_when_active() -> None: + state = AgentEventRelayState.for_invocation() + open_task_span(state, run_id="run-x") + assert state.tool_activity_metadata(thinking_step_id=None) == { + "spanId": state.active_span_id, + } + + +def test_merges_span_and_thinking_step_when_both_set() -> None: + state = AgentEventRelayState.for_invocation() + open_task_span(state, run_id="run-x") + md = state.tool_activity_metadata(thinking_step_id="thinking-7") + assert md == { + "spanId": state.active_span_id, + "thinkingStepId": "thinking-7", + } diff --git a/surfsense_backend/tests/unit/tasks/chat/test_content_builder.py b/surfsense_backend/tests/unit/tasks/chat/test_content_builder.py index c317eba20..9d3eb6fa4 100644 --- a/surfsense_backend/tests/unit/tasks/chat/test_content_builder.py +++ b/surfsense_backend/tests/unit/tasks/chat/test_content_builder.py @@ -15,6 +15,7 @@ import json import pytest +from app.services.new_streaming_service import VercelStreamingService from app.tasks.chat.content_builder import AssistantContentBuilder pytestmark = pytest.mark.unit @@ -161,7 +162,7 @@ class TestToolHeavyTurn: _assert_jsonb_safe(snap) def test_tool_input_available_without_prior_start_creates_card(self): - # Legacy / parity_v2-OFF path: tool-input-available may be + # Late-registration: tool-input-available may be # emitted without a prior tool-input-start (no streamed # tool_call_chunks). The card should still be created. b = AssistantContentBuilder() @@ -187,7 +188,7 @@ class TestToolHeavyTurn: assert part["result"] == {"matches": 3} def test_tool_input_start_idempotent_for_same_ui_id(self): - # parity_v2: tool-input-start can fire from BOTH the chunk + # tool-input-start can fire from BOTH the chunk # registration path AND the canonical ``on_tool_start`` path. # The second call must not create a duplicate part. 
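+        # (Both paths funnel into the same ui id, so the builder has to
+        # treat the repeated start as a no-op instead of appending a
+        # second tool-call part.)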
b = AssistantContentBuilder() @@ -231,6 +232,155 @@ class TestToolHeavyTurn: ) +# --------------------------------------------------------------------------- +# Task-span metadata on tool-call parts (JSONB persistence) +# --------------------------------------------------------------------------- + + +class TestToolCallSpanMetadata: + def test_input_available_merges_new_metadata_keys_after_start(self): + b = AssistantContentBuilder() + b.on_tool_input_start( + "call_t", "task", "lc_t", metadata={"spanId": "spn_1"} + ) + b.on_tool_input_available( + "call_t", + "task", + {"goal": "x"}, + "lc_t", + metadata={"traceId": "tr_1"}, + ) + part = b.snapshot()[0] + assert part["metadata"]["spanId"] == "spn_1" + assert part["metadata"]["traceId"] == "tr_1" + _assert_jsonb_safe(b.snapshot()) + + def test_input_available_does_not_overwrite_existing_metadata_keys(self): + b = AssistantContentBuilder() + b.on_tool_input_start( + "call_t", "task", "lc_t", metadata={"spanId": "spn_keep"} + ) + b.on_tool_input_available( + "call_t", "task", {}, "lc_t", metadata={"spanId": "spn_other"} + ) + assert b.snapshot()[0]["metadata"]["spanId"] == "spn_keep" + + def test_late_tool_input_available_carries_metadata(self): + b = AssistantContentBuilder() + b.on_tool_input_available( + "call_l", + "grep", + {"pattern": "TODO"}, + None, + metadata={"spanId": "spn_l"}, + ) + part = b.snapshot()[0] + assert part["metadata"] == {"spanId": "spn_l"} + _assert_jsonb_safe(b.snapshot()) + + def test_output_available_merges_without_clobbering_span_id(self): + b = AssistantContentBuilder() + b.on_tool_input_start("call_t", "ls", "lc", metadata={"spanId": "spn_x"}) + b.on_tool_input_available("call_t", "ls", {"path": "/"}, "lc") + b.on_tool_output_available( + "call_t", + {"ok": True}, + "lc", + metadata={"spanId": "spn_y", "extra": 1}, + ) + md = b.snapshot()[0]["metadata"] + assert md["spanId"] == "spn_x" + assert md["extra"] == 1 + + def test_output_available_adds_thinking_step_id_without_clobbering_span(self): + b = AssistantContentBuilder() + b.on_tool_input_start( + "call_t", + "ls", + "lc", + metadata={"spanId": "spn_x", "thinkingStepId": "thinking-3"}, + ) + b.on_tool_input_available("call_t", "ls", {"path": "/"}, "lc") + b.on_tool_output_available( + "call_t", + {"ok": True}, + "lc", + metadata={"spanId": "spn_x", "thinkingStepId": "thinking-3"}, + ) + md = b.snapshot()[0]["metadata"] + assert md["spanId"] == "spn_x" + assert md["thinkingStepId"] == "thinking-3" + + def test_output_available_with_none_metadata_preserves_prior(self): + b = AssistantContentBuilder() + b.on_tool_input_start("c", "ls", "lc", metadata={"spanId": "spn_1"}) + b.on_tool_input_available("c", "ls", {}, "lc") + b.on_tool_output_available("c", {"r": 1}, "lc", metadata=None) + assert b.snapshot()[0]["metadata"] == {"spanId": "spn_1"} + + def test_available_adds_thinking_step_id_after_chunk_only_start(self): + """Mirrors chunk ``tool-input-start`` then ``on_tool_start`` ``available``.""" + b = AssistantContentBuilder() + b.on_tool_input_start("lc_1", "ls", "lc_1", metadata={"spanId": "spn_a"}) + b.on_tool_input_available( + "lc_1", + "ls", + {"path": "/"}, + "lc_1", + metadata={"spanId": "spn_a", "thinkingStepId": "thinking-2"}, + ) + md = b.snapshot()[0]["metadata"] + assert md["spanId"] == "spn_a" + assert md["thinkingStepId"] == "thinking-2" + + +class TestVercelStreamingServiceToolMetadataWire: + """SSE payloads include optional ``metadata`` for FE grouping.""" + + @staticmethod + def _parse_sse_data_line(raw: str) -> dict: + assert 
raw.startswith("data: ") + payload = raw.split("data: ", 1)[1].split("\n\n", 1)[0].strip() + return json.loads(payload) + + def test_tool_input_available_includes_metadata_when_set(self): + svc = VercelStreamingService() + raw = svc.format_tool_input_available( + "id1", + "task", + {"a": 1}, + langchain_tool_call_id="lc1", + metadata={"spanId": "spn_w", "thinkingStepId": "thinking-4"}, + ) + body = self._parse_sse_data_line(raw) + assert body["type"] == "tool-input-available" + assert body["metadata"] == { + "spanId": "spn_w", + "thinkingStepId": "thinking-4", + } + + def test_tool_output_available_includes_metadata_when_set(self): + svc = VercelStreamingService() + raw = svc.format_tool_output_available( + "id1", + {"status": "completed"}, + langchain_tool_call_id="lc1", + metadata={"spanId": "spn_o", "thinkingStepId": "thinking-9"}, + ) + body = self._parse_sse_data_line(raw) + assert body["type"] == "tool-output-available" + assert body["metadata"] == { + "spanId": "spn_o", + "thinkingStepId": "thinking-9", + } + + def test_tool_input_available_omits_metadata_key_when_none(self): + svc = VercelStreamingService() + raw = svc.format_tool_input_available("id1", "ls", {}) + body = self._parse_sse_data_line(raw) + assert "metadata" not in body + + # --------------------------------------------------------------------------- # Thinking steps & separators # --------------------------------------------------------------------------- diff --git a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py index 60750396c..ada32d168 100644 --- a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py +++ b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py @@ -1,16 +1,13 @@ """Unit tests for live tool-call argument streaming. -Pins the wire format that ``_stream_agent_events`` emits when -``SURFSENSE_ENABLE_STREAM_PARITY_V2=true``: ``tool-input-start`` → -``tool-input-delta``... → ``tool-input-available`` → ``tool-output-available`` -all keyed by the same LangChain ``tool_call.id``. +Pins the wire format that ``_stream_agent_events`` emits: +``tool-input-start`` → ``tool-input-delta``... → ``tool-input-available`` → +``tool-output-available``, keyed consistently with LangChain ``tool_call.id`` +when the model streams indexed chunks. Identity is tracked in ``index_to_meta`` (per-chunk ``index``) and -``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are private to -``_stream_agent_events`` so we exercise them via the public wire output. - -These tests also lock in the legacy / parity_v2-OFF behaviour so the -synthetic ``call_`` shape stays stable for older clients. +``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are internal to the +streaming layer so we assert on the public SSE payloads. 
""" from __future__ import annotations @@ -22,8 +19,6 @@ from typing import Any import pytest -import app.tasks.chat.stream_new_chat as stream_module -from app.agents.new_chat.feature_flags import AgentFeatureFlags from app.services.new_streaming_service import VercelStreamingService from app.tasks.chat.stream_new_chat import ( StreamResult, @@ -164,24 +159,6 @@ def _tool_end( } -@pytest.fixture -def parity_v2_on(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr( - stream_module, - "get_flags", - lambda: AgentFeatureFlags(enable_stream_parity_v2=True), - ) - - -@pytest.fixture -def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr( - stream_module, - "get_flags", - lambda: AgentFeatureFlags(enable_stream_parity_v2=False), - ) - - async def _drain( events: list[dict[str, Any]], state: _FakeAgentState | None = None ) -> list[dict[str, Any]]: @@ -253,12 +230,12 @@ class TestLegacyMatch: # --------------------------------------------------------------------------- -# parity_v2 wire format tests. +# Tool input streaming wire format # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None: +async def test_idless_chunk_merging_by_index() -> None: """First chunk carries id+name; later idless chunks at the same ``index`` merge into the SAME ``tool-input-start`` ui id and emit one ``tool-input-delta`` per chunk.""" @@ -302,9 +279,7 @@ async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None: @pytest.mark.asyncio -async def test_two_interleaved_tool_calls_route_by_index( - parity_v2_on: None, -) -> None: +async def test_two_interleaved_tool_calls_route_by_index() -> None: """Two same-name calls with distinct indices keep their deltas routed to the right card.""" events = [ @@ -344,7 +319,7 @@ async def test_two_interleaved_tool_calls_route_by_index( @pytest.mark.asyncio -async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None: +async def test_identity_stable_across_lifecycle() -> None: """Whatever id ``tool-input-start`` chose must be the SAME id used on ``tool-input-available`` AND ``tool-output-available``.""" events = [ @@ -367,7 +342,7 @@ async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None: @pytest.mark.asyncio -async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None: +async def test_no_duplicate_tool_input_start() -> None: """When the chunk-emission loop already fired ``tool-input-start`` for this run, ``on_tool_start`` MUST NOT emit a second one.""" events = [ @@ -386,9 +361,7 @@ async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None: @pytest.mark.asyncio -async def test_active_text_closes_before_early_tool_input_start( - parity_v2_on: None, -) -> None: +async def test_active_text_closes_before_early_tool_input_start() -> None: """Streaming a text-delta then a tool-call chunk in subsequent chunks: the wire MUST contain ``text-end`` before the FIRST ``tool-input-start`` (clean part boundary on the frontend).""" @@ -409,9 +382,7 @@ async def test_active_text_closes_before_early_tool_input_start( @pytest.mark.asyncio -async def test_mixed_text_and_tool_chunk_preserve_order( - parity_v2_on: None, -) -> None: +async def test_mixed_text_and_tool_chunk_preserve_order() -> None: """One AIMessageChunk that carries BOTH ``text`` content AND ``tool_call_chunks`` should emit the text delta FIRST, then close text, then 
``tool-input-start``+``tool-input-delta``.""" @@ -441,45 +412,7 @@ async def test_mixed_text_and_tool_chunk_preserve_order( @pytest.mark.asyncio -async def test_parity_v2_off_preserves_legacy_shape( - parity_v2_off: None, -) -> None: - """When the flag is OFF, no deltas are emitted and the ``toolCallId`` - is ``call_`` (NOT the lc id).""" - events = [ - _model_stream( - tool_call_chunks=[ - {"id": "lc-1", "name": "ls", "args": '{"path":"/"}', "index": 0} - ] - ), - _tool_start(name="ls", run_id="run-A", input_payload={"path": "/"}), - _tool_end(name="ls", run_id="run-A", tool_call_id="lc-1"), - ] - payloads = await _drain(events) - - assert _of_type(payloads, "tool-input-delta") == [] - starts = _of_type(payloads, "tool-input-start") - assert len(starts) == 1 - assert starts[0]["toolCallId"].startswith("call_run-A") - # No ``langchainToolCallId`` propagation on ``tool-input-start`` in - # legacy mode (the start event fires before the ToolMessage is - # available, so we can't extract the authoritative LangChain id yet). - assert "langchainToolCallId" not in starts[0] - output = _of_type(payloads, "tool-output-available") - assert output[0]["toolCallId"].startswith("call_run-A") - # ``tool-output-available`` MUST carry ``langchainToolCallId`` even - # in legacy mode: the chat tool card uses it to backfill the - # LangChain id and join against the ``data-action-log`` SSE event - # (keyed by ``lc_tool_call_id``) so the inline Revert button can - # light up. Sourced from the returned ``ToolMessage.tool_call_id``, - # which is populated regardless of feature-flag state. - assert output[0]["langchainToolCallId"] == "lc-1" - - -@pytest.mark.asyncio -async def test_skip_append_prevents_stale_id_reuse( - parity_v2_on: None, -) -> None: +async def test_skip_append_prevents_stale_id_reuse() -> None: """Two same-name tools: the SECOND tool's ``langchainToolCallId`` must NOT come from the first tool's chunk (``pending_tool_call_chunks`` must stay empty for indexed-registered chunks).""" @@ -506,9 +439,7 @@ async def test_skip_append_prevents_stale_id_reuse( @pytest.mark.asyncio -async def test_registration_waits_for_both_id_and_name( - parity_v2_on: None, -) -> None: +async def test_registration_waits_for_both_id_and_name() -> None: """An id-only chunk (no name yet) must NOT emit ``tool-input-start``.""" events = [ _model_stream( @@ -520,12 +451,9 @@ async def test_registration_waits_for_both_id_and_name( @pytest.mark.asyncio -async def test_unmatched_fallback_still_attaches_lc_id( - parity_v2_on: None, -) -> None: - """parity_v2 ON, but the provider didn't include an ``index``: the - legacy fallback path must still emit ``tool-input-start`` with the - matching ``langchainToolCallId``.""" +async def test_unmatched_fallback_still_attaches_lc_id() -> None: + """When the provider omits chunk ``index``, buffered chunks still get a + ``tool-input-start`` with the matching ``langchainToolCallId``.""" events = [ # No index on the chunk → not registered into index_to_meta; # falls through to ``pending_tool_call_chunks`` so the legacy @@ -542,9 +470,7 @@ async def test_unmatched_fallback_still_attaches_lc_id( @pytest.mark.asyncio -async def test_interrupt_request_uses_task_that_contains_interrupt( - parity_v2_on: None, -) -> None: +async def test_interrupt_request_uses_task_that_contains_interrupt() -> None: interrupt_payload = { "type": "calendar_event_create", "action": { diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx 
b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx
index 9b5510df3..9550eed05 100644
--- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx
+++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx
@@ -43,13 +43,14 @@ import {
 	type EditMessageDialogChoice,
 } from "@/components/assistant-ui/edit-message-dialog";
 import { StepSeparatorDataUI } from "@/components/assistant-ui/step-separator";
-import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps";
 import { Thread } from "@/components/assistant-ui/thread";
 import {
 	createTokenUsageStore,
 	type TokenUsageData,
 	TokenUsageProvider,
 } from "@/components/assistant-ui/token-usage-context";
+import { type HitlDecision, PendingInterruptProvider } from "@/features/chat-messages/hitl";
+import { TimelineDataUI } from "@/features/chat-messages/timeline";
 import {
 	applyActionLogSse,
 	applyActionLogUpdatedSse,
@@ -63,7 +64,10 @@ import { documentsApiService } from "@/lib/apis/documents-api.service";
 import { getBearerToken } from "@/lib/auth-utils";
 import { type ChatFlow, classifyChatError } from "@/lib/chat/chat-error-classifier";
 import { tagPreAcceptSendFailure, toHttpResponseError } from "@/lib/chat/chat-request-errors";
-import { convertToThreadMessage } from "@/lib/chat/message-utils";
+import {
+	convertToThreadMessage,
+	reconcileInterruptedAssistantMessages,
+} from "@/lib/chat/message-utils";
 import {
 	isPodcastGenerating,
 	looksLikePodcastRequest,
@@ -107,7 +111,6 @@ import {
 	type NewChatUserImagePayload,
 } from "@/lib/chat/user-turn-api-parts";
 import { NotFoundError } from "@/lib/error";
-import { type BundleSubmit, HitlBundleProvider } from "@/lib/hitl";
 import {
 	trackChatBlocked,
 	trackChatCreated,
@@ -126,7 +129,7 @@ const MobileEditorPanel = dynamic(
 );
 const MobileHitlEditPanel = dynamic(
 	() =>
-		import("@/components/hitl-edit-panel/hitl-edit-panel").then((m) => ({
+		import("@/features/chat-messages/hitl").then((m) => ({
 			default: m.MobileHitlEditPanel,
 		})),
 	{ ssr: false }
@@ -395,7 +398,7 @@ export default function NewChatPage() {
 			const memberById = new Map(membersData?.map((m) => [m.user_id, m]) ?? []);
 			const prevById = new Map(prev.map((m) => [m.id, m]));
 
-			return syncedMessages.map((msg) => {
+			return reconcileInterruptedAssistantMessages(syncedMessages).map((msg) => {
 				const member = msg.author_id ? (memberById.get(msg.author_id) ?? null) : null;
 				// Preserve existing author info if member lookup fails (e.g., cloned chats)
@@ -622,7 +625,9 @@ export default function NewChatPage() {
 				setCurrentThread(threadData);
 
 				if (messagesResponse.messages && messagesResponse.messages.length > 0) {
-					const loadedMessages = messagesResponse.messages.map(convertToThreadMessage);
+					const loadedMessages = reconcileInterruptedAssistantMessages(
+						messagesResponse.messages
+					).map(convertToThreadMessage);
 					setMessages(loadedMessages);
 
 					for (const msg of messagesResponse.messages) {
@@ -1388,6 +1393,8 @@ export default function NewChatPage() {
 		const existingMsg = messages.find((m) => m.id === assistantMsgId);
 		if (existingMsg && Array.isArray(existingMsg.content)) {
+			// See ``ContentPartsState.suppressStepSeparators`` doc.
+			contentPartsState.suppressStepSeparators = true;
 			for (const part of existingMsg.content) {
 				if (typeof part === "object" && part !== null) {
 					const p = part as Record<string, unknown>;
@@ -1402,15 +1409,19 @@
 							toolName: String(p.toolName),
 							args: (p.args as Record<string, unknown>) ??
{},
 							result: p.result as unknown,
-							// Restore argsText so persisted pretty-printed
-							// JSON survives reloads (assistant-ui prefers
-							// supplied argsText over JSON.stringify(args)).
-							// langchainToolCallId restoration also fixes a
-							// pre-existing dropped-id bug on resume.
+							// argsText: assistant-ui prefers it over
+							// JSON.stringify(args), so restoring it keeps
+							// pretty-printed JSON across reloads.
 							...(typeof p.argsText === "string" ? { argsText: p.argsText } : {}),
 							...(typeof p.langchainToolCallId === "string"
 								? { langchainToolCallId: p.langchainToolCallId }
 								: {}),
+							// metadata: spanId / thinkingStepId drive the
+							// timeline's step↔tool join. Dropping these
+							// here orphans every rehydrated tool-call.
+							...(p.metadata && typeof p.metadata === "object"
+								? { metadata: p.metadata as Record<string, unknown> }
+								: {}),
 						});
 						contentPartsState.currentTextPartIndex = -1;
 					} else if (p.type === "data-thinking-steps") {
@@ -1730,57 +1741,6 @@
 		return () => window.removeEventListener("hitl-decision", handler);
 	}, [handleResume, pendingInterrupt]);
 
-	// Mirror staged bundle decisions onto the cards visually so prev/next nav
-	// reflects past choices instead of re-prompting. Submit's ``hitl-decision``
-	// handler still runs the actual resume.
-	useEffect(() => {
-		const handler = (e: Event) => {
-			const detail = (e as CustomEvent).detail as {
-				toolCallId: string;
-				decision: {
-					type: string;
-					message?: string;
-					edited_action?: { name: string; args: Record<string, unknown> };
-				};
-			};
-			if (!detail?.toolCallId || !detail?.decision || !pendingInterrupt) return;
-			setMessages((prev) =>
-				prev.map((m) => {
-					if (m.id !== pendingInterrupt.assistantMsgId) return m;
-					const parts = m.content as unknown as Array<Record<string, unknown>>;
-					const newContent = parts.map((part) => {
-						if (part.toolCallId !== detail.toolCallId) return part;
-						if (part.type !== "tool-call") return part;
-						if (typeof part.result !== "object" || part.result === null) return part;
-						if (!("__interrupt__" in (part.result as Record<string, unknown>))) return part;
-						const decided = detail.decision.type as "approve" | "reject" | "edit";
-						if (decided === "edit" && detail.decision.edited_action) {
-							return {
-								...part,
-								args: detail.decision.edited_action.args,
-								argsText: JSON.stringify(detail.decision.edited_action.args, null, 2),
-								result: {
-									...(part.result as Record<string, unknown>),
-									__decided__: decided,
-								},
-							};
-						}
-						return {
-							...part,
-							result: {
-								...(part.result as Record<string, unknown>),
-								__decided__: decided,
-							},
-						};
-					});
-					return { ...m, content: newContent as unknown as ThreadMessageLike["content"] };
-				})
-			);
-		};
-		window.addEventListener("hitl-stage", handler);
-		return () => window.removeEventListener("hitl-stage", handler);
-	}, [pendingInterrupt]);
-
 	// Convert message (pass through since already in correct format)
 	const convertMessage = useCallback(
 		(message: ThreadMessageLike): ThreadMessageLike => message,
@@ -2279,7 +2239,7 @@
 		[handleRegenerate, messages, agentActionItems]
 	);
 
-	const handleBundleSubmit = useCallback((orderedDecisions) => {
+	const handleApprovalSubmit = useCallback((orderedDecisions: HitlDecision[]) => {
 		window.dispatchEvent(
 			new CustomEvent("hitl-decision", { detail: { decisions: orderedDecisions } })
 		);
@@ -2353,11 +2313,11 @@
 	return (
 		[hunk body lost in extraction — swaps the HITL wrapper element: the
 		deleted HitlBundleProvider (wired to handleBundleSubmit) becomes
 		PendingInterruptProvider (wired to handleApprovalSubmit)]
@@ -2367,7 +2327,7 @@
 	[hunk body lost in extraction — the matching close of the wrapper
 	swapped above (HitlBundleProvider → PendingInterruptProvider)]
- + { diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx index 7bccc22ee..00f3acebf 100644 --- a/surfsense_web/components/assistant-ui/assistant-message.tsx +++ b/surfsense_web/components/assistant-ui/assistant-message.tsx @@ -4,6 +4,7 @@ import { AuiIf, ErrorPrimitive, MessagePrimitive, + type ToolCallMessagePartComponent, useAui, useAuiState, } from "@assistant-ui/react"; @@ -36,11 +37,9 @@ import { MarkdownText } from "@/components/assistant-ui/markdown-text"; import { ReasoningMessagePart } from "@/components/assistant-ui/reasoning-message-part"; import { RevertTurnButton } from "@/components/assistant-ui/revert-turn-button"; import { useTokenUsage } from "@/components/assistant-ui/token-usage-context"; -import { ToolFallback } from "@/components/assistant-ui/tool-fallback"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { CommentPanelContainer } from "@/components/chat-comments/comment-panel-container/comment-panel-container"; import { CommentSheet } from "@/components/chat-comments/comment-sheet/comment-sheet"; -import { withBundleStep } from "@/components/hitl-bundle-pager"; import type { SerializableCitation } from "@/components/tool-ui/citation"; import { openSafeNavigationHref, @@ -100,146 +99,6 @@ const GenerateImageToolUI = dynamic( import("@/components/tool-ui/generate-image").then((m) => ({ default: m.GenerateImageToolUI })), { ssr: false } ); -const UpdateMemoryToolUI = dynamic( - () => import("@/components/tool-ui/user-memory").then((m) => ({ default: m.UpdateMemoryToolUI })), - { ssr: false } -); -const SandboxExecuteToolUI = dynamic( - () => - import("@/components/tool-ui/sandbox-execute").then((m) => ({ - default: m.SandboxExecuteToolUI, - })), - { ssr: false } -); -const CreateNotionPageToolUI = dynamic( - () => import("@/components/tool-ui/notion").then((m) => ({ default: m.CreateNotionPageToolUI })), - { ssr: false } -); -const UpdateNotionPageToolUI = dynamic( - () => import("@/components/tool-ui/notion").then((m) => ({ default: m.UpdateNotionPageToolUI })), - { ssr: false } -); -const DeleteNotionPageToolUI = dynamic( - () => import("@/components/tool-ui/notion").then((m) => ({ default: m.DeleteNotionPageToolUI })), - { ssr: false } -); -const CreateLinearIssueToolUI = dynamic( - () => import("@/components/tool-ui/linear").then((m) => ({ default: m.CreateLinearIssueToolUI })), - { ssr: false } -); -const UpdateLinearIssueToolUI = dynamic( - () => import("@/components/tool-ui/linear").then((m) => ({ default: m.UpdateLinearIssueToolUI })), - { ssr: false } -); -const DeleteLinearIssueToolUI = dynamic( - () => import("@/components/tool-ui/linear").then((m) => ({ default: m.DeleteLinearIssueToolUI })), - { ssr: false } -); -const CreateGoogleDriveFileToolUI = dynamic( - () => - import("@/components/tool-ui/google-drive").then((m) => ({ - default: m.CreateGoogleDriveFileToolUI, - })), - { ssr: false } -); -const DeleteGoogleDriveFileToolUI = dynamic( - () => - import("@/components/tool-ui/google-drive").then((m) => ({ - default: m.DeleteGoogleDriveFileToolUI, - })), - { ssr: false } -); -const CreateOneDriveFileToolUI = dynamic( - () => - import("@/components/tool-ui/onedrive").then((m) => ({ default: m.CreateOneDriveFileToolUI })), - { ssr: false } -); -const DeleteOneDriveFileToolUI = dynamic( - () => - import("@/components/tool-ui/onedrive").then((m) => ({ default: m.DeleteOneDriveFileToolUI })), - { ssr: false } -); -const 
CreateDropboxFileToolUI = dynamic( - () => - import("@/components/tool-ui/dropbox").then((m) => ({ default: m.CreateDropboxFileToolUI })), - { ssr: false } -); -const DeleteDropboxFileToolUI = dynamic( - () => - import("@/components/tool-ui/dropbox").then((m) => ({ default: m.DeleteDropboxFileToolUI })), - { ssr: false } -); -const CreateCalendarEventToolUI = dynamic( - () => - import("@/components/tool-ui/google-calendar").then((m) => ({ - default: m.CreateCalendarEventToolUI, - })), - { ssr: false } -); -const UpdateCalendarEventToolUI = dynamic( - () => - import("@/components/tool-ui/google-calendar").then((m) => ({ - default: m.UpdateCalendarEventToolUI, - })), - { ssr: false } -); -const DeleteCalendarEventToolUI = dynamic( - () => - import("@/components/tool-ui/google-calendar").then((m) => ({ - default: m.DeleteCalendarEventToolUI, - })), - { ssr: false } -); -const CreateGmailDraftToolUI = dynamic( - () => import("@/components/tool-ui/gmail").then((m) => ({ default: m.CreateGmailDraftToolUI })), - { ssr: false } -); -const UpdateGmailDraftToolUI = dynamic( - () => import("@/components/tool-ui/gmail").then((m) => ({ default: m.UpdateGmailDraftToolUI })), - { ssr: false } -); -const SendGmailEmailToolUI = dynamic( - () => import("@/components/tool-ui/gmail").then((m) => ({ default: m.SendGmailEmailToolUI })), - { ssr: false } -); -const TrashGmailEmailToolUI = dynamic( - () => import("@/components/tool-ui/gmail").then((m) => ({ default: m.TrashGmailEmailToolUI })), - { ssr: false } -); -const CreateJiraIssueToolUI = dynamic( - () => import("@/components/tool-ui/jira").then((m) => ({ default: m.CreateJiraIssueToolUI })), - { ssr: false } -); -const UpdateJiraIssueToolUI = dynamic( - () => import("@/components/tool-ui/jira").then((m) => ({ default: m.UpdateJiraIssueToolUI })), - { ssr: false } -); -const DeleteJiraIssueToolUI = dynamic( - () => import("@/components/tool-ui/jira").then((m) => ({ default: m.DeleteJiraIssueToolUI })), - { ssr: false } -); -const CreateConfluencePageToolUI = dynamic( - () => - import("@/components/tool-ui/confluence").then((m) => ({ - default: m.CreateConfluencePageToolUI, - })), - { ssr: false } -); -const UpdateConfluencePageToolUI = dynamic( - () => - import("@/components/tool-ui/confluence").then((m) => ({ - default: m.UpdateConfluencePageToolUI, - })), - { ssr: false } -); -const DeleteConfluencePageToolUI = dynamic( - () => - import("@/components/tool-ui/confluence").then((m) => ({ - default: m.DeleteConfluencePageToolUI, - })), - { ssr: false } -); - function extractDomain(url: string): string | undefined { try { return new URL(url).hostname.replace(/^www\./, ""); @@ -503,50 +362,26 @@ const MessageInfoDropdown: FC = () => { ); }; -// Wrap each tool-ui card with ``withBundleStep`` so multi-card HITL bundles -// page through them and stage decisions instead of firing one resume per card. 
-const TOOLS_BY_NAME = { - generate_report: withBundleStep(GenerateReportToolUI), - generate_resume: withBundleStep(GenerateResumeToolUI), - generate_podcast: withBundleStep(GeneratePodcastToolUI), - generate_video_presentation: withBundleStep(GenerateVideoPresentationToolUI), - display_image: withBundleStep(GenerateImageToolUI), - generate_image: withBundleStep(GenerateImageToolUI), - update_memory: withBundleStep(UpdateMemoryToolUI), - execute: withBundleStep(SandboxExecuteToolUI), - execute_code: withBundleStep(SandboxExecuteToolUI), - create_notion_page: withBundleStep(CreateNotionPageToolUI), - update_notion_page: withBundleStep(UpdateNotionPageToolUI), - delete_notion_page: withBundleStep(DeleteNotionPageToolUI), - create_linear_issue: withBundleStep(CreateLinearIssueToolUI), - update_linear_issue: withBundleStep(UpdateLinearIssueToolUI), - delete_linear_issue: withBundleStep(DeleteLinearIssueToolUI), - create_google_drive_file: withBundleStep(CreateGoogleDriveFileToolUI), - delete_google_drive_file: withBundleStep(DeleteGoogleDriveFileToolUI), - create_onedrive_file: withBundleStep(CreateOneDriveFileToolUI), - delete_onedrive_file: withBundleStep(DeleteOneDriveFileToolUI), - create_dropbox_file: withBundleStep(CreateDropboxFileToolUI), - delete_dropbox_file: withBundleStep(DeleteDropboxFileToolUI), - create_calendar_event: withBundleStep(CreateCalendarEventToolUI), - update_calendar_event: withBundleStep(UpdateCalendarEventToolUI), - delete_calendar_event: withBundleStep(DeleteCalendarEventToolUI), - create_gmail_draft: withBundleStep(CreateGmailDraftToolUI), - update_gmail_draft: withBundleStep(UpdateGmailDraftToolUI), - send_gmail_email: withBundleStep(SendGmailEmailToolUI), - trash_gmail_email: withBundleStep(TrashGmailEmailToolUI), - create_jira_issue: withBundleStep(CreateJiraIssueToolUI), - update_jira_issue: withBundleStep(UpdateJiraIssueToolUI), - delete_jira_issue: withBundleStep(DeleteJiraIssueToolUI), - create_confluence_page: withBundleStep(CreateConfluencePageToolUI), - update_confluence_page: withBundleStep(UpdateConfluencePageToolUI), - delete_confluence_page: withBundleStep(DeleteConfluencePageToolUI), - web_search: () => null, - link_preview: () => null, - multi_link_preview: () => null, - scrape_webpage: () => null, +/** + * Tools rendered in the message BODY — value-add deliverables only. + * + * Process tools (connector CRUD, sandbox execute, memory updates, + * etc.) are NOT here; they render in the timeline via the slice's + * tool registry (see ``features/chat-messages/timeline``). The body + * opts out of every other tool by registering ``NullBodyTool`` as the + * fallback — any tool name not in this map renders nothing in the + * body and is picked up by the timeline instead. 
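+ *
+ * e.g. ``create_notion_page`` — present in the deleted TOOLS_BY_NAME
+ * map above — now surfaces only through the timeline, while
+ * ``generate_report`` below stays an inline body deliverable.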
+ */ +const BODY_TOOLS = { + generate_report: GenerateReportToolUI, + generate_resume: GenerateResumeToolUI, + generate_podcast: GeneratePodcastToolUI, + generate_video_presentation: GenerateVideoPresentationToolUI, + display_image: GenerateImageToolUI, + generate_image: GenerateImageToolUI, } as const; -const TOOLS_FALLBACK = withBundleStep(ToolFallback); +const NullBodyTool: ToolCallMessagePartComponent = () => null; const AssistantMessageInner: FC = () => { const isMobile = !useMediaQuery("(min-width: 768px)"); @@ -559,8 +394,8 @@ const AssistantMessageInner: FC = () => { Text: MarkdownText, Reasoning: ReasoningMessagePart, tools: { - by_name: TOOLS_BY_NAME, - Fallback: TOOLS_FALLBACK, + by_name: BODY_TOOLS, + Fallback: NullBodyTool, }, }} /> diff --git a/surfsense_web/components/assistant-ui/reasoning-message-part.tsx b/surfsense_web/components/assistant-ui/reasoning-message-part.tsx index 70636eab8..6e7aaf048 100644 --- a/surfsense_web/components/assistant-ui/reasoning-message-part.tsx +++ b/surfsense_web/components/assistant-ui/reasoning-message-part.tsx @@ -7,8 +7,8 @@ import { TextShimmerLoader } from "@/components/prompt-kit/loader"; import { cn } from "@/lib/utils"; /** - * Renders the structured `reasoning` part emitted by the backend's - * stream-parity v2 path (A1). + * Renders the structured `reasoning` part emitted by the backend stream + * (typed reasoning deltas from the chat model). * * Behaviour mirrors the existing `ThinkingStepsDisplay`: * - collapsed by default; diff --git a/surfsense_web/components/assistant-ui/thinking-steps.tsx b/surfsense_web/components/assistant-ui/thinking-steps.tsx deleted file mode 100644 index df1cef12c..000000000 --- a/surfsense_web/components/assistant-ui/thinking-steps.tsx +++ /dev/null @@ -1,175 +0,0 @@ -import { makeAssistantDataUI, useAuiState } from "@assistant-ui/react"; -import { ChevronRightIcon } from "lucide-react"; -import type { FC } from "react"; -import { useCallback, useEffect, useState } from "react"; -import { ChainOfThoughtItem } from "@/components/prompt-kit/chain-of-thought"; -import { TextShimmerLoader } from "@/components/prompt-kit/loader"; -import { cn } from "@/lib/utils"; - -export interface ThinkingStep { - id: string; - title: string; - items: string[]; - status: "pending" | "in_progress" | "completed"; -} - -/** - * Chain of thought display component - single collapsible dropdown design - */ -export const ThinkingStepsDisplay: FC<{ steps: ThinkingStep[]; isThreadRunning?: boolean }> = ({ - steps, - isThreadRunning = true, -}) => { - const getEffectiveStatus = useCallback( - (step: ThinkingStep): "pending" | "in_progress" | "completed" => { - if (step.status === "in_progress" && !isThreadRunning) { - return "completed"; - } - return step.status; - }, - [isThreadRunning] - ); - - const inProgressStep = steps.find((s) => getEffectiveStatus(s) === "in_progress"); - const allCompleted = - steps.length > 0 && - !isThreadRunning && - steps.every((s) => getEffectiveStatus(s) === "completed"); - const isProcessing = isThreadRunning && !allCompleted; - const [isOpen, setIsOpen] = useState(() => isProcessing); - - useEffect(() => { - if (isProcessing) { - setIsOpen(true); - return; - } - - if (allCompleted) { - setIsOpen(false); - } - }, [allCompleted, isProcessing]); - - if (steps.length === 0) return null; - - const getHeaderText = () => { - if (allCompleted) { - return "Reviewed"; - } - if (inProgressStep) { - return inProgressStep.title; - } - if (isProcessing) { - return "Processing"; - } - return "Reviewed"; - }; - - 
return (
-		/* [container JSX lost in extraction — a single collapsible "chain of
-		   thought" card: a trigger header showing getHeaderText() (with a
-		   TextShimmerLoader while isProcessing and a ChevronRightIcon that
-		   rotates when open), wired to isOpen / setIsOpen, and a body that
-		   renders the step list:] */
-					{steps.map((step, index) => {
-						const effectiveStatus = getEffectiveStatus(step);
-						const isLast = index === steps.length - 1;
-
-						return (
-							/* [per-step row, markup lost: a status rail (spinner while
-							   in_progress, check icon when completed, connector line
-							   unless isLast), the {step.title} label, and — when
-							   step.items is non-empty — a ChainOfThoughtItem per
-							   {item}] */
-						);
-					})}
-	);
-};
-
-/**
- * assistant-ui data UI component that renders thinking steps from message content.
- * Registered globally via makeAssistantDataUI — renders inside MessagePrimitive.Parts
- * at the position of the data part in the content array.
- */
-function ThinkingStepsDataRenderer({ data }: { name: string; data: unknown }) {
-	const isThreadRunning = useAuiState(({ thread }) => thread.isRunning);
-	const isLastMessage = useAuiState(({ message }) => message?.isLast ?? false);
-	const isMessageStreaming = isThreadRunning && isLastMessage;
-
-	const steps = (data as { steps: ThinkingStep[] } | null)?.steps ?? [];
-	if (steps.length === 0) return null;
-
-	return (
-		/* [elided: ThinkingStepsDisplay element — steps={steps},
-		   isThreadRunning={isMessageStreaming}; markup lost in extraction] */
- ); -} - -export const ThinkingStepsDataUI = makeAssistantDataUI({ - name: "thinking-steps", - render: ThinkingStepsDataRenderer, -}); diff --git a/surfsense_web/components/assistant-ui/tool-fallback.tsx b/surfsense_web/components/assistant-ui/tool-fallback.tsx deleted file mode 100644 index 06082c9c7..000000000 --- a/surfsense_web/components/assistant-ui/tool-fallback.tsx +++ /dev/null @@ -1,512 +0,0 @@ -import { type ToolCallMessagePartComponent, useAuiState } from "@assistant-ui/react"; -import { useQueryClient } from "@tanstack/react-query"; -import { useAtomValue } from "jotai"; -import { CheckIcon, ChevronDownIcon, RotateCcw, XCircleIcon } from "lucide-react"; -import { useEffect, useMemo, useState } from "react"; -import { toast } from "sonner"; -import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; -import { NestedScroll } from "@/components/assistant-ui/nested-scroll"; -import { - DoomLoopApprovalToolUI, - isDoomLoopInterrupt, -} from "@/components/tool-ui/doom-loop-approval"; -import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval"; -import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, - AlertDialogTrigger, -} from "@/components/ui/alert-dialog"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { Card } from "@/components/ui/card"; -import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible"; -import { Separator } from "@/components/ui/separator"; -import { Spinner } from "@/components/ui/spinner"; -import { getToolDisplayName } from "@/contracts/enums/toolIcons"; -import { markActionRevertedInCache, useAgentActionsQuery } from "@/hooks/use-agent-actions-query"; -import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; -import { AppError } from "@/lib/error"; -import { isInterruptResult } from "@/lib/hitl"; -import { cn } from "@/lib/utils"; - -/** - * Inline Revert button rendered on a tool card when the matching - * ``AgentActionLog`` row is reversible and hasn't been reverted yet. - * - * Reads from the unified ``useAgentActionsQuery`` cache — the SAME - * react-query cache the agent-actions sheet consumes. SSE events - * (``data-action-log`` / ``data-action-log-updated``) and - * ``POST /threads/{id}/revert/{id}`` responses both flow through the - * cache via ``setQueryData`` helpers, so the card and the sheet stay - * in lockstep on every code path: page reload, navigation, live - * stream, post-stream reversibility flip, and explicit revert clicks. - * - * Match key (in priority order): - * 1. ``a.tool_call_id === toolCallId`` — direct hit in parity_v2 when - * the model streamed ``tool_call_chunks`` so the card's synthetic - * id IS the LangChain id. - * 2. ``a.tool_call_id === langchainToolCallId`` — legacy mode (or - * parity_v2 with provider-side chunk emission) where the card's - * synthetic id is ``call_`` and the LangChain id is - * backfilled onto the part by ``tool-output-available``. - * 3. ``(chat_turn_id, tool_name, position-within-turn)`` — fallback - * for cards whose synthetic id is ``call_`` AND whose - * ``langchainToolCallId`` never got backfilled (provider emitted - * the tool_call as a single payload with no chunks AND streaming - * pre-dated the ``tool-output-available langchainToolCallId`` - * backfill, e.g. older threads). 
Reads the parent message's - * ``chatTurnId`` and ``content`` via ``useAuiState`` so we can - * match position-by-tool-name within the turn against the - * action_log rows the server returned in ``created_at`` order. - */ -function ToolCardRevertButton({ - toolCallId, - toolName, - langchainToolCallId, -}: { - toolCallId: string; - toolName: string; - langchainToolCallId?: string; -}) { - const session = useAtomValue(chatSessionStateAtom); - const threadId = session?.threadId ?? null; - const queryClient = useQueryClient(); - const { findByToolCallId, findByChatTurnAndTool } = useAgentActionsQuery(threadId); - - // Parent message metadata, read via the narrowest possible - // selectors so this card doesn't re-render on every text-delta of - // every other part in the same message during streaming. - // - // IMPORTANT — ``useAuiState`` re-renders the component whenever the - // returned slice's identity changes. Returning ``message?.content`` - // (an array) would re-render on every token because the runtime - // rebuilds the parts array. Returning a PRIMITIVE (the position - // number) lets ``useAuiState``'s ``Object.is`` check short-circuit - // when the position hasn't actually moved — which is the common - // case during text streaming, when only ``text``/``reasoning`` - // parts are mutating and the same-toolName tool-call ordering is - // stable. (See Vercel React rule ``rerender-defer-reads``.) - const chatTurnId = useAuiState(({ message }) => { - const meta = message?.metadata as { custom?: { chatTurnId?: string } } | undefined; - return meta?.custom?.chatTurnId ?? null; - }); - const positionInTurn = useAuiState(({ message }) => { - const content = message?.content; - if (!Array.isArray(content)) return -1; - let n = -1; - for (const part of content) { - if ( - part && - typeof part === "object" && - (part as { type?: string }).type === "tool-call" && - (part as { toolName?: string }).toolName === toolName - ) { - n += 1; - if ((part as { toolCallId?: string }).toolCallId === toolCallId) return n; - } - } - return -1; - }); - - const action = useMemo(() => { - // Tier 1 + 2: O(1) Map-backed direct id match. Covers - // ~all parity_v2 streams and any legacy stream that backfilled - // ``langchainToolCallId`` via ``tool-output-available``. - const direct = findByToolCallId(toolCallId) ?? findByToolCallId(langchainToolCallId); - if (direct) return direct; - // Tier 3: position-within-turn fallback. Only kicks in when the - // card has a synthetic ``call_`` id AND no - // ``langchainToolCallId`` was ever backfilled — i.e. the tool - // was emitted as a single non-chunked payload AND streaming - // pre-dated the on_tool_end backfill. - if (!chatTurnId || positionInTurn < 0) return null; - const turnSameTool = findByChatTurnAndTool(chatTurnId, toolName); - return turnSameTool[positionInTurn] ?? 
null;
-	}, [
-		findByToolCallId,
-		findByChatTurnAndTool,
-		toolCallId,
-		langchainToolCallId,
-		chatTurnId,
-		toolName,
-		positionInTurn,
-	]);
-
-	const [isReverting, setIsReverting] = useState(false);
-	const [confirmOpen, setConfirmOpen] = useState(false);
-
-	if (!action) return null;
-	if (!action.reversible) return null;
-	if (action.reverted_by_action_id !== null && action.reverted_by_action_id !== undefined)
-		return null;
-	if (action.is_revert_action) return null;
-	if (action.error !== null && action.error !== undefined) return null;
-	if (!threadId) return null;
-
-	const handleRevert = async () => {
-		setIsReverting(true);
-		try {
-			const response = await agentActionsApiService.revert(threadId, action.id);
-			markActionRevertedInCache(queryClient, threadId, action.id, response.new_action_id ?? null);
-			toast.success(response.message || "Action reverted.");
-		} catch (err) {
-			// 503 means revert is gated off on this deployment — hide the
-			// button silently rather than nagging the user. Any other error
-			// is surfaced as a toast so the operator can investigate.
-			if (err instanceof AppError && err.status === 503) {
-				return;
-			}
-			const message =
-				err instanceof AppError
-					? err.message
-					: err instanceof Error
-						? err.message
-						: "Failed to revert action.";
-			toast.error(message);
-		} finally {
-			setIsReverting(false);
-			setConfirmOpen(false);
-		}
-	};
-
-	return (
-		/* [AlertDialog JSX lost in extraction: a trigger Button (RotateCcw icon)
-		   opening a confirm dialog — title "Revert this action?", description
-		   "This will undo {getToolDisplayName(action.tool_name)} and add a new
-		   entry to the history. Your chat is preserved — only the changes the
-		   agent made to your knowledge base or connected apps will be rolled
-		   back where possible.", a Cancel action, and a Revert action wired to
-		   handleRevert(), disabled={isReverting}, showing a Spinner while
-		   reverting] */
-	);
-}
-
-/**
- * Compact tool-call card.
- *
- * shadcn composition note: we intentionally use ``Card`` as a visual
- * frame WITHOUT ``CardHeader / CardContent``. The full composition's
- * ``p-6`` padding doesn't fit a compact collapsible header that IS the
- * trigger; using ``Card`` alone preserves the rounded border, shadow,
- * and ``bg-card`` token (semantic colors) without forcing a layout
- * that doesn't fit. All status colors use semantic tokens — no manual
- * dark-mode overrides, no raw hex.
- */
-const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {
-	const { toolCallId, toolName, argsText, result, status } = props;
-	// ``langchainToolCallId`` is a SurfSense-specific extension the
-	// streaming pipeline attaches to the tool-call content part so
-	// the Revert button can resolve its ``AgentActionLog`` row even
-	// when only the LC id is known. assistant-ui's
-	// ``ToolCallMessagePartProps`` doesn't list it, but the runtime
-	// spreads ``{...part}`` so the prop reaches us at runtime.
-	const langchainToolCallId = (props as { langchainToolCallId?: string }).langchainToolCallId;
-
-	const isCancelled = status?.type === "incomplete" && status.reason === "cancelled";
-	const isError = status?.type === "incomplete" && status.reason === "error";
-	const isRunning = status?.type === "running" || status?.type === "requires-action";
-
-	/*
-	  Per-card expansion state. Initial value is ``isRunning`` so a
-	  card streaming in mounts already-expanded (no flash of
-	  collapsed → expanded on first paint), while a card loaded from
-	  history (status="complete") mounts collapsed. The useEffect
-	  below keeps this in lockstep with this card's own ``isRunning``
-	  when it transitions: false → true auto-expands (e.g. a tool
-	  that re-runs after edit), true → false auto-collapses once the
-	  tool finishes. Because the dep is per-card ``isRunning`` and
-	  not the chat-level streaming flag, sibling cards on the same
-	  assistant turn each manage their own expansion independently.
-	  Once ``isRunning`` is false the user controls expansion via
-	  ``onOpenChange``.
-	*/
-	const [isExpanded, setIsExpanded] = useState(isRunning);
-	useEffect(() => {
-		setIsExpanded(isRunning);
-	}, [isRunning]);
-	const errorData = status?.type === "incomplete" ? status.error : undefined;
-	const serializedError = useMemo(
-		() => (errorData && typeof errorData !== "string" ? JSON.stringify(errorData) : null),
-		[errorData]
-	);
-
-	const serializedResult = useMemo(
-		() =>
-			result !== undefined && typeof result !== "string" ? JSON.stringify(result, null, 2) : null,
-		[result]
-	);
-
-	const cancelledReason =
-		isCancelled && status.error
-			? typeof status.error === "string"
-				? status.error
-				: serializedError
-			: null;
-	const errorReason =
-		isError && status.error
-			? typeof status.error === "string"
-				? status.error
-				: serializedError
-			: null;
-
-	const displayName = getToolDisplayName(toolName);
-	const subtitle = errorReason ?? cancelledReason;
-
-	return (
-		/* [Card / Collapsible markup lost in extraction. Per the surviving
-		   fragments: a ``Card`` framing a fully-controlled ``Collapsible``
-		   (open={isExpanded}); a header row (main trigger with the tool icon,
-		   {displayName} and optional {subtitle}); right-side controls
-		   (ToolCardRevertButton + a chevron CollapsibleTrigger); and a
-		   CollapsibleContent body with an Inputs and a Result panel. The
-		   original inline comments and expressions follow.] */
-		{/*
-			``group`` lets the chevron (rendered as a sibling of the
-			main trigger button) read the Collapsible Root's
-			``data-[state=open]`` for rotation. The Collapsible is
-			fully controlled via ``isExpanded`` — the useEffect
-			above syncs it to ``isRunning`` so the card auto-opens
-			while a tool streams in and auto-collapses once it
-			finishes. We deliberately DON'T pass ``disabled`` so
-			both triggers stay clickable; ``onOpenChange`` is wired
-			to a setter that no-ops while ``isRunning`` (see
-			``handleOpenChange`` below) which keeps the card pinned
-			open mid-stream without losing keyboard / pointer
-			affordance the moment streaming ends.
-		*/}
-		onOpenChange={(next) => {
-			// Block manual collapse while the tool is still
-			// streaming — otherwise a stray click on either
-			// trigger would close the card and hide the live
-			// ``argsText`` panel mid-run. After streaming the
-			// user has full control again.
-			if (isRunning) return;
-			setIsExpanded(next);
-		}}
-		{/*
-			Header row: main trigger on the left (icon + title
-			col), Revert + chevron-trigger on the right as
-			siblings of the main trigger. The chevron is wrapped
-			in its OWN ``CollapsibleTrigger`` (Radix supports
-			multiple triggers per Root) so clicking the chevron
-			toggles the same state as clicking the title row.
-			The Revert button stays a separate AlertDialog
-			trigger and stops propagation in its onClick so it
-			doesn't toggle the collapsible while opening the
-			confirm dialog. Keeping these as flat siblings —
-			rather than nesting Revert / chevron inside the
-			title trigger — avoids invalid HTML
-			(button-in-button) and lets the Revert button
-			render in BOTH the collapsed and expanded states.
-		*/}
-		{/*
-			Right-side controls. The Revert button is
-			visible whenever the matching action is
-			reversible — including the collapsed state —
-			but ``ToolCardRevertButton`` itself returns
-			``null`` while a tool is still running because
-			no action-log row exists yet, so it doesn't
-			need an explicit ``isRunning`` gate here.
-		*/}
-		{/*
-			CollapsibleContent body — auto-open while streaming
-			(see ``open`` prop above) so the live ``argsText``
-			streams into the Inputs panel directly, no need for
-			a separate "Live input" panel. Native
-			``overflow-auto`` instead of ``ScrollArea`` because
-			Radix's Viewport can let content bleed past
-			``max-h-*`` in dynamic flex layouts. ``min-w-0`` on
-			the column wrappers guarantees ``break-all`` wraps
-			correctly within the bounded ``max-w-lg`` Card.
-		*/}
-		{(argsText || isRunning) && (
-			/* [Inputs panel, markup lost: an "Inputs" label and a <pre>
-			   rendering {argsText}; while argsText is still empty a
-			   "Waiting for input…" shimmer shows instead.] */
-			// Bridges the brief gap between
-			// ``tool-input-start`` (creates the
-			// card, ``argsText`` undefined) and
-			// the first ``tool-input-delta``.
-		)}
-		{!isCancelled && result !== undefined && (
-			/* [Result panel, markup lost: a Separator, a "Result" label and a
-			   <pre> rendering
-			   {typeof result === "string" ? result : serializedResult}] */
-		)}
-	);
-};
-
-export const ToolFallback: ToolCallMessagePartComponent = (props) => {
-	if (isInterruptResult(props.result)) {
-		if (isDoomLoopInterrupt(props.result)) {
-			return <DoomLoopApprovalToolUI {...props} />;
-		}
-		return <GenericHitlApprovalToolUI {...props} />;
-	}
-	return <DefaultToolFallbackInner {...props} />;
-};
diff --git a/surfsense_web/components/free-chat/free-chat-page.tsx b/surfsense_web/components/free-chat/free-chat-page.tsx
index 080d9a2b6..927eaef87 100644
--- a/surfsense_web/components/free-chat/free-chat-page.tsx
+++ b/surfsense_web/components/free-chat/free-chat-page.tsx
@@ -10,13 +10,13 @@ import { Turnstile, type TurnstileInstance } from "@marsidev/react-turnstile";
 import { ShieldCheck } from "lucide-react";
 import { useCallback, useEffect, useRef, useState } from "react";
 import { StepSeparatorDataUI } from "@/components/assistant-ui/step-separator";
-import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps";
 import {
 	createTokenUsageStore,
 	type TokenUsageData,
 	TokenUsageProvider,
 } from "@/components/assistant-ui/token-usage-context";
 import { useAnonymousMode } from "@/contexts/anonymous-mode";
+import { TimelineDataUI } from "@/features/chat-messages/timeline";
 import {
 	addStepSeparator,
 	addToolCall,
@@ -228,7 +228,8 @@ export function FreeChatPage() {
 						parsed.toolName,
 						{},
 						false,
-						parsed.langchainToolCallId
+						parsed.langchainToolCallId,
+						parsed.metadata
 					);
 					forceFlush();
 					break;
@@ -245,6 +246,7 @@ export function FreeChatPage() {
 							args: parsed.input || {},
 							argsText: finalArgsText,
 							langchainToolCallId: parsed.langchainToolCallId,
+							metadata: parsed.metadata,
 						});
 					} else {
 						addToolCall(
@@ -254,7 +256,8 @@ export function FreeChatPage() {
 							parsed.toolName,
 							parsed.input || {},
 							false,
-							parsed.langchainToolCallId
+							parsed.langchainToolCallId,
+							parsed.metadata
 						);
 						updateToolCall(contentPartsState, parsed.toolCallId, {
 							argsText: finalArgsText,
@@ -268,6 +271,7 @@ export function FreeChatPage() {
 					updateToolCall(contentPartsState, parsed.toolCallId, {
 						result: parsed.output,
 						langchainToolCallId: parsed.langchainToolCallId,
+						metadata: parsed.metadata,
 					});
 					forceFlush();
 					break;
@@ -469,7 +473,7 @@ export function FreeChatPage() {
 	return (
 		[context markup lost in extraction]
-			<ThinkingStepsDataUI />
+			<TimelineDataUI />
diff --git a/surfsense_web/components/hitl-bundle-pager/index.ts b/surfsense_web/components/hitl-bundle-pager/index.ts deleted file mode 100644 index ce434d224..000000000 --- a/surfsense_web/components/hitl-bundle-pager/index.ts +++ /dev/null @@ -1,2 +0,0 @@ -export { PagerChrome } from "./pager-chrome"; -export { withBundleStep } from "./with-bundle-step"; diff --git a/surfsense_web/components/hitl-bundle-pager/pager-chrome.tsx b/surfsense_web/components/hitl-bundle-pager/pager-chrome.tsx deleted file mode 100644 index 77d75fb6d..000000000 --- a/surfsense_web/components/hitl-bundle-pager/pager-chrome.tsx +++ /dev/null @@ -1,61 +0,0 @@ -"use client"; - -import { ChevronLeftIcon, ChevronRightIcon } from "lucide-react"; -import { Button } from "@/components/ui/button"; -import { useHitlBundle } from "@/lib/hitl"; - -/** - * Prev/next nav and Submit for the current step of an active HITL bundle. - * Submission is gated on every action_request having a staged decision. - */ -export function PagerChrome() { - const bundle = useHitlBundle(); - if (!bundle) return null; - - const total = bundle.toolCallIds.length; - const step = bundle.currentStep; - const allStaged = bundle.stagedCount === total; - - return ( -
-		/* [JSX lost in extraction: prev/next chevron Buttons, a
-		   "{step + 1} / {total}" position indicator, a "·" separator, a
-		   "{bundle.stagedCount} of {total} decided" counter, and a Submit
-		   Button enabled only when allStaged] */
-	);
-}
diff --git a/surfsense_web/components/hitl-bundle-pager/with-bundle-step.tsx b/surfsense_web/components/hitl-bundle-pager/with-bundle-step.tsx
deleted file mode 100644
index 64ac801fb..000000000
--- a/surfsense_web/components/hitl-bundle-pager/with-bundle-step.tsx
+++ /dev/null
@@ -1,37 +0,0 @@
-"use client";
-
-import type { ToolCallMessagePartProps } from "@assistant-ui/react";
-import type { ComponentType } from "react";
-import { ToolCallIdProvider, useHitlBundle } from "@/lib/hitl";
-import { PagerChrome } from "./pager-chrome";
-
-/**
- * Wrap a tool-ui card so that, when a multi-card HITL bundle is active:
- * - cards belonging to the bundle but not the current step render ``null``;
- * - the current-step card renders normally and is followed by ``PagerChrome``.
- *
- * Cards stay completely unchanged — the wrapper provides the
- * ``ToolCallIdContext`` that ``useHitlDecision`` reads to stage decisions
- * against the right ``toolCallId`` instead of firing the global event.
- */
-export function withBundleStep<P extends ToolCallMessagePartProps>(
-	Component: ComponentType<P>
-): ComponentType<P> {
-	function BundleStepWrapped(props: P) {
-		const bundle = useHitlBundle();
-		const toolCallId = props.toolCallId;
-		const inBundle = bundle?.isInBundle(toolCallId) ?? false;
-		const isStep = bundle?.isCurrentStep(toolCallId) ?? false;
-
-		if (bundle && inBundle && !isStep) return null;
-
-		return (
-			<ToolCallIdProvider toolCallId={toolCallId}>
-				<Component {...props} />
-				{bundle && isStep ? <PagerChrome /> : null}
-			</ToolCallIdProvider>
-		);
-	}
-	BundleStepWrapped.displayName = `withBundleStep(${Component.displayName ?? Component.name ?? "ToolUI"})`;
-	return BundleStepWrapped as ComponentType<P>
; -} diff --git a/surfsense_web/components/hitl-edit-panel/hitl-edit-panel.tsx b/surfsense_web/components/hitl-edit-panel/hitl-edit-panel.tsx deleted file mode 100644 index b33392f38..000000000 --- a/surfsense_web/components/hitl-edit-panel/hitl-edit-panel.tsx +++ /dev/null @@ -1,405 +0,0 @@ -"use client"; - -import { format } from "date-fns"; -import { TagInput, type Tag as TagType } from "emblor"; -import { useAtomValue, useSetAtom } from "jotai"; -import { CalendarIcon, XIcon } from "lucide-react"; -import dynamic from "next/dynamic"; -import { useCallback, useEffect, useMemo, useRef, useState } from "react"; -import type { ExtraField } from "@/atoms/chat/hitl-edit-panel.atom"; -import { closeHitlEditPanelAtom, hitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom"; -import { Button } from "@/components/ui/button"; -import { Calendar } from "@/components/ui/calendar"; -import { Drawer, DrawerContent, DrawerHandle, DrawerTitle } from "@/components/ui/drawer"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; -import { Skeleton } from "@/components/ui/skeleton"; -import { Textarea } from "@/components/ui/textarea"; -import { useMediaQuery } from "@/hooks/use-media-query"; - -const PlateEditor = dynamic( - () => import("@/components/editor/plate-editor").then((m) => ({ default: m.PlateEditor })), - { ssr: false, loading: () => } -); - -function parseEmailsToTags(value: string): TagType[] { - if (!value.trim()) return []; - return value - .split(",") - .map((s) => s.trim()) - .filter(Boolean) - .map((email, i) => ({ id: `${Date.now()}-${i}`, text: email })); -} - -function tagsToEmailString(tags: TagType[]): string { - return tags.map((t) => t.text).join(", "); -} - -function EmailsTagField({ - id, - value, - onChange, - placeholder, -}: { - id: string; - value: string; - onChange: (value: string) => void; - placeholder?: string; -}) { - const [tags, setTags] = useState(() => parseEmailsToTags(value)); - const [activeTagIndex, setActiveTagIndex] = useState(null); - const isInitialMount = useRef(true); - const onChangeRef = useRef(onChange); - onChangeRef.current = onChange; - - useEffect(() => { - if (isInitialMount.current) { - isInitialMount.current = false; - return; - } - onChangeRef.current(tagsToEmailString(tags)); - }, [tags]); - - const handleSetTags = useCallback((newTags: TagType[] | ((prev: TagType[]) => TagType[])) => { - setTags((prev) => (typeof newTags === "function" ? newTags(prev) : newTags)); - }, []); - - const handleAddTag = useCallback((text: string) => { - const trimmed = text.trim(); - if (!trimmed) return; - setTags((prev) => { - if (prev.some((tag) => tag.text === trimmed)) return prev; - const newTag: TagType = { id: Date.now().toString(), text: trimmed }; - return [...prev, newTag]; - }); - }, []); - - return ( - - ); -} - -function parseDateTimeValue(value: string): { date: Date | undefined; time: string } { - if (!value) return { date: undefined, time: "09:00" }; - try { - const d = new Date(value); - if (Number.isNaN(d.getTime())) return { date: undefined, time: "09:00" }; - return { - date: d, - time: format(d, "HH:mm"), - }; - } catch { - return { date: undefined, time: "09:00" }; - } -} - -function buildLocalDateTimeString(date: Date | undefined, time: string): string { - if (!date) return ""; - const [hours, minutes] = time.split(":").map(Number); - const combined = new Date(date); - combined.setHours(hours ?? 9, minutes ?? 
0, 0, 0); - const y = combined.getFullYear(); - const m = String(combined.getMonth() + 1).padStart(2, "0"); - const d = String(combined.getDate()).padStart(2, "0"); - const h = String(combined.getHours()).padStart(2, "0"); - const min = String(combined.getMinutes()).padStart(2, "0"); - return `${y}-${m}-${d}T${h}:${min}:00`; -} - -function DateTimePickerField({ - id, - value, - onChange, -}: { - id: string; - value: string; - onChange: (value: string) => void; -}) { - const parsed = useMemo(() => parseDateTimeValue(value), [value]); - const [selectedDate, setSelectedDate] = useState(parsed.date); - const [time, setTime] = useState(parsed.time); - const [open, setOpen] = useState(false); - - const handleDateSelect = useCallback( - (day: Date | undefined) => { - setSelectedDate(day); - onChange(buildLocalDateTimeString(day, time)); - setOpen(false); - }, - [time, onChange] - ); - - const handleTimeChange = useCallback( - (e: React.ChangeEvent) => { - const newTime = e.target.value; - setTime(newTime); - onChange(buildLocalDateTimeString(selectedDate, newTime)); - }, - [selectedDate, onChange] - ); - - const displayLabel = selectedDate - ? `${format(selectedDate, "MMM d, yyyy")} at ${time}` - : "Pick date & time"; - - return ( -

- - - - - - - - - -
- ); -} - -export function HitlEditPanelContent({ - title: initialTitle, - content: initialContent, - contentFormat, - extraFields, - onSave, - onClose, - showCloseButton = true, -}: { - title: string; - content: string; - toolName: string; - contentFormat?: "markdown" | "html"; - extraFields?: ExtraField[]; - onSave: (title: string, content: string, extraFieldValues?: Record) => void; - onClose?: () => void; - showCloseButton?: boolean; -}) { - const [editedTitle, setEditedTitle] = useState(initialTitle); - const contentRef = useRef(initialContent); - const [isSaving, setIsSaving] = useState(false); - const [extraFieldValues, setExtraFieldValues] = useState>(() => { - if (!extraFields) return {}; - const initial: Record = {}; - for (const field of extraFields) { - initial[field.key] = field.value; - } - return initial; - }); - - const handleContentChange = useCallback((content: string) => { - contentRef.current = content; - }, []); - - const handleExtraFieldChange = useCallback((key: string, value: string) => { - setExtraFieldValues((prev) => ({ ...prev, [key]: value })); - }, []); - - const handleSave = useCallback(() => { - if (!editedTitle.trim()) return; - setIsSaving(true); - const extras = extraFields && extraFields.length > 0 ? extraFieldValues : undefined; - onSave(editedTitle, contentRef.current, extras); - onClose?.(); - }, [editedTitle, onSave, onClose, extraFields, extraFieldValues]); - - return ( - <> -
- setEditedTitle(e.target.value)} - placeholder="Untitled" - className="flex-1 min-w-0 bg-transparent text-sm font-semibold text-foreground outline-none placeholder:text-muted-foreground" - aria-label="Page title" - /> - {onClose && showCloseButton && ( - - )} -
- - {extraFields && extraFields.length > 0 && ( -
- {extraFields.map((field) => ( -
- - {field.type === "emails" ? ( - handleExtraFieldChange(field.key, v)} - placeholder={`Add ${field.label.toLowerCase()}`} - /> - ) : field.type === "datetime-local" ? ( - handleExtraFieldChange(field.key, v)} - /> - ) : field.type === "textarea" ? ( -