Mirror of https://github.com/MODSetter/SurfSense.git (synced 2026-05-12 17:22:38 +02:00)

Merge pull request #1357 from CREDO23/feature/multi-agent

[Feature] Multi-agent chat: hierarchical timeline, live subagent streaming, and inline HITL approvals

Commit 28a02a9143: 232 changed files with 9014 additions and 4055 deletions
@@ -315,14 +315,6 @@ LANGSMITH_PROJECT=surfsense

# SURFSENSE_ENABLE_ACTION_LOG=false
# SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships

# Streaming parity v2 — opt in to LangChain's structured AIMessageChunk
# content (typed reasoning blocks, tool-input deltas) and propagate the
# real tool_call_id to the SSE layer. When OFF, the stream falls back to
# the str-only text path and synthetic "call_<run_id>" tool-call ids.
# Schema migrations 135/136 ship unconditionally because they are
# forward-compatible.
# SURFSENSE_ENABLE_STREAM_PARITY_V2=false

# Plugins
# SURFSENSE_ENABLE_PLUGIN_LOADER=false
# Comma-separated allowlist of plugin entry-point names
@@ -6,12 +6,19 @@ exposes the side-channel ``stream_resume_chat`` uses to ferry resume payloads.

from __future__ import annotations

import logging
from typing import Any

from langchain.tools import ToolRuntime

from .constants import DEFAULT_SUBAGENT_RECURSION_LIMIT

logger = logging.getLogger(__name__)

# langgraph stores the parent task's scratchpad under this configurable key;
# subagents inherit the chain via ``parent_scratchpad`` fallback.
_LANGGRAPH_SCRATCHPAD_KEY = "__pregel_scratchpad"


def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
    """RunnableConfig for the nested invoke; raises ``recursion_limit`` to the parent's budget."""
@@ -42,3 +49,42 @@ def has_surfsense_resume(runtime: ToolRuntime) -> bool:
    if not isinstance(configurable, dict):
        return False
    return "surfsense_resume_value" in configurable


def drain_parent_null_resume(runtime: ToolRuntime) -> None:
    """Consume the parent's lingering ``NULL_TASK_ID/RESUME`` write before delegating.

    ``stream_resume_chat`` wakes the main agent with
    ``Command(resume={"decisions": [...]})`` so the propagated
    ``_lg_interrupt(...)`` can return. langgraph stores that payload as the
    parent task's ``null_resume`` pending write, which only gets consumed
    *after* ``subagent.[a]invoke`` returns (when the post-call propagation
    re-fires). While the subagent is mid-execution, any *new* ``interrupt()``
    inside it (e.g. a follow-up tool call after a mixed approve/reject) walks
    ``subagent_scratchpad → parent_scratchpad.get_null_resume`` and picks up
    the parent's still-live decisions — mismatching against a different number
    of hanging tool calls and crashing ``HumanInTheLoopMiddleware``.

    Draining the write here closes that cross-graph leak so subagent
    interrupts pause cleanly and re-propagate as a fresh approval card.
    """
    cfg = runtime.config or {}
    configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
    if not isinstance(configurable, dict):
        return
    scratchpad = configurable.get(_LANGGRAPH_SCRATCHPAD_KEY)
    if scratchpad is None:
        return
    consume = getattr(scratchpad, "get_null_resume", None)
    if not callable(consume):
        return
    try:
        consume(True)
    except Exception:
        # Defensive: if langgraph's internal scratchpad shape changes we don't
        # want to break the resume path. Worst case the original ValueError
        # still surfaces — same behavior as before this fix.
        logger.debug(
            "drain_parent_null_resume: scratchpad.get_null_resume raised",
            exc_info=True,
        )
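To make the drain's consume-on-read contract concrete, here is a minimal runnable sketch. FakeScratchpad and the config shape below are illustrative stand-ins for langgraph's internal scratchpad, not its real classes; only the get_null_resume(True) consumption semantics mirror what drain_parent_null_resume relies on.

class FakeScratchpad:
    """Stand-in for langgraph's per-task scratchpad: holds one pending resume payload."""

    def __init__(self, pending):
        self._pending = pending

    def get_null_resume(self, consume: bool):
        # Consume-on-read: the first caller that passes consume=True takes the payload.
        value = self._pending
        if consume:
            self._pending = None
        return value


scratchpad = FakeScratchpad({"decisions": ["approve"]})
print(scratchpad.get_null_resume(True))  # {'decisions': ['approve']} -- parent's payload, now drained
print(scratchpad.get_null_resume(True))  # None -- a later subagent interrupt no longer sees it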
@@ -20,6 +20,7 @@ from langgraph.types import Command

from .config import (
    consume_surfsense_resume,
    drain_parent_null_resume,
    has_surfsense_resume,
    subagent_invoke_config,
)

@@ -157,6 +158,9 @@ def build_task_tool_with_parent_config(
        )
        expected = hitlrequest_action_count(pending_value)
        resume_value = fan_out_decisions_to_match(resume_value, expected)
        # Prevent the parent's resume payload from leaking into subagent
        # interrupts via langgraph's parent_scratchpad fallback.
        drain_parent_null_resume(runtime)
        result = subagent.invoke(
            build_resume_command(resume_value, pending_id),
            config=sub_config,
@@ -221,6 +225,9 @@ def build_task_tool_with_parent_config(
        )
        expected = hitlrequest_action_count(pending_value)
        resume_value = fan_out_decisions_to_match(resume_value, expected)
        # Prevent the parent's resume payload from leaking into subagent
        # interrupts via langgraph's parent_scratchpad fallback.
        drain_parent_null_resume(runtime)
        result = await subagent.ainvoke(
            build_resume_command(resume_value, pending_id),
            config=sub_config,
@@ -1,11 +1,3 @@
"""Jira tools for creating, updating, and deleting issues."""
"""Jira route: native tool factories are empty; MCP supplies tools when configured."""

from .create_issue import create_create_jira_issue_tool
from .delete_issue import create_delete_jira_issue_tool
from .update_issue import create_update_jira_issue_tool

__all__ = [
    "create_create_jira_issue_tool",
    "create_delete_jira_issue_tool",
    "create_update_jira_issue_tool",
]
__all__: list[str] = []
@@ -6,29 +6,9 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
)

from .create_issue import create_create_jira_issue_tool
from .delete_issue import create_delete_jira_issue_tool
from .update_issue import create_update_jira_issue_tool


def load_tools(
    *, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
    d = {**(dependencies or {}), **kwargs}
    common = {
        "db_session": d["db_session"],
        "search_space_id": d["search_space_id"],
        "user_id": d["user_id"],
        "connector_id": d.get("connector_id"),
    }
    create = create_create_jira_issue_tool(**common)
    update = create_update_jira_issue_tool(**common)
    delete = create_delete_jira_issue_tool(**common)
    return {
        "allow": [],
        "ask": [
            {"name": getattr(create, "name", "") or "", "tool": create},
            {"name": getattr(update, "name", "") or "", "tool": update},
            {"name": getattr(delete, "name", "") or "", "tool": delete},
        ],
    }
    _ = {**(dependencies or {}), **kwargs}
    return {"allow": [], "ask": []}
@@ -1,11 +1,3 @@
"""Linear tools for creating, updating, and deleting issues."""
"""Linear route: native tool factories are empty; MCP supplies tools when configured."""

from .create_issue import create_create_linear_issue_tool
from .delete_issue import create_delete_linear_issue_tool
from .update_issue import create_update_linear_issue_tool

__all__ = [
    "create_create_linear_issue_tool",
    "create_delete_linear_issue_tool",
    "create_update_linear_issue_tool",
]
__all__: list[str] = []
@@ -1,248 +0,0 @@
import logging
from typing import Any

from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector
from app.services.linear import LinearToolMetadataService

logger = logging.getLogger(__name__)


def create_create_linear_issue_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
    connector_id: int | None = None,
):
    """
    Factory function to create the create_linear_issue tool.

    Args:
        db_session: Database session for accessing the Linear connector
        search_space_id: Search space ID to find the Linear connector
        user_id: User ID for fetching user-specific context
        connector_id: Optional specific connector ID (if known)

    Returns:
        Configured create_linear_issue tool
    """

    @tool
    async def create_linear_issue(
        title: str,
        description: str | None = None,
    ) -> dict[str, Any]:
        """Create a new issue in Linear.

        Use this tool when the user explicitly asks to create, add, or file
        a new issue / ticket / task in Linear. The user MUST describe the issue
        before you call this tool. If the request is vague, ask what the issue
        should be about. Never call this tool without a clear topic from the user.

        Args:
            title: Short, descriptive issue title. Infer from the user's request.
            description: Optional markdown body for the issue. Generate from context.

        Returns:
            Dictionary with:
            - status: "success", "rejected", or "error"
            - issue_id: Linear issue UUID (if success)
            - identifier: Human-readable ID like "ENG-42" (if success)
            - url: URL to the created issue (if success)
            - message: Result message

        IMPORTANT: If status is "rejected", the user explicitly declined the action.
        Respond with a brief acknowledgment (e.g., "Understood, I won't create the issue.")
        and move on. Do NOT retry, troubleshoot, or suggest alternatives.

        Examples:
        - "Create a Linear issue for the login bug"
        - "File a ticket about the payment timeout problem"
        - "Add an issue for the broken search feature"
        """
        logger.info(f"create_linear_issue called: title='{title}'")

        if db_session is None or search_space_id is None or user_id is None:
            logger.error(
                "Linear tool not properly configured - missing required parameters"
            )
            return {
                "status": "error",
                "message": "Linear tool not properly configured. Please contact support.",
            }

        try:
            metadata_service = LinearToolMetadataService(db_session)
            context = await metadata_service.get_creation_context(
                search_space_id, user_id
            )

            if "error" in context:
                logger.error(f"Failed to fetch creation context: {context['error']}")
                return {"status": "error", "message": context["error"]}

            workspaces = context.get("workspaces", [])
            if workspaces and all(w.get("auth_expired") for w in workspaces):
                logger.warning("All Linear accounts have expired authentication")
                return {
                    "status": "auth_error",
                    "message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.",
                    "connector_type": "linear",
                }

            logger.info(f"Requesting approval for creating Linear issue: '{title}'")
            result = request_approval(
                action_type="linear_issue_creation",
                tool_name="create_linear_issue",
                params={
                    "title": title,
                    "description": description,
                    "team_id": None,
                    "state_id": None,
                    "assignee_id": None,
                    "priority": None,
                    "label_ids": [],
                    "connector_id": connector_id,
                },
                context=context,
            )

            if result.rejected:
                logger.info("Linear issue creation rejected by user")
                return {
                    "status": "rejected",
                    "message": "User declined. Do not retry or suggest alternatives.",
                }

            final_title = result.params.get("title", title)
            final_description = result.params.get("description", description)
            final_team_id = result.params.get("team_id")
            final_state_id = result.params.get("state_id")
            final_assignee_id = result.params.get("assignee_id")
            final_priority = result.params.get("priority")
            final_label_ids = result.params.get("label_ids") or []
            final_connector_id = result.params.get("connector_id", connector_id)

            if not final_title or not final_title.strip():
                logger.error("Title is empty or contains only whitespace")
                return {"status": "error", "message": "Issue title cannot be empty."}
            if not final_team_id:
                return {
                    "status": "error",
                    "message": "A team must be selected to create an issue.",
                }

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            actual_connector_id = final_connector_id
            if actual_connector_id is None:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.LINEAR_CONNECTOR,
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "No Linear connector found. Please connect Linear in your workspace settings.",
                    }
                actual_connector_id = connector.id
                logger.info(f"Found Linear connector: id={actual_connector_id}")
            else:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == actual_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.LINEAR_CONNECTOR,
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Linear connector is invalid or has been disconnected.",
                    }
                logger.info(f"Validated Linear connector: id={actual_connector_id}")

            logger.info(
                f"Creating Linear issue with final params: title='{final_title}'"
            )
            linear_client = LinearConnector(
                session=db_session, connector_id=actual_connector_id
            )
            result = await linear_client.create_issue(
                team_id=final_team_id,
                title=final_title,
                description=final_description,
                state_id=final_state_id,
                assignee_id=final_assignee_id,
                priority=final_priority,
                label_ids=final_label_ids if final_label_ids else None,
            )

            if result.get("status") == "error":
                logger.error(f"Failed to create Linear issue: {result.get('message')}")
                return {"status": "error", "message": result.get("message")}

            logger.info(
                f"Linear issue created: {result.get('identifier')} - {result.get('title')}"
            )

            kb_message_suffix = ""
            try:
                from app.services.linear import LinearKBSyncService

                kb_service = LinearKBSyncService(db_session)
                kb_result = await kb_service.sync_after_create(
                    issue_id=result.get("id"),
                    issue_identifier=result.get("identifier", ""),
                    issue_title=result.get("title", final_title),
                    issue_url=result.get("url"),
                    description=final_description,
                    connector_id=actual_connector_id,
                    search_space_id=search_space_id,
                    user_id=user_id,
                )
                if kb_result["status"] == "success":
                    kb_message_suffix = " Your knowledge base has also been updated."
                else:
                    kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
            except Exception as kb_err:
                logger.warning(f"KB sync after create failed: {kb_err}")
                kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."

            return {
                "status": "success",
                "issue_id": result.get("id"),
                "identifier": result.get("identifier"),
                "url": result.get("url"),
                "message": (result.get("message", "") + kb_message_suffix),
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error creating Linear issue: {e}", exc_info=True)
            if isinstance(e, ValueError | LinearAPIError):
                message = str(e)
            else:
                message = (
                    "Something went wrong while creating the issue. Please try again."
                )
            return {"status": "error", "message": message}

    return create_linear_issue
@@ -1,245 +0,0 @@
import logging
from typing import Any

from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector
from app.services.linear import LinearToolMetadataService

logger = logging.getLogger(__name__)


def create_delete_linear_issue_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
    connector_id: int | None = None,
):
    """
    Factory function to create the delete_linear_issue tool.

    Args:
        db_session: Database session for accessing the Linear connector
        search_space_id: Search space ID to find the Linear connector
        user_id: User ID for finding the correct Linear connector
        connector_id: Optional specific connector ID (if known)

    Returns:
        Configured delete_linear_issue tool
    """

    @tool
    async def delete_linear_issue(
        issue_ref: str,
        delete_from_kb: bool = False,
    ) -> dict[str, Any]:
        """Archive (delete) a Linear issue.

        Use this tool when the user asks to delete, remove, or archive a Linear issue.
        Note that Linear archives issues rather than permanently deleting them
        (they can be restored from the archive).

        Args:
            issue_ref: The issue to delete. Can be the issue title (e.g. "Fix login bug"),
                the identifier (e.g. "ENG-42"), or the full document title
                (e.g. "ENG-42: Fix login bug").
            delete_from_kb: Whether to also remove the issue from the knowledge base.
                Default is False. Set to True to remove from both Linear
                and the knowledge base.

        Returns:
            Dictionary with:
            - status: "success", "rejected", "not_found", or "error"
            - identifier: Human-readable ID like "ENG-42" (if success)
            - message: Success or error message
            - deleted_from_kb: Whether the issue was also removed from the knowledge base (if success)

        IMPORTANT:
        - If status is "rejected", the user explicitly declined the action.
          Respond with a brief acknowledgment (e.g., "Understood, I won't delete the issue.")
          and move on. Do NOT ask for alternatives or troubleshoot.
        - If status is "not_found", inform the user conversationally using the exact message
          provided. Do NOT treat this as an error. Simply relay the message and ask the user
          to verify the issue title or identifier, or check if it has been indexed.

        Examples:
        - "Delete the 'Fix login bug' Linear issue"
        - "Archive ENG-42"
        - "Remove the 'Old payment flow' issue from Linear"
        """
        logger.info(
            f"delete_linear_issue called: issue_ref='{issue_ref}', delete_from_kb={delete_from_kb}"
        )

        if db_session is None or search_space_id is None or user_id is None:
            logger.error(
                "Linear tool not properly configured - missing required parameters"
            )
            return {
                "status": "error",
                "message": "Linear tool not properly configured. Please contact support.",
            }

        try:
            metadata_service = LinearToolMetadataService(db_session)
            context = await metadata_service.get_delete_context(
                search_space_id, user_id, issue_ref
            )

            if "error" in context:
                error_msg = context["error"]
                if context.get("auth_expired"):
                    logger.warning(f"Auth expired for delete context: {error_msg}")
                    return {
                        "status": "auth_error",
                        "message": error_msg,
                        "connector_id": context.get("connector_id"),
                        "connector_type": "linear",
                    }
                if "not found" in error_msg.lower():
                    logger.warning(f"Issue not found: {error_msg}")
                    return {"status": "not_found", "message": error_msg}
                else:
                    logger.error(f"Failed to fetch delete context: {error_msg}")
                    return {"status": "error", "message": error_msg}

            issue_id = context["issue"]["id"]
            issue_identifier = context["issue"].get("identifier", "")
            document_id = context["issue"]["document_id"]
            connector_id_from_context = context.get("workspace", {}).get("id")

            logger.info(
                f"Requesting approval for deleting Linear issue: '{issue_ref}' "
                f"(id={issue_id}, delete_from_kb={delete_from_kb})"
            )
            result = request_approval(
                action_type="linear_issue_deletion",
                tool_name="delete_linear_issue",
                params={
                    "issue_id": issue_id,
                    "connector_id": connector_id_from_context,
                    "delete_from_kb": delete_from_kb,
                },
                context=context,
            )

            if result.rejected:
                logger.info("Linear issue deletion rejected by user")
                return {
                    "status": "rejected",
                    "message": "User declined. Do not retry or suggest alternatives.",
                }

            final_issue_id = result.params.get("issue_id", issue_id)
            final_connector_id = result.params.get(
                "connector_id", connector_id_from_context
            )
            final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)

            logger.info(
                f"Deleting Linear issue with final params: issue_id={final_issue_id}, "
                f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
            )

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            if final_connector_id:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.LINEAR_CONNECTOR,
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    logger.error(
                        f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
                    )
                    return {
                        "status": "error",
                        "message": "Selected Linear connector is invalid or has been disconnected.",
                    }
                actual_connector_id = connector.id
                logger.info(f"Validated Linear connector: id={actual_connector_id}")
            else:
                logger.error("No connector found for this issue")
                return {
                    "status": "error",
                    "message": "No connector found for this issue.",
                }

            linear_client = LinearConnector(
                session=db_session, connector_id=actual_connector_id
            )

            result = await linear_client.archive_issue(issue_id=final_issue_id)

            logger.info(
                f"archive_issue result: {result.get('status')} - {result.get('message', '')}"
            )

            deleted_from_kb = False
            if (
                result.get("status") == "success"
                and final_delete_from_kb
                and document_id
            ):
                try:
                    from app.db import Document

                    doc_result = await db_session.execute(
                        select(Document).filter(Document.id == document_id)
                    )
                    document = doc_result.scalars().first()
                    if document:
                        await db_session.delete(document)
                        await db_session.commit()
                        deleted_from_kb = True
                        logger.info(
                            f"Deleted document {document_id} from knowledge base"
                        )
                    else:
                        logger.warning(f"Document {document_id} not found in KB")
                except Exception as e:
                    logger.error(f"Failed to delete document from KB: {e}")
                    await db_session.rollback()
                    result["warning"] = (
                        f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}"
                    )

            if result.get("status") == "success":
                result["deleted_from_kb"] = deleted_from_kb
                if issue_identifier:
                    result["message"] = (
                        f"Issue {issue_identifier} archived successfully."
                    )
                if deleted_from_kb:
                    result["message"] = (
                        f"{result.get('message', '')} Also removed from the knowledge base."
                    )

            return result

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error deleting Linear issue: {e}", exc_info=True)
            if isinstance(e, ValueError | LinearAPIError):
                message = str(e)
            else:
                message = (
                    "Something went wrong while deleting the issue. Please try again."
                )
            return {"status": "error", "message": message}

    return delete_linear_issue
@@ -6,29 +6,9 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
)

from .create_issue import create_create_linear_issue_tool
from .delete_issue import create_delete_linear_issue_tool
from .update_issue import create_update_linear_issue_tool


def load_tools(
    *, dependencies: dict[str, Any] | None = None, **kwargs: Any
) -> ToolsPermissions:
    d = {**(dependencies or {}), **kwargs}
    common = {
        "db_session": d["db_session"],
        "search_space_id": d["search_space_id"],
        "user_id": d["user_id"],
        "connector_id": d.get("connector_id"),
    }
    create = create_create_linear_issue_tool(**common)
    update = create_update_linear_issue_tool(**common)
    delete = create_delete_linear_issue_tool(**common)
    return {
        "allow": [],
        "ask": [
            {"name": getattr(create, "name", "") or "", "tool": create},
            {"name": getattr(update, "name", "") or "", "tool": update},
            {"name": getattr(delete, "name", "") or "", "tool": delete},
        ],
    }
    _ = {**(dependencies or {}), **kwargs}
    return {"allow": [], "ask": []}
@@ -28,7 +28,6 @@ Defaults:
    SURFSENSE_ENABLE_PERMISSION=true
    SURFSENSE_ENABLE_DOOM_LOOP=true
    SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false  # adds a per-turn LLM call
    SURFSENSE_ENABLE_STREAM_PARITY_V2=true

Master kill-switch (overrides everything else):

@@ -88,15 +87,6 @@ class AgentFeatureFlags:
    enable_action_log: bool = True
    enable_revert_route: bool = True

    # Streaming parity v2 — opt in to LangChain's structured
    # ``AIMessageChunk`` content (typed reasoning blocks, tool-input
    # deltas) and propagate the real ``tool_call_id`` to the SSE layer.
    # When OFF the ``stream_new_chat`` task falls back to the str-only
    # text path and the synthetic ``call_<run_id>`` tool-call id (no
    # ``langchainToolCallId`` propagation). Schema migrations 135/136
    # ship unconditionally because they're forward-compatible.
    enable_stream_parity_v2: bool = True

    # Plugins
    enable_plugin_loader: bool = False

@@ -169,7 +159,6 @@ class AgentFeatureFlags:
        enable_kb_planner_runnable=False,
        enable_action_log=False,
        enable_revert_route=False,
        enable_stream_parity_v2=False,
        enable_plugin_loader=False,
        enable_otel=False,
        enable_agent_cache=False,

@@ -208,10 +197,6 @@ class AgentFeatureFlags:
        # Snapshot / revert
        enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", True),
        enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", True),
        # Streaming parity v2
        enable_stream_parity_v2=_env_bool(
            "SURFSENSE_ENABLE_STREAM_PARITY_V2", True
        ),
        # Plugins
        enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
        # Observability
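The from-env hunk above leans on an _env_bool helper that this diff does not show. A plausible minimal reading, assumed here for illustration and not the repository's actual code:

import os


def _env_bool(name: str, default: bool) -> bool:
    # Hypothetical sketch of _env_bool: unset -> default, otherwise parse
    # common truthy spellings. The real SurfSense helper may differ.
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


# SURFSENSE_ENABLE_STREAM_PARITY_V2 stays on unless explicitly disabled:
print(_env_bool("SURFSENSE_ENABLE_STREAM_PARITY_V2", True))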
@@ -71,7 +71,10 @@ from app.schemas.new_chat import (
    TokenUsageSummary,
    TurnStatusResponse,
)
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
from app.tasks.chat.stream_new_chat import (
    stream_new_chat,
    stream_resume_chat,
)
from app.users import current_active_user
from app.utils.perf import get_perf_logger
from app.utils.rbac import check_permission
@@ -380,7 +380,7 @@ class ResumeRequest(BaseModel):
            "/regenerate. Resume reuses the original interrupted user "
            "turn so the server does not write a new user message. "
            "Currently unused but accepted to keep request bodies "
            "uniform across the three streaming entrypoints."
            "uniform across new-message, regenerate, and resume stream routes."
        ),
    )
@@ -456,6 +456,8 @@ class VercelStreamingService:
        title: str,
        status: str = "in_progress",
        items: list[str] | None = None,
        *,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """
        Format a thinking step for chain-of-thought display (SurfSense specific).

@@ -469,15 +471,15 @@ class VercelStreamingService:
        Returns:
            str: SSE formatted thinking step data part
        """
        return self.format_data(
            "thinking-step",
            {
                "id": step_id,
                "title": title,
                "status": status,
                "items": items or [],
            },
        )
        payload: dict[str, Any] = {
            "id": step_id,
            "title": title,
            "status": status,
            "items": items or [],
        }
        if metadata:
            payload["metadata"] = metadata
        return self.format_data("thinking-step", payload)

    def format_thread_title_update(self, thread_id: int, title: str) -> str:
        """
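To show what the new keyword-only metadata parameter changes, here is a standalone replica of the new payload logic above, minus the SSE framing; the function name and example values are illustrative, not part of the PR.

from typing import Any


def build_thinking_step_payload(
    step_id: str,
    title: str,
    status: str = "in_progress",
    items: list[str] | None = None,
    *,
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    # Mirrors the new method body: metadata is attached only when truthy,
    # so existing call sites without metadata emit byte-identical payloads.
    payload: dict[str, Any] = {
        "id": step_id,
        "title": title,
        "status": status,
        "items": items or [],
    }
    if metadata:
        payload["metadata"] = metadata
    return payload


print(build_thinking_step_payload("step_1", "Searching knowledge base"))
print(build_thinking_step_payload(
    "step_2", "Delegating", metadata={"subagent_run_id": "subagent_abc123"}
))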
@@ -601,6 +603,7 @@ class VercelStreamingService:
        tool_name: str,
        *,
        langchain_tool_call_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """
        Format the start of tool input streaming.

@@ -608,15 +611,14 @@ class VercelStreamingService:
        Args:
            tool_call_id: The unique tool call identifier. May be EITHER the
                synthetic ``call_<run_id>`` id derived from LangGraph
                ``run_id`` (legacy / ``SURFSENSE_ENABLE_STREAM_PARITY_V2``
                OFF, or the unmatched-fallback path under parity_v2) OR
                the authoritative LangChain ``tool_call.id`` (parity_v2
                path: when the provider streams ``tool_call_chunks`` we
                register the ``index`` and reuse the lc-id as the card
                id so live ``tool-input-delta`` events can be routed
                without a downstream join). Either way, the same id is
                preserved across ``tool-input-start`` / ``-delta`` /
                ``-available`` / ``tool-output-available`` for one call.
                ``run_id`` (unmatched chunk fallback when no ``index`` was
                registered) OR the authoritative LangChain ``tool_call.id``
                (when the provider streams ``tool_call_chunks`` we register
                the ``index`` and reuse the lc-id as the card id so live
                ``tool-input-delta`` events route without a downstream join).
                Either way, the same id is preserved across
                ``tool-input-start`` / ``-delta`` / ``-available`` /
                ``tool-output-available`` for one call.
            tool_name: The name of the tool being called.
            langchain_tool_call_id: Optional authoritative LangChain
                ``tool_call.id``. When set, surfaces as

@@ -636,6 +638,8 @@ class VercelStreamingService:
        }
        if langchain_tool_call_id:
            payload["langchainToolCallId"] = langchain_tool_call_id
        if metadata:
            payload["metadata"] = metadata
        return self._format_sse(payload)

    def format_tool_input_delta(self, tool_call_id: str, input_text_delta: str) -> str:

@@ -667,6 +671,7 @@ class VercelStreamingService:
        input_data: dict[str, Any],
        *,
        langchain_tool_call_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """
        Format the completion of tool input.

@@ -692,6 +697,8 @@ class VercelStreamingService:
        }
        if langchain_tool_call_id:
            payload["langchainToolCallId"] = langchain_tool_call_id
        if metadata:
            payload["metadata"] = metadata
        return self._format_sse(payload)

    def format_tool_output_available(

@@ -700,6 +707,7 @@ class VercelStreamingService:
        output: Any,
        *,
        langchain_tool_call_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """
        Format tool execution output.

@@ -726,6 +734,8 @@ class VercelStreamingService:
        }
        if langchain_tool_call_id:
            payload["langchainToolCallId"] = langchain_tool_call_id
        if metadata:
            payload["metadata"] = metadata
        return self._format_sse(payload)

    # =========================================================================
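To summarize the two id regimes the docstring above describes, here are illustrative tool-input-start payload shapes; the ids are invented for the example, and the field names match the protocol fields shown in this PR.

# Legacy / parity_v2 OFF, or the unmatched-chunk fallback: synthetic id from run_id.
legacy_frame = {
    "type": "tool-input-start",
    "toolCallId": "call_3f9c0d2e",  # synthetic call_<run_id>
    "toolName": "search_knowledge_base",
}

# parity_v2 path with a registered chunk index: the LangChain id is the card id
# and is also surfaced explicitly, so the FE needs no downstream join.
parity_v2_frame = {
    "type": "tool-input-start",
    "toolCallId": "call_abc123",            # authoritative LangChain tool_call.id
    "toolName": "search_knowledge_base",
    "langchainToolCallId": "call_abc123",
}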
surfsense_backend/app/services/streaming/__init__.py (new file, 20 lines)

@@ -0,0 +1,20 @@
"""Single-responsibility split of the streaming SSE protocol.

Layout:
* ``envelope/`` - SSE wire framing + ID generators
* ``emitter/`` - identity of the agent that emitted an event + runtime registry
* ``events/`` - one module per SSE event family
* ``service.py`` - composition root used when emitting chat SSE
* ``interrupt_correlation.py`` - id-aware lookup over LangGraph state

Naming on the wire:
* AI SDK protocol fields keep their existing camelCase
  (``toolCallId``, ``messageId``, ``inputTextDelta``, ``langchainToolCallId``).
* Every SurfSense-added field uses ``snake_case``, including the
  top-level ``emitted_by`` envelope and all inner ``data`` payloads.

Production chat uses ``app.services.new_streaming_service`` from
``app.tasks.chat.stream_new_chat`` and related routes.
"""

from __future__ import annotations
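An illustrative frame mixing the two naming conventions described above (all values made up):

event = {
    "type": "tool-input-available",      # AI SDK protocol fields stay camelCase...
    "toolCallId": "call_abc123",
    "toolName": "task",
    "input": {"description": "summarize the thread"},
    "emitted_by": {                      # ...while SurfSense-added fields are snake_case
        "level": "subagent",
        "subagent_type": "researcher",
        "subagent_run_id": "subagent_7f3a",
        "parent_tool_call_id": "call_parent1",
    },
}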
surfsense_backend/app/services/streaming/emitter/__init__.py (new file, 29 lines)

@@ -0,0 +1,29 @@
"""Identity of the agent that emitted a streamed event.

The wire field is ``emitted_by``; the Python identity is :class:`Emitter`.
``EmitterRegistry`` resolves which emitter owns a LangGraph event, with
LangGraph's own namespace metadata as the primary key and a parent_ids
walk as a fallback for cases where context vars don't propagate.
"""

from __future__ import annotations

from .emitter import (
    MAIN_EMITTER,
    Emitter,
    EmitterLevel,
    attach_emitted_by,
    main_emitter,
    subagent_emitter,
)
from .registry import EmitterRegistry

__all__ = [
    "MAIN_EMITTER",
    "Emitter",
    "EmitterLevel",
    "EmitterRegistry",
    "attach_emitted_by",
    "main_emitter",
    "subagent_emitter",
]
surfsense_backend/app/services/streaming/emitter/emitter.py (new file, 61 lines)

@@ -0,0 +1,61 @@
"""Identity payload describing which agent produced a stream event."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Literal

EmitterLevel = Literal["main", "subagent"]


@dataclass(frozen=True)
class Emitter:
    level: EmitterLevel
    subagent_type: str | None = None
    subagent_run_id: str | None = None
    parent_tool_call_id: str | None = None
    extra: dict[str, Any] = field(default_factory=dict)

    def to_payload(self) -> dict[str, Any]:
        payload: dict[str, Any] = {"level": self.level}
        if self.subagent_type is not None:
            payload["subagent_type"] = self.subagent_type
        if self.subagent_run_id is not None:
            payload["subagent_run_id"] = self.subagent_run_id
        if self.parent_tool_call_id is not None:
            payload["parent_tool_call_id"] = self.parent_tool_call_id
        if self.extra:
            payload.update(self.extra)
        return payload


MAIN_EMITTER = Emitter(level="main")


def main_emitter() -> Emitter:
    return MAIN_EMITTER


def subagent_emitter(
    *,
    subagent_type: str,
    subagent_run_id: str,
    parent_tool_call_id: str | None = None,
    extra: dict[str, Any] | None = None,
) -> Emitter:
    return Emitter(
        level="subagent",
        subagent_type=subagent_type,
        subagent_run_id=subagent_run_id,
        parent_tool_call_id=parent_tool_call_id,
        extra=dict(extra or {}),
    )


def attach_emitted_by(
    payload: dict[str, Any], emitter: Emitter | None
) -> dict[str, Any]:
    if emitter is None:
        return payload
    payload["emitted_by"] = emitter.to_payload()
    return payload
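A short usage sketch of the module above (import path as shipped in this PR; ids made up):

from app.services.streaming.emitter.emitter import attach_emitted_by, subagent_emitter

emitter = subagent_emitter(
    subagent_type="researcher",
    subagent_run_id="subagent_7f3a",
    parent_tool_call_id="call_parent1",
)
payload = attach_emitted_by({"type": "finish"}, emitter)
# {'type': 'finish', 'emitted_by': {'level': 'subagent', 'subagent_type': 'researcher',
#  'subagent_run_id': 'subagent_7f3a', 'parent_tool_call_id': 'call_parent1'}}
print(payload)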
surfsense_backend/app/services/streaming/emitter/registry.py (new file, 51 lines)

@@ -0,0 +1,51 @@
"""Resolve which agent owns a streamed event from its LangGraph run lineage."""

from __future__ import annotations

from collections.abc import Iterable

from .emitter import Emitter, main_emitter


class EmitterRegistry:
    def __init__(self) -> None:
        self._by_run_id: dict[str, Emitter] = {}

    def register(self, run_id: str, emitter: Emitter) -> None:
        if not run_id:
            return
        self._by_run_id[run_id] = emitter

    def unregister(self, run_id: str) -> Emitter | None:
        if not run_id:
            return None
        return self._by_run_id.pop(run_id, None)

    def get(self, run_id: str | None) -> Emitter | None:
        if not run_id:
            return None
        return self._by_run_id.get(run_id)

    def resolve(
        self,
        *,
        run_id: str | None,
        parent_ids: Iterable[str] | None,
    ) -> Emitter:
        own = self.get(run_id)
        if own is not None:
            return own
        if parent_ids:
            for ancestor in reversed(list(parent_ids)):
                emitter = self.get(ancestor)
                if emitter is not None:
                    return emitter
        return main_emitter()

    def has_active_subagents(self) -> bool:
        return any(
            emitter.level == "subagent" for emitter in self._by_run_id.values()
        )

    def clear(self) -> None:
        self._by_run_id.clear()
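A usage sketch of the resolution order above (direct run-id hit, then innermost-first parent walk, then main; run ids invented):

from app.services.streaming.emitter import EmitterRegistry, subagent_emitter

registry = EmitterRegistry()
sub = subagent_emitter(subagent_type="coder", subagent_run_id="subagent_1")
registry.register("run-sub", sub)

# Direct hit on the event's own run id:
assert registry.resolve(run_id="run-sub", parent_ids=None) is sub
# Fallback: an unknown run id walks parent_ids innermost-first ("run-sub" wins):
assert registry.resolve(run_id="run-leaf", parent_ids=["run-root", "run-sub"]) is sub
# Nothing registered anywhere -> attributed to the main agent:
print(registry.resolve(run_id="run-x", parent_ids=None).level)  # "main"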
@@ -0,0 +1,23 @@
"""Wire framing layer."""

from __future__ import annotations

from .identifiers import (
    generate_message_id,
    generate_reasoning_id,
    generate_subagent_run_id,
    generate_text_id,
    generate_tool_call_id,
)
from .sse import format_done, format_sse, get_response_headers

__all__ = [
    "format_done",
    "format_sse",
    "generate_message_id",
    "generate_reasoning_id",
    "generate_subagent_run_id",
    "generate_text_id",
    "generate_tool_call_id",
    "get_response_headers",
]
@@ -0,0 +1,25 @@
"""Prefixed UUID generators for stream parts."""

from __future__ import annotations

import uuid


def generate_message_id() -> str:
    return f"msg_{uuid.uuid4().hex}"


def generate_text_id() -> str:
    return f"text_{uuid.uuid4().hex}"


def generate_reasoning_id() -> str:
    return f"reasoning_{uuid.uuid4().hex}"


def generate_tool_call_id() -> str:
    return f"call_{uuid.uuid4().hex}"


def generate_subagent_run_id() -> str:
    return f"subagent_{uuid.uuid4().hex}"
surfsense_backend/app/services/streaming/envelope/sse.py (new file, 25 lines)

@@ -0,0 +1,25 @@
"""Server-Sent-Events wire framing."""

from __future__ import annotations

import json
from typing import Any


def format_sse(data: Any) -> str:
    if isinstance(data, str):
        return f"data: {data}\n\n"
    return f"data: {json.dumps(data)}\n\n"


def format_done() -> str:
    return "data: [DONE]\n\n"


def get_response_headers() -> dict[str, str]:
    return {
        "Content-Type": "text/event-stream",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "x-vercel-ai-ui-message-stream": "v1",
    }
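For reference, what the framing helpers above produce:

from app.services.streaming.envelope.sse import format_done, format_sse

print(repr(format_sse({"type": "finish"})))  # 'data: {"type": "finish"}\n\n'
print(repr(format_sse("[DONE]")))            # strings pass through unencoded
print(repr(format_done()))                   # 'data: [DONE]\n\n'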
surfsense_backend/app/services/streaming/events/__init__.py (new file, 29 lines)

@@ -0,0 +1,29 @@
"""SSE event payload formatters, one module per event family."""

from __future__ import annotations

from . import (
    action_log,
    data,
    error,
    interrupt,
    lifecycle,
    reasoning,
    source,
    subagent_lifecycle,
    text,
    tool,
)

__all__ = [
    "action_log",
    "data",
    "error",
    "interrupt",
    "lifecycle",
    "reasoning",
    "source",
    "subagent_lifecycle",
    "text",
    "tool",
]
@@ -0,0 +1,24 @@
"""Action-log events relayed from ``ActionLogMiddleware`` custom dispatches."""

from __future__ import annotations

from typing import Any

from ..emitter import Emitter
from .data import format_data


def format_action_log(
    payload: dict[str, Any],
    *,
    emitter: Emitter | None = None,
) -> str:
    return format_data("action-log", payload, emitter=emitter)


def format_action_log_updated(
    payload: dict[str, Any],
    *,
    emitter: Emitter | None = None,
) -> str:
    return format_data("action-log-updated", payload, emitter=emitter)
surfsense_backend/app/services/streaming/events/data.py (new file, 118 lines)

@@ -0,0 +1,118 @@
"""Generic ``data-*`` envelopes and SurfSense-specific data parts.

Inner ``data`` dict fields use snake_case. Legacy ``threadId`` /
``messageId`` keys are preserved where they cross the AI SDK boundary.
"""

from __future__ import annotations

from typing import Any

from ..emitter import Emitter, attach_emitted_by
from ..envelope import format_sse


def format_data(
    data_type: str,
    data: Any,
    *,
    emitter: Emitter | None = None,
) -> str:
    payload: dict[str, Any] = {"type": f"data-{data_type}", "data": data}
    return format_sse(attach_emitted_by(payload, emitter))


def format_terminal_info(
    text: str,
    *,
    message_type: str = "info",
    emitter: Emitter | None = None,
) -> str:
    return format_data(
        "terminal-info",
        {"text": text, "type": message_type},
        emitter=emitter,
    )


def format_further_questions(
    questions: list[str],
    *,
    emitter: Emitter | None = None,
) -> str:
    return format_data("further-questions", {"questions": questions}, emitter=emitter)


def format_thinking_step(
    *,
    step_id: str,
    title: str,
    status: str = "in_progress",
    items: list[str] | None = None,
    emitter: Emitter | None = None,
) -> str:
    return format_data(
        "thinking-step",
        {
            "id": step_id,
            "title": title,
            "status": status,
            "items": items or [],
        },
        emitter=emitter,
    )


def format_thread_title_update(
    *,
    thread_id: int,
    title: str,
    emitter: Emitter | None = None,
) -> str:
    return format_data(
        "thread-title-update",
        {"threadId": thread_id, "title": title},
        emitter=emitter,
    )


def format_turn_info(
    *,
    chat_turn_id: str,
    emitter: Emitter | None = None,
) -> str:
    return format_data("turn-info", {"chat_turn_id": chat_turn_id}, emitter=emitter)


def format_turn_status(
    *,
    status: str,
    emitter: Emitter | None = None,
) -> str:
    return format_data("turn-status", {"status": status}, emitter=emitter)


def format_user_message_id(
    *,
    message_id: str,
    turn_id: str,
    emitter: Emitter | None = None,
) -> str:
    return format_data(
        "user-message-id",
        {"message_id": message_id, "turn_id": turn_id},
        emitter=emitter,
    )


def format_assistant_message_id(
    *,
    message_id: str,
    turn_id: str,
    emitter: Emitter | None = None,
) -> str:
    return format_data(
        "assistant-message-id",
        {"message_id": message_id, "turn_id": turn_id},
        emitter=emitter,
    )
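A quick demonstration of the data-* envelope and the emitted_by attachment, using the helpers above (values invented):

from app.services.streaming.emitter import main_emitter
from app.services.streaming.events.data import format_data, format_turn_status

# Every data part is wrapped in a data-<type> envelope:
print(format_data("turn-info", {"chat_turn_id": "turn_1"}))
# data: {"type": "data-turn-info", "data": {"chat_turn_id": "turn_1"}}

# Typed helpers route through the same envelope and accept an emitter:
print(format_turn_status(status="completed", emitter=main_emitter()))
# data: {"type": "data-turn-status", "data": {"status": "completed"}, "emitted_by": {"level": "main"}}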
surfsense_backend/app/services/streaming/events/error.py (new file, 23 lines)

@@ -0,0 +1,23 @@
"""Single terminal error path chat streaming must route through."""

from __future__ import annotations

from typing import Any

from ..emitter import Emitter, attach_emitted_by
from ..envelope import format_sse


def format_error(
    error_text: str,
    *,
    error_code: str | None = None,
    extra: dict[str, Any] | None = None,
    emitter: Emitter | None = None,
) -> str:
    payload: dict[str, Any] = {"type": "error", "errorText": error_text}
    if error_code:
        payload["errorCode"] = error_code
    if extra:
        payload.update(extra)
    return format_sse(attach_emitted_by(payload, emitter))
surfsense_backend/app/services/streaming/events/interrupt.py (new file, 56 lines)

@@ -0,0 +1,56 @@
"""Interrupt-request events with a single canonical payload shape."""

from __future__ import annotations

from typing import Any

from ..emitter import Emitter
from .data import format_data


def normalize_interrupt_payload(interrupt_value: dict[str, Any]) -> dict[str, Any]:
    if "action_requests" in interrupt_value and "review_configs" in interrupt_value:
        return interrupt_value

    interrupt_type = interrupt_value.get("type", "unknown")
    message = interrupt_value.get("message")
    action = interrupt_value.get("action", {}) or {}
    context = interrupt_value.get("context", {}) or {}

    normalized: dict[str, Any] = {
        "action_requests": [
            {
                "name": action.get("tool", "unknown_tool"),
                "args": action.get("params", {}),
            }
        ],
        "review_configs": [
            {
                "action_name": action.get("tool", "unknown_tool"),
                "allowed_decisions": ["approve", "edit", "reject"],
            }
        ],
        "interrupt_type": interrupt_type,
        "context": context,
    }
    if message:
        normalized["message"] = message
    return normalized


def format_interrupt_request(
    interrupt_value: dict[str, Any],
    *,
    interrupt_id: str | None = None,
    pending_interrupt_count: int | None = None,
    chat_turn_id: str | None = None,
    emitter: Emitter | None = None,
) -> str:
    payload = normalize_interrupt_payload(interrupt_value)
    if interrupt_id is not None:
        payload["interrupt_id"] = interrupt_id
    if pending_interrupt_count is not None:
        payload["pending_interrupt_count"] = pending_interrupt_count
    if chat_turn_id is not None:
        payload["chat_turn_id"] = chat_turn_id
    return format_data("interrupt-request", payload, emitter=emitter)
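A worked example of the legacy-to-canonical normalization above (payload values invented):

from app.services.streaming.events.interrupt import normalize_interrupt_payload

legacy = {
    "type": "tool_approval",
    "message": "Create this Linear issue?",
    "action": {"tool": "create_linear_issue", "params": {"title": "Fix login bug"}},
}
print(normalize_interrupt_payload(legacy))
# {'action_requests': [{'name': 'create_linear_issue', 'args': {'title': 'Fix login bug'}}],
#  'review_configs': [{'action_name': 'create_linear_issue',
#                      'allowed_decisions': ['approve', 'edit', 'reject']}],
#  'interrupt_type': 'tool_approval', 'context': {},
#  'message': 'Create this Linear issue?'}
# A payload already carrying action_requests + review_configs is returned unchanged.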
surfsense_backend/app/services/streaming/events/lifecycle.py (new file, 29 lines)

@@ -0,0 +1,29 @@
"""High-level message and step lifecycle events.

Wire verbs are fixed by the AI SDK protocol (``start`` / ``finish`` for
the whole message, ``start-step`` / ``finish-step`` for each step).
Python helpers always read ``format_<entity>_<verb>`` so pairs are
visible at the call site.
"""

from __future__ import annotations

from ..emitter import Emitter, attach_emitted_by
from ..envelope import format_sse


def format_message_start(message_id: str, *, emitter: Emitter | None = None) -> str:
    payload = {"type": "start", "messageId": message_id}
    return format_sse(attach_emitted_by(payload, emitter))


def format_message_finish(*, emitter: Emitter | None = None) -> str:
    return format_sse(attach_emitted_by({"type": "finish"}, emitter))


def format_step_start(*, emitter: Emitter | None = None) -> str:
    return format_sse(attach_emitted_by({"type": "start-step"}, emitter))


def format_step_finish(*, emitter: Emitter | None = None) -> str:
    return format_sse(attach_emitted_by({"type": "finish-step"}, emitter))
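How the verb pairs above line up over one message carrying a single step (message id invented):

from app.services.streaming.events.lifecycle import (
    format_message_finish,
    format_message_start,
    format_step_finish,
    format_step_start,
)

print(format_message_start("msg_1"), end="")  # data: {"type": "start", "messageId": "msg_1"}
print(format_step_start(), end="")            # data: {"type": "start-step"}
print(format_step_finish(), end="")           # data: {"type": "finish-step"}
print(format_message_finish(), end="")        # data: {"type": "finish"}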
surfsense_backend/app/services/streaming/events/reasoning.py (new file, 36 lines)

@@ -0,0 +1,36 @@
"""Reasoning-block streaming events."""

from __future__ import annotations

from ..emitter import Emitter, attach_emitted_by
from ..envelope import format_sse


def format_reasoning_start(
    reasoning_id: str, *, emitter: Emitter | None = None
) -> str:
    return format_sse(
        attach_emitted_by({"type": "reasoning-start", "id": reasoning_id}, emitter)
    )


def format_reasoning_delta(
    reasoning_id: str,
    delta: str,
    *,
    emitter: Emitter | None = None,
) -> str:
    return format_sse(
        attach_emitted_by(
            {"type": "reasoning-delta", "id": reasoning_id, "delta": delta},
            emitter,
        )
    )


def format_reasoning_end(
    reasoning_id: str, *, emitter: Emitter | None = None
) -> str:
    return format_sse(
        attach_emitted_by({"type": "reasoning-end", "id": reasoning_id}, emitter)
    )
surfsense_backend/app/services/streaming/events/source.py (new file, 59 lines)

@@ -0,0 +1,59 @@
"""Source and file reference events."""

from __future__ import annotations

from typing import Any

from ..emitter import Emitter, attach_emitted_by
from ..envelope import format_sse


def format_source_url(
    url: str,
    *,
    source_id: str | None = None,
    title: str | None = None,
    emitter: Emitter | None = None,
) -> str:
    payload: dict[str, Any] = {
        "type": "source-url",
        "sourceId": source_id or url,
        "url": url,
    }
    if title:
        payload["title"] = title
    return format_sse(attach_emitted_by(payload, emitter))


def format_source_document(
    source_id: str,
    *,
    media_type: str = "file",
    title: str | None = None,
    description: str | None = None,
    emitter: Emitter | None = None,
) -> str:
    payload: dict[str, Any] = {
        "type": "source-document",
        "sourceId": source_id,
        "mediaType": media_type,
    }
    if title:
        payload["title"] = title
    if description:
        payload["description"] = description
    return format_sse(attach_emitted_by(payload, emitter))


def format_file(
    url: str,
    media_type: str,
    *,
    emitter: Emitter | None = None,
) -> str:
    payload: dict[str, Any] = {
        "type": "file",
        "url": url,
        "mediaType": media_type,
    }
    return format_sse(attach_emitted_by(payload, emitter))
@ -0,0 +1,86 @@
"""Sub-agent lifecycle events the FE pairs into one timeline lane.

A sub-agent run is a high-level boundary (a whole agent invocation),
so we use the ``start`` / ``finish`` verb pair, matching how the AI SDK
spells message- and step-level lifecycles.
"""

from __future__ import annotations

from typing import Any

from ..emitter import Emitter
from .data import format_data


def format_subagent_start(
    *,
    subagent_run_id: str,
    subagent_type: str,
    parent_tool_call_id: str,
    chat_turn_id: str | None = None,
    description: str | None = None,
    started_at: str | None = None,
    emitter: Emitter | None = None,
) -> str:
    payload: dict[str, Any] = {
        "subagent_run_id": subagent_run_id,
        "subagent_type": subagent_type,
        "parent_tool_call_id": parent_tool_call_id,
    }
    if chat_turn_id is not None:
        payload["chat_turn_id"] = chat_turn_id
    if description is not None:
        payload["description"] = description
    if started_at is not None:
        payload["started_at"] = started_at
    return format_data("subagent-start", payload, emitter=emitter)


def format_subagent_finish(
    *,
    subagent_run_id: str,
    subagent_type: str,
    parent_tool_call_id: str,
    status: str = "completed",
    ended_at: str | None = None,
    duration_ms: int | None = None,
    emitter: Emitter | None = None,
) -> str:
    payload: dict[str, Any] = {
        "subagent_run_id": subagent_run_id,
        "subagent_type": subagent_type,
        "parent_tool_call_id": parent_tool_call_id,
        "status": status,
    }
    if ended_at is not None:
        payload["ended_at"] = ended_at
    if duration_ms is not None:
        payload["duration_ms"] = duration_ms
    return format_data("subagent-finish", payload, emitter=emitter)


def format_subagent_error(
    *,
    subagent_run_id: str,
    subagent_type: str,
    parent_tool_call_id: str,
    error_text: str,
    error_type: str | None = None,
    ended_at: str | None = None,
    duration_ms: int | None = None,
    emitter: Emitter | None = None,
) -> str:
    payload: dict[str, Any] = {
        "subagent_run_id": subagent_run_id,
        "subagent_type": subagent_type,
        "parent_tool_call_id": parent_tool_call_id,
        "error_text": error_text,
    }
    if error_type is not None:
        payload["error_type"] = error_type
    if ended_at is not None:
        payload["ended_at"] = ended_at
    if duration_ms is not None:
        payload["duration_ms"] = duration_ms
    return format_data("subagent-error", payload, emitter=emitter)

surfsense_backend/app/services/streaming/events/text.py (new file, 31 lines)

@ -0,0 +1,31 @@
"""Text-block streaming events."""

from __future__ import annotations

from ..emitter import Emitter, attach_emitted_by
from ..envelope import format_sse


def format_text_start(text_id: str, *, emitter: Emitter | None = None) -> str:
    return format_sse(
        attach_emitted_by({"type": "text-start", "id": text_id}, emitter)
    )


def format_text_delta(
    text_id: str,
    delta: str,
    *,
    emitter: Emitter | None = None,
) -> str:
    return format_sse(
        attach_emitted_by(
            {"type": "text-delta", "id": text_id, "delta": delta}, emitter
        )
    )


def format_text_end(text_id: str, *, emitter: Emitter | None = None) -> str:
    return format_sse(
        attach_emitted_by({"type": "text-end", "id": text_id}, emitter)
    )
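
Each text block is expected to reach the client as an ordered start/delta/end triple sharing one id; a minimal sketch with an illustrative id:

    text_id = "txt_01"  # illustrative
    frames = [
        format_text_start(text_id),
        format_text_delta(text_id, "Hello, "),
        format_text_delta(text_id, "world."),
        format_text_end(text_id),
    ]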

surfsense_backend/app/services/streaming/events/tool.py (new file, 80 lines)

@ -0,0 +1,80 @@
"""Tool-call streaming events.

``toolCallId`` and ``langchainToolCallId`` are AI SDK protocol fields
and stay camelCase. Sub-agent provenance rides on the snake_case
top-level ``emitted_by`` envelope added by :func:`attach_emitted_by`.
"""

from __future__ import annotations

from typing import Any

from ..emitter import Emitter, attach_emitted_by
from ..envelope import format_sse


def format_tool_input_start(
    tool_call_id: str,
    tool_name: str,
    *,
    langchain_tool_call_id: str | None = None,
    emitter: Emitter | None = None,
) -> str:
    payload: dict[str, Any] = {
        "type": "tool-input-start",
        "toolCallId": tool_call_id,
        "toolName": tool_name,
    }
    if langchain_tool_call_id:
        payload["langchainToolCallId"] = langchain_tool_call_id
    return format_sse(attach_emitted_by(payload, emitter))


def format_tool_input_delta(
    tool_call_id: str,
    input_text_delta: str,
    *,
    emitter: Emitter | None = None,
) -> str:
    payload: dict[str, Any] = {
        "type": "tool-input-delta",
        "toolCallId": tool_call_id,
        "inputTextDelta": input_text_delta,
    }
    return format_sse(attach_emitted_by(payload, emitter))


def format_tool_input_available(
    tool_call_id: str,
    tool_name: str,
    input_data: dict[str, Any],
    *,
    langchain_tool_call_id: str | None = None,
    emitter: Emitter | None = None,
) -> str:
    payload: dict[str, Any] = {
        "type": "tool-input-available",
        "toolCallId": tool_call_id,
        "toolName": tool_name,
        "input": input_data,
    }
    if langchain_tool_call_id:
        payload["langchainToolCallId"] = langchain_tool_call_id
    return format_sse(attach_emitted_by(payload, emitter))


def format_tool_output_available(
    tool_call_id: str,
    output: Any,
    *,
    langchain_tool_call_id: str | None = None,
    emitter: Emitter | None = None,
) -> str:
    payload: dict[str, Any] = {
        "type": "tool-output-available",
        "toolCallId": tool_call_id,
        "output": output,
    }
    if langchain_tool_call_id:
        payload["langchainToolCallId"] = langchain_tool_call_id
    return format_sse(attach_emitted_by(payload, emitter))
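
A hedged sketch of producing one frame — the tool name and all values are illustrative, and the snake_case ``emitted_by`` envelope is only attached when an emitter is passed:

    frame = format_tool_input_available(
        "call_abc",                       # AI SDK toolCallId (camelCase on the wire)
        "search_documents",               # hypothetical tool name
        {"query": "surfsense"},
        langchain_tool_call_id="lc_123",  # illustrative LangChain id
    )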
@ -0,0 +1,84 @@
"""Id-aware lookup of pending LangGraph interrupts (replaces first-wins)."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class PendingInterrupt:
    interrupt_id: str | None
    value: dict[str, Any]
    source_task_id: str | None = None


def list_pending_interrupts(state: Any) -> list[PendingInterrupt]:
    out: list[PendingInterrupt] = []

    for task in getattr(state, "tasks", None) or ():
        task_id = _safe_str(getattr(task, "id", None))
        for it in getattr(task, "interrupts", None) or ():
            value = _coerce_interrupt_value(it)
            if value is None:
                continue
            interrupt_id = _safe_str(getattr(it, "id", None))
            out.append(
                PendingInterrupt(
                    interrupt_id=interrupt_id,
                    value=value,
                    source_task_id=task_id,
                )
            )

    for it in getattr(state, "interrupts", None) or ():
        value = _coerce_interrupt_value(it)
        if value is None:
            continue
        interrupt_id = _safe_str(getattr(it, "id", None))
        out.append(PendingInterrupt(interrupt_id=interrupt_id, value=value))

    return out


def get_pending_interrupt_by_id(
    state: Any, interrupt_id: str
) -> PendingInterrupt | None:
    for pending in list_pending_interrupts(state):
        if pending.interrupt_id == interrupt_id:
            return pending
    return None


def get_pending_interrupt_for_tool_call(
    state: Any, tool_call_id: str
) -> PendingInterrupt | None:
    for pending in list_pending_interrupts(state):
        actions = pending.value.get("action_requests")
        if not isinstance(actions, list):
            continue
        for action in actions:
            if not isinstance(action, dict):
                continue
            if action.get("tool_call_id") == tool_call_id:
                return pending
    return None


def first_pending_interrupt(state: Any) -> PendingInterrupt | None:
    """Explicit opt-in to legacy first-wins; prefer the id-aware helpers above."""
    pending = list_pending_interrupts(state)
    return pending[0] if pending else None


def _coerce_interrupt_value(item: Any) -> dict[str, Any] | None:
    if isinstance(item, dict):
        return item if item else None
    value = getattr(item, "value", None)
    if isinstance(value, dict):
        return value if value else None
    return None


def _safe_str(value: Any) -> str | None:
    return value if isinstance(value, str) and value else None
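
A minimal sketch of id-aware resolution in a resume route — ``state`` stands in for whatever LangGraph state snapshot the caller fetched, and the id is illustrative:

    pending = get_pending_interrupt_for_tool_call(state, "call_abc")
    if pending is not None:
        ...  # key the resume decision to pending.interrupt_id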

surfsense_backend/app/services/streaming/service.py (new file, 414 lines)

@ -0,0 +1,414 @@
"""Composition root: bundles every formatter + a per-invocation emitter registry."""

from __future__ import annotations

from collections.abc import Iterable
from typing import Any

from . import envelope
from .emitter import Emitter, EmitterRegistry
from .events import (
    action_log,
    data,
    error,
    interrupt,
    lifecycle,
    reasoning,
    source,
    subagent_lifecycle,
    text,
    tool,
)


class StreamingService:
    def __init__(self) -> None:
        self._message_id: str | None = None
        self.emitter_registry = EmitterRegistry()

    @property
    def message_id(self) -> str | None:
        return self._message_id

    def begin_message(self, message_id: str | None = None) -> str:
        self._message_id = message_id or envelope.generate_message_id()
        return self._message_id

    @staticmethod
    def generate_text_id() -> str:
        return envelope.generate_text_id()

    @staticmethod
    def generate_reasoning_id() -> str:
        return envelope.generate_reasoning_id()

    @staticmethod
    def generate_tool_call_id() -> str:
        return envelope.generate_tool_call_id()

    @staticmethod
    def generate_subagent_run_id() -> str:
        return envelope.generate_subagent_run_id()

    @staticmethod
    def get_response_headers() -> dict[str, str]:
        return envelope.get_response_headers()

    @staticmethod
    def format_done() -> str:
        return envelope.format_done()

    def resolve_emitter(
        self,
        *,
        run_id: str | None,
        parent_ids: Iterable[str] | None,
    ) -> Emitter:
        return self.emitter_registry.resolve(run_id=run_id, parent_ids=parent_ids)

    def format_message_start(
        self,
        message_id: str | None = None,
        *,
        emitter: Emitter | None = None,
    ) -> str:
        chosen = self.begin_message(message_id)
        return lifecycle.format_message_start(chosen, emitter=emitter)

    def format_message_finish(self, *, emitter: Emitter | None = None) -> str:
        return lifecycle.format_message_finish(emitter=emitter)

    def format_step_start(self, *, emitter: Emitter | None = None) -> str:
        return lifecycle.format_step_start(emitter=emitter)

    def format_step_finish(self, *, emitter: Emitter | None = None) -> str:
        return lifecycle.format_step_finish(emitter=emitter)

    def format_text_start(
        self, text_id: str, *, emitter: Emitter | None = None
    ) -> str:
        return text.format_text_start(text_id, emitter=emitter)

    def format_text_delta(
        self, text_id: str, delta: str, *, emitter: Emitter | None = None
    ) -> str:
        return text.format_text_delta(text_id, delta, emitter=emitter)

    def format_text_end(
        self, text_id: str, *, emitter: Emitter | None = None
    ) -> str:
        return text.format_text_end(text_id, emitter=emitter)

    def format_reasoning_start(
        self, reasoning_id: str, *, emitter: Emitter | None = None
    ) -> str:
        return reasoning.format_reasoning_start(reasoning_id, emitter=emitter)

    def format_reasoning_delta(
        self,
        reasoning_id: str,
        delta: str,
        *,
        emitter: Emitter | None = None,
    ) -> str:
        return reasoning.format_reasoning_delta(reasoning_id, delta, emitter=emitter)

    def format_reasoning_end(
        self, reasoning_id: str, *, emitter: Emitter | None = None
    ) -> str:
        return reasoning.format_reasoning_end(reasoning_id, emitter=emitter)

    def format_tool_input_start(
        self,
        tool_call_id: str,
        tool_name: str,
        *,
        langchain_tool_call_id: str | None = None,
        emitter: Emitter | None = None,
    ) -> str:
        return tool.format_tool_input_start(
            tool_call_id,
            tool_name,
            langchain_tool_call_id=langchain_tool_call_id,
            emitter=emitter,
        )

    def format_tool_input_delta(
        self,
        tool_call_id: str,
        input_text_delta: str,
        *,
        emitter: Emitter | None = None,
    ) -> str:
        return tool.format_tool_input_delta(
            tool_call_id, input_text_delta, emitter=emitter
        )

    def format_tool_input_available(
        self,
        tool_call_id: str,
        tool_name: str,
        input_data: dict[str, Any],
        *,
        langchain_tool_call_id: str | None = None,
        emitter: Emitter | None = None,
    ) -> str:
        return tool.format_tool_input_available(
            tool_call_id,
            tool_name,
            input_data,
            langchain_tool_call_id=langchain_tool_call_id,
            emitter=emitter,
        )

    def format_tool_output_available(
        self,
        tool_call_id: str,
        output: Any,
        *,
        langchain_tool_call_id: str | None = None,
        emitter: Emitter | None = None,
    ) -> str:
        return tool.format_tool_output_available(
            tool_call_id,
            output,
            langchain_tool_call_id=langchain_tool_call_id,
            emitter=emitter,
        )

    def format_source_url(
        self,
        url: str,
        *,
        source_id: str | None = None,
        title: str | None = None,
        emitter: Emitter | None = None,
    ) -> str:
        return source.format_source_url(
            url, source_id=source_id, title=title, emitter=emitter
        )

    def format_source_document(
        self,
        source_id: str,
        *,
        media_type: str = "file",
        title: str | None = None,
        description: str | None = None,
        emitter: Emitter | None = None,
    ) -> str:
        return source.format_source_document(
            source_id,
            media_type=media_type,
            title=title,
            description=description,
            emitter=emitter,
        )

    def format_file(
        self, url: str, media_type: str, *, emitter: Emitter | None = None
    ) -> str:
        return source.format_file(url, media_type, emitter=emitter)

    def format_data(
        self, data_type: str, payload: Any, *, emitter: Emitter | None = None
    ) -> str:
        return data.format_data(data_type, payload, emitter=emitter)

    def format_terminal_info(
        self,
        text_value: str,
        *,
        message_type: str = "info",
        emitter: Emitter | None = None,
    ) -> str:
        return data.format_terminal_info(
            text_value, message_type=message_type, emitter=emitter
        )

    def format_further_questions(
        self,
        questions: list[str],
        *,
        emitter: Emitter | None = None,
    ) -> str:
        return data.format_further_questions(questions, emitter=emitter)

    def format_thinking_step(
        self,
        *,
        step_id: str,
        title: str,
        status: str = "in_progress",
        items: list[str] | None = None,
        emitter: Emitter | None = None,
    ) -> str:
        return data.format_thinking_step(
            step_id=step_id,
            title=title,
            status=status,
            items=items,
            emitter=emitter,
        )

    def format_thread_title_update(
        self,
        *,
        thread_id: int,
        title: str,
        emitter: Emitter | None = None,
    ) -> str:
        return data.format_thread_title_update(
            thread_id=thread_id, title=title, emitter=emitter
        )

    def format_turn_info(
        self,
        *,
        chat_turn_id: str,
        emitter: Emitter | None = None,
    ) -> str:
        return data.format_turn_info(chat_turn_id=chat_turn_id, emitter=emitter)

    def format_turn_status(
        self,
        *,
        status: str,
        emitter: Emitter | None = None,
    ) -> str:
        return data.format_turn_status(status=status, emitter=emitter)

    def format_user_message_id(
        self,
        *,
        message_id: str,
        turn_id: str,
        emitter: Emitter | None = None,
    ) -> str:
        return data.format_user_message_id(
            message_id=message_id, turn_id=turn_id, emitter=emitter
        )

    def format_assistant_message_id(
        self,
        *,
        message_id: str,
        turn_id: str,
        emitter: Emitter | None = None,
    ) -> str:
        return data.format_assistant_message_id(
            message_id=message_id, turn_id=turn_id, emitter=emitter
        )

    def format_error(
        self,
        error_text: str,
        *,
        error_code: str | None = None,
        extra: dict[str, Any] | None = None,
        emitter: Emitter | None = None,
    ) -> str:
        return error.format_error(
            error_text,
            error_code=error_code,
            extra=extra,
            emitter=emitter,
        )

    def format_interrupt_request(
        self,
        interrupt_value: dict[str, Any],
        *,
        interrupt_id: str | None = None,
        pending_interrupt_count: int | None = None,
        chat_turn_id: str | None = None,
        emitter: Emitter | None = None,
    ) -> str:
        return interrupt.format_interrupt_request(
            interrupt_value,
            interrupt_id=interrupt_id,
            pending_interrupt_count=pending_interrupt_count,
            chat_turn_id=chat_turn_id,
            emitter=emitter,
        )

    def format_subagent_start(
        self,
        *,
        subagent_run_id: str,
        subagent_type: str,
        parent_tool_call_id: str,
        chat_turn_id: str | None = None,
        description: str | None = None,
        started_at: str | None = None,
        emitter: Emitter | None = None,
    ) -> str:
        return subagent_lifecycle.format_subagent_start(
            subagent_run_id=subagent_run_id,
            subagent_type=subagent_type,
            parent_tool_call_id=parent_tool_call_id,
            chat_turn_id=chat_turn_id,
            description=description,
            started_at=started_at,
            emitter=emitter,
        )

    def format_subagent_finish(
        self,
        *,
        subagent_run_id: str,
        subagent_type: str,
        parent_tool_call_id: str,
        status: str = "completed",
        ended_at: str | None = None,
        duration_ms: int | None = None,
        emitter: Emitter | None = None,
    ) -> str:
        return subagent_lifecycle.format_subagent_finish(
            subagent_run_id=subagent_run_id,
            subagent_type=subagent_type,
            parent_tool_call_id=parent_tool_call_id,
            status=status,
            ended_at=ended_at,
            duration_ms=duration_ms,
            emitter=emitter,
        )

    def format_subagent_error(
        self,
        *,
        subagent_run_id: str,
        subagent_type: str,
        parent_tool_call_id: str,
        error_text: str,
        error_type: str | None = None,
        ended_at: str | None = None,
        duration_ms: int | None = None,
        emitter: Emitter | None = None,
    ) -> str:
        return subagent_lifecycle.format_subagent_error(
            subagent_run_id=subagent_run_id,
            subagent_type=subagent_type,
            parent_tool_call_id=parent_tool_call_id,
            error_text=error_text,
            error_type=error_type,
            ended_at=ended_at,
            duration_ms=duration_ms,
            emitter=emitter,
        )

    def format_action_log(
        self,
        payload: dict[str, Any],
        *,
        emitter: Emitter | None = None,
    ) -> str:
        return action_log.format_action_log(payload, emitter=emitter)

    def format_action_log_updated(
        self,
        payload: dict[str, Any],
        *,
        emitter: Emitter | None = None,
    ) -> str:
        return action_log.format_action_log_updated(payload, emitter=emitter)
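
A minimal usage sketch of the composition root, chaining the lifecycle and text formatters (nothing here is wired to a real agent stream):

    svc = StreamingService()
    frames = [svc.format_message_start()]
    text_id = svc.generate_text_id()
    frames += [
        svc.format_text_start(text_id),
        svc.format_text_delta(text_id, "Hi"),
        svc.format_text_end(text_id),
        svc.format_message_finish(),
        svc.format_done(),
    ]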
@ -51,6 +51,21 @@ logger = logging.getLogger(__name__)
_MEANINGFUL_PART_TYPES: frozenset[str] = frozenset({"text", "reasoning", "tool-call"})
+
+
+def _merge_tool_part_metadata(part: dict[str, Any], metadata: dict[str, Any] | None) -> None:
+    """Shallow-merge ``metadata`` into ``part["metadata"]``; first key wins.
+
+    Used for tool-call linkage (``spanId``, ``thinkingStepId``, …): a later
+    event must not overwrite an existing key so chunk order vs ``on_tool_start``
+    stays stable.
+    """
+    if not metadata:
+        return
+    md = part.setdefault("metadata", {})
+    for k, v in metadata.items():
+        if k not in md:
+            md[k] = v
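
First-key-wins in practice — a tiny sketch with illustrative keys:

    part = {"metadata": {"spanId": "span-1"}}
    _merge_tool_part_metadata(part, {"spanId": "span-2", "thinkingStepId": "step-9"})
    assert part["metadata"] == {"spanId": "span-1", "thinkingStepId": "step-9"}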


class AssistantContentBuilder:
    """Server-side projection of ``surfsense_web/lib/chat/streaming-state.ts``.
@ -61,6 +76,7 @@ class AssistantContentBuilder:
    | { type: "reasoning"; text: string }
    | { type: "tool-call"; toolCallId: str; toolName: str;
        args: dict; result?: any; argsText?: str; langchainToolCallId?: str;
+       metadata?: { spanId?: str; thinkingStepId?: str; ... };
        state?: "aborted" }
    | { type: "data-thinking-steps"; data: { steps: ThinkingStepData[] } }
    | { type: "data-step-separator"; data: { stepIndex: int } }
@ -85,8 +101,8 @@ class AssistantContentBuilder:
        self._current_text_idx: int = -1
        self._current_reasoning_idx: int = -1
        # ``ui_id``-keyed indexes for tool-call parts. ``ui_id`` is the
-        # synthetic ``call_<run_id>`` (legacy) or the LangChain
-        # ``tool_call.id`` (parity_v2) — same key the streaming layer
+        # synthetic ``call_<run_id>`` (chunk fallback) or the LangChain
+        # ``tool_call.id`` (indexed chunk path) — same key the streaming layer
        # threads through every ``tool-input-*`` / ``tool-output-*`` event.
        self._tool_call_idx_by_ui_id: dict[str, int] = {}
        # Live argsText accumulator (concatenated ``tool-input-delta`` chunks)
@ -177,21 +193,27 @@ class AssistantContentBuilder:
        ui_id: str,
        tool_name: str,
        langchain_tool_call_id: str | None,
+        *,
+        metadata: dict[str, Any] | None = None,
    ) -> None:
-        """Register a tool-call card. Args are filled in by later events."""
+        """Register a tool-call card. Args are filled in by later events.
+
+        Optional ``metadata`` (``spanId``, ``thinkingStepId``, …) is stored on the
+        part; duplicate ``tool-input-start`` calls merge with first-key-wins.
+        """
        if not ui_id:
            return
-        # Skip duplicate registration: parity_v2 may emit
+        # Skip duplicate registration: the stream may emit
        # ``tool-input-start`` from both ``on_chat_model_stream``
        # (when tool_call_chunks register a name) and ``on_tool_start``
        # (the canonical path). The FE de-dupes via ``toolCallIndices``;
        # we mirror that here.
        if ui_id in self._tool_call_idx_by_ui_id:
-            if langchain_tool_call_id:
-                idx = self._tool_call_idx_by_ui_id[ui_id]
-                part = self.parts[idx]
-                if not part.get("langchainToolCallId"):
-                    part["langchainToolCallId"] = langchain_tool_call_id
+            idx = self._tool_call_idx_by_ui_id[ui_id]
+            part = self.parts[idx]
+            if langchain_tool_call_id and not part.get("langchainToolCallId"):
+                part["langchainToolCallId"] = langchain_tool_call_id
+            _merge_tool_part_metadata(part, metadata)
            return

        part: dict[str, Any] = {
@ -202,6 +224,8 @@ class AssistantContentBuilder:
        }
        if langchain_tool_call_id:
            part["langchainToolCallId"] = langchain_tool_call_id
+        if metadata:
+            part["metadata"] = dict(metadata)
        self.parts.append(part)
        self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1
@ -235,6 +259,8 @@ class AssistantContentBuilder:
        tool_name: str,
        args: dict[str, Any],
        langchain_tool_call_id: str | None,
+        *,
+        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Finalize the tool-call card's input.
@ -243,7 +269,7 @@ class AssistantContentBuilder:
        pretty-printed JSON, sets the full ``args`` dict, and backfills
        ``langchainToolCallId`` if it wasn't known at ``tool-input-start`` time.
        Also creates the card if no prior ``tool-input-start`` registered it
-        (legacy parity_v2-OFF / late-registration paths).
+        (late-registration when no prior ``tool-input-start``).
        """
        if not ui_id:
            return
@ -264,6 +290,7 @@ class AssistantContentBuilder:
            part["argsText"] = final_args_text
            if langchain_tool_call_id and not part.get("langchainToolCallId"):
                part["langchainToolCallId"] = langchain_tool_call_id
+            _merge_tool_part_metadata(part, metadata)
            return

        # No prior tool-input-start: register the card now.
@ -276,6 +303,7 @@ class AssistantContentBuilder:
        }
        if langchain_tool_call_id:
            new_part["langchainToolCallId"] = langchain_tool_call_id
+        _merge_tool_part_metadata(new_part, metadata)
        self.parts.append(new_part)
        self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1
@ -287,6 +315,8 @@ class AssistantContentBuilder:
        ui_id: str,
        output: Any,
        langchain_tool_call_id: str | None,
+        *,
+        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Attach the tool's output (``result``) to the matching card.
@ -305,6 +335,7 @@ class AssistantContentBuilder:
        part["result"] = output
        if langchain_tool_call_id and not part.get("langchainToolCallId"):
            part["langchainToolCallId"] = langchain_tool_call_id
+        _merge_tool_part_metadata(part, metadata)

    # ------------------------------------------------------------------
    # Thinking steps & step separators
@ -316,6 +347,8 @@ class AssistantContentBuilder:
        title: str,
        status: str,
        items: list[str] | None,
+        *,
+        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Update / insert the singleton ``data-thinking-steps`` part.
@ -328,12 +361,14 @@ class AssistantContentBuilder:
        if not step_id:
            return

-        new_step = {
+        new_step: dict[str, Any] = {
            "id": step_id,
            "title": title or "",
            "status": status or "in_progress",
            "items": list(items) if items else [],
        }
+        if metadata:
+            new_step["metadata"] = dict(metadata)

        # Find existing data-thinking-steps part.
        existing_idx = -1
@ -347,6 +382,8 @@ class AssistantContentBuilder:
        replaced = False
        for i, step in enumerate(current_steps):
            if step.get("id") == step_id:
+                if not metadata and step.get("metadata"):
+                    new_step["metadata"] = dict(step["metadata"])
                current_steps[i] = new_step
                replaced = True
                break

File diff suppressed because it is too large.

surfsense_backend/app/tasks/chat/streaming/__init__.py (new file, 3 lines)

@ -0,0 +1,3 @@
"""Chat streaming helpers (e.g. LangGraph → SSE relay under ``graph_stream``)."""

from __future__ import annotations
@ -0,0 +1,3 @@
"""Error classification, structured logging, and terminal-error SSE emission."""

from __future__ import annotations

surfsense_backend/app/tasks/chat/streaming/errors/classifier.py (new file, 187 lines)

@ -0,0 +1,187 @@
"""Classify stream exceptions for logging and client error payloads."""

from __future__ import annotations

import json
import logging
import time
from typing import Any, Literal

from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.middleware.busy_mutex import (
    get_cancel_state,
    is_cancel_requested,
)

TURN_CANCELLING_INITIAL_DELAY_MS = 200
TURN_CANCELLING_BACKOFF_FACTOR = 2
TURN_CANCELLING_MAX_DELAY_MS = 1500


def compute_turn_cancelling_retry_delay(attempt: int) -> int:
    if attempt < 1:
        attempt = 1
    delay = TURN_CANCELLING_INITIAL_DELAY_MS * (
        TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1)
    )
    return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
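
With the constants above, the retry schedule is 200 ms, 400 ms, 800 ms, then capped at 1500 ms:

    assert [compute_turn_cancelling_retry_delay(n) for n in (1, 2, 3, 4, 5)] == [
        200, 400, 800, 1500, 1500,
    ]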
def log_chat_stream_error(
    *,
    flow: Literal["new", "resume", "regenerate"],
    error_kind: str,
    error_code: str | None,
    severity: Literal["info", "warn", "error"],
    is_expected: bool,
    request_id: str | None,
    thread_id: int | None,
    search_space_id: int | None,
    user_id: str | None,
    message: str,
    extra: dict[str, Any] | None = None,
) -> None:
    payload: dict[str, Any] = {
        "event": "chat_stream_error",
        "flow": flow,
        "error_kind": error_kind,
        "error_code": error_code,
        "severity": severity,
        "is_expected": is_expected,
        "request_id": request_id or "unknown",
        "thread_id": thread_id,
        "search_space_id": search_space_id,
        "user_id": user_id,
        "message": message,
    }
    if extra:
        payload.update(extra)

    logger = logging.getLogger(__name__)
    rendered = json.dumps(payload, ensure_ascii=False)
    if severity == "error":
        logger.error("[chat_stream_error] %s", rendered)
    elif severity == "warn":
        logger.warning("[chat_stream_error] %s", rendered)
    else:
        logger.info("[chat_stream_error] %s", rendered)


def _parse_error_payload(message: str) -> dict[str, Any] | None:
    candidates = [message]
    first_brace_idx = message.find("{")
    if first_brace_idx >= 0:
        candidates.append(message[first_brace_idx:])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
            if isinstance(parsed, dict):
                return parsed
        except Exception:
            continue
    return None


def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None:
    if not isinstance(parsed, dict):
        return None
    candidates: list[Any] = [parsed.get("code")]
    nested = parsed.get("error")
    if isinstance(nested, dict):
        candidates.append(nested.get("code"))
    for value in candidates:
        try:
            if value is None:
                continue
            return int(value)
        except Exception:
            continue
    return None


def is_provider_rate_limited(exc: BaseException) -> bool:
    """Return True if the exception looks like an upstream HTTP 429 / rate limit."""
    raw = str(exc)
    lowered = raw.lower()
    if "ratelimit" in type(exc).__name__.lower():
        return True
    parsed = _parse_error_payload(raw)
    provider_code = _extract_provider_error_code(parsed)
    if provider_code == 429:
        return True

    provider_error_type = ""
    if parsed:
        top_type = parsed.get("type")
        if isinstance(top_type, str):
            provider_error_type = top_type.lower()
        nested = parsed.get("error")
        if isinstance(nested, dict):
            nested_type = nested.get("type")
            if isinstance(nested_type, str):
                provider_error_type = nested_type.lower()
    if provider_error_type == "rate_limit_error":
        return True

    return (
        "rate limited" in lowered
        or "rate-limited" in lowered
        or "temporarily rate-limited upstream" in lowered
    )


def classify_stream_exception(
    exc: Exception,
    *,
    flow_label: str,
) -> tuple[
    str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None
]:
    """Return kind, code, severity, expected flag, message, and optional extra dict."""
    raw = str(exc)
    if isinstance(exc, BusyError) or "Thread is busy with another request" in raw:
        busy_thread_id = str(exc.request_id) if isinstance(exc, BusyError) else None
        if busy_thread_id and is_cancel_requested(busy_thread_id):
            cancel_state = get_cancel_state(busy_thread_id)
            attempt = cancel_state[0] if cancel_state else 1
            retry_after_ms = compute_turn_cancelling_retry_delay(attempt)
            retry_after_at = int(time.time() * 1000) + retry_after_ms
            return (
                "thread_busy",
                "TURN_CANCELLING",
                "info",
                True,
                "A previous response is still stopping. Please try again in a moment.",
                {
                    "retry_after_ms": retry_after_ms,
                    "retry_after_at": retry_after_at,
                },
            )
        return (
            "thread_busy",
            "THREAD_BUSY",
            "warn",
            True,
            "Another response is still finishing for this thread. Please try again in a moment.",
            None,
        )

    if is_provider_rate_limited(exc):
        return (
            "rate_limited",
            "RATE_LIMITED",
            "warn",
            True,
            "This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
            None,
        )

    return (
        "server_error",
        "SERVER_ERROR",
        "error",
        False,
        f"Error during {flow_label}: {raw}",
        None,
    )
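
Callers unpack the six-tuple directly inside their stream's except block; a sketch (the flow label is illustrative):

    try:
        ...  # drive the agent stream
    except Exception as exc:
        kind, code, severity, expected, message, extra = classify_stream_exception(
            exc, flow_label="new chat stream"
        )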

surfsense_backend/app/tasks/chat/streaming/errors/emitter.py (new file, 38 lines)

@ -0,0 +1,38 @@
"""Emit one terminal error SSE frame and log via the stream error classifier."""

from __future__ import annotations

from typing import Any, Literal

from .classifier import log_chat_stream_error


def emit_stream_terminal_error(
    *,
    streaming_service: Any,
    flow: Literal["new", "resume", "regenerate"],
    request_id: str | None,
    thread_id: int,
    search_space_id: int,
    user_id: str | None,
    message: str,
    error_kind: str = "server_error",
    error_code: str = "SERVER_ERROR",
    severity: Literal["info", "warn", "error"] = "error",
    is_expected: bool = False,
    extra: dict[str, Any] | None = None,
) -> str:
    log_chat_stream_error(
        flow=flow,
        error_kind=error_kind,
        error_code=error_code,
        severity=severity,
        is_expected=is_expected,
        request_id=request_id,
        thread_id=thread_id,
        search_space_id=search_space_id,
        user_id=user_id,
        message=message,
        extra=extra,
    )
    return streaming_service.format_error(message, error_code=error_code, extra=extra)
@ -0,0 +1,21 @@
"""LangGraph ``astream_events`` → SSE (``stream_output`` + ``StreamingResult``).

Imports are lazy to avoid a circular import with ``relay.event_relay``.
"""

from __future__ import annotations

__all__ = ["StreamingResult", "stream_output"]


def __getattr__(name: str):
    if name == "stream_output":
        from app.tasks.chat.streaming.graph_stream.event_stream import stream_output

        return stream_output
    if name == "StreamingResult":
        from app.tasks.chat.streaming.graph_stream.result import StreamingResult

        return StreamingResult
    msg = f"module {__name__!r} has no attribute {name!r}"
    raise AttributeError(msg)
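
This is the PEP 562 module-level ``__getattr__`` pattern: the submodule only imports on first attribute access, so the cycle with ``relay.event_relay`` never forms at import time. A caller writes the usual import and pays the cost lazily:

    from app.tasks.chat.streaming.graph_stream import stream_output  # resolved on access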
@ -0,0 +1,51 @@
"""Run LangGraph event streams through ``EventRelay``."""

from __future__ import annotations

from collections.abc import AsyncIterator
from typing import Any

from app.tasks.chat.streaming.graph_stream.result import StreamingResult
from app.tasks.chat.streaming.relay.event_relay import EventRelay
from app.tasks.chat.streaming.relay.state import AgentEventRelayState


async def stream_output(
    *,
    agent: Any,
    config: dict[str, Any],
    input_data: Any,
    streaming_service: Any,
    result: StreamingResult,
    step_prefix: str = "thinking",
    initial_step_id: str | None = None,
    initial_step_title: str = "",
    initial_step_items: list[str] | None = None,
    content_builder: Any | None = None,
    runtime_context: Any = None,
) -> AsyncIterator[str]:
    """Yield SSE frames from agent ``astream_events`` via ``EventRelay``."""
    state = AgentEventRelayState.for_invocation(
        initial_step_id=initial_step_id,
        initial_step_title=initial_step_title,
        initial_step_items=initial_step_items,
    )

    astream_kwargs: dict[str, Any] = {"config": config, "version": "v2"}
    if runtime_context is not None:
        astream_kwargs["context"] = runtime_context

    events = agent.astream_events(input_data, **astream_kwargs)
    relay = EventRelay(streaming_service=streaming_service)
    async for frame in relay.relay(
        events,
        state=state,
        result=result,
        step_prefix=step_prefix,
        content_builder=content_builder,
        config=config,
    ):
        yield frame

    result.accumulated_text = state.accumulated_text
    result.agent_called_update_memory = state.called_update_memory
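
A hedged sketch of driving the relay end to end — ``agent`` and ``svc`` stand in for a compiled LangGraph agent and a ``StreamingService``:

    async def run() -> None:
        result = StreamingResult()
        async for frame in stream_output(
            agent=agent,                   # assumed compiled agent
            config={"configurable": {}},   # illustrative config
            input_data={"messages": []},
            streaming_service=svc,         # assumed StreamingService
            result=result,
        ):
            ...  # write each SSE frame to the HTTP response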
@ -0,0 +1,28 @@
"""Mutable facts collected while relaying one agent stream (``stream_output``)."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass
class StreamingResult:
    accumulated_text: str = ""
    is_interrupted: bool = False
    interrupt_value: dict[str, Any] | None = None
    sandbox_files: list[str] = field(default_factory=list)
    agent_called_update_memory: bool = False
    request_id: str | None = None
    turn_id: str = ""
    filesystem_mode: str = "cloud"
    client_platform: str = "web"
    intent_detected: str = "chat_only"
    intent_confidence: float = 0.0
    write_attempted: bool = False
    write_succeeded: bool = False
    verification_succeeded: bool = False
    commit_gate_passed: bool = True
    commit_gate_reason: str = ""
    assistant_message_id: int | None = None
    content_builder: Any | None = field(default=None, repr=False)
@ -0,0 +1,3 @@
"""LangGraph stream handlers by event kind."""

from __future__ import annotations
@ -0,0 +1,23 @@
"""Close open text when a LangGraph chain or agent node finishes."""

from __future__ import annotations

from collections.abc import Iterator
from typing import Any

from app.tasks.chat.streaming.relay.state import AgentEventRelayState


def iter_chain_end_frames(
    _event: dict[str, Any],
    *,
    state: AgentEventRelayState,
    streaming_service: Any,
    content_builder: Any | None,
) -> Iterator[str]:
    """Close the open text stream if one is open."""
    if state.current_text_id is not None:
        yield streaming_service.format_text_end(state.current_text_id)
        if content_builder is not None:
            content_builder.on_text_end(state.current_text_id)
        state.current_text_id = None
@ -0,0 +1,159 @@
"""Chat model stream: text, reasoning, and tool-call chunk SSE."""

from __future__ import annotations

from collections.abc import Iterator
from typing import Any

from app.tasks.chat.streaming.helpers.chunk_parts import extract_chunk_parts
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
from app.tasks.chat.streaming.relay.task_span import ensure_pending_task_span_for_lc
from app.tasks.chat.streaming.relay.thinking_step_completion import (
    complete_active_thinking_step,
)


def iter_chat_model_stream_frames(
    event: dict[str, Any],
    *,
    state: AgentEventRelayState,
    streaming_service: Any,
    content_builder: Any | None,
    step_prefix: str,
) -> Iterator[str]:
    """SSE frames for one chat-model chunk."""
    if state.active_tool_depth > 0:
        return
    if "surfsense:internal" in event.get("tags", []):
        return
    chunk = event.get("data", {}).get("chunk")
    if not chunk:
        return
    parts = extract_chunk_parts(chunk)

    reasoning_delta = parts["reasoning"]
    text_delta = parts["text"]

    if reasoning_delta:
        if state.current_text_id is not None:
            yield streaming_service.format_text_end(state.current_text_id)
            if content_builder is not None:
                content_builder.on_text_end(state.current_text_id)
            state.current_text_id = None
        if state.current_reasoning_id is None:
            comp, new_active = complete_active_thinking_step(
                state=state,
                streaming_service=streaming_service,
                content_builder=content_builder,
                last_active_step_id=state.last_active_step_id,
                last_active_step_title=state.last_active_step_title,
                last_active_step_items=state.last_active_step_items,
                completed_step_ids=state.completed_step_ids,
            )
            if comp:
                yield comp
            state.last_active_step_id = new_active
            if state.just_finished_tool:
                state.last_active_step_id = None
                state.last_active_step_title = ""
                state.last_active_step_items = []
                state.just_finished_tool = False
            state.current_reasoning_id = streaming_service.generate_reasoning_id()
            yield streaming_service.format_reasoning_start(state.current_reasoning_id)
            if content_builder is not None:
                content_builder.on_reasoning_start(state.current_reasoning_id)
        yield streaming_service.format_reasoning_delta(
            state.current_reasoning_id, reasoning_delta
        )
        if content_builder is not None:
            content_builder.on_reasoning_delta(
                state.current_reasoning_id, reasoning_delta
            )

    if text_delta:
        if state.current_reasoning_id is not None:
            yield streaming_service.format_reasoning_end(state.current_reasoning_id)
            if content_builder is not None:
                content_builder.on_reasoning_end(state.current_reasoning_id)
            state.current_reasoning_id = None
        if state.current_text_id is None:
            comp, new_active = complete_active_thinking_step(
                state=state,
                streaming_service=streaming_service,
                content_builder=content_builder,
                last_active_step_id=state.last_active_step_id,
                last_active_step_title=state.last_active_step_title,
                last_active_step_items=state.last_active_step_items,
                completed_step_ids=state.completed_step_ids,
            )
            if comp:
                yield comp
            state.last_active_step_id = new_active
            if state.just_finished_tool:
                state.last_active_step_id = None
                state.last_active_step_title = ""
                state.last_active_step_items = []
                state.just_finished_tool = False
            state.current_text_id = streaming_service.generate_text_id()
            yield streaming_service.format_text_start(state.current_text_id)
            if content_builder is not None:
                content_builder.on_text_start(state.current_text_id)
        yield streaming_service.format_text_delta(state.current_text_id, text_delta)
        state.accumulated_text += text_delta
        if content_builder is not None:
            content_builder.on_text_delta(state.current_text_id, text_delta)

    if parts["tool_call_chunks"]:
        for tcc in parts["tool_call_chunks"]:
            idx = tcc.get("index")

            if idx is not None and idx not in state.index_to_meta:
                lc_id = tcc.get("id")
                name = tcc.get("name")
                if lc_id and name:
                    ui_id = lc_id
                    tool_input_metadata: dict[str, Any] | None = None
                    if name == "task":
                        sid = ensure_pending_task_span_for_lc(state, str(lc_id))
                        tool_input_metadata = {"spanId": sid}

                    if state.current_text_id is not None:
                        yield streaming_service.format_text_end(state.current_text_id)
                        if content_builder is not None:
                            content_builder.on_text_end(state.current_text_id)
                        state.current_text_id = None
                    if state.current_reasoning_id is not None:
                        yield streaming_service.format_reasoning_end(
                            state.current_reasoning_id
                        )
                        if content_builder is not None:
                            content_builder.on_reasoning_end(state.current_reasoning_id)
                        state.current_reasoning_id = None

                    state.index_to_meta[idx] = {
                        "ui_id": ui_id,
                        "lc_id": lc_id,
                        "name": name,
                    }
                    yield streaming_service.format_tool_input_start(
                        ui_id,
                        name,
                        langchain_tool_call_id=lc_id,
                        metadata=tool_input_metadata,
                    )
                    if content_builder is not None:
                        content_builder.on_tool_input_start(
                            ui_id, name, lc_id, metadata=tool_input_metadata
                        )

            meta = state.index_to_meta.get(idx) if idx is not None else None
            if meta:
                args_chunk = tcc.get("args") or ""
                if args_chunk:
                    yield streaming_service.format_tool_input_delta(
                        meta["ui_id"], args_chunk
                    )
                    if content_builder is not None:
                        content_builder.on_tool_input_delta(meta["ui_id"], args_chunk)
            else:
                state.pending_tool_call_chunks.append(tcc)
@ -0,0 +1,57 @@
"""Custom graph events routed to SSE (documents, action logs, report progress)."""

from __future__ import annotations

from collections.abc import Iterator
from typing import Any

from app.tasks.chat.streaming.handlers.custom_events import (
    handle_action_log,
    handle_action_log_updated,
    handle_document_created,
    handle_report_progress,
)
from app.tasks.chat.streaming.relay.state import AgentEventRelayState


def iter_custom_event_frames(
    event: dict[str, Any],
    *,
    state: AgentEventRelayState,
    streaming_service: Any,
    content_builder: Any | None,
) -> Iterator[str]:
    """Yield any SSE produced by ad-hoc graph events (documents, action logs, report progress)."""
    name = event.get("name")
    data = event.get("data", {})

    if name == "report_progress":
        frame, state.last_active_step_items = handle_report_progress(
            data,
            last_active_step_id=state.last_active_step_id,
            last_active_step_title=state.last_active_step_title,
            last_active_step_items=state.last_active_step_items,
            streaming_service=streaming_service,
            content_builder=content_builder,
            thinking_metadata=state.span_metadata_if_active(),
        )
        if frame:
            yield frame
        return

    if name == "document_created":
        frame = handle_document_created(data, streaming_service=streaming_service)
        if frame:
            yield frame
        return

    if name == "action_log":
        frame = handle_action_log(data, streaming_service=streaming_service)
        if frame:
            yield frame
        return

    if name == "action_log_updated":
        frame = handle_action_log_updated(data, streaming_service=streaming_service)
        if frame:
            yield frame
@ -0,0 +1,79 @@
"""Custom-event payloads turned into SSE (no model/tool stream handling)."""

from __future__ import annotations

from typing import Any

from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame


def handle_report_progress(
    data: dict[str, Any],
    *,
    last_active_step_id: str | None,
    last_active_step_title: str,
    last_active_step_items: list[str],
    streaming_service: Any,
    content_builder: Any | None,
    thinking_metadata: dict[str, Any] | None = None,
) -> tuple[str | None, list[str]]:
    """Update report step items; may emit one thinking SSE frame.

    Returns (frame or None, items list after update).
    """
    message = data.get("message", "")
    if not message or not last_active_step_id:
        return None, last_active_step_items

    phase = data.get("phase", "")
    topic_items = [
        item for item in last_active_step_items if item.startswith("Topic:")
    ]

    if phase in ("revising_section", "adding_section"):
        plan_items = [
            item
            for item in last_active_step_items
            if item.startswith("Topic:")
            or item.startswith("Modifying ")
            or item.startswith("Adding ")
            or item.startswith("Removing ")
        ]
        plan_items = [item for item in plan_items if not item.endswith("...")]
        new_items = [*plan_items, message]
    else:
        new_items = [*topic_items, message]

    frame = emit_thinking_step_frame(
        streaming_service=streaming_service,
        content_builder=content_builder,
        step_id=last_active_step_id,
        title=last_active_step_title,
        status="in_progress",
        items=new_items,
        metadata=thinking_metadata,
    )
    return frame, new_items


def handle_document_created(data: dict[str, Any], *, streaming_service: Any) -> str | None:
    if not data.get("id"):
        return None
    return streaming_service.format_data(
        "documents-updated",
        {"action": "created", "document": data},
    )


def handle_action_log(data: dict[str, Any], *, streaming_service: Any) -> str | None:
    if data.get("id") is None:
        return None
    return streaming_service.format_data("action-log", data)


def handle_action_log_updated(
    data: dict[str, Any], *, streaming_service: Any
) -> str | None:
    if data.get("id") is None:
        return None
    return streaming_service.format_data("action-log-updated", data)

surfsense_backend/app/tasks/chat/streaming/handlers/tool_end.py (new file, 121 lines)

@ -0,0 +1,121 @@
"""Tool end: thinking completion, tool output, and terminal SSE."""

from __future__ import annotations

import json
from collections.abc import Iterator
from typing import Any

from app.tasks.chat.streaming.handlers.tools import (
    ToolCompletionEmissionContext,
    iter_tool_completion_emission_frames,
    resolve_tool_completed_thinking_step,
)
from app.tasks.chat.streaming.helpers.tool_output import tool_output_has_error
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
from app.tasks.chat.streaming.relay.task_span import clear_task_span_if_delegating_task_ended
from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame


def iter_tool_end_frames(
    event: dict[str, Any],
    *,
    state: AgentEventRelayState,
    streaming_service: Any,
    content_builder: Any | None,
    result: Any,
    step_prefix: str,
    config: dict[str, Any],
) -> Iterator[str]:
    """SSE frames when one tool run finishes."""
    state.active_tool_depth = max(0, state.active_tool_depth - 1)
    run_id = event.get("run_id", "")
    tool_name = event.get("name", "unknown_tool")
    raw_output = event.get("data", {}).get("output", "")
    staged_file_path = (
        state.file_path_by_run.pop(run_id, None) if run_id else None
    )

    if tool_name == "update_memory":
        state.called_update_memory = True

    if hasattr(raw_output, "content"):
        content = raw_output.content
        if isinstance(content, str):
            try:
                tool_output = json.loads(content)
            except (json.JSONDecodeError, TypeError):
                tool_output = {"result": content}
        elif isinstance(content, dict):
            tool_output = content
        else:
            tool_output = {"result": str(content)}
    elif isinstance(raw_output, dict):
        tool_output = raw_output
    else:
        tool_output = {"result": str(raw_output) if raw_output else "completed"}

    if tool_name in ("write_file", "edit_file"):
        if tool_output_has_error(tool_output):
            pass
        else:
            result.write_succeeded = True
            result.verification_succeeded = True

    tool_call_id = state.ui_tool_call_id_by_run.get(
        run_id,
        f"call_{run_id[:32]}" if run_id else "call_unknown",
    )
    original_step_id = state.tool_step_ids.get(
        run_id, f"{step_prefix}-unknown-{run_id[:8]}"
    )
    state.completed_step_ids.add(original_step_id)

    holder = state.current_lc_tool_call_id
    holder["value"] = None
    authoritative = getattr(raw_output, "tool_call_id", None)
    if isinstance(authoritative, str) and authoritative:
        holder["value"] = authoritative
        if run_id:
            state.lc_tool_call_id_by_run[run_id] = authoritative
    elif run_id and run_id in state.lc_tool_call_id_by_run:
        holder["value"] = state.lc_tool_call_id_by_run[run_id]

    items = state.last_active_step_items
    title, completed_items = resolve_tool_completed_thinking_step(
        tool_name, tool_output, items
    )
    yield emit_thinking_step_frame(
        streaming_service=streaming_service,
        content_builder=content_builder,
        step_id=original_step_id,
        title=title,
        status="completed",
        items=completed_items,
        metadata=state.span_metadata_if_active(),
    )

    state.just_finished_tool = True
    state.last_active_step_id = None
    state.last_active_step_title = ""
    state.last_active_step_items = []

    emission_ctx = ToolCompletionEmissionContext(
        tool_name=tool_name,
        tool_call_id=tool_call_id,
        tool_output=tool_output,
        streaming_service=streaming_service,
        content_builder=content_builder,
        langchain_tool_call_id_holder=holder,
        stream_result=result,
        langgraph_config=config,
        staged_workspace_file_path=staged_file_path,
        tool_metadata=state.tool_activity_metadata(
            thinking_step_id=original_step_id,
        ),
    )
    yield from iter_tool_completion_emission_frames(emission_ctx)

    clear_task_span_if_delegating_task_ended(
        state, tool_name=tool_name, run_id=run_id
    )
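
The output coercion above means a tool message whose string ``content`` parses as a JSON object becomes the dict itself, while unparsable text is wrapped; a standalone sketch of that branch:

    import json

    def coerce(content: str) -> dict:
        # Mirrors the string branch above (illustrative only).
        try:
            return json.loads(content)
        except (json.JSONDecodeError, TypeError):
            return {"result": content}

    assert coerce('{"ok": true}') == {"ok": True}
    assert coerce("all done") == {"result": "all done"}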
@ -0,0 +1,29 @@
"""Emit tool-output SSE and optional assistant content updates."""

from __future__ import annotations

from typing import Any


def emit_tool_output_available_frame(
    *,
    streaming_service: Any,
    content_builder: Any | None,
    langchain_id_holder: dict[str, str | None],
    call_id: str,
    output: Any,
    tool_metadata: dict[str, Any] | None = None,
) -> str:
    if content_builder is not None:
        content_builder.on_tool_output_available(
            call_id,
            output,
            langchain_id_holder["value"],
            metadata=tool_metadata,
        )
    return streaming_service.format_tool_output_available(
        call_id,
        output,
        langchain_tool_call_id=langchain_id_holder["value"],
        metadata=tool_metadata,
    )
@ -0,0 +1,161 @@
"""Tool start: thinking-step and tool-input SSE."""

from __future__ import annotations

import json
from collections.abc import Iterator
from typing import Any

from app.tasks.chat.streaming.handlers.tools import resolve_tool_start_thinking
from app.tasks.chat.streaming.helpers.tool_call_matching import (
    match_buffered_langchain_tool_call_id,
)
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
from app.tasks.chat.streaming.relay.task_span import open_task_span
from app.tasks.chat.streaming.relay.thinking_step_completion import (
    complete_active_thinking_step,
)
from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame


def iter_tool_start_frames(
    event: dict[str, Any],
    *,
    state: AgentEventRelayState,
    streaming_service: Any,
    content_builder: Any | None,
    result: Any,
    step_prefix: str,
) -> Iterator[str]:
    """SSE frames for the start of one tool run."""
    state.active_tool_depth += 1
    tool_name = event.get("name", "unknown_tool")
    run_id = event.get("run_id", "")
    tool_input = event.get("data", {}).get("input", {})
    if tool_name in ("write_file", "edit_file"):
        result.write_attempted = True
        if isinstance(tool_input, dict):
            file_path = tool_input.get("file_path")
            if isinstance(file_path, str) and file_path.strip() and run_id:
                state.file_path_by_run[run_id] = file_path.strip()

    if state.current_text_id is not None:
        yield streaming_service.format_text_end(state.current_text_id)
        if content_builder is not None:
            content_builder.on_text_end(state.current_text_id)
        state.current_text_id = None

    if state.last_active_step_title != "Synthesizing response":
        comp, new_active = complete_active_thinking_step(
            state=state,
            streaming_service=streaming_service,
            content_builder=content_builder,
            last_active_step_id=state.last_active_step_id,
            last_active_step_title=state.last_active_step_title,
            last_active_step_items=state.last_active_step_items,
            completed_step_ids=state.completed_step_ids,
        )
        if comp:
            yield comp
        state.last_active_step_id = new_active

    state.just_finished_tool = False
    tool_step_id = state.next_thinking_step_id(step_prefix)
    state.tool_step_ids[run_id] = tool_step_id
    state.last_active_step_id = tool_step_id

    matched_meta: dict[str, str] | None = None
    taken_ui_ids = set(state.ui_tool_call_id_by_run.values())
    for meta in state.index_to_meta.values():
        if meta["name"] == tool_name and meta["ui_id"] not in taken_ui_ids:
            matched_meta = meta
            break

    tool_call_id: str
    langchain_tool_call_id: str | None = None
    if matched_meta is not None:
        tool_call_id = matched_meta["ui_id"]
        langchain_tool_call_id = matched_meta["lc_id"]
        if run_id:
            state.lc_tool_call_id_by_run[run_id] = matched_meta["lc_id"]
    else:
        tool_call_id = (
            f"call_{run_id[:32]}"
            if run_id
            else streaming_service.generate_tool_call_id()
        )
        langchain_tool_call_id = match_buffered_langchain_tool_call_id(
            state.pending_tool_call_chunks,
            tool_name,
            run_id,
            state.lc_tool_call_id_by_run,
        )

    if tool_name == "task":
        open_task_span(
            state,
            run_id=run_id,
            langchain_tool_call_id=langchain_tool_call_id,
        )

    span_md = state.span_metadata_if_active()
    tool_md = state.tool_activity_metadata(thinking_step_id=tool_step_id)

    if matched_meta is None:
        yield streaming_service.format_tool_input_start(
            tool_call_id,
            tool_name,
            langchain_tool_call_id=langchain_tool_call_id,
            metadata=tool_md,
        )
        if content_builder is not None:
            content_builder.on_tool_input_start(
                tool_call_id,
                tool_name,
                langchain_tool_call_id,
                metadata=tool_md,
            )

    thinking = resolve_tool_start_thinking(tool_name, tool_input)
    state.last_active_step_title = thinking.title
    state.last_active_step_items = thinking.items
    frame_kw: dict[str, Any] = {
        "streaming_service": streaming_service,
        "content_builder": content_builder,
        "step_id": tool_step_id,
        "title": thinking.title,
        "status": "in_progress",
        "metadata": span_md,
    }
    if thinking.include_items_on_frame:
        frame_kw["items"] = thinking.items
    yield emit_thinking_step_frame(**frame_kw)

    if run_id:
        state.ui_tool_call_id_by_run[run_id] = tool_call_id

    if isinstance(tool_input, dict):
        _safe_input: dict[str, Any] = {}
        for _k, _v in tool_input.items():
            try:
                json.dumps(_v)
                _safe_input[_k] = _v
            except (TypeError, ValueError, OverflowError):
                pass
    else:
        _safe_input = {"input": tool_input}
    yield streaming_service.format_tool_input_available(
        tool_call_id,
        tool_name,
        _safe_input,
        langchain_tool_call_id=langchain_tool_call_id,
        metadata=tool_md,
    )
    if content_builder is not None:
        content_builder.on_tool_input_available(
            tool_call_id,
            tool_name,
            _safe_input,
            langchain_tool_call_id,
            metadata=tool_md,
        )
|
||||
|
|
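The `_safe_input` loop above silently drops values that cannot survive JSON encoding before the tool-input frame is emitted. A standalone sketch of that probe (the sample input dict is hypothetical):

import json
from typing import Any

tool_input: dict[str, Any] = {"file_path": "notes.md", "session": object()}

safe_input: dict[str, Any] = {}
for key, value in tool_input.items():
    try:
        json.dumps(value)        # probe JSON serializability
        safe_input[key] = value  # keep only values the SSE payload can carry
    except (TypeError, ValueError, OverflowError):
        pass                     # drop the rest silently

print(safe_input)  # {'file_path': 'notes.md'}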
@@ -0,0 +1,23 @@
"""Per-tool streaming: thinking-step and completion emission."""

from __future__ import annotations

from app.tasks.chat.streaming.handlers.tools.emission_context import (
    ToolCompletionEmissionContext,
)
from app.tasks.chat.streaming.handlers.tools.registry import (
    iter_tool_completion_emission_frames,
    resolve_tool_completed_thinking_step,
    resolve_tool_start_thinking,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
    ToolStartThinking,
)

__all__ = [
    "ToolCompletionEmissionContext",
    "ToolStartThinking",
    "iter_tool_completion_emission_frames",
    "resolve_tool_completed_thinking_step",
    "resolve_tool_start_thinking",
]
@@ -0,0 +1,15 @@
from __future__ import annotations

from collections.abc import Iterator

from app.tasks.chat.streaming.handlers.tools.emission_context import (
    ToolCompletionEmissionContext,
)


def iter_completion_emission_frames(
    ctx: ToolCompletionEmissionContext,
) -> Iterator[str]:
    out = ctx.tool_output
    payload = out if isinstance(out, dict) else {"result": out}
    yield ctx.emit_tool_output_card(payload)
@@ -0,0 +1,22 @@
from __future__ import annotations

from typing import Any

from app.tasks.chat.streaming.handlers.tools.default import (
    thinking as default_thinking,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
    ToolStartThinking,
)


def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
    return default_thinking.resolve_start_thinking(tool_name, tool_input)


def resolve_completed_thinking(
    tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
    return default_thinking.resolve_completed_thinking(
        tool_name, tool_output, last_items
    )
@@ -0,0 +1,31 @@
from __future__ import annotations

SHARED_CONNECTOR_TOOLS: frozenset[str] = frozenset(
    {
        "create_calendar_event",
        "create_confluence_page",
        "create_dropbox_file",
        "create_gmail_draft",
        "create_google_drive_file",
        "create_jira_issue",
        "create_linear_issue",
        "create_notion_page",
        "create_onedrive_file",
        "delete_calendar_event",
        "delete_confluence_page",
        "delete_dropbox_file",
        "delete_google_drive_file",
        "delete_jira_issue",
        "delete_linear_issue",
        "delete_notion_page",
        "delete_onedrive_file",
        "send_gmail_email",
        "trash_gmail_email",
        "update_calendar_event",
        "update_confluence_page",
        "update_gmail_draft",
        "update_jira_issue",
        "update_linear_issue",
        "update_notion_page",
    }
)
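One plausible consumer of this frozenset (an assumption; the actual call sites live in files not shown in this excerpt) is an O(1) gate deciding whether a connector tool needs human approval before it runs:

# Re-declared subset so the sketch is standalone; the real set is above.
SHARED_CONNECTOR_TOOLS = frozenset({"send_gmail_email", "create_jira_issue"})

def requires_approval(tool_name: str) -> bool:
    # Every tool in the set mutates an external system, hence the HITL gate.
    return tool_name in SHARED_CONNECTOR_TOOLS

assert requires_approval("send_gmail_email")
assert not requires_approval("read_file")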
@@ -0,0 +1,3 @@
"""Fallback tool package."""

from __future__ import annotations
@@ -0,0 +1,24 @@
"""Default tool-output card and a short completion terminal line."""

from __future__ import annotations

from collections.abc import Iterator

from app.tasks.chat.streaming.handlers.tools.emission_context import (
    ToolCompletionEmissionContext,
)


def iter_completion_emission_frames(
    ctx: ToolCompletionEmissionContext,
) -> Iterator[str]:
    yield ctx.emit_tool_output_card(
        {
            "status": "completed",
            "result_length": len(str(ctx.tool_output)),
        },
    )
    yield ctx.streaming_service.format_terminal_info(
        f"Tool {ctx.tool_name} completed",
        "success",
    )
@@ -0,0 +1,23 @@
"""Fallback thinking-step copy for unknown tools and connectors without custom UI."""

from __future__ import annotations

from typing import Any

from app.tasks.chat.streaming.handlers.tools.shared.model import (
    ToolStartThinking,
)


def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
    del tool_input
    title = tool_name.replace("_", " ").strip().capitalize() or tool_name
    return ToolStartThinking(title=title, items=[], include_items_on_frame=False)


def resolve_completed_thinking(
    tool_name: str, tool_output: Any, last_items: list[str]
) -> tuple[str, list[str]]:
    del tool_output
    title = tool_name.replace("_", " ").strip().capitalize() or tool_name
    return (title, last_items)
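The fallback title is derived mechanically from the tool name; a quick demonstration of the expression used above:

for name in ("create_jira_issue", "ls", "_"):
    title = name.replace("_", " ").strip().capitalize() or name
    print(repr(title))
# 'Create jira issue'
# 'Ls'
# '_'   <- the trailing `or name` covers names that strip to nothing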
@@ -0,0 +1,28 @@
"""generate_image: tool card + terminal summary."""

from __future__ import annotations

from collections.abc import Iterator

from app.tasks.chat.streaming.handlers.tools.emission_context import (
    ToolCompletionEmissionContext,
)


def iter_completion_emission_frames(
    ctx: ToolCompletionEmissionContext,
) -> Iterator[str]:
    out = ctx.tool_output
    payload = out if isinstance(out, dict) else {"result": out}
    yield ctx.emit_tool_output_card(payload)
    if isinstance(out, dict):
        if out.get("error"):
            yield ctx.streaming_service.format_terminal_info(
                f"Image generation failed: {out['error'][:60]}",
                "error",
            )
        else:
            yield ctx.streaming_service.format_terminal_info(
                "Image generated successfully",
                "success",
            )
@@ -0,0 +1,39 @@
"""generate_image: thinking-step copy."""

from __future__ import annotations

from typing import Any

from app.tasks.chat.streaming.handlers.tools.deliverables.shared.tool_input import (
    as_tool_input_dict,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
    ToolStartThinking,
)


def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
    del tool_name
    d = as_tool_input_dict(tool_input)
    prompt = d.get("prompt", "") if isinstance(tool_input, dict) else str(tool_input)
    return ToolStartThinking(
        title="Generating image",
        items=[f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}"],
    )


def resolve_completed_thinking(
    tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
    del tool_name
    items = last_items
    if isinstance(tool_output, dict) and not tool_output.get("error"):
        completed = [*items, "Image generated successfully"]
    else:
        error_msg = (
            tool_output.get("error", "Generation failed")
            if isinstance(tool_output, dict)
            else "Generation failed"
        )
        completed = [*items, f"Error: {error_msg}"]
    return ("Generating image", completed)
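The prompt item is capped at 80 characters with a literal ellipsis suffix; the same expression in isolation:

prompt = "a" * 100  # hypothetical over-length prompt
item = f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}"
assert item == "Prompt: " + "a" * 80 + "..."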
@@ -0,0 +1,37 @@
"""generate_podcast: tool card + queue / success / failure terminal lines."""

from __future__ import annotations

from collections.abc import Iterator

from app.tasks.chat.streaming.handlers.tools.emission_context import (
    ToolCompletionEmissionContext,
)


def iter_completion_emission_frames(
    ctx: ToolCompletionEmissionContext,
) -> Iterator[str]:
    out = ctx.tool_output
    payload = out if isinstance(out, dict) else {"result": out}
    yield ctx.emit_tool_output_card(payload)
    if isinstance(out, dict) and out.get("status") in (
        "pending",
        "generating",
        "processing",
    ):
        yield ctx.streaming_service.format_terminal_info(
            f"Podcast queued: {out.get('title', 'Podcast')}",
            "success",
        )
    elif isinstance(out, dict) and out.get("status") in ("ready", "success"):
        yield ctx.streaming_service.format_terminal_info(
            f"Podcast generated successfully: {out.get('title', 'Podcast')}",
            "success",
        )
    elif isinstance(out, dict) and out.get("status") in ("failed", "error"):
        error_msg = out.get("error", "Unknown error")
        yield ctx.streaming_service.format_terminal_info(
            f"Podcast generation failed: {error_msg}",
            "error",
        )
@@ -0,0 +1,80 @@
"""generate_podcast: thinking-step copy."""

from __future__ import annotations

from typing import Any

from app.tasks.chat.streaming.handlers.tools.deliverables.shared.tool_input import (
    as_tool_input_dict,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
    ToolStartThinking,
)


def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
    del tool_name
    d = as_tool_input_dict(tool_input)
    podcast_title = (
        d.get("podcast_title", "SurfSense Podcast")
        if isinstance(tool_input, dict)
        else "SurfSense Podcast"
    )
    content_len = len(
        d.get("source_content", "") if isinstance(tool_input, dict) else ""
    )
    return ToolStartThinking(
        title="Generating podcast",
        items=[
            f"Title: {podcast_title}",
            f"Content: {content_len:,} characters",
            "Preparing audio generation...",
        ],
    )


def resolve_completed_thinking(
    tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
    del tool_name
    items = last_items
    podcast_status = (
        tool_output.get("status", "unknown")
        if isinstance(tool_output, dict)
        else "unknown"
    )
    podcast_title = (
        tool_output.get("title", "Podcast")
        if isinstance(tool_output, dict)
        else "Podcast"
    )
    if podcast_status in ("pending", "generating", "processing"):
        completed = [
            f"Title: {podcast_title}",
            "Podcast generation started",
            "Processing in background...",
        ]
    elif podcast_status == "already_generating":
        completed = [
            f"Title: {podcast_title}",
            "Podcast already in progress",
            "Please wait for it to complete",
        ]
    elif podcast_status in ("failed", "error"):
        error_msg = (
            tool_output.get("error", "Unknown error")
            if isinstance(tool_output, dict)
            else "Unknown error"
        )
        completed = [
            f"Title: {podcast_title}",
            f"Error: {error_msg[:50]}",
        ]
    elif podcast_status in ("ready", "success"):
        completed = [
            f"Title: {podcast_title}",
            "Podcast ready",
        ]
    else:
        completed = items
    return ("Generating podcast", completed)
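Assuming the module above is importable, a queued podcast output maps to the in-progress copy like this (the sample output dict is hypothetical):

tool_output = {"status": "generating", "title": "Weekly digest"}
title, items = resolve_completed_thinking("generate_podcast", tool_output, [])
assert title == "Generating podcast"
assert items == [
    "Title: Weekly digest",
    "Podcast generation started",
    "Processing in background...",
]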
@@ -0,0 +1,33 @@
"""generate_report: full payload + terminal line."""

from __future__ import annotations

from collections.abc import Iterator

from app.tasks.chat.streaming.handlers.tools.emission_context import (
    ToolCompletionEmissionContext,
)


def iter_completion_emission_frames(
    ctx: ToolCompletionEmissionContext,
) -> Iterator[str]:
    out = ctx.tool_output
    payload = out if isinstance(out, dict) else {"result": out}
    yield ctx.emit_tool_output_card(payload)
    if isinstance(out, dict) and out.get("status") == "ready":
        word_count = out.get("word_count", 0)
        yield ctx.streaming_service.format_terminal_info(
            f"Report generated: {out.get('title', 'Report')} ({word_count:,} words)",
            "success",
        )
    else:
        error_msg = (
            out.get("error", "Unknown error")
            if isinstance(out, dict)
            else "Unknown error"
        )
        yield ctx.streaming_service.format_terminal_info(
            f"Report generation failed: {error_msg}",
            "error",
        )
@@ -0,0 +1,77 @@
"""generate_report: thinking-step copy."""

from __future__ import annotations

from typing import Any

from app.tasks.chat.streaming.handlers.tools.deliverables.shared.tool_input import (
    as_tool_input_dict,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
    ToolStartThinking,
)


def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
    del tool_name
    d = as_tool_input_dict(tool_input)
    report_topic = (
        d.get("topic", "Report") if isinstance(tool_input, dict) else "Report"
    )
    is_revision = bool(
        isinstance(tool_input, dict) and tool_input.get("parent_report_id")
    )
    step_title = "Revising report" if is_revision else "Generating report"
    return ToolStartThinking(
        title=step_title,
        items=[f"Topic: {report_topic}", "Analyzing source content..."],
    )


def resolve_completed_thinking(
    tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
    del tool_name
    items = last_items
    report_status = (
        tool_output.get("status", "unknown")
        if isinstance(tool_output, dict)
        else "unknown"
    )
    report_title = (
        tool_output.get("title", "Report")
        if isinstance(tool_output, dict)
        else "Report"
    )
    word_count = (
        tool_output.get("word_count", 0)
        if isinstance(tool_output, dict)
        else 0
    )
    is_revision = (
        tool_output.get("is_revision", False)
        if isinstance(tool_output, dict)
        else False
    )
    step_title = "Revising report" if is_revision else "Generating report"

    if report_status == "ready":
        completed = [
            f"Topic: {report_title}",
            f"{word_count:,} words",
            "Report ready",
        ]
    elif report_status == "failed":
        error_msg = (
            tool_output.get("error", "Unknown error")
            if isinstance(tool_output, dict)
            else "Unknown error"
        )
        completed = [
            f"Topic: {report_title}",
            f"Error: {error_msg[:50]}",
        ]
    else:
        completed = items

    return (step_title, completed)
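The `is_revision` flag only changes the step title; a hypothetical revision result walks the `ready` branch:

tool_output = {
    "status": "ready",
    "title": "Q3 summary",
    "word_count": 1200,
    "is_revision": True,
}
title, items = resolve_completed_thinking("generate_report", tool_output, [])
assert title == "Revising report"
assert items == ["Topic: Q3 summary", "1,200 words", "Report ready"]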
@@ -0,0 +1,32 @@
"""generate_resume: full payload + terminal line."""

from __future__ import annotations

from collections.abc import Iterator

from app.tasks.chat.streaming.handlers.tools.emission_context import (
    ToolCompletionEmissionContext,
)


def iter_completion_emission_frames(
    ctx: ToolCompletionEmissionContext,
) -> Iterator[str]:
    out = ctx.tool_output
    payload = out if isinstance(out, dict) else {"result": out}
    yield ctx.emit_tool_output_card(payload)
    if isinstance(out, dict) and out.get("status") == "ready":
        yield ctx.streaming_service.format_terminal_info(
            f"Resume generated: {out.get('title', 'Resume')}",
            "success",
        )
    else:
        error_msg = (
            out.get("error", "Unknown error")
            if isinstance(out, dict)
            else "Unknown error"
        )
        yield ctx.streaming_service.format_terminal_info(
            f"Resume generation failed: {error_msg}",
            "error",
        )
@@ -0,0 +1,24 @@
"""generate_resume: generic thinking titles and items."""

from __future__ import annotations

from typing import Any

from app.tasks.chat.streaming.handlers.tools.default import (
    thinking as default_thinking,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
    ToolStartThinking,
)


def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
    return default_thinking.resolve_start_thinking(tool_name, tool_input)


def resolve_completed_thinking(
    tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
    return default_thinking.resolve_completed_thinking(
        tool_name, tool_output, last_items
    )
@@ -0,0 +1,28 @@
"""generate_video_presentation: tool card + terminal line."""

from __future__ import annotations

from collections.abc import Iterator

from app.tasks.chat.streaming.handlers.tools.emission_context import (
    ToolCompletionEmissionContext,
)


def iter_completion_emission_frames(
    ctx: ToolCompletionEmissionContext,
) -> Iterator[str]:
    out = ctx.tool_output
    payload = out if isinstance(out, dict) else {"result": out}
    yield ctx.emit_tool_output_card(payload)
    if isinstance(out, dict) and out.get("status") == "pending":
        yield ctx.streaming_service.format_terminal_info(
            f"Video presentation queued: {out.get('title', 'Presentation')}",
            "success",
        )
    elif isinstance(out, dict) and out.get("status") == "failed":
        error_msg = out.get("error", "Unknown error")
        yield ctx.streaming_service.format_terminal_info(
            f"Presentation generation failed: {error_msg}",
            "error",
        )
@@ -0,0 +1,52 @@
"""generate_video_presentation: generic in-progress thinking; completion is status-driven."""

from __future__ import annotations

from typing import Any

from app.tasks.chat.streaming.handlers.tools.default import (
    thinking as default_thinking,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
    ToolStartThinking,
)


def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
    return default_thinking.resolve_start_thinking(tool_name, tool_input)


def resolve_completed_thinking(
    tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
    del tool_name
    items = last_items
    vp_status = (
        tool_output.get("status", "unknown")
        if isinstance(tool_output, dict)
        else "unknown"
    )
    vp_title = (
        tool_output.get("title", "Presentation")
        if isinstance(tool_output, dict)
        else "Presentation"
    )
    if vp_status in ("pending", "generating"):
        completed = [
            f"Title: {vp_title}",
            "Presentation generation started",
            "Processing in background...",
        ]
    elif vp_status == "failed":
        error_msg = (
            tool_output.get("error", "Unknown error")
            if isinstance(tool_output, dict)
            else "Unknown error"
        )
        completed = [
            f"Title: {vp_title}",
            f"Error: {error_msg[:50]}",
        ]
    else:
        completed = items
    return ("Generating video presentation", completed)
@@ -0,0 +1,16 @@
"""save_document: default completion card and terminal line."""

from __future__ import annotations

from collections.abc import Iterator

from app.tasks.chat.streaming.handlers.tools.default import emission as _default
from app.tasks.chat.streaming.handlers.tools.emission_context import (
    ToolCompletionEmissionContext,
)


def iter_completion_emission_frames(
    ctx: ToolCompletionEmissionContext,
) -> Iterator[str]:
    yield from _default.iter_completion_emission_frames(ctx)
@@ -0,0 +1,38 @@
"""save_document: thinking-step copy."""

from __future__ import annotations

from typing import Any

from app.tasks.chat.streaming.handlers.tools.deliverables.shared.tool_input import (
    as_tool_input_dict,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
    ToolStartThinking,
)


def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
    del tool_name
    d = as_tool_input_dict(tool_input)
    doc_title = d.get("title", "") if isinstance(tool_input, dict) else str(tool_input)
    display_title = doc_title[:60] + ("…" if len(doc_title) > 60 else "")
    return ToolStartThinking(title="Saving document", items=[display_title])


def resolve_completed_thinking(
    tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
    del tool_name
    items = last_items
    result_str = (
        tool_output.get("result", "")
        if isinstance(tool_output, dict)
        else str(tool_output)
    )
    is_error = "Error" in result_str
    completed = [
        *items,
        result_str[:80] if is_error else "Saved to knowledge base",
    ]
    return ("Saving document", completed)
@@ -0,0 +1,9 @@
"""Tool-call args for deliverable thinking modules."""

from __future__ import annotations

from typing import Any


def as_tool_input_dict(tool_input: Any) -> dict[str, Any]:
    return tool_input if isinstance(tool_input, dict) else {}
@@ -0,0 +1,12 @@
from __future__ import annotations

DELIVERABLE_TOOLS: frozenset[str] = frozenset(
    {
        "generate_image",
        "generate_podcast",
        "generate_report",
        "generate_resume",
        "generate_video_presentation",
        "save_document",
    }
)
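A hedged sketch of how these sets could route a tool to its handler family (the real registry module is not shown in this excerpt; the function and return values are assumptions):

DELIVERABLE_TOOLS = frozenset({"generate_report", "save_document"})  # subset
SHARED_CONNECTOR_TOOLS = frozenset({"send_gmail_email"})             # subset

def pick_handler_family(tool_name: str) -> str:
    if tool_name in DELIVERABLE_TOOLS:
        return "deliverables"  # per-tool emission/thinking modules
    if tool_name in SHARED_CONNECTOR_TOOLS:
        return "connectors"    # shared connector copy
    return "default"           # generic fallback package

assert pick_handler_family("generate_report") == "deliverables"
assert pick_handler_family("grep") == "default"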
@@ -0,0 +1,36 @@
"""Context for one tool-completion emission pass."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any

from app.tasks.chat.streaming.handlers.tool_output_frame import (
    emit_tool_output_available_frame,
)


@dataclass
class ToolCompletionEmissionContext:
    """Streaming service, tool output, and ids for completion frames."""

    tool_name: str
    tool_call_id: str
    tool_output: Any
    streaming_service: Any
    content_builder: Any | None
    langchain_tool_call_id_holder: dict[str, str | None]
    stream_result: Any
    langgraph_config: dict[str, Any]
    staged_workspace_file_path: str | None
    tool_metadata: dict[str, Any] | None = None

    def emit_tool_output_card(self, payload: Any) -> str:
        return emit_tool_output_available_frame(
            streaming_service=self.streaming_service,
            content_builder=self.content_builder,
            langchain_id_holder=self.langchain_tool_call_id_holder,
            call_id=self.tool_call_id,
            output=payload,
            tool_metadata=self.tool_metadata,
        )
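Wiring the context takes one stub with the two methods the emission modules call; everything else can be inert (a sketch, assuming the dataclass above is importable):

class StubService:
    def format_tool_output_available(
        self, call_id, output, *, langchain_tool_call_id=None, metadata=None
    ):
        return f"output {call_id}: {output}"

    def format_terminal_info(self, text, level):
        return f"terminal[{level}]: {text}"

ctx = ToolCompletionEmissionContext(
    tool_name="read_file",
    tool_call_id="call_123",
    tool_output={"result": "hello"},
    streaming_service=StubService(),
    content_builder=None,
    langchain_tool_call_id_holder={"value": None},
    stream_result=None,  # only the execute handler touches this
    langgraph_config={},
    staged_workspace_file_path=None,
)
frame = ctx.emit_tool_output_card({"result": "hello"})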
@@ -0,0 +1,27 @@
"""edit_file: thinking-step copy."""

from __future__ import annotations

from typing import Any

from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import (
    as_tool_input_dict,
    truncate_path,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
    ToolStartThinking,
)


def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
    del tool_name
    d = as_tool_input_dict(tool_input)
    fp = d.get("file_path", "") if isinstance(tool_input, dict) else str(tool_input)
    return ToolStartThinking(title="Editing file", items=[truncate_path(fp)])


def resolve_completed_thinking(
    tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
    del tool_output, tool_name
    return ("Editing file", last_items)
@@ -0,0 +1,40 @@
"""execute: exit code, stdout, sandbox file hints."""

from __future__ import annotations

import re
from collections.abc import Iterator

from app.tasks.chat.streaming.handlers.tools.emission_context import (
    ToolCompletionEmissionContext,
)


def iter_completion_emission_frames(
    ctx: ToolCompletionEmissionContext,
) -> Iterator[str]:
    out = ctx.tool_output
    raw_text = out.get("result", "") if isinstance(out, dict) else str(out)
    exit_code: int | None = None
    output_text = raw_text
    m = re.match(r"^Exit code:\s*(\d+)", raw_text)
    if m:
        exit_code = int(m.group(1))
        om = re.search(r"\nOutput:\n([\s\S]*)", raw_text)
        output_text = om.group(1) if om else ""
    thread_id_str = ctx.langgraph_config.get("configurable", {}).get("thread_id", "")

    for sf_match in re.finditer(
        r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE
    ):
        fpath = sf_match.group(1).strip()
        if fpath and fpath not in ctx.stream_result.sandbox_files:
            ctx.stream_result.sandbox_files.append(fpath)

    yield ctx.emit_tool_output_card(
        {
            "exit_code": exit_code,
            "output": output_text,
            "thread_id": thread_id_str,
        },
    )
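The parsing above assumes the sandbox reply format "Exit code: N\nOutput:\n..."; a standalone run over a hypothetical reply shows each step:

import re

raw_text = "Exit code: 0\nOutput:\nbuilt ok\nSANDBOX_FILE: /tmp/report.pdf\n"
m = re.match(r"^Exit code:\s*(\d+)", raw_text)
exit_code = int(m.group(1)) if m else None            # 0
om = re.search(r"\nOutput:\n([\s\S]*)", raw_text)
output_text = om.group(1) if om else ""               # text after "Output:"
files = re.findall(r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE)
assert exit_code == 0 and files == ["/tmp/report.pdf"]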
@@ -0,0 +1,42 @@
"""execute: sandbox command thinking + completion lines."""

from __future__ import annotations

import re
from typing import Any

from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import (
    as_tool_input_dict,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
    ToolStartThinking,
)


def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
    del tool_name
    d = as_tool_input_dict(tool_input)
    cmd = d.get("command", "") if isinstance(tool_input, dict) else str(tool_input)
    display_cmd = cmd[:80] + ("…" if len(cmd) > 80 else "")
    return ToolStartThinking(title="Running command", items=[f"$ {display_cmd}"])


def resolve_completed_thinking(
    tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
    del tool_name
    items = last_items
    raw_text = (
        tool_output.get("result", "")
        if isinstance(tool_output, dict)
        else str(tool_output)
    )
    m = re.match(r"^Exit code:\s*(\d+)", raw_text)
    exit_code_val = int(m.group(1)) if m else None
    if exit_code_val is not None and exit_code_val == 0:
        completed = [*items, "Completed successfully"]
    elif exit_code_val is not None:
        completed = [*items, f"Exit code: {exit_code_val}"]
    else:
        completed = [*items, "Finished"]
    return ("Running command", completed)
@@ -0,0 +1,27 @@
"""glob: thinking-step copy."""

from __future__ import annotations

from typing import Any

from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import (
    as_tool_input_dict,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
    ToolStartThinking,
)


def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
    del tool_name
    d = as_tool_input_dict(tool_input)
    pat = d.get("pattern", "") if isinstance(tool_input, dict) else str(tool_input)
    base = d.get("path", "/") if isinstance(tool_input, dict) else "/"
    return ToolStartThinking(title="Searching files", items=[f"{pat} in {base}"])


def resolve_completed_thinking(
    tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
    del tool_output, tool_name
    return ("Searching files", last_items)
@@ -0,0 +1,31 @@
"""grep: thinking-step copy."""

from __future__ import annotations

from typing import Any

from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import (
    as_tool_input_dict,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
    ToolStartThinking,
)


def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
    del tool_name
    d = as_tool_input_dict(tool_input)
    pat = d.get("pattern", "") if isinstance(tool_input, dict) else str(tool_input)
    grep_path = d.get("path", "") if isinstance(tool_input, dict) else ""
    display_pat = pat[:60] + ("…" if len(pat) > 60 else "")
    return ToolStartThinking(
        title="Searching content",
        items=[f'"{display_pat}"' + (f" in {grep_path}" if grep_path else "")],
    )


def resolve_completed_thinking(
    tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
    del tool_output, tool_name
    return ("Searching content", last_items)
@@ -0,0 +1,59 @@
"""ls: thinking-step copy for directory listing."""

from __future__ import annotations

import ast
from typing import Any

from app.tasks.chat.streaming.handlers.tools.shared.model import (
    ToolStartThinking,
)


def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
    del tool_name
    if isinstance(tool_input, dict):
        path = tool_input.get("path", "/")
    else:
        path = str(tool_input)
    return ToolStartThinking(title="Listing files", items=[path])


def resolve_completed_thinking(
    tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
    del tool_name
    if isinstance(tool_output, dict):
        ls_output = tool_output.get("result", "")
    elif isinstance(tool_output, str):
        ls_output = tool_output
    else:
        ls_output = str(tool_output) if tool_output else ""
    file_names: list[str] = []
    if ls_output:
        paths: list[str] = []
        try:
            parsed = ast.literal_eval(ls_output)
            if isinstance(parsed, list):
                paths = [str(p) for p in parsed]
        except (ValueError, SyntaxError):
            paths = [
                line.strip()
                for line in ls_output.strip().split("\n")
                if line.strip()
            ]
        for p in paths:
            name = p.rstrip("/").split("/")[-1]
            if name and len(name) <= 40:
                file_names.append(name)
            elif name:
                file_names.append(name[:37] + "...")
    if file_names:
        if len(file_names) <= 5:
            completed = [f"[{name}]" for name in file_names]
        else:
            completed = [f"[{name}]" for name in file_names[:4]]
            completed.append(f"(+{len(file_names) - 4} more)")
    else:
        completed = ["No files found"]
    return ("Listing files", completed)
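The listing parser first tries ast.literal_eval (the tool may return a Python-list repr) and falls back to newline splitting; in isolation, with a hypothetical result string:

import ast

ls_output = "['/workspace/a.md', '/workspace/src/']"
try:
    parsed = ast.literal_eval(ls_output)
    paths = [str(p) for p in parsed] if isinstance(parsed, list) else []
except (ValueError, SyntaxError):
    paths = [ln.strip() for ln in ls_output.strip().split("\n") if ln.strip()]

names = [p.rstrip("/").split("/")[-1] for p in paths]
assert names == ["a.md", "src"]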
@@ -0,0 +1,27 @@
"""mkdir: thinking-step copy."""

from __future__ import annotations

from typing import Any

from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import (
    as_tool_input_dict,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
    ToolStartThinking,
)


def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
    del tool_name
    d = as_tool_input_dict(tool_input)
    p = d.get("path", "") if isinstance(tool_input, dict) else str(tool_input)
    display = p if len(p) <= 80 else "…" + p[-77:]
    return ToolStartThinking(title="Creating folder", items=[display] if display else [])


def resolve_completed_thinking(
    tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
    del tool_output, tool_name
    return ("Creating folder", last_items)
@@ -0,0 +1,33 @@
"""move_file: thinking-step copy."""

from __future__ import annotations

from typing import Any

from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import (
    as_tool_input_dict,
    truncate_middle,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
    ToolStartThinking,
)


def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
    del tool_name
    d = as_tool_input_dict(tool_input)
    src = d.get("source_path", "") if isinstance(tool_input, dict) else ""
    dst = d.get("destination_path", "") if isinstance(tool_input, dict) else ""
    display_src = truncate_middle(src, max_len=60)
    display_dst = truncate_middle(dst, max_len=60)
    return ToolStartThinking(
        title="Moving file",
        items=[f"{display_src} → {display_dst}"] if src or dst else [],
    )


def resolve_completed_thinking(
    tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
    del tool_output, tool_name
    return ("Moving file", last_items)
@@ -0,0 +1,27 @@
"""read_file: thinking-step copy."""

from __future__ import annotations

from typing import Any

from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import (
    as_tool_input_dict,
    truncate_path,
)
from app.tasks.chat.streaming.handlers.tools.shared.model import (
    ToolStartThinking,
)


def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking:
    del tool_name
    d = as_tool_input_dict(tool_input)
    fp = d.get("file_path", "") if isinstance(tool_input, dict) else str(tool_input)
    return ToolStartThinking(title="Reading file", items=[truncate_path(fp)])


def resolve_completed_thinking(
    tool_name: str, tool_output: Any, last_items: list[str],
) -> tuple[str, list[str]]:
    del tool_output, tool_name
    return ("Reading file", last_items)
Some files were not shown because too many files have changed in this diff.