Merge pull request #1357 from CREDO23/feature/multi-agent

[Feature] Multi-agent chat: hierarchical timeline, live subagent streaming, and inline HITL approvals
This commit is contained in:
Rohan Verma 2026-05-09 16:13:04 -07:00 committed by GitHub
commit 28a02a9143
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
232 changed files with 9014 additions and 4055 deletions

View file

@ -1,4 +1,9 @@
{
"biome.configurationPath": "./surfsense_web/biome.json",
"deepscan.ignoreConfirmWarning": true
"deepscan.ignoreConfirmWarning": true,
"python.defaultInterpreterPath": "${workspaceFolder}/surfsense_backend/.venv/bin/python",
"basedpyright.analysis.extraPaths": [
"${workspaceFolder}/surfsense_backend"
],
"python-envs.pythonProjects": []
}

View file

@ -324,7 +324,6 @@ SURFSENSE_ENABLE_ACTION_LOG=true
SURFSENSE_ENABLE_REVERT_ROUTE=true
SURFSENSE_ENABLE_PERMISSION=true
SURFSENSE_ENABLE_DOOM_LOOP=true
SURFSENSE_ENABLE_STREAM_PARITY_V2=true
# Periodic connector sync interval (default: 5m)
# SCHEDULE_CHECKER_INTERVAL=5m

View file

@ -315,14 +315,6 @@ LANGSMITH_PROJECT=surfsense
# SURFSENSE_ENABLE_ACTION_LOG=false
# SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships
# Streaming parity v2 — opt in to LangChain's structured AIMessageChunk
# content (typed reasoning blocks, tool-input deltas) and propagate the
# real tool_call_id to the SSE layer. When OFF, the stream falls back to
# the str-only text path and synthetic "call_<run_id>" tool-call ids.
# Schema migrations 135/136 ship unconditionally because they are
# forward-compatible.
# SURFSENSE_ENABLE_STREAM_PARITY_V2=false
# Plugins
# SURFSENSE_ENABLE_PLUGIN_LOADER=false
# Comma-separated allowlist of plugin entry-point names

View file

@ -6,12 +6,19 @@ exposes the side-channel ``stream_resume_chat`` uses to ferry resume payloads.
from __future__ import annotations
import logging
from typing import Any
from langchain.tools import ToolRuntime
from .constants import DEFAULT_SUBAGENT_RECURSION_LIMIT
logger = logging.getLogger(__name__)
# langgraph stores the parent task's scratchpad under this configurable key;
# subagents inherit the chain via ``parent_scratchpad`` fallback.
_LANGGRAPH_SCRATCHPAD_KEY = "__pregel_scratchpad"
def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
"""RunnableConfig for the nested invoke; raises ``recursion_limit`` to the parent's budget."""
@ -42,3 +49,42 @@ def has_surfsense_resume(runtime: ToolRuntime) -> bool:
if not isinstance(configurable, dict):
return False
return "surfsense_resume_value" in configurable
def drain_parent_null_resume(runtime: ToolRuntime) -> None:
"""Consume the parent's lingering ``NULL_TASK_ID/RESUME`` write before delegating.
``stream_resume_chat`` wakes the main agent with
``Command(resume={"decisions": [...]})`` so the propagated
``_lg_interrupt(...)`` can return. langgraph stores that payload as the
parent task's ``null_resume`` pending write, which only gets consumed
*after* ``subagent.[a]invoke`` returns (when the post-call propagation
re-fires). While the subagent is mid-execution, any *new* ``interrupt()``
inside it (e.g. a follow-up tool call after a mixed approve/reject) walks
the ``subagent_scratchpad -> parent_scratchpad.get_null_resume`` chain and picks up
the parent's still-live decisions — mismatching against a different number
of hanging tool calls and crashing ``HumanInTheLoopMiddleware``.
Draining the write here closes that cross-graph leak so subagent
interrupts pause cleanly and re-propagate as a fresh approval card.
"""
cfg = runtime.config or {}
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
if not isinstance(configurable, dict):
return
scratchpad = configurable.get(_LANGGRAPH_SCRATCHPAD_KEY)
if scratchpad is None:
return
consume = getattr(scratchpad, "get_null_resume", None)
if not callable(consume):
return
try:
consume(True)
except Exception:
# Defensive: if langgraph's internal scratchpad shape changes we don't
# want to break the resume path. Worst case the original ValueError
# still surfaces — same behavior as before this fix.
logger.debug(
"drain_parent_null_resume: scratchpad.get_null_resume raised",
exc_info=True,
)
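A quick sanity check of the drain, as a minimal sketch: the fake scratchpad and the SimpleNamespace runtime below are test scaffolding standing in for langgraph's internals and ``ToolRuntime``, not part of this module.
from types import SimpleNamespace

class _FakeScratchpad:
    """Stands in for langgraph's pregel scratchpad in this sketch."""

    def __init__(self) -> None:
        self.drained = False

    def get_null_resume(self, consume: bool):
        # The real scratchpad would hand back the parent's pending resume
        # payload; here we only record that it was consumed.
        self.drained = consume
        return {"decisions": ["approve"]}

scratchpad = _FakeScratchpad()
runtime = SimpleNamespace(
    config={"configurable": {"__pregel_scratchpad": scratchpad}}
)
drain_parent_null_resume(runtime)
assert scratchpad.drained  # the parent's lingering NULL_TASK_ID/RESUME write is gone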

View file

@ -20,6 +20,7 @@ from langgraph.types import Command
from .config import (
consume_surfsense_resume,
drain_parent_null_resume,
has_surfsense_resume,
subagent_invoke_config,
)
@ -157,6 +158,9 @@ def build_task_tool_with_parent_config(
)
expected = hitlrequest_action_count(pending_value)
resume_value = fan_out_decisions_to_match(resume_value, expected)
# Prevent the parent's resume payload from leaking into subagent
# interrupts via langgraph's parent_scratchpad fallback.
drain_parent_null_resume(runtime)
result = subagent.invoke(
build_resume_command(resume_value, pending_id),
config=sub_config,
@ -221,6 +225,9 @@ def build_task_tool_with_parent_config(
)
expected = hitlrequest_action_count(pending_value)
resume_value = fan_out_decisions_to_match(resume_value, expected)
# Prevent the parent's resume payload from leaking into subagent
# interrupts via langgraph's parent_scratchpad fallback.
drain_parent_null_resume(runtime)
result = await subagent.ainvoke(
build_resume_command(resume_value, pending_id),
config=sub_config,

View file

@ -1,11 +1,3 @@
"""Jira tools for creating, updating, and deleting issues."""
"""Jira route: native tool factories are empty; MCP supplies tools when configured."""
from .create_issue import create_create_jira_issue_tool
from .delete_issue import create_delete_jira_issue_tool
from .update_issue import create_update_jira_issue_tool
__all__ = [
"create_create_jira_issue_tool",
"create_delete_jira_issue_tool",
"create_update_jira_issue_tool",
]
__all__: list[str] = []

View file

@ -6,29 +6,9 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import (
ToolsPermissions,
)
from .create_issue import create_create_jira_issue_tool
from .delete_issue import create_delete_jira_issue_tool
from .update_issue import create_update_jira_issue_tool
def load_tools(
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs}
common = {
"db_session": d["db_session"],
"search_space_id": d["search_space_id"],
"user_id": d["user_id"],
"connector_id": d.get("connector_id"),
}
create = create_create_jira_issue_tool(**common)
update = create_update_jira_issue_tool(**common)
delete = create_delete_jira_issue_tool(**common)
return {
"allow": [],
"ask": [
{"name": getattr(create, "name", "") or "", "tool": create},
{"name": getattr(update, "name", "") or "", "tool": update},
{"name": getattr(delete, "name", "") or "", "tool": delete},
],
}
_ = {**(dependencies or {}), **kwargs}
return {"allow": [], "ask": []}

View file

@ -1,11 +1,3 @@
"""Linear tools for creating, updating, and deleting issues."""
"""Linear route: native tool factories are empty; MCP supplies tools when configured."""
from .create_issue import create_create_linear_issue_tool
from .delete_issue import create_delete_linear_issue_tool
from .update_issue import create_update_linear_issue_tool
__all__ = [
"create_create_linear_issue_tool",
"create_delete_linear_issue_tool",
"create_update_linear_issue_tool",
]
__all__: list[str] = []

View file

@ -1,248 +0,0 @@
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector
from app.services.linear import LinearToolMetadataService
logger = logging.getLogger(__name__)
def create_create_linear_issue_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
connector_id: int | None = None,
):
"""
Factory function to create the create_linear_issue tool.
Args:
db_session: Database session for accessing the Linear connector
search_space_id: Search space ID to find the Linear connector
user_id: User ID for fetching user-specific context
connector_id: Optional specific connector ID (if known)
Returns:
Configured create_linear_issue tool
"""
@tool
async def create_linear_issue(
title: str,
description: str | None = None,
) -> dict[str, Any]:
"""Create a new issue in Linear.
Use this tool when the user explicitly asks to create, add, or file
a new issue / ticket / task in Linear. The user MUST describe the issue
before you call this tool. If the request is vague, ask what the issue
should be about. Never call this tool without a clear topic from the user.
Args:
title: Short, descriptive issue title. Infer from the user's request.
description: Optional markdown body for the issue. Generate from context.
Returns:
Dictionary with:
- status: "success", "rejected", or "error"
- issue_id: Linear issue UUID (if success)
- identifier: Human-readable ID like "ENG-42" (if success)
- url: URL to the created issue (if success)
- message: Result message
IMPORTANT: If status is "rejected", the user explicitly declined the action.
Respond with a brief acknowledgment (e.g., "Understood, I won't create the issue.")
and move on. Do NOT retry, troubleshoot, or suggest alternatives.
Examples:
- "Create a Linear issue for the login bug"
- "File a ticket about the payment timeout problem"
- "Add an issue for the broken search feature"
"""
logger.info(f"create_linear_issue called: title='{title}'")
if db_session is None or search_space_id is None or user_id is None:
logger.error(
"Linear tool not properly configured - missing required parameters"
)
return {
"status": "error",
"message": "Linear tool not properly configured. Please contact support.",
}
try:
metadata_service = LinearToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
if "error" in context:
logger.error(f"Failed to fetch creation context: {context['error']}")
return {"status": "error", "message": context["error"]}
workspaces = context.get("workspaces", [])
if workspaces and all(w.get("auth_expired") for w in workspaces):
logger.warning("All Linear accounts have expired authentication")
return {
"status": "auth_error",
"message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "linear",
}
logger.info(f"Requesting approval for creating Linear issue: '{title}'")
result = request_approval(
action_type="linear_issue_creation",
tool_name="create_linear_issue",
params={
"title": title,
"description": description,
"team_id": None,
"state_id": None,
"assignee_id": None,
"priority": None,
"label_ids": [],
"connector_id": connector_id,
},
context=context,
)
if result.rejected:
logger.info("Linear issue creation rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_title = result.params.get("title", title)
final_description = result.params.get("description", description)
final_team_id = result.params.get("team_id")
final_state_id = result.params.get("state_id")
final_assignee_id = result.params.get("assignee_id")
final_priority = result.params.get("priority")
final_label_ids = result.params.get("label_ids") or []
final_connector_id = result.params.get("connector_id", connector_id)
if not final_title or not final_title.strip():
logger.error("Title is empty or contains only whitespace")
return {"status": "error", "message": "Issue title cannot be empty."}
if not final_team_id:
return {
"status": "error",
"message": "A team must be selected to create an issue.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
actual_connector_id = final_connector_id
if actual_connector_id is None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.LINEAR_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Linear connector found. Please connect Linear in your workspace settings.",
}
actual_connector_id = connector.id
logger.info(f"Found Linear connector: id={actual_connector_id}")
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == actual_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.LINEAR_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Linear connector is invalid or has been disconnected.",
}
logger.info(f"Validated Linear connector: id={actual_connector_id}")
logger.info(
f"Creating Linear issue with final params: title='{final_title}'"
)
linear_client = LinearConnector(
session=db_session, connector_id=actual_connector_id
)
result = await linear_client.create_issue(
team_id=final_team_id,
title=final_title,
description=final_description,
state_id=final_state_id,
assignee_id=final_assignee_id,
priority=final_priority,
label_ids=final_label_ids if final_label_ids else None,
)
if result.get("status") == "error":
logger.error(f"Failed to create Linear issue: {result.get('message')}")
return {"status": "error", "message": result.get("message")}
logger.info(
f"Linear issue created: {result.get('identifier')} - {result.get('title')}"
)
kb_message_suffix = ""
try:
from app.services.linear import LinearKBSyncService
kb_service = LinearKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
issue_id=result.get("id"),
issue_identifier=result.get("identifier", ""),
issue_title=result.get("title", final_title),
issue_url=result.get("url"),
description=final_description,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = " Your knowledge base has also been updated."
else:
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"issue_id": result.get("id"),
"identifier": result.get("identifier"),
"url": result.get("url"),
"message": (result.get("message", "") + kb_message_suffix),
}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error(f"Error creating Linear issue: {e}", exc_info=True)
if isinstance(e, ValueError | LinearAPIError):
message = str(e)
else:
message = (
"Something went wrong while creating the issue. Please try again."
)
return {"status": "error", "message": message}
return create_linear_issue

View file

@ -1,245 +0,0 @@
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector
from app.services.linear import LinearToolMetadataService
logger = logging.getLogger(__name__)
def create_delete_linear_issue_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
connector_id: int | None = None,
):
"""
Factory function to create the delete_linear_issue tool.
Args:
db_session: Database session for accessing the Linear connector
search_space_id: Search space ID to find the Linear connector
user_id: User ID for finding the correct Linear connector
connector_id: Optional specific connector ID (if known)
Returns:
Configured delete_linear_issue tool
"""
@tool
async def delete_linear_issue(
issue_ref: str,
delete_from_kb: bool = False,
) -> dict[str, Any]:
"""Archive (delete) a Linear issue.
Use this tool when the user asks to delete, remove, or archive a Linear issue.
Note that Linear archives issues rather than permanently deleting them
(they can be restored from the archive).
Args:
issue_ref: The issue to delete. Can be the issue title (e.g. "Fix login bug"),
the identifier (e.g. "ENG-42"), or the full document title
(e.g. "ENG-42: Fix login bug").
delete_from_kb: Whether to also remove the issue from the knowledge base.
Default is False. Set to True to remove from both Linear
and the knowledge base.
Returns:
Dictionary with:
- status: "success", "rejected", "not_found", or "error"
- identifier: Human-readable ID like "ENG-42" (if success)
- message: Success or error message
- deleted_from_kb: Whether the issue was also removed from the knowledge base (if success)
IMPORTANT:
- If status is "rejected", the user explicitly declined the action.
Respond with a brief acknowledgment (e.g., "Understood, I won't delete the issue.")
and move on. Do NOT ask for alternatives or troubleshoot.
- If status is "not_found", inform the user conversationally using the exact message
provided. Do NOT treat this as an error. Simply relay the message and ask the user
to verify the issue title or identifier, or check if it has been indexed.
Examples:
- "Delete the 'Fix login bug' Linear issue"
- "Archive ENG-42"
- "Remove the 'Old payment flow' issue from Linear"
"""
logger.info(
f"delete_linear_issue called: issue_ref='{issue_ref}', delete_from_kb={delete_from_kb}"
)
if db_session is None or search_space_id is None or user_id is None:
logger.error(
"Linear tool not properly configured - missing required parameters"
)
return {
"status": "error",
"message": "Linear tool not properly configured. Please contact support.",
}
try:
metadata_service = LinearToolMetadataService(db_session)
context = await metadata_service.get_delete_context(
search_space_id, user_id, issue_ref
)
if "error" in context:
error_msg = context["error"]
if context.get("auth_expired"):
logger.warning(f"Auth expired for delete context: {error_msg}")
return {
"status": "auth_error",
"message": error_msg,
"connector_id": context.get("connector_id"),
"connector_type": "linear",
}
if "not found" in error_msg.lower():
logger.warning(f"Issue not found: {error_msg}")
return {"status": "not_found", "message": error_msg}
else:
logger.error(f"Failed to fetch delete context: {error_msg}")
return {"status": "error", "message": error_msg}
issue_id = context["issue"]["id"]
issue_identifier = context["issue"].get("identifier", "")
document_id = context["issue"]["document_id"]
connector_id_from_context = context.get("workspace", {}).get("id")
logger.info(
f"Requesting approval for deleting Linear issue: '{issue_ref}' "
f"(id={issue_id}, delete_from_kb={delete_from_kb})"
)
result = request_approval(
action_type="linear_issue_deletion",
tool_name="delete_linear_issue",
params={
"issue_id": issue_id,
"connector_id": connector_id_from_context,
"delete_from_kb": delete_from_kb,
},
context=context,
)
if result.rejected:
logger.info("Linear issue deletion rejected by user")
return {
"status": "rejected",
"message": "User declined. Do not retry or suggest alternatives.",
}
final_issue_id = result.params.get("issue_id", issue_id)
final_connector_id = result.params.get(
"connector_id", connector_id_from_context
)
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
logger.info(
f"Deleting Linear issue with final params: issue_id={final_issue_id}, "
f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
)
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
if final_connector_id:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.LINEAR_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
logger.error(
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
)
return {
"status": "error",
"message": "Selected Linear connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
logger.info(f"Validated Linear connector: id={actual_connector_id}")
else:
logger.error("No connector found for this issue")
return {
"status": "error",
"message": "No connector found for this issue.",
}
linear_client = LinearConnector(
session=db_session, connector_id=actual_connector_id
)
result = await linear_client.archive_issue(issue_id=final_issue_id)
logger.info(
f"archive_issue result: {result.get('status')} - {result.get('message', '')}"
)
deleted_from_kb = False
if (
result.get("status") == "success"
and final_delete_from_kb
and document_id
):
try:
from app.db import Document
doc_result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
document = doc_result.scalars().first()
if document:
await db_session.delete(document)
await db_session.commit()
deleted_from_kb = True
logger.info(
f"Deleted document {document_id} from knowledge base"
)
else:
logger.warning(f"Document {document_id} not found in KB")
except Exception as e:
logger.error(f"Failed to delete document from KB: {e}")
await db_session.rollback()
result["warning"] = (
f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}"
)
if result.get("status") == "success":
result["deleted_from_kb"] = deleted_from_kb
if issue_identifier:
result["message"] = (
f"Issue {issue_identifier} archived successfully."
)
if deleted_from_kb:
result["message"] = (
f"{result.get('message', '')} Also removed from the knowledge base."
)
return result
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error(f"Error deleting Linear issue: {e}", exc_info=True)
if isinstance(e, ValueError | LinearAPIError):
message = str(e)
else:
message = (
"Something went wrong while deleting the issue. Please try again."
)
return {"status": "error", "message": message}
return delete_linear_issue

View file

@ -6,29 +6,9 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import (
ToolsPermissions,
)
from .create_issue import create_create_linear_issue_tool
from .delete_issue import create_delete_linear_issue_tool
from .update_issue import create_update_linear_issue_tool
def load_tools(
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
d = {**(dependencies or {}), **kwargs}
common = {
"db_session": d["db_session"],
"search_space_id": d["search_space_id"],
"user_id": d["user_id"],
"connector_id": d.get("connector_id"),
}
create = create_create_linear_issue_tool(**common)
update = create_update_linear_issue_tool(**common)
delete = create_delete_linear_issue_tool(**common)
return {
"allow": [],
"ask": [
{"name": getattr(create, "name", "") or "", "tool": create},
{"name": getattr(update, "name", "") or "", "tool": update},
{"name": getattr(delete, "name", "") or "", "tool": delete},
],
}
_ = {**(dependencies or {}), **kwargs}
return {"allow": [], "ask": []}

View file

@ -28,7 +28,6 @@ Defaults:
SURFSENSE_ENABLE_PERMISSION=true
SURFSENSE_ENABLE_DOOM_LOOP=true
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call
SURFSENSE_ENABLE_STREAM_PARITY_V2=true
Master kill-switch (overrides everything else):
@ -88,15 +87,6 @@ class AgentFeatureFlags:
enable_action_log: bool = True
enable_revert_route: bool = True
# Streaming parity v2 — opt in to LangChain's structured
# ``AIMessageChunk`` content (typed reasoning blocks, tool-input
# deltas) and propagate the real ``tool_call_id`` to the SSE layer.
# When OFF the ``stream_new_chat`` task falls back to the str-only
# text path and the synthetic ``call_<run_id>`` tool-call id (no
# ``langchainToolCallId`` propagation). Schema migrations 135/136
# ship unconditionally because they're forward-compatible.
enable_stream_parity_v2: bool = True
# Plugins
enable_plugin_loader: bool = False
@ -169,7 +159,6 @@ class AgentFeatureFlags:
enable_kb_planner_runnable=False,
enable_action_log=False,
enable_revert_route=False,
enable_stream_parity_v2=False,
enable_plugin_loader=False,
enable_otel=False,
enable_agent_cache=False,
@ -208,10 +197,6 @@ class AgentFeatureFlags:
# Snapshot / revert
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", True),
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", True),
# Streaming parity v2
enable_stream_parity_v2=_env_bool(
"SURFSENSE_ENABLE_STREAM_PARITY_V2", True
),
# Plugins
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
# Observability

View file

@ -71,7 +71,10 @@ from app.schemas.new_chat import (
TokenUsageSummary,
TurnStatusResponse,
)
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
from app.tasks.chat.stream_new_chat import (
stream_new_chat,
stream_resume_chat,
)
from app.users import current_active_user
from app.utils.perf import get_perf_logger
from app.utils.rbac import check_permission

View file

@ -380,7 +380,7 @@ class ResumeRequest(BaseModel):
"/regenerate. Resume reuses the original interrupted user "
"turn so the server does not write a new user message. "
"Currently unused but accepted to keep request bodies "
"uniform across the three streaming entrypoints."
"uniform across new-message, regenerate, and resume stream routes."
),
)

View file

@ -456,6 +456,8 @@ class VercelStreamingService:
title: str,
status: str = "in_progress",
items: list[str] | None = None,
*,
metadata: dict[str, Any] | None = None,
) -> str:
"""
Format a thinking step for chain-of-thought display (SurfSense specific).
@ -469,15 +471,15 @@ class VercelStreamingService:
Returns:
str: SSE formatted thinking step data part
"""
return self.format_data(
"thinking-step",
{
"id": step_id,
"title": title,
"status": status,
"items": items or [],
},
)
payload: dict[str, Any] = {
"id": step_id,
"title": title,
"status": status,
"items": items or [],
}
if metadata:
payload["metadata"] = metadata
return self.format_data("thinking-step", payload)
def format_thread_title_update(self, thread_id: int, title: str) -> str:
"""
@ -601,6 +603,7 @@ class VercelStreamingService:
tool_name: str,
*,
langchain_tool_call_id: str | None = None,
metadata: dict[str, Any] | None = None,
) -> str:
"""
Format the start of tool input streaming.
@ -608,15 +611,14 @@ class VercelStreamingService:
Args:
tool_call_id: The unique tool call identifier. May be EITHER the
synthetic ``call_<run_id>`` id derived from LangGraph
``run_id`` (legacy / ``SURFSENSE_ENABLE_STREAM_PARITY_V2``
OFF, or the unmatched-fallback path under parity_v2) OR
the authoritative LangChain ``tool_call.id`` (parity_v2
path: when the provider streams ``tool_call_chunks`` we
register the ``index`` and reuse the lc-id as the card
id so live ``tool-input-delta`` events can be routed
without a downstream join). Either way, the same id is
preserved across ``tool-input-start`` / ``-delta`` /
``-available`` / ``tool-output-available`` for one call.
``run_id`` (unmatched chunk fallback when no ``index`` was
registered) OR the authoritative LangChain ``tool_call.id``
(when the provider streams ``tool_call_chunks`` we register
the ``index`` and reuse the lc-id as the card id so live
``tool-input-delta`` events route without a downstream join).
Either way, the same id is preserved across
``tool-input-start`` / ``-delta`` / ``-available`` /
``tool-output-available`` for one call.
tool_name: The name of the tool being called.
langchain_tool_call_id: Optional authoritative LangChain
``tool_call.id``. When set, surfaces as
@ -636,6 +638,8 @@ class VercelStreamingService:
}
if langchain_tool_call_id:
payload["langchainToolCallId"] = langchain_tool_call_id
if metadata:
payload["metadata"] = metadata
return self._format_sse(payload)
def format_tool_input_delta(self, tool_call_id: str, input_text_delta: str) -> str:
@ -667,6 +671,7 @@ class VercelStreamingService:
input_data: dict[str, Any],
*,
langchain_tool_call_id: str | None = None,
metadata: dict[str, Any] | None = None,
) -> str:
"""
Format the completion of tool input.
@ -692,6 +697,8 @@ class VercelStreamingService:
}
if langchain_tool_call_id:
payload["langchainToolCallId"] = langchain_tool_call_id
if metadata:
payload["metadata"] = metadata
return self._format_sse(payload)
def format_tool_output_available(
@ -700,6 +707,7 @@ class VercelStreamingService:
output: Any,
*,
langchain_tool_call_id: str | None = None,
metadata: dict[str, Any] | None = None,
) -> str:
"""
Format tool execution output.
@ -726,6 +734,8 @@ class VercelStreamingService:
}
if langchain_tool_call_id:
payload["langchainToolCallId"] = langchain_tool_call_id
if metadata:
payload["metadata"] = metadata
return self._format_sse(payload)
# =========================================================================

View file

@ -0,0 +1,20 @@
"""Single-responsibility split of the streaming SSE protocol.
Layout:
* ``envelope/`` - SSE wire framing + ID generators
* ``emitter/`` - identity of the agent that emitted an event + runtime registry
* ``events/`` - one module per SSE event family
* ``service.py`` - composition root used when emitting chat SSE
* ``interrupt_correlation.py`` - id-aware lookup over LangGraph state
Naming on the wire:
* AI SDK protocol fields keep their existing camelCase
(``toolCallId``, ``messageId``, ``inputTextDelta``, ``langchainToolCallId``).
* Every SurfSense-added field uses ``snake_case``, including the
top-level ``emitted_by`` envelope and all inner ``data`` payloads.
Production chat uses ``app.services.new_streaming_service`` from
``app.tasks.chat.stream_new_chat`` and related routes.
"""
from __future__ import annotations
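As a rough illustration of those naming rules (the tool name and ids below are made up, and the helpers come from the ``events.tool`` and ``emitter`` modules added later in this PR), a tool event emitted on behalf of a sub-agent keeps the AI SDK's camelCase fields while the provenance envelope stays snake_case:
emitter = subagent_emitter(
    subagent_type="research",          # illustrative subagent type
    subagent_run_id="subagent_0001",
    parent_tool_call_id="call_0001",
)
format_tool_input_start("call_0002", "search_knowledge_base", emitter=emitter)
# data: {"type": "tool-input-start", "toolCallId": "call_0002",
#        "toolName": "search_knowledge_base",
#        "emitted_by": {"level": "subagent", "subagent_type": "research",
#                       "subagent_run_id": "subagent_0001",
#                       "parent_tool_call_id": "call_0001"}}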

View file

@ -0,0 +1,29 @@
"""Identity of the agent that emitted a streamed event.
The wire field is ``emitted_by``; the Python identity is :class:`Emitter`.
``EmitterRegistry`` resolves which emitter owns a LangGraph event, with
LangGraph's own namespace metadata as the primary key and a parent_ids
walk as a fallback for cases where context vars don't propagate.
"""
from __future__ import annotations
from .emitter import (
MAIN_EMITTER,
Emitter,
EmitterLevel,
attach_emitted_by,
main_emitter,
subagent_emitter,
)
from .registry import EmitterRegistry
__all__ = [
"MAIN_EMITTER",
"Emitter",
"EmitterLevel",
"EmitterRegistry",
"attach_emitted_by",
"main_emitter",
"subagent_emitter",
]

View file

@ -0,0 +1,61 @@
"""Identity payload describing which agent produced a stream event."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Literal
EmitterLevel = Literal["main", "subagent"]
@dataclass(frozen=True)
class Emitter:
level: EmitterLevel
subagent_type: str | None = None
subagent_run_id: str | None = None
parent_tool_call_id: str | None = None
extra: dict[str, Any] = field(default_factory=dict)
def to_payload(self) -> dict[str, Any]:
payload: dict[str, Any] = {"level": self.level}
if self.subagent_type is not None:
payload["subagent_type"] = self.subagent_type
if self.subagent_run_id is not None:
payload["subagent_run_id"] = self.subagent_run_id
if self.parent_tool_call_id is not None:
payload["parent_tool_call_id"] = self.parent_tool_call_id
if self.extra:
payload.update(self.extra)
return payload
MAIN_EMITTER = Emitter(level="main")
def main_emitter() -> Emitter:
return MAIN_EMITTER
def subagent_emitter(
*,
subagent_type: str,
subagent_run_id: str,
parent_tool_call_id: str | None = None,
extra: dict[str, Any] | None = None,
) -> Emitter:
return Emitter(
level="subagent",
subagent_type=subagent_type,
subagent_run_id=subagent_run_id,
parent_tool_call_id=parent_tool_call_id,
extra=dict(extra or {}),
)
def attach_emitted_by(
payload: dict[str, Any], emitter: Emitter | None
) -> dict[str, Any]:
if emitter is None:
return payload
payload["emitted_by"] = emitter.to_payload()
return payload
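A minimal sketch of the attach semantics (the ``browser`` type and ids are illustrative): with no emitter the payload passes through untouched; with a sub-agent emitter the snake_case envelope is appended.
payload = {"type": "text-delta", "id": "text_0001", "delta": "hi"}
assert attach_emitted_by(payload, None) is payload  # unchanged; no "emitted_by" key

sub = subagent_emitter(subagent_type="browser", subagent_run_id="subagent_0001")
attach_emitted_by(dict(payload), sub)["emitted_by"]
# {"level": "subagent", "subagent_type": "browser", "subagent_run_id": "subagent_0001"}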

View file

@ -0,0 +1,51 @@
"""Resolve which agent owns a streamed event from its LangGraph run lineage."""
from __future__ import annotations
from collections.abc import Iterable
from .emitter import Emitter, main_emitter
class EmitterRegistry:
def __init__(self) -> None:
self._by_run_id: dict[str, Emitter] = {}
def register(self, run_id: str, emitter: Emitter) -> None:
if not run_id:
return
self._by_run_id[run_id] = emitter
def unregister(self, run_id: str) -> Emitter | None:
if not run_id:
return None
return self._by_run_id.pop(run_id, None)
def get(self, run_id: str | None) -> Emitter | None:
if not run_id:
return None
return self._by_run_id.get(run_id)
def resolve(
self,
*,
run_id: str | None,
parent_ids: Iterable[str] | None,
) -> Emitter:
own = self.get(run_id)
if own is not None:
return own
if parent_ids:
for ancestor in reversed(list(parent_ids)):
emitter = self.get(ancestor)
if emitter is not None:
return emitter
return main_emitter()
def has_active_subagents(self) -> bool:
return any(
emitter.level == "subagent" for emitter in self._by_run_id.values()
)
def clear(self) -> None:
self._by_run_id.clear()
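A short sketch of the resolution order (run ids and the ``browser`` type are illustrative; ``subagent_emitter`` comes from the sibling ``emitter`` module): an exact ``run_id`` match wins, otherwise the nearest registered ancestor in ``parent_ids``, otherwise the main agent.
registry = EmitterRegistry()
registry.register(
    "run-parent",
    subagent_emitter(subagent_type="browser", subagent_run_id="subagent_0001"),
)

# No direct registration for the child run: the walk over parent_ids
# (nearest ancestor first, hence reversed) finds the sub-agent emitter.
resolved = registry.resolve(run_id="run-child", parent_ids=["run-root", "run-parent"])
assert resolved.subagent_type == "browser"

# Unknown lineage falls back to the main emitter.
assert registry.resolve(run_id=None, parent_ids=None).level == "main"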

View file

@ -0,0 +1,23 @@
"""Wire framing layer."""
from __future__ import annotations
from .identifiers import (
generate_message_id,
generate_reasoning_id,
generate_subagent_run_id,
generate_text_id,
generate_tool_call_id,
)
from .sse import format_done, format_sse, get_response_headers
__all__ = [
"format_done",
"format_sse",
"generate_message_id",
"generate_reasoning_id",
"generate_subagent_run_id",
"generate_text_id",
"generate_tool_call_id",
"get_response_headers",
]

View file

@ -0,0 +1,25 @@
"""Prefixed UUID generators for stream parts."""
from __future__ import annotations
import uuid
def generate_message_id() -> str:
return f"msg_{uuid.uuid4().hex}"
def generate_text_id() -> str:
return f"text_{uuid.uuid4().hex}"
def generate_reasoning_id() -> str:
return f"reasoning_{uuid.uuid4().hex}"
def generate_tool_call_id() -> str:
return f"call_{uuid.uuid4().hex}"
def generate_subagent_run_id() -> str:
return f"subagent_{uuid.uuid4().hex}"

View file

@ -0,0 +1,25 @@
"""Server-Sent-Events wire framing."""
from __future__ import annotations
import json
from typing import Any
def format_sse(data: Any) -> str:
if isinstance(data, str):
return f"data: {data}\n\n"
return f"data: {json.dumps(data)}\n\n"
def format_done() -> str:
return "data: [DONE]\n\n"
def get_response_headers() -> dict[str, str]:
return {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"x-vercel-ai-ui-message-stream": "v1",
}
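For reference, the framing is the standard ``data:`` SSE line pair (payload ids below are illustrative); dicts are JSON-encoded and strings pass through as-is.
format_sse({"type": "text-delta", "id": "text_0001", "delta": "Hi"})
# 'data: {"type": "text-delta", "id": "text_0001", "delta": "Hi"}\n\n'

format_done()
# 'data: [DONE]\n\n'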

View file

@ -0,0 +1,29 @@
"""SSE event payload formatters, one module per event family."""
from __future__ import annotations
from . import (
action_log,
data,
error,
interrupt,
lifecycle,
reasoning,
source,
subagent_lifecycle,
text,
tool,
)
__all__ = [
"action_log",
"data",
"error",
"interrupt",
"lifecycle",
"reasoning",
"source",
"subagent_lifecycle",
"text",
"tool",
]

View file

@ -0,0 +1,24 @@
"""Action-log events relayed from ``ActionLogMiddleware`` custom dispatches."""
from __future__ import annotations
from typing import Any
from ..emitter import Emitter
from .data import format_data
def format_action_log(
payload: dict[str, Any],
*,
emitter: Emitter | None = None,
) -> str:
return format_data("action-log", payload, emitter=emitter)
def format_action_log_updated(
payload: dict[str, Any],
*,
emitter: Emitter | None = None,
) -> str:
return format_data("action-log-updated", payload, emitter=emitter)

View file

@ -0,0 +1,118 @@
"""Generic ``data-*`` envelopes and SurfSense-specific data parts.
Inner ``data`` dict fields use snake_case. Legacy ``threadId`` /
``messageId`` keys are preserved where they cross the AI SDK boundary.
"""
from __future__ import annotations
from typing import Any
from ..emitter import Emitter, attach_emitted_by
from ..envelope import format_sse
def format_data(
data_type: str,
data: Any,
*,
emitter: Emitter | None = None,
) -> str:
payload: dict[str, Any] = {"type": f"data-{data_type}", "data": data}
return format_sse(attach_emitted_by(payload, emitter))
def format_terminal_info(
text: str,
*,
message_type: str = "info",
emitter: Emitter | None = None,
) -> str:
return format_data(
"terminal-info",
{"text": text, "type": message_type},
emitter=emitter,
)
def format_further_questions(
questions: list[str],
*,
emitter: Emitter | None = None,
) -> str:
return format_data("further-questions", {"questions": questions}, emitter=emitter)
def format_thinking_step(
*,
step_id: str,
title: str,
status: str = "in_progress",
items: list[str] | None = None,
emitter: Emitter | None = None,
) -> str:
return format_data(
"thinking-step",
{
"id": step_id,
"title": title,
"status": status,
"items": items or [],
},
emitter=emitter,
)
def format_thread_title_update(
*,
thread_id: int,
title: str,
emitter: Emitter | None = None,
) -> str:
return format_data(
"thread-title-update",
{"threadId": thread_id, "title": title},
emitter=emitter,
)
def format_turn_info(
*,
chat_turn_id: str,
emitter: Emitter | None = None,
) -> str:
return format_data("turn-info", {"chat_turn_id": chat_turn_id}, emitter=emitter)
def format_turn_status(
*,
status: str,
emitter: Emitter | None = None,
) -> str:
return format_data("turn-status", {"status": status}, emitter=emitter)
def format_user_message_id(
*,
message_id: str,
turn_id: str,
emitter: Emitter | None = None,
) -> str:
return format_data(
"user-message-id",
{"message_id": message_id, "turn_id": turn_id},
emitter=emitter,
)
def format_assistant_message_id(
*,
message_id: str,
turn_id: str,
emitter: Emitter | None = None,
) -> str:
return format_data(
"assistant-message-id",
{"message_id": message_id, "turn_id": turn_id},
emitter=emitter,
)

View file

@ -0,0 +1,23 @@
"""Single terminal error path chat streaming must route through."""
from __future__ import annotations
from typing import Any
from ..emitter import Emitter, attach_emitted_by
from ..envelope import format_sse
def format_error(
error_text: str,
*,
error_code: str | None = None,
extra: dict[str, Any] | None = None,
emitter: Emitter | None = None,
) -> str:
payload: dict[str, Any] = {"type": "error", "errorText": error_text}
if error_code:
payload["errorCode"] = error_code
if extra:
payload.update(extra)
return format_sse(attach_emitted_by(payload, emitter))

View file

@ -0,0 +1,56 @@
"""Interrupt-request events with a single canonical payload shape."""
from __future__ import annotations
from typing import Any
from ..emitter import Emitter
from .data import format_data
def normalize_interrupt_payload(interrupt_value: dict[str, Any]) -> dict[str, Any]:
if "action_requests" in interrupt_value and "review_configs" in interrupt_value:
return interrupt_value
interrupt_type = interrupt_value.get("type", "unknown")
message = interrupt_value.get("message")
action = interrupt_value.get("action", {}) or {}
context = interrupt_value.get("context", {}) or {}
normalized: dict[str, Any] = {
"action_requests": [
{
"name": action.get("tool", "unknown_tool"),
"args": action.get("params", {}),
}
],
"review_configs": [
{
"action_name": action.get("tool", "unknown_tool"),
"allowed_decisions": ["approve", "edit", "reject"],
}
],
"interrupt_type": interrupt_type,
"context": context,
}
if message:
normalized["message"] = message
return normalized
def format_interrupt_request(
interrupt_value: dict[str, Any],
*,
interrupt_id: str | None = None,
pending_interrupt_count: int | None = None,
chat_turn_id: str | None = None,
emitter: Emitter | None = None,
) -> str:
payload = normalize_interrupt_payload(interrupt_value)
if interrupt_id is not None:
payload["interrupt_id"] = interrupt_id
if pending_interrupt_count is not None:
payload["pending_interrupt_count"] = pending_interrupt_count
if chat_turn_id is not None:
payload["chat_turn_id"] = chat_turn_id
return format_data("interrupt-request", payload, emitter=emitter)
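A rough example of the legacy-to-canonical normalization (payload values are illustrative; ``linear_issue_creation`` matches the ``request_approval`` call in the removed Linear tool):
legacy = {
    "type": "linear_issue_creation",
    "message": "Create this Linear issue?",
    "action": {"tool": "create_linear_issue", "params": {"title": "Fix login bug"}},
    "context": {"workspaces": []},
}
normalize_interrupt_payload(legacy)
# {
#     "action_requests": [{"name": "create_linear_issue",
#                          "args": {"title": "Fix login bug"}}],
#     "review_configs": [{"action_name": "create_linear_issue",
#                         "allowed_decisions": ["approve", "edit", "reject"]}],
#     "interrupt_type": "linear_issue_creation",
#     "context": {"workspaces": []},
#     "message": "Create this Linear issue?",
# }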

View file

@ -0,0 +1,29 @@
"""High-level message and step lifecycle events.
Wire verbs are fixed by the AI SDK protocol (``start`` / ``finish`` for
the whole message, ``start-step`` / ``finish-step`` for each step).
Python helpers are consistently named ``format_<entity>_<verb>`` so the
start/finish pairs are visible at the call site.
"""
from __future__ import annotations
from ..emitter import Emitter, attach_emitted_by
from ..envelope import format_sse
def format_message_start(message_id: str, *, emitter: Emitter | None = None) -> str:
payload = {"type": "start", "messageId": message_id}
return format_sse(attach_emitted_by(payload, emitter))
def format_message_finish(*, emitter: Emitter | None = None) -> str:
return format_sse(attach_emitted_by({"type": "finish"}, emitter))
def format_step_start(*, emitter: Emitter | None = None) -> str:
return format_sse(attach_emitted_by({"type": "start-step"}, emitter))
def format_step_finish(*, emitter: Emitter | None = None) -> str:
return format_sse(attach_emitted_by({"type": "finish-step"}, emitter))

View file

@ -0,0 +1,36 @@
"""Reasoning-block streaming events."""
from __future__ import annotations
from ..emitter import Emitter, attach_emitted_by
from ..envelope import format_sse
def format_reasoning_start(
reasoning_id: str, *, emitter: Emitter | None = None
) -> str:
return format_sse(
attach_emitted_by({"type": "reasoning-start", "id": reasoning_id}, emitter)
)
def format_reasoning_delta(
reasoning_id: str,
delta: str,
*,
emitter: Emitter | None = None,
) -> str:
return format_sse(
attach_emitted_by(
{"type": "reasoning-delta", "id": reasoning_id, "delta": delta},
emitter,
)
)
def format_reasoning_end(
reasoning_id: str, *, emitter: Emitter | None = None
) -> str:
return format_sse(
attach_emitted_by({"type": "reasoning-end", "id": reasoning_id}, emitter)
)

View file

@ -0,0 +1,59 @@
"""Source and file reference events."""
from __future__ import annotations
from typing import Any
from ..emitter import Emitter, attach_emitted_by
from ..envelope import format_sse
def format_source_url(
url: str,
*,
source_id: str | None = None,
title: str | None = None,
emitter: Emitter | None = None,
) -> str:
payload: dict[str, Any] = {
"type": "source-url",
"sourceId": source_id or url,
"url": url,
}
if title:
payload["title"] = title
return format_sse(attach_emitted_by(payload, emitter))
def format_source_document(
source_id: str,
*,
media_type: str = "file",
title: str | None = None,
description: str | None = None,
emitter: Emitter | None = None,
) -> str:
payload: dict[str, Any] = {
"type": "source-document",
"sourceId": source_id,
"mediaType": media_type,
}
if title:
payload["title"] = title
if description:
payload["description"] = description
return format_sse(attach_emitted_by(payload, emitter))
def format_file(
url: str,
media_type: str,
*,
emitter: Emitter | None = None,
) -> str:
payload: dict[str, Any] = {
"type": "file",
"url": url,
"mediaType": media_type,
}
return format_sse(attach_emitted_by(payload, emitter))

View file

@ -0,0 +1,86 @@
"""Sub-agent lifecycle events the FE pairs into one timeline lane.
A sub-agent run is a high-level boundary (a whole agent invocation),
so we use the ``start`` / ``finish`` verb pair, matching how the AI SDK
spells message- and step-level lifecycles.
"""
from __future__ import annotations
from typing import Any
from ..emitter import Emitter
from .data import format_data
def format_subagent_start(
*,
subagent_run_id: str,
subagent_type: str,
parent_tool_call_id: str,
chat_turn_id: str | None = None,
description: str | None = None,
started_at: str | None = None,
emitter: Emitter | None = None,
) -> str:
payload: dict[str, Any] = {
"subagent_run_id": subagent_run_id,
"subagent_type": subagent_type,
"parent_tool_call_id": parent_tool_call_id,
}
if chat_turn_id is not None:
payload["chat_turn_id"] = chat_turn_id
if description is not None:
payload["description"] = description
if started_at is not None:
payload["started_at"] = started_at
return format_data("subagent-start", payload, emitter=emitter)
def format_subagent_finish(
*,
subagent_run_id: str,
subagent_type: str,
parent_tool_call_id: str,
status: str = "completed",
ended_at: str | None = None,
duration_ms: int | None = None,
emitter: Emitter | None = None,
) -> str:
payload: dict[str, Any] = {
"subagent_run_id": subagent_run_id,
"subagent_type": subagent_type,
"parent_tool_call_id": parent_tool_call_id,
"status": status,
}
if ended_at is not None:
payload["ended_at"] = ended_at
if duration_ms is not None:
payload["duration_ms"] = duration_ms
return format_data("subagent-finish", payload, emitter=emitter)
def format_subagent_error(
*,
subagent_run_id: str,
subagent_type: str,
parent_tool_call_id: str,
error_text: str,
error_type: str | None = None,
ended_at: str | None = None,
duration_ms: int | None = None,
emitter: Emitter | None = None,
) -> str:
payload: dict[str, Any] = {
"subagent_run_id": subagent_run_id,
"subagent_type": subagent_type,
"parent_tool_call_id": parent_tool_call_id,
"error_text": error_text,
}
if error_type is not None:
payload["error_type"] = error_type
if ended_at is not None:
payload["ended_at"] = ended_at
if duration_ms is not None:
payload["duration_ms"] = duration_ms
return format_data("subagent-error", payload, emitter=emitter)
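A sketch of a paired run as the FE would see it (ids, the ``research`` type, and the timing are illustrative): the same ``subagent_run_id`` and ``parent_tool_call_id`` appear on both events so the timeline can fold them into one lane.
start_part = format_subagent_start(
    subagent_run_id="subagent_0001",
    subagent_type="research",
    parent_tool_call_id="call_0001",
    description="Summarize the open incident reports",
)
finish_part = format_subagent_finish(
    subagent_run_id="subagent_0001",
    subagent_type="research",
    parent_tool_call_id="call_0001",
    status="completed",
    duration_ms=8200,
)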

View file

@ -0,0 +1,31 @@
"""Text-block streaming events."""
from __future__ import annotations
from ..emitter import Emitter, attach_emitted_by
from ..envelope import format_sse
def format_text_start(text_id: str, *, emitter: Emitter | None = None) -> str:
return format_sse(
attach_emitted_by({"type": "text-start", "id": text_id}, emitter)
)
def format_text_delta(
text_id: str,
delta: str,
*,
emitter: Emitter | None = None,
) -> str:
return format_sse(
attach_emitted_by(
{"type": "text-delta", "id": text_id, "delta": delta}, emitter
)
)
def format_text_end(text_id: str, *, emitter: Emitter | None = None) -> str:
return format_sse(
attach_emitted_by({"type": "text-end", "id": text_id}, emitter)
)

View file

@ -0,0 +1,80 @@
"""Tool-call streaming events.
``toolCallId`` and ``langchainToolCallId`` are AI SDK protocol fields
and stay camelCase. Sub-agent provenance rides on the snake_case
top-level ``emitted_by`` envelope added by :func:`attach_emitted_by`.
"""
from __future__ import annotations
from typing import Any
from ..emitter import Emitter, attach_emitted_by
from ..envelope import format_sse
def format_tool_input_start(
tool_call_id: str,
tool_name: str,
*,
langchain_tool_call_id: str | None = None,
emitter: Emitter | None = None,
) -> str:
payload: dict[str, Any] = {
"type": "tool-input-start",
"toolCallId": tool_call_id,
"toolName": tool_name,
}
if langchain_tool_call_id:
payload["langchainToolCallId"] = langchain_tool_call_id
return format_sse(attach_emitted_by(payload, emitter))
def format_tool_input_delta(
tool_call_id: str,
input_text_delta: str,
*,
emitter: Emitter | None = None,
) -> str:
payload: dict[str, Any] = {
"type": "tool-input-delta",
"toolCallId": tool_call_id,
"inputTextDelta": input_text_delta,
}
return format_sse(attach_emitted_by(payload, emitter))
def format_tool_input_available(
tool_call_id: str,
tool_name: str,
input_data: dict[str, Any],
*,
langchain_tool_call_id: str | None = None,
emitter: Emitter | None = None,
) -> str:
payload: dict[str, Any] = {
"type": "tool-input-available",
"toolCallId": tool_call_id,
"toolName": tool_name,
"input": input_data,
}
if langchain_tool_call_id:
payload["langchainToolCallId"] = langchain_tool_call_id
return format_sse(attach_emitted_by(payload, emitter))
def format_tool_output_available(
tool_call_id: str,
output: Any,
*,
langchain_tool_call_id: str | None = None,
emitter: Emitter | None = None,
) -> str:
payload: dict[str, Any] = {
"type": "tool-output-available",
"toolCallId": tool_call_id,
"output": output,
}
if langchain_tool_call_id:
payload["langchainToolCallId"] = langchain_tool_call_id
return format_sse(attach_emitted_by(payload, emitter))

View file

@ -0,0 +1,84 @@
"""Id-aware lookup of pending LangGraph interrupts (replaces first-wins)."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
@dataclass(frozen=True)
class PendingInterrupt:
interrupt_id: str | None
value: dict[str, Any]
source_task_id: str | None = None
def list_pending_interrupts(state: Any) -> list[PendingInterrupt]:
out: list[PendingInterrupt] = []
for task in getattr(state, "tasks", None) or ():
task_id = _safe_str(getattr(task, "id", None))
for it in getattr(task, "interrupts", None) or ():
value = _coerce_interrupt_value(it)
if value is None:
continue
interrupt_id = _safe_str(getattr(it, "id", None))
out.append(
PendingInterrupt(
interrupt_id=interrupt_id,
value=value,
source_task_id=task_id,
)
)
for it in getattr(state, "interrupts", None) or ():
value = _coerce_interrupt_value(it)
if value is None:
continue
interrupt_id = _safe_str(getattr(it, "id", None))
out.append(PendingInterrupt(interrupt_id=interrupt_id, value=value))
return out
def get_pending_interrupt_by_id(
state: Any, interrupt_id: str
) -> PendingInterrupt | None:
for pending in list_pending_interrupts(state):
if pending.interrupt_id == interrupt_id:
return pending
return None
def get_pending_interrupt_for_tool_call(
state: Any, tool_call_id: str
) -> PendingInterrupt | None:
for pending in list_pending_interrupts(state):
actions = pending.value.get("action_requests")
if not isinstance(actions, list):
continue
for action in actions:
if not isinstance(action, dict):
continue
if action.get("tool_call_id") == tool_call_id:
return pending
return None
def first_pending_interrupt(state: Any) -> PendingInterrupt | None:
"""Explicit opt-in to legacy first-wins; prefer the id-aware helpers above."""
pending = list_pending_interrupts(state)
return pending[0] if pending else None
def _coerce_interrupt_value(item: Any) -> dict[str, Any] | None:
if isinstance(item, dict):
return item if item else None
value = getattr(item, "value", None)
if isinstance(value, dict):
return value if value else None
return None
def _safe_str(value: Any) -> str | None:
return value if isinstance(value, str) and value else None
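A minimal sketch of the id-aware lookups over a fake LangGraph state (``SimpleNamespace`` stands in for the real state/task/interrupt objects; ids and the tool name are illustrative):
from types import SimpleNamespace

pending = SimpleNamespace(
    id="int_0001",
    value={
        "action_requests": [
            {"name": "create_jira_issue", "args": {}, "tool_call_id": "call_0042"}
        ],
        "review_configs": [],
    },
)
state = SimpleNamespace(
    tasks=[SimpleNamespace(id="task_a", interrupts=[pending])],
    interrupts=[],
)

assert len(list_pending_interrupts(state)) == 1
assert get_pending_interrupt_by_id(state, "int_0001").source_task_id == "task_a"
assert get_pending_interrupt_for_tool_call(state, "call_0042").interrupt_id == "int_0001"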

View file

@ -0,0 +1,414 @@
"""Composition root: bundles every formatter + a per-invocation emitter registry."""
from __future__ import annotations
from collections.abc import Iterable
from typing import Any
from . import envelope
from .emitter import Emitter, EmitterRegistry
from .events import (
action_log,
data,
error,
interrupt,
lifecycle,
reasoning,
source,
subagent_lifecycle,
text,
tool,
)
class StreamingService:
def __init__(self) -> None:
self._message_id: str | None = None
self.emitter_registry = EmitterRegistry()
@property
def message_id(self) -> str | None:
return self._message_id
def begin_message(self, message_id: str | None = None) -> str:
self._message_id = message_id or envelope.generate_message_id()
return self._message_id
@staticmethod
def generate_text_id() -> str:
return envelope.generate_text_id()
@staticmethod
def generate_reasoning_id() -> str:
return envelope.generate_reasoning_id()
@staticmethod
def generate_tool_call_id() -> str:
return envelope.generate_tool_call_id()
@staticmethod
def generate_subagent_run_id() -> str:
return envelope.generate_subagent_run_id()
@staticmethod
def get_response_headers() -> dict[str, str]:
return envelope.get_response_headers()
@staticmethod
def format_done() -> str:
return envelope.format_done()
def resolve_emitter(
self,
*,
run_id: str | None,
parent_ids: Iterable[str] | None,
) -> Emitter:
return self.emitter_registry.resolve(run_id=run_id, parent_ids=parent_ids)
def format_message_start(
self,
message_id: str | None = None,
*,
emitter: Emitter | None = None,
) -> str:
chosen = self.begin_message(message_id)
return lifecycle.format_message_start(chosen, emitter=emitter)
def format_message_finish(self, *, emitter: Emitter | None = None) -> str:
return lifecycle.format_message_finish(emitter=emitter)
def format_step_start(self, *, emitter: Emitter | None = None) -> str:
return lifecycle.format_step_start(emitter=emitter)
def format_step_finish(self, *, emitter: Emitter | None = None) -> str:
return lifecycle.format_step_finish(emitter=emitter)
def format_text_start(
self, text_id: str, *, emitter: Emitter | None = None
) -> str:
return text.format_text_start(text_id, emitter=emitter)
def format_text_delta(
self, text_id: str, delta: str, *, emitter: Emitter | None = None
) -> str:
return text.format_text_delta(text_id, delta, emitter=emitter)
def format_text_end(
self, text_id: str, *, emitter: Emitter | None = None
) -> str:
return text.format_text_end(text_id, emitter=emitter)
def format_reasoning_start(
self, reasoning_id: str, *, emitter: Emitter | None = None
) -> str:
return reasoning.format_reasoning_start(reasoning_id, emitter=emitter)
def format_reasoning_delta(
self,
reasoning_id: str,
delta: str,
*,
emitter: Emitter | None = None,
) -> str:
return reasoning.format_reasoning_delta(reasoning_id, delta, emitter=emitter)
def format_reasoning_end(
self, reasoning_id: str, *, emitter: Emitter | None = None
) -> str:
return reasoning.format_reasoning_end(reasoning_id, emitter=emitter)
def format_tool_input_start(
self,
tool_call_id: str,
tool_name: str,
*,
langchain_tool_call_id: str | None = None,
emitter: Emitter | None = None,
) -> str:
return tool.format_tool_input_start(
tool_call_id,
tool_name,
langchain_tool_call_id=langchain_tool_call_id,
emitter=emitter,
)
def format_tool_input_delta(
self,
tool_call_id: str,
input_text_delta: str,
*,
emitter: Emitter | None = None,
) -> str:
return tool.format_tool_input_delta(
tool_call_id, input_text_delta, emitter=emitter
)
def format_tool_input_available(
self,
tool_call_id: str,
tool_name: str,
input_data: dict[str, Any],
*,
langchain_tool_call_id: str | None = None,
emitter: Emitter | None = None,
) -> str:
return tool.format_tool_input_available(
tool_call_id,
tool_name,
input_data,
langchain_tool_call_id=langchain_tool_call_id,
emitter=emitter,
)
def format_tool_output_available(
self,
tool_call_id: str,
output: Any,
*,
langchain_tool_call_id: str | None = None,
emitter: Emitter | None = None,
) -> str:
return tool.format_tool_output_available(
tool_call_id,
output,
langchain_tool_call_id=langchain_tool_call_id,
emitter=emitter,
)
def format_source_url(
self,
url: str,
*,
source_id: str | None = None,
title: str | None = None,
emitter: Emitter | None = None,
) -> str:
return source.format_source_url(
url, source_id=source_id, title=title, emitter=emitter
)
def format_source_document(
self,
source_id: str,
*,
media_type: str = "file",
title: str | None = None,
description: str | None = None,
emitter: Emitter | None = None,
) -> str:
return source.format_source_document(
source_id,
media_type=media_type,
title=title,
description=description,
emitter=emitter,
)
def format_file(
self, url: str, media_type: str, *, emitter: Emitter | None = None
) -> str:
return source.format_file(url, media_type, emitter=emitter)
def format_data(
self, data_type: str, payload: Any, *, emitter: Emitter | None = None
) -> str:
return data.format_data(data_type, payload, emitter=emitter)
def format_terminal_info(
self,
text_value: str,
*,
message_type: str = "info",
emitter: Emitter | None = None,
) -> str:
return data.format_terminal_info(
text_value, message_type=message_type, emitter=emitter
)
def format_further_questions(
self,
questions: list[str],
*,
emitter: Emitter | None = None,
) -> str:
return data.format_further_questions(questions, emitter=emitter)
def format_thinking_step(
self,
*,
step_id: str,
title: str,
status: str = "in_progress",
items: list[str] | None = None,
emitter: Emitter | None = None,
) -> str:
return data.format_thinking_step(
step_id=step_id,
title=title,
status=status,
items=items,
emitter=emitter,
)
def format_thread_title_update(
self,
*,
thread_id: int,
title: str,
emitter: Emitter | None = None,
) -> str:
return data.format_thread_title_update(
thread_id=thread_id, title=title, emitter=emitter
)
def format_turn_info(
self,
*,
chat_turn_id: str,
emitter: Emitter | None = None,
) -> str:
return data.format_turn_info(chat_turn_id=chat_turn_id, emitter=emitter)
def format_turn_status(
self,
*,
status: str,
emitter: Emitter | None = None,
) -> str:
return data.format_turn_status(status=status, emitter=emitter)
def format_user_message_id(
self,
*,
message_id: str,
turn_id: str,
emitter: Emitter | None = None,
) -> str:
return data.format_user_message_id(
message_id=message_id, turn_id=turn_id, emitter=emitter
)
def format_assistant_message_id(
self,
*,
message_id: str,
turn_id: str,
emitter: Emitter | None = None,
) -> str:
return data.format_assistant_message_id(
message_id=message_id, turn_id=turn_id, emitter=emitter
)
def format_error(
self,
error_text: str,
*,
error_code: str | None = None,
extra: dict[str, Any] | None = None,
emitter: Emitter | None = None,
) -> str:
return error.format_error(
error_text,
error_code=error_code,
extra=extra,
emitter=emitter,
)
def format_interrupt_request(
self,
interrupt_value: dict[str, Any],
*,
interrupt_id: str | None = None,
pending_interrupt_count: int | None = None,
chat_turn_id: str | None = None,
emitter: Emitter | None = None,
) -> str:
return interrupt.format_interrupt_request(
interrupt_value,
interrupt_id=interrupt_id,
pending_interrupt_count=pending_interrupt_count,
chat_turn_id=chat_turn_id,
emitter=emitter,
)
def format_subagent_start(
self,
*,
subagent_run_id: str,
subagent_type: str,
parent_tool_call_id: str,
chat_turn_id: str | None = None,
description: str | None = None,
started_at: str | None = None,
emitter: Emitter | None = None,
) -> str:
return subagent_lifecycle.format_subagent_start(
subagent_run_id=subagent_run_id,
subagent_type=subagent_type,
parent_tool_call_id=parent_tool_call_id,
chat_turn_id=chat_turn_id,
description=description,
started_at=started_at,
emitter=emitter,
)
def format_subagent_finish(
self,
*,
subagent_run_id: str,
subagent_type: str,
parent_tool_call_id: str,
status: str = "completed",
ended_at: str | None = None,
duration_ms: int | None = None,
emitter: Emitter | None = None,
) -> str:
return subagent_lifecycle.format_subagent_finish(
subagent_run_id=subagent_run_id,
subagent_type=subagent_type,
parent_tool_call_id=parent_tool_call_id,
status=status,
ended_at=ended_at,
duration_ms=duration_ms,
emitter=emitter,
)
def format_subagent_error(
self,
*,
subagent_run_id: str,
subagent_type: str,
parent_tool_call_id: str,
error_text: str,
error_type: str | None = None,
ended_at: str | None = None,
duration_ms: int | None = None,
emitter: Emitter | None = None,
) -> str:
return subagent_lifecycle.format_subagent_error(
subagent_run_id=subagent_run_id,
subagent_type=subagent_type,
parent_tool_call_id=parent_tool_call_id,
error_text=error_text,
error_type=error_type,
ended_at=ended_at,
duration_ms=duration_ms,
emitter=emitter,
)
def format_action_log(
self,
payload: dict[str, Any],
*,
emitter: Emitter | None = None,
) -> str:
return action_log.format_action_log(payload, emitter=emitter)
def format_action_log_updated(
self,
payload: dict[str, Any],
*,
emitter: Emitter | None = None,
) -> str:
return action_log.format_action_log_updated(payload, emitter=emitter)

View file

@ -51,6 +51,21 @@ logger = logging.getLogger(__name__)
_MEANINGFUL_PART_TYPES: frozenset[str] = frozenset({"text", "reasoning", "tool-call"})
def _merge_tool_part_metadata(part: dict[str, Any], metadata: dict[str, Any] | None) -> None:
"""Shallow-merge ``metadata`` into ``part["metadata"]``; first key wins.
Used for tool-call linkage (``spanId``, ``thinkingStepId``, ): a later
event must not overwrite an existing key so chunk order vs ``on_tool_start``
stays stable.
"""
if not metadata:
return
md = part.setdefault("metadata", {})
for k, v in metadata.items():
if k not in md:
md[k] = v
class AssistantContentBuilder:
"""Server-side projection of ``surfsense_web/lib/chat/streaming-state.ts``.
@ -61,6 +76,7 @@ class AssistantContentBuilder:
| { type: "reasoning"; text: string }
| { type: "tool-call"; toolCallId: str; toolName: str;
args: dict; result?: any; argsText?: str; langchainToolCallId?: str;
metadata?: { spanId?: str; thinkingStepId?: str; ... };
state?: "aborted" }
| { type: "data-thinking-steps"; data: { steps: ThinkingStepData[] } }
| { type: "data-step-separator"; data: { stepIndex: int } }
@ -85,8 +101,8 @@ class AssistantContentBuilder:
self._current_text_idx: int = -1
self._current_reasoning_idx: int = -1
# ``ui_id``-keyed indexes for tool-call parts. ``ui_id`` is the
# synthetic ``call_<run_id>`` (legacy) or the LangChain
# ``tool_call.id`` (parity_v2) — same key the streaming layer
# synthetic ``call_<run_id>`` (chunk fallback) or the LangChain
# ``tool_call.id`` (indexed chunk path) — same key the streaming layer
# threads through every ``tool-input-*`` / ``tool-output-*`` event.
self._tool_call_idx_by_ui_id: dict[str, int] = {}
# Live argsText accumulator (concatenated ``tool-input-delta`` chunks)
@ -177,21 +193,27 @@ class AssistantContentBuilder:
ui_id: str,
tool_name: str,
langchain_tool_call_id: str | None,
*,
metadata: dict[str, Any] | None = None,
) -> None:
"""Register a tool-call card. Args are filled in by later events."""
"""Register a tool-call card. Args are filled in by later events.
Optional ``metadata`` (``spanId``, ``thinkingStepId``, …) is stored on the
part; duplicate ``tool-input-start`` calls merge it with existing keys taking precedence.
"""
if not ui_id:
return
# Skip duplicate registration: parity_v2 may emit
# Skip duplicate registration: the stream may emit
# ``tool-input-start`` from both ``on_chat_model_stream``
# (when tool_call_chunks register a name) and ``on_tool_start``
# (the canonical path). The FE de-dupes via ``toolCallIndices``;
# we mirror that here.
if ui_id in self._tool_call_idx_by_ui_id:
if langchain_tool_call_id:
idx = self._tool_call_idx_by_ui_id[ui_id]
part = self.parts[idx]
if not part.get("langchainToolCallId"):
part["langchainToolCallId"] = langchain_tool_call_id
idx = self._tool_call_idx_by_ui_id[ui_id]
part = self.parts[idx]
if langchain_tool_call_id and not part.get("langchainToolCallId"):
part["langchainToolCallId"] = langchain_tool_call_id
_merge_tool_part_metadata(part, metadata)
return
part: dict[str, Any] = {
@ -202,6 +224,8 @@ class AssistantContentBuilder:
}
if langchain_tool_call_id:
part["langchainToolCallId"] = langchain_tool_call_id
if metadata:
part["metadata"] = dict(metadata)
self.parts.append(part)
self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1
@ -235,6 +259,8 @@ class AssistantContentBuilder:
tool_name: str,
args: dict[str, Any],
langchain_tool_call_id: str | None,
*,
metadata: dict[str, Any] | None = None,
) -> None:
"""Finalize the tool-call card's input.
@ -243,7 +269,7 @@ class AssistantContentBuilder:
pretty-printed JSON, sets the full ``args`` dict, and backfills
``langchainToolCallId`` if it wasn't known at ``tool-input-start`` time.
Also creates the card if no prior ``tool-input-start`` registered it
(legacy parity_v2-OFF / late-registration paths).
(late-registration when no prior ``tool-input-start``).
"""
if not ui_id:
return
@ -264,6 +290,7 @@ class AssistantContentBuilder:
part["argsText"] = final_args_text
if langchain_tool_call_id and not part.get("langchainToolCallId"):
part["langchainToolCallId"] = langchain_tool_call_id
_merge_tool_part_metadata(part, metadata)
return
# No prior tool-input-start: register the card now.
@ -276,6 +303,7 @@ class AssistantContentBuilder:
}
if langchain_tool_call_id:
new_part["langchainToolCallId"] = langchain_tool_call_id
_merge_tool_part_metadata(new_part, metadata)
self.parts.append(new_part)
self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1
@ -287,6 +315,8 @@ class AssistantContentBuilder:
ui_id: str,
output: Any,
langchain_tool_call_id: str | None,
*,
metadata: dict[str, Any] | None = None,
) -> None:
"""Attach the tool's output (``result``) to the matching card.
@ -305,6 +335,7 @@ class AssistantContentBuilder:
part["result"] = output
if langchain_tool_call_id and not part.get("langchainToolCallId"):
part["langchainToolCallId"] = langchain_tool_call_id
_merge_tool_part_metadata(part, metadata)
# ------------------------------------------------------------------
# Thinking steps & step separators
@ -316,6 +347,8 @@ class AssistantContentBuilder:
title: str,
status: str,
items: list[str] | None,
*,
metadata: dict[str, Any] | None = None,
) -> None:
"""Update / insert the singleton ``data-thinking-steps`` part.
@ -328,12 +361,14 @@ class AssistantContentBuilder:
if not step_id:
return
new_step = {
new_step: dict[str, Any] = {
"id": step_id,
"title": title or "",
"status": status or "in_progress",
"items": list(items) if items else [],
}
if metadata:
new_step["metadata"] = dict(metadata)
# Find existing data-thinking-steps part.
existing_idx = -1
@ -347,6 +382,8 @@ class AssistantContentBuilder:
replaced = False
for i, step in enumerate(current_steps):
if step.get("id") == step_id:
if not metadata and step.get("metadata"):
new_step["metadata"] = dict(step["metadata"])
current_steps[i] = new_step
replaced = True
break

File diff suppressed because it is too large

View file

@ -0,0 +1,3 @@
"""Chat streaming helpers (e.g. LangGraph → SSE relay under ``graph_stream``)."""
from __future__ import annotations

View file

@ -0,0 +1,3 @@
"""Error classification, structured logging, and terminal-error SSE emission."""
from __future__ import annotations

View file

@ -0,0 +1,187 @@
"""Classify stream exceptions for logging and client error payloads."""
from __future__ import annotations
import json
import logging
import time
from typing import Any, Literal
from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.middleware.busy_mutex import (
get_cancel_state,
is_cancel_requested,
)
TURN_CANCELLING_INITIAL_DELAY_MS = 200
TURN_CANCELLING_BACKOFF_FACTOR = 2
TURN_CANCELLING_MAX_DELAY_MS = 1500
def compute_turn_cancelling_retry_delay(attempt: int) -> int:
if attempt < 1:
attempt = 1
delay = TURN_CANCELLING_INITIAL_DELAY_MS * (
TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1)
)
return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
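For orientation, the schedule these constants produce (a quick worked check derived only from the function above; illustrative, not exhaustive):
assert compute_turn_cancelling_retry_delay(1) == 200
assert compute_turn_cancelling_retry_delay(2) == 400
assert compute_turn_cancelling_retry_delay(3) == 800
assert compute_turn_cancelling_retry_delay(4) == 1500  # 1600 capped at TURN_CANCELLING_MAX_DELAY_MS
assert compute_turn_cancelling_retry_delay(0) == 200   # attempts below 1 are clamped to 1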
def log_chat_stream_error(
*,
flow: Literal["new", "resume", "regenerate"],
error_kind: str,
error_code: str | None,
severity: Literal["info", "warn", "error"],
is_expected: bool,
request_id: str | None,
thread_id: int | None,
search_space_id: int | None,
user_id: str | None,
message: str,
extra: dict[str, Any] | None = None,
) -> None:
payload: dict[str, Any] = {
"event": "chat_stream_error",
"flow": flow,
"error_kind": error_kind,
"error_code": error_code,
"severity": severity,
"is_expected": is_expected,
"request_id": request_id or "unknown",
"thread_id": thread_id,
"search_space_id": search_space_id,
"user_id": user_id,
"message": message,
}
if extra:
payload.update(extra)
logger = logging.getLogger(__name__)
rendered = json.dumps(payload, ensure_ascii=False)
if severity == "error":
logger.error("[chat_stream_error] %s", rendered)
elif severity == "warn":
logger.warning("[chat_stream_error] %s", rendered)
else:
logger.info("[chat_stream_error] %s", rendered)
def _parse_error_payload(message: str) -> dict[str, Any] | None:
candidates = [message]
first_brace_idx = message.find("{")
if first_brace_idx >= 0:
candidates.append(message[first_brace_idx:])
for candidate in candidates:
try:
parsed = json.loads(candidate)
if isinstance(parsed, dict):
return parsed
except Exception:
continue
return None
def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None:
if not isinstance(parsed, dict):
return None
candidates: list[Any] = [parsed.get("code")]
nested = parsed.get("error")
if isinstance(nested, dict):
candidates.append(nested.get("code"))
for value in candidates:
try:
if value is None:
continue
return int(value)
except Exception:
continue
return None
def is_provider_rate_limited(exc: BaseException) -> bool:
"""Return True if the exception looks like an upstream HTTP 429 / rate limit."""
raw = str(exc)
lowered = raw.lower()
if "ratelimit" in type(exc).__name__.lower():
return True
parsed = _parse_error_payload(raw)
provider_code = _extract_provider_error_code(parsed)
if provider_code == 429:
return True
provider_error_type = ""
if parsed:
top_type = parsed.get("type")
if isinstance(top_type, str):
provider_error_type = top_type.lower()
nested = parsed.get("error")
if isinstance(nested, dict):
nested_type = nested.get("type")
if isinstance(nested_type, str):
provider_error_type = nested_type.lower()
if provider_error_type == "rate_limit_error":
return True
return (
"rate limited" in lowered
or "rate-limited" in lowered
or "temporarily rate-limited upstream" in lowered
)
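A hedged illustration of the heuristics above; the exception class and payload text are made up, but they exercise two branches the function explicitly checks (the type-name match and the embedded JSON ``code``):
class FakeRateLimitError(Exception):  # "ratelimit" in the type name triggers the first check
    pass

assert is_provider_rate_limited(FakeRateLimitError("slow down"))
assert is_provider_rate_limited(
    Exception('upstream replied: {"error": {"type": "rate_limit_error", "code": 429}}')
)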
def classify_stream_exception(
exc: Exception,
*,
flow_label: str,
) -> tuple[
str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None
]:
"""Return kind, code, severity, expected flag, message, and optional extra dict."""
raw = str(exc)
if isinstance(exc, BusyError) or "Thread is busy with another request" in raw:
busy_thread_id = str(exc.request_id) if isinstance(exc, BusyError) else None
if busy_thread_id and is_cancel_requested(busy_thread_id):
cancel_state = get_cancel_state(busy_thread_id)
attempt = cancel_state[0] if cancel_state else 1
retry_after_ms = compute_turn_cancelling_retry_delay(attempt)
retry_after_at = int(time.time() * 1000) + retry_after_ms
return (
"thread_busy",
"TURN_CANCELLING",
"info",
True,
"A previous response is still stopping. Please try again in a moment.",
{
"retry_after_ms": retry_after_ms,
"retry_after_at": retry_after_at,
},
)
return (
"thread_busy",
"THREAD_BUSY",
"warn",
True,
"Another response is still finishing for this thread. Please try again in a moment.",
None,
)
if is_provider_rate_limited(exc):
return (
"rate_limited",
"RATE_LIMITED",
"warn",
True,
"This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
None,
)
return (
"server_error",
"SERVER_ERROR",
"error",
False,
f"Error during {flow_label}: {raw}",
None,
)
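Two worked examples of the classification (hedged; the exception messages are hypothetical, but they hit the busy-string and fallback branches shown above):
kind, code, severity, expected, message, extra = classify_stream_exception(
    Exception("Thread is busy with another request"), flow_label="resume stream"
)
# ("thread_busy", "THREAD_BUSY", "warn", True, "Another response is still finishing ...", None)

kind, code, severity, expected, message, extra = classify_stream_exception(
    Exception("boom"), flow_label="resume stream"
)
# ("server_error", "SERVER_ERROR", "error", False, "Error during resume stream: boom", None)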

View file

@ -0,0 +1,38 @@
"""Emit one terminal error SSE frame and log via the stream error classifier."""
from __future__ import annotations
from typing import Any, Literal
from .classifier import log_chat_stream_error
def emit_stream_terminal_error(
*,
streaming_service: Any,
flow: Literal["new", "resume", "regenerate"],
request_id: str | None,
thread_id: int,
search_space_id: int,
user_id: str | None,
message: str,
error_kind: str = "server_error",
error_code: str = "SERVER_ERROR",
severity: Literal["info", "warn", "error"] = "error",
is_expected: bool = False,
extra: dict[str, Any] | None = None,
) -> str:
log_chat_stream_error(
flow=flow,
error_kind=error_kind,
error_code=error_code,
severity=severity,
is_expected=is_expected,
request_id=request_id,
thread_id=thread_id,
search_space_id=search_space_id,
user_id=user_id,
message=message,
extra=extra,
)
return streaming_service.format_error(message, error_code=error_code, extra=extra)

View file

@ -0,0 +1,21 @@
"""LangGraph ``astream_events`` → SSE (``stream_output`` + ``StreamingResult``).
Imports are lazy to avoid a circular import with ``relay.event_relay``.
"""
from __future__ import annotations
__all__ = ["StreamingResult", "stream_output"]
def __getattr__(name: str):
if name == "stream_output":
from app.tasks.chat.streaming.graph_stream.event_stream import stream_output
return stream_output
if name == "StreamingResult":
from app.tasks.chat.streaming.graph_stream.result import StreamingResult
return StreamingResult
msg = f"module {__name__!r} has no attribute {name!r}"
raise AttributeError(msg)
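Usage note (hedged): because resolution happens in the module-level ``__getattr__``, nothing under ``graph_stream`` is imported until an attribute is first read.
from app.tasks.chat.streaming import graph_stream

stream_fn = graph_stream.stream_output      # first access triggers the lazy import
result_cls = graph_stream.StreamingResult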

View file

@ -0,0 +1,51 @@
"""Run LangGraph event streams through ``EventRelay``."""
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import Any
from app.tasks.chat.streaming.graph_stream.result import StreamingResult
from app.tasks.chat.streaming.relay.event_relay import EventRelay
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
async def stream_output(
*,
agent: Any,
config: dict[str, Any],
input_data: Any,
streaming_service: Any,
result: StreamingResult,
step_prefix: str = "thinking",
initial_step_id: str | None = None,
initial_step_title: str = "",
initial_step_items: list[str] | None = None,
content_builder: Any | None = None,
runtime_context: Any = None,
) -> AsyncIterator[str]:
"""Yield SSE frames from agent ``astream_events`` via ``EventRelay``."""
state = AgentEventRelayState.for_invocation(
initial_step_id=initial_step_id,
initial_step_title=initial_step_title,
initial_step_items=initial_step_items,
)
astream_kwargs: dict[str, Any] = {"config": config, "version": "v2"}
if runtime_context is not None:
astream_kwargs["context"] = runtime_context
events = agent.astream_events(input_data, **astream_kwargs)
relay = EventRelay(streaming_service=streaming_service)
async for frame in relay.relay(
events,
state=state,
result=result,
step_prefix=step_prefix,
content_builder=content_builder,
config=config,
):
yield frame
result.accumulated_text = state.accumulated_text
result.agent_called_update_memory = state.called_update_memory
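A minimal caller sketch (hedged: ``agent``, ``graph_config``, ``service``, and the input shape are hypothetical stand-ins) showing how a route handler might relay the frames into an SSE body:
result = StreamingResult()

async def sse_body():
    async for frame in stream_output(
        agent=agent,                 # compiled LangGraph agent, assumed in scope
        config=graph_config,         # RunnableConfig-style dict, assumed in scope
        input_data={"messages": [("user", "Summarize my notes")]},
        streaming_service=service,   # streaming facade, assumed in scope
        result=result,
    ):
        yield frame                  # each frame is one SSE-formatted string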

View file

@ -0,0 +1,28 @@
"""Mutable facts collected while relaying one agent stream (``stream_output``)."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
@dataclass
class StreamingResult:
accumulated_text: str = ""
is_interrupted: bool = False
interrupt_value: dict[str, Any] | None = None
sandbox_files: list[str] = field(default_factory=list)
agent_called_update_memory: bool = False
request_id: str | None = None
turn_id: str = ""
filesystem_mode: str = "cloud"
client_platform: str = "web"
intent_detected: str = "chat_only"
intent_confidence: float = 0.0
write_attempted: bool = False
write_succeeded: bool = False
verification_succeeded: bool = False
commit_gate_passed: bool = True
commit_gate_reason: str = ""
assistant_message_id: int | None = None
content_builder: Any | None = field(default=None, repr=False)

View file

@ -0,0 +1,3 @@
"""LangGraph stream handlers by event kind."""
from __future__ import annotations

View file

@ -0,0 +1,23 @@
"""Close open text when a LangGraph chain or agent node finishes."""
from __future__ import annotations
from collections.abc import Iterator
from typing import Any
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
def iter_chain_end_frames(
_event: dict[str, Any],
*,
state: AgentEventRelayState,
streaming_service: Any,
content_builder: Any | None,
) -> Iterator[str]:
"""Close the open text stream if one is open."""
if state.current_text_id is not None:
yield streaming_service.format_text_end(state.current_text_id)
if content_builder is not None:
content_builder.on_text_end(state.current_text_id)
state.current_text_id = None

View file

@ -0,0 +1,159 @@
"""Chat model stream: text, reasoning, and tool-call chunk SSE."""
from __future__ import annotations
from collections.abc import Iterator
from typing import Any
from app.tasks.chat.streaming.helpers.chunk_parts import extract_chunk_parts
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
from app.tasks.chat.streaming.relay.task_span import ensure_pending_task_span_for_lc
from app.tasks.chat.streaming.relay.thinking_step_completion import (
complete_active_thinking_step,
)
def iter_chat_model_stream_frames(
event: dict[str, Any],
*,
state: AgentEventRelayState,
streaming_service: Any,
content_builder: Any | None,
step_prefix: str,
) -> Iterator[str]:
"""SSE frames for one chat-model chunk."""
if state.active_tool_depth > 0:
return
if "surfsense:internal" in event.get("tags", []):
return
chunk = event.get("data", {}).get("chunk")
if not chunk:
return
parts = extract_chunk_parts(chunk)
reasoning_delta = parts["reasoning"]
text_delta = parts["text"]
if reasoning_delta:
if state.current_text_id is not None:
yield streaming_service.format_text_end(state.current_text_id)
if content_builder is not None:
content_builder.on_text_end(state.current_text_id)
state.current_text_id = None
if state.current_reasoning_id is None:
comp, new_active = complete_active_thinking_step(
state=state,
streaming_service=streaming_service,
content_builder=content_builder,
last_active_step_id=state.last_active_step_id,
last_active_step_title=state.last_active_step_title,
last_active_step_items=state.last_active_step_items,
completed_step_ids=state.completed_step_ids,
)
if comp:
yield comp
state.last_active_step_id = new_active
if state.just_finished_tool:
state.last_active_step_id = None
state.last_active_step_title = ""
state.last_active_step_items = []
state.just_finished_tool = False
state.current_reasoning_id = streaming_service.generate_reasoning_id()
yield streaming_service.format_reasoning_start(state.current_reasoning_id)
if content_builder is not None:
content_builder.on_reasoning_start(state.current_reasoning_id)
yield streaming_service.format_reasoning_delta(
state.current_reasoning_id, reasoning_delta
)
if content_builder is not None:
content_builder.on_reasoning_delta(
state.current_reasoning_id, reasoning_delta
)
if text_delta:
if state.current_reasoning_id is not None:
yield streaming_service.format_reasoning_end(state.current_reasoning_id)
if content_builder is not None:
content_builder.on_reasoning_end(state.current_reasoning_id)
state.current_reasoning_id = None
if state.current_text_id is None:
comp, new_active = complete_active_thinking_step(
state=state,
streaming_service=streaming_service,
content_builder=content_builder,
last_active_step_id=state.last_active_step_id,
last_active_step_title=state.last_active_step_title,
last_active_step_items=state.last_active_step_items,
completed_step_ids=state.completed_step_ids,
)
if comp:
yield comp
state.last_active_step_id = new_active
if state.just_finished_tool:
state.last_active_step_id = None
state.last_active_step_title = ""
state.last_active_step_items = []
state.just_finished_tool = False
state.current_text_id = streaming_service.generate_text_id()
yield streaming_service.format_text_start(state.current_text_id)
if content_builder is not None:
content_builder.on_text_start(state.current_text_id)
yield streaming_service.format_text_delta(state.current_text_id, text_delta)
state.accumulated_text += text_delta
if content_builder is not None:
content_builder.on_text_delta(state.current_text_id, text_delta)
if parts["tool_call_chunks"]:
for tcc in parts["tool_call_chunks"]:
idx = tcc.get("index")
if idx is not None and idx not in state.index_to_meta:
lc_id = tcc.get("id")
name = tcc.get("name")
if lc_id and name:
ui_id = lc_id
tool_input_metadata: dict[str, Any] | None = None
if name == "task":
sid = ensure_pending_task_span_for_lc(state, str(lc_id))
tool_input_metadata = {"spanId": sid}
if state.current_text_id is not None:
yield streaming_service.format_text_end(state.current_text_id)
if content_builder is not None:
content_builder.on_text_end(state.current_text_id)
state.current_text_id = None
if state.current_reasoning_id is not None:
yield streaming_service.format_reasoning_end(
state.current_reasoning_id
)
if content_builder is not None:
content_builder.on_reasoning_end(state.current_reasoning_id)
state.current_reasoning_id = None
state.index_to_meta[idx] = {
"ui_id": ui_id,
"lc_id": lc_id,
"name": name,
}
yield streaming_service.format_tool_input_start(
ui_id,
name,
langchain_tool_call_id=lc_id,
metadata=tool_input_metadata,
)
if content_builder is not None:
content_builder.on_tool_input_start(
ui_id, name, lc_id, metadata=tool_input_metadata
)
meta = state.index_to_meta.get(idx) if idx is not None else None
if meta:
args_chunk = tcc.get("args") or ""
if args_chunk:
yield streaming_service.format_tool_input_delta(
meta["ui_id"], args_chunk
)
if content_builder is not None:
content_builder.on_tool_input_delta(meta["ui_id"], args_chunk)
else:
state.pending_tool_call_chunks.append(tcc)
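For reference, the chunk shapes this loop expects follow LangChain's ``tool_call_chunks`` convention (hedged illustration; the id and args are made up): the first chunk for a call carries ``id`` and ``name`` and registers the card, while later chunks carry only an ``args`` fragment for the same ``index``.
first_chunk = {"index": 0, "id": "call_abc123", "name": "task", "args": ""}
later_chunk = {"index": 0, "id": None, "name": None, "args": '{"description": "summarize'}
final_chunk = {"index": 0, "id": None, "name": None, "args": ' the thread"}'}
# first_chunk  -> tool-input-start (ui_id "call_abc123" registered in index_to_meta)
# later chunks -> tool-input-delta frames appended to the same card's argsText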

View file

@ -0,0 +1,57 @@
"""Custom graph events routed to SSE (documents, action logs, report progress)."""
from __future__ import annotations
from collections.abc import Iterator
from typing import Any
from app.tasks.chat.streaming.handlers.custom_events import (
handle_action_log,
handle_action_log_updated,
handle_document_created,
handle_report_progress,
)
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
def iter_custom_event_frames(
event: dict[str, Any],
*,
state: AgentEventRelayState,
streaming_service: Any,
content_builder: Any | None,
) -> Iterator[str]:
"""Yield any SSE produced by ad-hoc graph events (documents, action logs, report progress)."""
name = event.get("name")
data = event.get("data", {})
if name == "report_progress":
frame, state.last_active_step_items = handle_report_progress(
data,
last_active_step_id=state.last_active_step_id,
last_active_step_title=state.last_active_step_title,
last_active_step_items=state.last_active_step_items,
streaming_service=streaming_service,
content_builder=content_builder,
thinking_metadata=state.span_metadata_if_active(),
)
if frame:
yield frame
return
if name == "document_created":
frame = handle_document_created(data, streaming_service=streaming_service)
if frame:
yield frame
return
if name == "action_log":
frame = handle_action_log(data, streaming_service=streaming_service)
if frame:
yield frame
return
if name == "action_log_updated":
frame = handle_action_log_updated(data, streaming_service=streaming_service)
if frame:
yield frame

View file

@ -0,0 +1,79 @@
"""Custom-event payloads turned into SSE (no model/tool stream handling)."""
from __future__ import annotations
from typing import Any
from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame
def handle_report_progress(
data: dict[str, Any],
*,
last_active_step_id: str | None,
last_active_step_title: str,
last_active_step_items: list[str],
streaming_service: Any,
content_builder: Any | None,
thinking_metadata: dict[str, Any] | None = None,
) -> tuple[str | None, list[str]]:
"""Update report step items; may emit one thinking SSE frame.
Returns (frame or None, items list after update).
"""
message = data.get("message", "")
if not message or not last_active_step_id:
return None, last_active_step_items
phase = data.get("phase", "")
topic_items = [
item for item in last_active_step_items if item.startswith("Topic:")
]
if phase in ("revising_section", "adding_section"):
plan_items = [
item
for item in last_active_step_items
if item.startswith("Topic:")
or item.startswith("Modifying ")
or item.startswith("Adding ")
or item.startswith("Removing ")
]
plan_items = [item for item in plan_items if not item.endswith("...")]
new_items = [*plan_items, message]
else:
new_items = [*topic_items, message]
frame = emit_thinking_step_frame(
streaming_service=streaming_service,
content_builder=content_builder,
step_id=last_active_step_id,
title=last_active_step_title,
status="in_progress",
items=new_items,
metadata=thinking_metadata,
)
return frame, new_items
def handle_document_created(data: dict[str, Any], *, streaming_service: Any) -> str | None:
if not data.get("id"):
return None
return streaming_service.format_data(
"documents-updated",
{"action": "created", "document": data},
)
def handle_action_log(data: dict[str, Any], *, streaming_service: Any) -> str | None:
if data.get("id") is None:
return None
return streaming_service.format_data("action-log", data)
def handle_action_log_updated(
data: dict[str, Any], *, streaming_service: Any
) -> str | None:
if data.get("id") is None:
return None
return streaming_service.format_data("action-log-updated", data)
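A worked example (hedged; the step id, title, and ``service`` facade are hypothetical) of how ``handle_report_progress`` rebuilds the step items for a non-section phase:
frame, items = handle_report_progress(
    {"phase": "writing", "message": "Drafting introduction"},
    last_active_step_id="thinking-3",
    last_active_step_title="Generating report",
    last_active_step_items=["Topic: Q3 revenue", "Collecting sources..."],
    streaming_service=service,   # streaming facade, assumed in scope
    content_builder=None,
)
# items == ["Topic: Q3 revenue", "Drafting introduction"]  (only "Topic:" lines are carried over)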

View file

@ -0,0 +1,121 @@
"""Tool end: thinking completion, tool output, and terminal SSE."""
from __future__ import annotations
import json
from collections.abc import Iterator
from typing import Any
from app.tasks.chat.streaming.handlers.tools import (
ToolCompletionEmissionContext,
iter_tool_completion_emission_frames,
resolve_tool_completed_thinking_step,
)
from app.tasks.chat.streaming.helpers.tool_output import tool_output_has_error
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
from app.tasks.chat.streaming.relay.task_span import clear_task_span_if_delegating_task_ended
from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame
def iter_tool_end_frames(
event: dict[str, Any],
*,
state: AgentEventRelayState,
streaming_service: Any,
content_builder: Any | None,
result: Any,
step_prefix: str,
config: dict[str, Any],
) -> Iterator[str]:
"""SSE frames when one tool run finishes."""
state.active_tool_depth = max(0, state.active_tool_depth - 1)
run_id = event.get("run_id", "")
tool_name = event.get("name", "unknown_tool")
raw_output = event.get("data", {}).get("output", "")
staged_file_path = (
state.file_path_by_run.pop(run_id, None) if run_id else None
)
if tool_name == "update_memory":
state.called_update_memory = True
if hasattr(raw_output, "content"):
content = raw_output.content
if isinstance(content, str):
try:
tool_output = json.loads(content)
except (json.JSONDecodeError, TypeError):
tool_output = {"result": content}
elif isinstance(content, dict):
tool_output = content
else:
tool_output = {"result": str(content)}
elif isinstance(raw_output, dict):
tool_output = raw_output
else:
tool_output = {"result": str(raw_output) if raw_output else "completed"}
if tool_name in ("write_file", "edit_file"):
if tool_output_has_error(tool_output):
pass
else:
result.write_succeeded = True
result.verification_succeeded = True
tool_call_id = state.ui_tool_call_id_by_run.get(
run_id,
f"call_{run_id[:32]}" if run_id else "call_unknown",
)
original_step_id = state.tool_step_ids.get(
run_id, f"{step_prefix}-unknown-{run_id[:8]}"
)
state.completed_step_ids.add(original_step_id)
holder = state.current_lc_tool_call_id
holder["value"] = None
authoritative = getattr(raw_output, "tool_call_id", None)
if isinstance(authoritative, str) and authoritative:
holder["value"] = authoritative
if run_id:
state.lc_tool_call_id_by_run[run_id] = authoritative
elif run_id and run_id in state.lc_tool_call_id_by_run:
holder["value"] = state.lc_tool_call_id_by_run[run_id]
items = state.last_active_step_items
title, completed_items = resolve_tool_completed_thinking_step(
tool_name, tool_output, items
)
yield emit_thinking_step_frame(
streaming_service=streaming_service,
content_builder=content_builder,
step_id=original_step_id,
title=title,
status="completed",
items=completed_items,
metadata=state.span_metadata_if_active(),
)
state.just_finished_tool = True
state.last_active_step_id = None
state.last_active_step_title = ""
state.last_active_step_items = []
emission_ctx = ToolCompletionEmissionContext(
tool_name=tool_name,
tool_call_id=tool_call_id,
tool_output=tool_output,
streaming_service=streaming_service,
content_builder=content_builder,
langchain_tool_call_id_holder=holder,
stream_result=result,
langgraph_config=config,
staged_workspace_file_path=staged_file_path,
tool_metadata=state.tool_activity_metadata(
thinking_step_id=original_step_id,
),
)
yield from iter_tool_completion_emission_frames(emission_ctx)
clear_task_span_if_delegating_task_ended(
state, tool_name=tool_name, run_id=run_id
)

View file

@ -0,0 +1,29 @@
"""Emit tool-output SSE and optional assistant content updates."""
from __future__ import annotations
from typing import Any
def emit_tool_output_available_frame(
*,
streaming_service: Any,
content_builder: Any | None,
langchain_id_holder: dict[str, str | None],
call_id: str,
output: Any,
tool_metadata: dict[str, Any] | None = None,
) -> str:
if content_builder is not None:
content_builder.on_tool_output_available(
call_id,
output,
langchain_id_holder["value"],
metadata=tool_metadata,
)
return streaming_service.format_tool_output_available(
call_id,
output,
langchain_tool_call_id=langchain_id_holder["value"],
metadata=tool_metadata,
)

View file

@ -0,0 +1,161 @@
"""Tool start: thinking-step and tool-input SSE."""
from __future__ import annotations
import json
from collections.abc import Iterator
from typing import Any
from app.tasks.chat.streaming.handlers.tools import resolve_tool_start_thinking
from app.tasks.chat.streaming.helpers.tool_call_matching import (
match_buffered_langchain_tool_call_id,
)
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
from app.tasks.chat.streaming.relay.task_span import open_task_span
from app.tasks.chat.streaming.relay.thinking_step_completion import (
complete_active_thinking_step,
)
from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame
def iter_tool_start_frames(
event: dict[str, Any],
*,
state: AgentEventRelayState,
streaming_service: Any,
content_builder: Any | None,
result: Any,
step_prefix: str,
) -> Iterator[str]:
"""SSE frames for the start of one tool run."""
state.active_tool_depth += 1
tool_name = event.get("name", "unknown_tool")
run_id = event.get("run_id", "")
tool_input = event.get("data", {}).get("input", {})
if tool_name in ("write_file", "edit_file"):
result.write_attempted = True
if isinstance(tool_input, dict):
file_path = tool_input.get("file_path")
if isinstance(file_path, str) and file_path.strip() and run_id:
state.file_path_by_run[run_id] = file_path.strip()
if state.current_text_id is not None:
yield streaming_service.format_text_end(state.current_text_id)
if content_builder is not None:
content_builder.on_text_end(state.current_text_id)
state.current_text_id = None
if state.last_active_step_title != "Synthesizing response":
comp, new_active = complete_active_thinking_step(
state=state,
streaming_service=streaming_service,
content_builder=content_builder,
last_active_step_id=state.last_active_step_id,
last_active_step_title=state.last_active_step_title,
last_active_step_items=state.last_active_step_items,
completed_step_ids=state.completed_step_ids,
)
if comp:
yield comp
state.last_active_step_id = new_active
state.just_finished_tool = False
tool_step_id = state.next_thinking_step_id(step_prefix)
state.tool_step_ids[run_id] = tool_step_id
state.last_active_step_id = tool_step_id
matched_meta: dict[str, str] | None = None
taken_ui_ids = set(state.ui_tool_call_id_by_run.values())
for meta in state.index_to_meta.values():
if meta["name"] == tool_name and meta["ui_id"] not in taken_ui_ids:
matched_meta = meta
break
tool_call_id: str
langchain_tool_call_id: str | None = None
if matched_meta is not None:
tool_call_id = matched_meta["ui_id"]
langchain_tool_call_id = matched_meta["lc_id"]
if run_id:
state.lc_tool_call_id_by_run[run_id] = matched_meta["lc_id"]
else:
tool_call_id = (
f"call_{run_id[:32]}"
if run_id
else streaming_service.generate_tool_call_id()
)
langchain_tool_call_id = match_buffered_langchain_tool_call_id(
state.pending_tool_call_chunks,
tool_name,
run_id,
state.lc_tool_call_id_by_run,
)
if tool_name == "task":
open_task_span(
state,
run_id=run_id,
langchain_tool_call_id=langchain_tool_call_id,
)
span_md = state.span_metadata_if_active()
tool_md = state.tool_activity_metadata(thinking_step_id=tool_step_id)
if matched_meta is None:
yield streaming_service.format_tool_input_start(
tool_call_id,
tool_name,
langchain_tool_call_id=langchain_tool_call_id,
metadata=tool_md,
)
if content_builder is not None:
content_builder.on_tool_input_start(
tool_call_id,
tool_name,
langchain_tool_call_id,
metadata=tool_md,
)
thinking = resolve_tool_start_thinking(tool_name, tool_input)
state.last_active_step_title = thinking.title
state.last_active_step_items = thinking.items
frame_kw: dict[str, Any] = {
"streaming_service": streaming_service,
"content_builder": content_builder,
"step_id": tool_step_id,
"title": thinking.title,
"status": "in_progress",
"metadata": span_md,
}
if thinking.include_items_on_frame:
frame_kw["items"] = thinking.items
yield emit_thinking_step_frame(**frame_kw)
if run_id:
state.ui_tool_call_id_by_run[run_id] = tool_call_id
if isinstance(tool_input, dict):
_safe_input: dict[str, Any] = {}
for _k, _v in tool_input.items():
try:
json.dumps(_v)
_safe_input[_k] = _v
except (TypeError, ValueError, OverflowError):
pass
else:
_safe_input = {"input": tool_input}
yield streaming_service.format_tool_input_available(
tool_call_id,
tool_name,
_safe_input,
langchain_tool_call_id=langchain_tool_call_id,
metadata=tool_md,
)
if content_builder is not None:
content_builder.on_tool_input_available(
tool_call_id,
tool_name,
_safe_input,
langchain_tool_call_id,
metadata=tool_md,
)
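The id handling above has two paths; a hedged sketch with made-up ids shows what the UI card ends up keyed on in each case:
run_id = "3f9c2c7e-5f1a-4a61-9e57-0d2b8c4a1d22"   # hypothetical LangGraph run id
matched_meta = {"ui_id": "call_abc123", "lc_id": "call_abc123", "name": "read_file"}

ui_id = matched_meta["ui_id"] if matched_meta else f"call_{run_id[:32]}"
# matched via tool_call_chunks -> "call_abc123" (the LangChain tool_call id)
# no chunk match               -> "call_3f9c2c7e-5f1a-4a61-9e57-0d2b8c4a" (synthetic fallback)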

View file

@ -0,0 +1,23 @@
"""Per-tool streaming: thinking-step and completion emission."""
from __future__ import annotations
from app.tasks.chat.streaming.handlers.tools.emission_context import (
ToolCompletionEmissionContext,
)
from app.tasks.chat.streaming.handlers.tools.registry import (
iter_tool_completion_emission_frames,
resolve_tool_completed_thinking_step,
resolve_tool_start_thinking,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
ToolStartThinking,
)
__all__ = [
"ToolCompletionEmissionContext",
"ToolStartThinking",
"iter_tool_completion_emission_frames",
"resolve_tool_completed_thinking_step",
"resolve_tool_start_thinking",
]

View file

@ -0,0 +1,15 @@
from __future__ import annotations
from collections.abc import Iterator
from app.tasks.chat.streaming.handlers.tools.emission_context import (
ToolCompletionEmissionContext,
)
def iter_completion_emission_frames(
ctx: ToolCompletionEmissionContext,
) -> Iterator[str]:
out = ctx.tool_output
payload = out if isinstance(out, dict) else {"result": out}
yield ctx.emit_tool_output_card(payload)

View file

@ -0,0 +1,22 @@
from __future__ import annotations
from typing import Any
from app.tasks.chat.streaming.handlers.tools.default import (
thinking as default_thinking,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
ToolStartThinking,
)
def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
return default_thinking.resolve_start_thinking(tool_name, tool_input)
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
return default_thinking.resolve_completed_thinking(
tool_name, tool_output, last_items
)

View file

@ -0,0 +1,31 @@
from __future__ import annotations
SHARED_CONNECTOR_TOOLS: frozenset[str] = frozenset(
{
"create_calendar_event",
"create_confluence_page",
"create_dropbox_file",
"create_gmail_draft",
"create_google_drive_file",
"create_jira_issue",
"create_linear_issue",
"create_notion_page",
"create_onedrive_file",
"delete_calendar_event",
"delete_confluence_page",
"delete_dropbox_file",
"delete_google_drive_file",
"delete_jira_issue",
"delete_linear_issue",
"delete_notion_page",
"delete_onedrive_file",
"send_gmail_email",
"trash_gmail_email",
"update_calendar_event",
"update_confluence_page",
"update_gmail_draft",
"update_jira_issue",
"update_linear_issue",
"update_notion_page",
}
)

View file

@ -0,0 +1,3 @@
"""Fallback tool package."""
from __future__ import annotations

View file

@ -0,0 +1,24 @@
"""Default tool-output card and a short completion terminal line."""
from __future__ import annotations
from collections.abc import Iterator
from app.tasks.chat.streaming.handlers.tools.emission_context import (
ToolCompletionEmissionContext,
)
def iter_completion_emission_frames(
ctx: ToolCompletionEmissionContext,
) -> Iterator[str]:
yield ctx.emit_tool_output_card(
{
"status": "completed",
"result_length": len(str(ctx.tool_output)),
},
)
yield ctx.streaming_service.format_terminal_info(
f"Tool {ctx.tool_name} completed",
"success",
)

View file

@ -0,0 +1,23 @@
"""Fallback thinking-step copy for unknown tools and connectors without custom UI."""
from __future__ import annotations
from typing import Any
from app.tasks.chat.streaming.handlers.tools.shared.model import (
ToolStartThinking,
)
def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
del tool_input
title = tool_name.replace("_", " ").strip().capitalize() or tool_name
return ToolStartThinking(title=title, items=[], include_items_on_frame=False)
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str]
) -> tuple[str, list[str]]:
del tool_output
title = tool_name.replace("_", " ").strip().capitalize() or tool_name
return (title, last_items)
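Two worked examples of the fallback title derivation (a pure string transformation of the tool name):
assert resolve_start_thinking("create_jira_issue", {}).title == "Create jira issue"
assert resolve_start_thinking("web_search", {"query": "surfsense"}).title == "Web search"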

View file

@ -0,0 +1,28 @@
"""generate_image: tool card + terminal summary."""
from __future__ import annotations
from collections.abc import Iterator
from app.tasks.chat.streaming.handlers.tools.emission_context import (
ToolCompletionEmissionContext,
)
def iter_completion_emission_frames(
ctx: ToolCompletionEmissionContext,
) -> Iterator[str]:
out = ctx.tool_output
payload = out if isinstance(out, dict) else {"result": out}
yield ctx.emit_tool_output_card(payload)
if isinstance(out, dict):
if out.get("error"):
yield ctx.streaming_service.format_terminal_info(
f"Image generation failed: {out['error'][:60]}",
"error",
)
else:
yield ctx.streaming_service.format_terminal_info(
"Image generated successfully",
"success",
)

View file

@ -0,0 +1,39 @@
"""generate_image: thinking-step copy."""
from __future__ import annotations
from typing import Any
from app.tasks.chat.streaming.handlers.tools.deliverables.shared.tool_input import (
as_tool_input_dict,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
ToolStartThinking,
)
def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
del tool_name
d = as_tool_input_dict(tool_input)
prompt = d.get("prompt", "") if isinstance(tool_input, dict) else str(tool_input)
return ToolStartThinking(
title="Generating image",
items=[f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}"],
)
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
items = last_items
if isinstance(tool_output, dict) and not tool_output.get("error"):
completed = [*items, "Image generated successfully"]
else:
error_msg = (
tool_output.get("error", "Generation failed")
if isinstance(tool_output, dict)
else "Generation failed"
)
completed = [*items, f"Error: {error_msg}"]
return ("Generating image", completed)

View file

@ -0,0 +1,37 @@
"""generate_podcast: tool card + queue / success / failure terminal lines."""
from __future__ import annotations
from collections.abc import Iterator
from app.tasks.chat.streaming.handlers.tools.emission_context import (
ToolCompletionEmissionContext,
)
def iter_completion_emission_frames(
ctx: ToolCompletionEmissionContext,
) -> Iterator[str]:
out = ctx.tool_output
payload = out if isinstance(out, dict) else {"result": out}
yield ctx.emit_tool_output_card(payload)
if isinstance(out, dict) and out.get("status") in (
"pending",
"generating",
"processing",
):
yield ctx.streaming_service.format_terminal_info(
f"Podcast queued: {out.get('title', 'Podcast')}",
"success",
)
elif isinstance(out, dict) and out.get("status") in ("ready", "success"):
yield ctx.streaming_service.format_terminal_info(
f"Podcast generated successfully: {out.get('title', 'Podcast')}",
"success",
)
elif isinstance(out, dict) and out.get("status") in ("failed", "error"):
error_msg = out.get("error", "Unknown error")
yield ctx.streaming_service.format_terminal_info(
f"Podcast generation failed: {error_msg}",
"error",
)

View file

@ -0,0 +1,80 @@
"""generate_podcast: thinking-step copy."""
from __future__ import annotations
from typing import Any
from app.tasks.chat.streaming.handlers.tools.deliverables.shared.tool_input import (
as_tool_input_dict,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
ToolStartThinking,
)
def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
del tool_name
d = as_tool_input_dict(tool_input)
podcast_title = (
d.get("podcast_title", "SurfSense Podcast")
if isinstance(tool_input, dict)
else "SurfSense Podcast"
)
content_len = len(
d.get("source_content", "") if isinstance(tool_input, dict) else ""
)
return ToolStartThinking(
title="Generating podcast",
items=[
f"Title: {podcast_title}",
f"Content: {content_len:,} characters",
"Preparing audio generation...",
],
)
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
items = last_items
podcast_status = (
tool_output.get("status", "unknown")
if isinstance(tool_output, dict)
else "unknown"
)
podcast_title = (
tool_output.get("title", "Podcast")
if isinstance(tool_output, dict)
else "Podcast"
)
if podcast_status in ("pending", "generating", "processing"):
completed = [
f"Title: {podcast_title}",
"Podcast generation started",
"Processing in background...",
]
elif podcast_status == "already_generating":
completed = [
f"Title: {podcast_title}",
"Podcast already in progress",
"Please wait for it to complete",
]
elif podcast_status in ("failed", "error"):
error_msg = (
tool_output.get("error", "Unknown error")
if isinstance(tool_output, dict)
else "Unknown error"
)
completed = [
f"Title: {podcast_title}",
f"Error: {error_msg[:50]}",
]
elif podcast_status in ("ready", "success"):
completed = [
f"Title: {podcast_title}",
"Podcast ready",
]
else:
completed = items
return ("Generating podcast", completed)

View file

@ -0,0 +1,33 @@
"""generate_report: full payload + terminal line."""
from __future__ import annotations
from collections.abc import Iterator
from app.tasks.chat.streaming.handlers.tools.emission_context import (
ToolCompletionEmissionContext,
)
def iter_completion_emission_frames(
ctx: ToolCompletionEmissionContext,
) -> Iterator[str]:
out = ctx.tool_output
payload = out if isinstance(out, dict) else {"result": out}
yield ctx.emit_tool_output_card(payload)
if isinstance(out, dict) and out.get("status") == "ready":
word_count = out.get("word_count", 0)
yield ctx.streaming_service.format_terminal_info(
f"Report generated: {out.get('title', 'Report')} ({word_count:,} words)",
"success",
)
else:
error_msg = (
out.get("error", "Unknown error")
if isinstance(out, dict)
else "Unknown error"
)
yield ctx.streaming_service.format_terminal_info(
f"Report generation failed: {error_msg}",
"error",
)

View file

@ -0,0 +1,77 @@
"""generate_report: thinking-step copy."""
from __future__ import annotations
from typing import Any
from app.tasks.chat.streaming.handlers.tools.deliverables.shared.tool_input import (
as_tool_input_dict,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
ToolStartThinking,
)
def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
del tool_name
d = as_tool_input_dict(tool_input)
report_topic = (
d.get("topic", "Report") if isinstance(tool_input, dict) else "Report"
)
is_revision = bool(
isinstance(tool_input, dict) and tool_input.get("parent_report_id")
)
step_title = "Revising report" if is_revision else "Generating report"
return ToolStartThinking(
title=step_title,
items=[f"Topic: {report_topic}", "Analyzing source content..."],
)
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
items = last_items
report_status = (
tool_output.get("status", "unknown")
if isinstance(tool_output, dict)
else "unknown"
)
report_title = (
tool_output.get("title", "Report")
if isinstance(tool_output, dict)
else "Report"
)
word_count = (
tool_output.get("word_count", 0)
if isinstance(tool_output, dict)
else 0
)
is_revision = (
tool_output.get("is_revision", False)
if isinstance(tool_output, dict)
else False
)
step_title = "Revising report" if is_revision else "Generating report"
if report_status == "ready":
completed = [
f"Topic: {report_title}",
f"{word_count:,} words",
"Report ready",
]
elif report_status == "failed":
error_msg = (
tool_output.get("error", "Unknown error")
if isinstance(tool_output, dict)
else "Unknown error"
)
completed = [
f"Topic: {report_title}",
f"Error: {error_msg[:50]}",
]
else:
completed = items
return (step_title, completed)

View file

@ -0,0 +1,32 @@
"""generate_resume: full payload + terminal line."""
from __future__ import annotations
from collections.abc import Iterator
from app.tasks.chat.streaming.handlers.tools.emission_context import (
ToolCompletionEmissionContext,
)
def iter_completion_emission_frames(
ctx: ToolCompletionEmissionContext,
) -> Iterator[str]:
out = ctx.tool_output
payload = out if isinstance(out, dict) else {"result": out}
yield ctx.emit_tool_output_card(payload)
if isinstance(out, dict) and out.get("status") == "ready":
yield ctx.streaming_service.format_terminal_info(
f"Resume generated: {out.get('title', 'Resume')}",
"success",
)
else:
error_msg = (
out.get("error", "Unknown error")
if isinstance(out, dict)
else "Unknown error"
)
yield ctx.streaming_service.format_terminal_info(
f"Resume generation failed: {error_msg}",
"error",
)

View file

@ -0,0 +1,24 @@
"""generate_resume: generic thinking titles and items."""
from __future__ import annotations
from typing import Any
from app.tasks.chat.streaming.handlers.tools.default import (
thinking as default_thinking,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
ToolStartThinking,
)
def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
return default_thinking.resolve_start_thinking(tool_name, tool_input)
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
return default_thinking.resolve_completed_thinking(
tool_name, tool_output, last_items
)

View file

@ -0,0 +1,28 @@
"""generate_video_presentation: tool card + terminal line."""
from __future__ import annotations
from collections.abc import Iterator
from app.tasks.chat.streaming.handlers.tools.emission_context import (
ToolCompletionEmissionContext,
)
def iter_completion_emission_frames(
ctx: ToolCompletionEmissionContext,
) -> Iterator[str]:
out = ctx.tool_output
payload = out if isinstance(out, dict) else {"result": out}
yield ctx.emit_tool_output_card(payload)
if isinstance(out, dict) and out.get("status") == "pending":
yield ctx.streaming_service.format_terminal_info(
f"Video presentation queued: {out.get('title', 'Presentation')}",
"success",
)
elif isinstance(out, dict) and out.get("status") == "failed":
error_msg = out.get("error", "Unknown error")
yield ctx.streaming_service.format_terminal_info(
f"Presentation generation failed: {error_msg}",
"error",
)

View file

@ -0,0 +1,52 @@
"""generate_video_presentation: generic in-progress thinking; completion is status-driven."""
from __future__ import annotations
from typing import Any
from app.tasks.chat.streaming.handlers.tools.default import (
thinking as default_thinking,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
ToolStartThinking,
)
def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
return default_thinking.resolve_start_thinking(tool_name, tool_input)
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
items = last_items
vp_status = (
tool_output.get("status", "unknown")
if isinstance(tool_output, dict)
else "unknown"
)
vp_title = (
tool_output.get("title", "Presentation")
if isinstance(tool_output, dict)
else "Presentation"
)
if vp_status in ("pending", "generating"):
completed = [
f"Title: {vp_title}",
"Presentation generation started",
"Processing in background...",
]
elif vp_status == "failed":
error_msg = (
tool_output.get("error", "Unknown error")
if isinstance(tool_output, dict)
else "Unknown error"
)
completed = [
f"Title: {vp_title}",
f"Error: {error_msg[:50]}",
]
else:
completed = items
return ("Generating video presentation", completed)

View file

@ -0,0 +1,16 @@
"""save_document: default completion card and terminal line."""
from __future__ import annotations
from collections.abc import Iterator
from app.tasks.chat.streaming.handlers.tools.default import emission as _default
from app.tasks.chat.streaming.handlers.tools.emission_context import (
ToolCompletionEmissionContext,
)
def iter_completion_emission_frames(
ctx: ToolCompletionEmissionContext,
) -> Iterator[str]:
yield from _default.iter_completion_emission_frames(ctx)

View file

@ -0,0 +1,38 @@
"""save_document: thinking-step copy."""
from __future__ import annotations
from typing import Any
from app.tasks.chat.streaming.handlers.tools.deliverables.shared.tool_input import (
as_tool_input_dict,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
ToolStartThinking,
)
def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
del tool_name
d = as_tool_input_dict(tool_input)
doc_title = d.get("title", "") if isinstance(tool_input, dict) else str(tool_input)
display_title = doc_title[:60] + ("…" if len(doc_title) > 60 else "")
return ToolStartThinking(title="Saving document", items=[display_title])
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
items = last_items
result_str = (
tool_output.get("result", "")
if isinstance(tool_output, dict)
else str(tool_output)
)
is_error = "Error" in result_str
completed = [
*items,
result_str[:80] if is_error else "Saved to knowledge base",
]
return ("Saving document", completed)

View file

@ -0,0 +1,9 @@
"""Tool-call args for deliverable thinking modules."""
from __future__ import annotations
from typing import Any
def as_tool_input_dict(tool_input: Any) -> dict[str, Any]:
return tool_input if isinstance(tool_input, dict) else {}

View file

@ -0,0 +1,12 @@
from __future__ import annotations
DELIVERABLE_TOOLS: frozenset[str] = frozenset(
{
"generate_image",
"generate_podcast",
"generate_report",
"generate_resume",
"generate_video_presentation",
"save_document",
}
)

View file

@ -0,0 +1,36 @@
"""Context for one tool-completion emission pass."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from app.tasks.chat.streaming.handlers.tool_output_frame import (
emit_tool_output_available_frame,
)
@dataclass
class ToolCompletionEmissionContext:
"""Streaming service, tool output, and ids for completion frames."""
tool_name: str
tool_call_id: str
tool_output: Any
streaming_service: Any
content_builder: Any | None
langchain_tool_call_id_holder: dict[str, str | None]
stream_result: Any
langgraph_config: dict[str, Any]
staged_workspace_file_path: str | None
tool_metadata: dict[str, Any] | None = None
def emit_tool_output_card(self, payload: Any) -> str:
return emit_tool_output_available_frame(
streaming_service=self.streaming_service,
content_builder=self.content_builder,
langchain_id_holder=self.langchain_tool_call_id_holder,
call_id=self.tool_call_id,
output=payload,
tool_metadata=self.tool_metadata,
)

View file

@ -0,0 +1,27 @@
"""edit_file: thinking-step copy."""
from __future__ import annotations
from typing import Any
from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import (
as_tool_input_dict,
truncate_path,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
ToolStartThinking,
)
def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
del tool_name
d = as_tool_input_dict(tool_input)
fp = d.get("file_path", "") if isinstance(tool_input, dict) else str(tool_input)
return ToolStartThinking(title="Editing file", items=[truncate_path(fp)])
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Editing file", last_items)

View file

@ -0,0 +1,40 @@
"""execute: exit code, stdout, sandbox file hints."""
from __future__ import annotations
import re
from collections.abc import Iterator
from app.tasks.chat.streaming.handlers.tools.emission_context import (
ToolCompletionEmissionContext,
)
def iter_completion_emission_frames(
ctx: ToolCompletionEmissionContext,
) -> Iterator[str]:
out = ctx.tool_output
raw_text = out.get("result", "") if isinstance(out, dict) else str(out)
exit_code: int | None = None
output_text = raw_text
m = re.match(r"^Exit code:\s*(\d+)", raw_text)
if m:
exit_code = int(m.group(1))
om = re.search(r"\nOutput:\n([\s\S]*)", raw_text)
output_text = om.group(1) if om else ""
thread_id_str = ctx.langgraph_config.get("configurable", {}).get("thread_id", "")
for sf_match in re.finditer(
r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE
):
fpath = sf_match.group(1).strip()
if fpath and fpath not in ctx.stream_result.sandbox_files:
ctx.stream_result.sandbox_files.append(fpath)
yield ctx.emit_tool_output_card(
{
"exit_code": exit_code,
"output": output_text,
"thread_id": thread_id_str,
},
)
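A worked example (hedged; the result string is made up but follows the ``Exit code:`` / ``Output:`` / ``SANDBOX_FILE:`` layout the regexes above expect):
raw = "Exit code: 0\nOutput:\nrows written: 120\nSANDBOX_FILE: /workspace/report.csv\n"

assert int(re.match(r"^Exit code:\s*(\d+)", raw).group(1)) == 0
out_text = re.search(r"\nOutput:\n([\s\S]*)", raw).group(1)
assert re.findall(r"^SANDBOX_FILE:\s*(.+)$", out_text, re.MULTILINE) == ["/workspace/report.csv"]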

View file

@ -0,0 +1,42 @@
"""execute: sandbox command thinking + completion lines."""
from __future__ import annotations
import re
from typing import Any
from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import (
as_tool_input_dict,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
ToolStartThinking,
)
def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
del tool_name
d = as_tool_input_dict(tool_input)
cmd = d.get("command", "") if isinstance(tool_input, dict) else str(tool_input)
display_cmd = cmd[:80] + ("…" if len(cmd) > 80 else "")
return ToolStartThinking(title="Running command", items=[f"$ {display_cmd}"])
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
items = last_items
raw_text = (
tool_output.get("result", "")
if isinstance(tool_output, dict)
else str(tool_output)
)
m = re.match(r"^Exit code:\s*(\d+)", raw_text)
exit_code_val = int(m.group(1)) if m else None
if exit_code_val is not None and exit_code_val == 0:
completed = [*items, "Completed successfully"]
elif exit_code_val is not None:
completed = [*items, f"Exit code: {exit_code_val}"]
else:
completed = [*items, "Finished"]
return ("Running command", completed)

View file

@ -0,0 +1,27 @@
"""glob: thinking-step copy."""
from __future__ import annotations
from typing import Any
from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import (
as_tool_input_dict,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
ToolStartThinking,
)
def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
del tool_name
d = as_tool_input_dict(tool_input)
pat = d.get("pattern", "") if isinstance(tool_input, dict) else str(tool_input)
base = d.get("path", "/") if isinstance(tool_input, dict) else "/"
return ToolStartThinking(title="Searching files", items=[f"{pat} in {base}"])
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Searching files", last_items)

View file

@@ -0,0 +1,31 @@
"""grep: thinking-step copy."""
from __future__ import annotations
from typing import Any
from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import (
as_tool_input_dict,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
ToolStartThinking,
)
def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
del tool_name
d = as_tool_input_dict(tool_input)
pat = d.get("pattern", "") if isinstance(tool_input, dict) else str(tool_input)
grep_path = d.get("path", "") if isinstance(tool_input, dict) else ""
    display_pat = pat[:60] + ("…" if len(pat) > 60 else "")
return ToolStartThinking(
title="Searching content",
items=[f'"{display_pat}"' + (f" in {grep_path}" if grep_path else "")],
)
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Searching content", last_items)
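As with the glob resolver above, a hypothetical call showing the pattern truncation and the optional path suffix:

    resolve_start_thinking("grep", {"pattern": "def resolve_", "path": "app/tasks"})
    # -> ToolStartThinking(title="Searching content", items=['"def resolve_" in app/tasks'])
    resolve_start_thinking("grep", {"pattern": "x" * 100})
    # the pattern is cut to 60 characters and marked as truncated; with no path, the " in <path>" suffix is omitted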

View file

@@ -0,0 +1,59 @@
"""ls: thinking-step copy for directory listing."""
from __future__ import annotations
import ast
from typing import Any
from app.tasks.chat.streaming.handlers.tools.shared.model import (
ToolStartThinking,
)
def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
del tool_name
if isinstance(tool_input, dict):
path = tool_input.get("path", "/")
else:
path = str(tool_input)
return ToolStartThinking(title="Listing files", items=[path])
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
if isinstance(tool_output, dict):
ls_output = tool_output.get("result", "")
elif isinstance(tool_output, str):
ls_output = tool_output
else:
ls_output = str(tool_output) if tool_output else ""
file_names: list[str] = []
if ls_output:
paths: list[str] = []
try:
parsed = ast.literal_eval(ls_output)
if isinstance(parsed, list):
paths = [str(p) for p in parsed]
except (ValueError, SyntaxError):
paths = [
line.strip()
for line in ls_output.strip().split("\n")
if line.strip()
]
for p in paths:
name = p.rstrip("/").split("/")[-1]
if name and len(name) <= 40:
file_names.append(name)
elif name:
file_names.append(name[:37] + "...")
if file_names:
if len(file_names) <= 5:
completed = [f"[{name}]" for name in file_names]
else:
completed = [f"[{name}]" for name in file_names[:4]]
completed.append(f"(+{len(file_names) - 4} more)")
else:
completed = ["No files found"]
return ("Listing files", completed)
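The two parsing branches and the "+N more" cap, sketched against made-up listings:

    out = {"result": "['/work/a.py', '/work/b.py', '/work/c.py', '/work/d.py', '/work/e.py', '/work/f.py']"}
    resolve_completed_thinking("ls", out, [])
    # ast.literal_eval parses the list literal
    # -> ("Listing files", ["[a.py]", "[b.py]", "[c.py]", "[d.py]", "(+2 more)"])
    resolve_completed_thinking("ls", "a.py\nb.py\n", [])
    # literal_eval fails, so the newline fallback is used
    # -> ("Listing files", ["[a.py]", "[b.py]"])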

View file

@@ -0,0 +1,27 @@
"""mkdir: thinking-step copy."""
from __future__ import annotations
from typing import Any
from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import (
as_tool_input_dict,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
ToolStartThinking,
)
def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
del tool_name
d = as_tool_input_dict(tool_input)
p = d.get("path", "") if isinstance(tool_input, dict) else str(tool_input)
    display = p if len(p) <= 80 else "…" + p[-77:]
return ToolStartThinking(title="Creating folder", items=[display] if display else [])
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Creating folder", last_items)

View file

@@ -0,0 +1,33 @@
"""move_file: thinking-step copy."""
from __future__ import annotations
from typing import Any
from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import (
as_tool_input_dict,
truncate_middle,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
ToolStartThinking,
)
def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
del tool_name
d = as_tool_input_dict(tool_input)
src = d.get("source_path", "") if isinstance(tool_input, dict) else ""
dst = d.get("destination_path", "") if isinstance(tool_input, dict) else ""
display_src = truncate_middle(src, max_len=60)
display_dst = truncate_middle(dst, max_len=60)
return ToolStartThinking(
title="Moving file",
items=[f"{display_src}{display_dst}"] if src or dst else [],
)
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Moving file", last_items)
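A hypothetical invocation showing the source-to-destination line, assuming truncate_middle (from the shared helpers above) returns short paths unchanged:

    resolve_start_thinking("move_file", {"source_path": "/work/old.md", "destination_path": "/work/docs/new.md"})
    # -> ToolStartThinking(title="Moving file", items=["/work/old.md → /work/docs/new.md"])
    # with neither path present, items stays []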

Some files were not shown because too many files have changed in this diff.