diff --git a/.cursor/skills/improve-codebase-architecture/DEEPENING.md b/.cursor/skills/improve-codebase-architecture/DEEPENING.md new file mode 100644 index 000000000..ecaf5d7dc --- /dev/null +++ b/.cursor/skills/improve-codebase-architecture/DEEPENING.md @@ -0,0 +1,37 @@ +# Deepening + +How to deepen a cluster of shallow modules safely, given its dependencies. Assumes the vocabulary in [LANGUAGE.md](LANGUAGE.md) — **module**, **interface**, **seam**, **adapter**. + +## Dependency categories + +When assessing a candidate for deepening, classify its dependencies. The category determines how the deepened module is tested across its seam. + +### 1. In-process + +Pure computation, in-memory state, no I/O. Always deepenable — merge the modules and test through the new interface directly. No adapter needed. + +### 2. Local-substitutable + +Dependencies that have local test stand-ins (PGLite for Postgres, in-memory filesystem). Deepenable if the stand-in exists. The deepened module is tested with the stand-in running in the test suite. The seam is internal; no port at the module's external interface. + +### 3. Remote but owned (Ports & Adapters) + +Your own services across a network boundary (microservices, internal APIs). Define a **port** (interface) at the seam. The deep module owns the logic; the transport is injected as an **adapter**. Tests use an in-memory adapter. Production uses an HTTP/gRPC/queue adapter. + +Recommendation shape: *"Define a port at the seam, implement an HTTP adapter for production and an in-memory adapter for testing, so the logic sits in one deep module even though it's deployed across a network."* + +### 4. True external (Mock) + +Third-party services (Stripe, Twilio, etc.) you don't control. The deepened module takes the external dependency as an injected port; tests provide a mock adapter. + +## Seam discipline + +- **One adapter means a hypothetical seam. Two adapters means a real one.** Don't introduce a port unless at least two adapters are justified (typically production + test). A single-adapter seam is just indirection. +- **Internal seams vs external seams.** A deep module can have internal seams (private to its implementation, used by its own tests) as well as the external seam at its interface. Don't expose internal seams through the interface just because tests use them. + +## Testing strategy: replace, don't layer + +- Old unit tests on shallow modules become waste once tests at the deepened module's interface exist — delete them. +- Write new tests at the deepened module's interface. The **interface is the test surface**. +- Tests assert on observable outcomes through the interface, not internal state. +- Tests should survive internal refactors — they describe behaviour, not implementation. If a test has to change when the implementation changes, it's testing past the interface. diff --git a/.cursor/skills/improve-codebase-architecture/INTERFACE-DESIGN.md b/.cursor/skills/improve-codebase-architecture/INTERFACE-DESIGN.md new file mode 100644 index 000000000..3197723a0 --- /dev/null +++ b/.cursor/skills/improve-codebase-architecture/INTERFACE-DESIGN.md @@ -0,0 +1,44 @@ +# Interface Design + +When the user wants to explore alternative interfaces for a chosen deepening candidate, use this parallel sub-agent pattern. Based on "Design It Twice" (Ousterhout) — your first idea is unlikely to be the best. + +Uses the vocabulary in [LANGUAGE.md](LANGUAGE.md) — **module**, **interface**, **seam**, **adapter**, **leverage**. + +## Process + +### 1. Frame the problem space + +Before spawning sub-agents, write a user-facing explanation of the problem space for the chosen candidate: + +- The constraints any new interface would need to satisfy +- The dependencies it would rely on, and which category they fall into (see [DEEPENING.md](DEEPENING.md)) +- A rough illustrative code sketch to ground the constraints — not a proposal, just a way to make the constraints concrete + +Show this to the user, then immediately proceed to Step 2. The user reads and thinks while the sub-agents work in parallel. + +### 2. Spawn sub-agents + +Spawn 3+ sub-agents in parallel using the Agent tool. Each must produce a **radically different** interface for the deepened module. + +Prompt each sub-agent with a separate technical brief (file paths, coupling details, dependency category from [DEEPENING.md](DEEPENING.md), what sits behind the seam). The brief is independent of the user-facing problem-space explanation in Step 1. Give each agent a different design constraint: + +- Agent 1: "Minimize the interface — aim for 1–3 entry points max. Maximise leverage per entry point." +- Agent 2: "Maximise flexibility — support many use cases and extension." +- Agent 3: "Optimise for the most common caller — make the default case trivial." +- Agent 4 (if applicable): "Design around ports & adapters for cross-seam dependencies." + +Include both [LANGUAGE.md](LANGUAGE.md) vocabulary and CONTEXT.md vocabulary in the brief so each sub-agent names things consistently with the architecture language and the project's domain language. + +Each sub-agent outputs: + +1. Interface (types, methods, params — plus invariants, ordering, error modes) +2. Usage example showing how callers use it +3. What the implementation hides behind the seam +4. Dependency strategy and adapters (see [DEEPENING.md](DEEPENING.md)) +5. Trade-offs — where leverage is high, where it's thin + +### 3. Present and compare + +Present designs sequentially so the user can absorb each one, then compare them in prose. Contrast by **depth** (leverage at the interface), **locality** (where change concentrates), and **seam placement**. + +After comparing, give your own recommendation: which design you think is strongest and why. If elements from different designs would combine well, propose a hybrid. Be opinionated — the user wants a strong read, not a menu. diff --git a/.cursor/skills/improve-codebase-architecture/LANGUAGE.md b/.cursor/skills/improve-codebase-architecture/LANGUAGE.md new file mode 100644 index 000000000..530c27630 --- /dev/null +++ b/.cursor/skills/improve-codebase-architecture/LANGUAGE.md @@ -0,0 +1,53 @@ +# Language + +Shared vocabulary for every suggestion this skill makes. Use these terms exactly — don't substitute "component," "service," "API," or "boundary." Consistent language is the whole point. + +## Terms + +**Module** +Anything with an interface and an implementation. Deliberately scale-agnostic — applies equally to a function, class, package, or tier-spanning slice. +_Avoid_: unit, component, service. + +**Interface** +Everything a caller must know to use the module correctly. Includes the type signature, but also invariants, ordering constraints, error modes, required configuration, and performance characteristics. +_Avoid_: API, signature (too narrow — those refer only to the type-level surface). + +**Implementation** +What's inside a module — its body of code. Distinct from **Adapter**: a thing can be a small adapter with a large implementation (a Postgres repo) or a large adapter with a small implementation (an in-memory fake). Reach for "adapter" when the seam is the topic; "implementation" otherwise. + +**Depth** +Leverage at the interface — the amount of behaviour a caller (or test) can exercise per unit of interface they have to learn. A module is **deep** when a large amount of behaviour sits behind a small interface. A module is **shallow** when the interface is nearly as complex as the implementation. + +**Seam** _(from Michael Feathers)_ +A place where you can alter behaviour without editing in that place. The *location* at which a module's interface lives. Choosing where to put the seam is its own design decision, distinct from what goes behind it. +_Avoid_: boundary (overloaded with DDD's bounded context). + +**Adapter** +A concrete thing that satisfies an interface at a seam. Describes *role* (what slot it fills), not substance (what's inside). + +**Leverage** +What callers get from depth. More capability per unit of interface they have to learn. One implementation pays back across N call sites and M tests. + +**Locality** +What maintainers get from depth. Change, bugs, knowledge, and verification concentrate at one place rather than spreading across callers. Fix once, fixed everywhere. + +## Principles + +- **Depth is a property of the interface, not the implementation.** A deep module can be internally composed of small, mockable, swappable parts — they just aren't part of the interface. A module can have **internal seams** (private to its implementation, used by its own tests) as well as the **external seam** at its interface. +- **The deletion test.** Imagine deleting the module. If complexity vanishes, the module wasn't hiding anything (it was a pass-through). If complexity reappears across N callers, the module was earning its keep. +- **The interface is the test surface.** Callers and tests cross the same seam. If you want to test *past* the interface, the module is probably the wrong shape. +- **One adapter means a hypothetical seam. Two adapters means a real one.** Don't introduce a seam unless something actually varies across it. + +## Relationships + +- A **Module** has exactly one **Interface** (the surface it presents to callers and tests). +- **Depth** is a property of a **Module**, measured against its **Interface**. +- A **Seam** is where a **Module**'s **Interface** lives. +- An **Adapter** sits at a **Seam** and satisfies the **Interface**. +- **Depth** produces **Leverage** for callers and **Locality** for maintainers. + +## Rejected framings + +- **Depth as ratio of implementation-lines to interface-lines** (Ousterhout): rewards padding the implementation. We use depth-as-leverage instead. +- **"Interface" as the TypeScript `interface` keyword or a class's public methods**: too narrow — interface here includes every fact a caller must know. +- **"Boundary"**: overloaded with DDD's bounded context. Say **seam** or **interface**. diff --git a/.cursor/skills/improve-codebase-architecture/SKILL.md b/.cursor/skills/improve-codebase-architecture/SKILL.md new file mode 100644 index 000000000..05984a609 --- /dev/null +++ b/.cursor/skills/improve-codebase-architecture/SKILL.md @@ -0,0 +1,71 @@ +--- +name: improve-codebase-architecture +description: Find deepening opportunities in a codebase, informed by the domain language in CONTEXT.md and the decisions in docs/adr/. Use when the user wants to improve architecture, find refactoring opportunities, consolidate tightly-coupled modules, or make a codebase more testable and AI-navigable. +--- + +# Improve Codebase Architecture + +Surface architectural friction and propose **deepening opportunities** — refactors that turn shallow modules into deep ones. The aim is testability and AI-navigability. + +## Glossary + +Use these terms exactly in every suggestion. Consistent language is the point — don't drift into "component," "service," "API," or "boundary." Full definitions in [LANGUAGE.md](LANGUAGE.md). + +- **Module** — anything with an interface and an implementation (function, class, package, slice). +- **Interface** — everything a caller must know to use the module: types, invariants, error modes, ordering, config. Not just the type signature. +- **Implementation** — the code inside. +- **Depth** — leverage at the interface: a lot of behaviour behind a small interface. **Deep** = high leverage. **Shallow** = interface nearly as complex as the implementation. +- **Seam** — where an interface lives; a place behaviour can be altered without editing in place. (Use this, not "boundary.") +- **Adapter** — a concrete thing satisfying an interface at a seam. +- **Leverage** — what callers get from depth. +- **Locality** — what maintainers get from depth: change, bugs, knowledge concentrated in one place. + +Key principles (see [LANGUAGE.md](LANGUAGE.md) for the full list): + +- **Deletion test**: imagine deleting the module. If complexity vanishes, it was a pass-through. If complexity reappears across N callers, it was earning its keep. +- **The interface is the test surface.** +- **One adapter = hypothetical seam. Two adapters = real seam.** + +This skill is _informed_ by the project's domain model. The domain language gives names to good seams; ADRs record decisions the skill should not re-litigate. + +## Process + +### 1. Explore + +Read the project's domain glossary and any ADRs in the area you're touching first. + +Then use the Agent tool with `subagent_type=Explore` to walk the codebase. Don't follow rigid heuristics — explore organically and note where you experience friction: + +- Where does understanding one concept require bouncing between many small modules? +- Where are modules **shallow** — interface nearly as complex as the implementation? +- Where have pure functions been extracted just for testability, but the real bugs hide in how they're called (no **locality**)? +- Where do tightly-coupled modules leak across their seams? +- Which parts of the codebase are untested, or hard to test through their current interface? + +Apply the **deletion test** to anything you suspect is shallow: would deleting it concentrate complexity, or just move it? A "yes, concentrates" is the signal you want. + +### 2. Present candidates + +Present a numbered list of deepening opportunities. For each candidate: + +- **Files** — which files/modules are involved +- **Problem** — why the current architecture is causing friction +- **Solution** — plain English description of what would change +- **Benefits** — explained in terms of locality and leverage, and also in how tests would improve + +**Use CONTEXT.md vocabulary for the domain, and [LANGUAGE.md](LANGUAGE.md) vocabulary for the architecture.** If `CONTEXT.md` defines "Order," talk about "the Order intake module" — not "the FooBarHandler," and not "the Order service." + +**ADR conflicts**: if a candidate contradicts an existing ADR, only surface it when the friction is real enough to warrant revisiting the ADR. Mark it clearly (e.g. _"contradicts ADR-0007 — but worth reopening because…"_). Don't list every theoretical refactor an ADR forbids. + +Do NOT propose interfaces yet. Ask the user: "Which of these would you like to explore?" + +### 3. Grilling loop + +Once the user picks a candidate, drop into a grilling conversation. Walk the design tree with them — constraints, dependencies, the shape of the deepened module, what sits behind the seam, what tests survive. + +Side effects happen inline as decisions crystallize: + +- **Naming a deepened module after a concept not in `CONTEXT.md`?** Add the term to `CONTEXT.md` — same discipline as `/grill-with-docs` (see [CONTEXT-FORMAT.md](../grill-with-docs/CONTEXT-FORMAT.md)). Create the file lazily if it doesn't exist. +- **Sharpening a fuzzy term during the conversation?** Update `CONTEXT.md` right there. +- **User rejects the candidate with a load-bearing reason?** Offer an ADR, framed as: _"Want me to record this as an ADR so future architecture reviews don't re-suggest it?"_ Only offer when the reason would actually be needed by a future explorer to avoid re-suggesting the same thing — skip ephemeral reasons ("not worth it right now") and self-evident ones. See [ADR-FORMAT.md](../grill-with-docs/ADR-FORMAT.md). +- **Want to explore alternative interfaces for the deepened module?** See [INTERFACE-DESIGN.md](INTERFACE-DESIGN.md). diff --git a/.gitignore b/.gitignore index 29c140ed3..ab24c0c05 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ surfsense_web/playwright/.auth/ surfsense_web/playwright-report/ surfsense_web/test-results/ surfsense_web/blob-report/ +hermes-agent/ diff --git a/.vscode/settings.json b/.vscode/settings.json index 05bd30702..7da4b54f8 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,4 +1,9 @@ { "biome.configurationPath": "./surfsense_web/biome.json", - "deepscan.ignoreConfirmWarning": true + "deepscan.ignoreConfirmWarning": true, + "python.defaultInterpreterPath": "${workspaceFolder}/surfsense_backend/.venv/bin/python", + "basedpyright.analysis.extraPaths": [ + "${workspaceFolder}/surfsense_backend" + ], + "python-envs.pythonProjects": [] } \ No newline at end of file diff --git a/VERSION b/VERSION index 818944f5b..df5db66fe 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.0.22 +0.0.23 diff --git a/docker/.env.example b/docker/.env.example index fd56bdccc..aba15f13f 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -324,7 +324,6 @@ SURFSENSE_ENABLE_ACTION_LOG=true SURFSENSE_ENABLE_REVERT_ROUTE=true SURFSENSE_ENABLE_PERMISSION=true SURFSENSE_ENABLE_DOOM_LOOP=true -SURFSENSE_ENABLE_STREAM_PARITY_V2=true # Periodic connector sync interval (default: 5m) # SCHEDULE_CHECKER_INTERVAL=5m diff --git a/skills-lock.json b/skills-lock.json index ce251e303..f722ec0d3 100644 --- a/skills-lock.json +++ b/skills-lock.json @@ -46,6 +46,12 @@ "sourceType": "github", "computedHash": "ddd61f32254be1303ce4b7be5d507c932de4af53489a0ebb1309bf61de99018c" }, + "improve-codebase-architecture": { + "source": "mattpocock/skills", + "sourceType": "github", + "skillPath": "skills/engineering/improve-codebase-architecture/SKILL.md", + "computedHash": "2da1d23b8f53cfe67f2e0b68924ab9f4ec400bb6480de097007eeaeb517d1722" + }, "internal-linking-optimizer": { "source": "aaron-he-zhu/seo-geo-claude-skills", "sourceType": "github", diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index ba89059c8..3d442973c 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -315,14 +315,6 @@ LANGSMITH_PROJECT=surfsense # SURFSENSE_ENABLE_ACTION_LOG=false # SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships -# Streaming parity v2 — opt in to LangChain's structured AIMessageChunk -# content (typed reasoning blocks, tool-input deltas) and propagate the -# real tool_call_id to the SSE layer. When OFF, the stream falls back to -# the str-only text path and synthetic "call_" tool-call ids. -# Schema migrations 135/136 ship unconditionally because they are -# forward-compatible. -# SURFSENSE_ENABLE_STREAM_PARITY_V2=false - # Plugins # SURFSENSE_ENABLE_PLUGIN_LOADER=false # Comma-separated allowlist of plugin entry-point names diff --git a/surfsense_backend/app/agents/multi_agent_chat/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/__init__.py index f568dc6b2..6c7d79eb8 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/__init__.py +++ b/surfsense_backend/app/agents/multi_agent_chat/__init__.py @@ -2,6 +2,6 @@ from __future__ import annotations -from .main_agent import create_surfsense_deep_agent +from .main_agent import create_multi_agent_chat_deep_agent -__all__ = ["create_surfsense_deep_agent"] +__all__ = ["create_multi_agent_chat_deep_agent"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/__init__.py index b9a18fe53..f74ca0cd0 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/__init__.py +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/__init__.py @@ -2,6 +2,6 @@ from __future__ import annotations -from .runtime import create_surfsense_deep_agent +from .runtime import create_multi_agent_chat_deep_agent -__all__ = ["create_surfsense_deep_agent"] +__all__ = ["create_multi_agent_chat_deep_agent"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/compile_graph_sync.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/compile_graph_sync.py index 7afa30a31..4ed94bf7b 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/compile_graph_sync.py +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/compile_graph_sync.py @@ -11,6 +11,9 @@ from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool from langgraph.types import Checkpointer +from app.agents.multi_agent_chat.middleware import ( + build_main_agent_deepagent_middleware, +) from app.agents.multi_agent_chat.subagents.shared.permissions import ( ToolsPermissions, ) @@ -19,8 +22,6 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags from app.agents.new_chat.filesystem_selection import FilesystemMode from app.db import ChatVisibility -from .middleware import build_main_agent_deepagent_middleware - def build_compiled_agent_graph_sync( *, diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/__init__.py deleted file mode 100644 index 757ee02f8..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Main-agent graph middleware assembly (SurfSense + LangChain + deepagents).""" - -from __future__ import annotations - -from .deepagent_stack import build_main_agent_deepagent_middleware - -__all__ = ["build_main_agent_deepagent_middleware"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/config.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/config.py deleted file mode 100644 index 16211686c..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/config.py +++ /dev/null @@ -1,44 +0,0 @@ -"""RunnableConfig wiring for nested subagent invocations. - -Forwards the parent's ``runtime.config`` (thread_id, …) into the subagent and -exposes the side-channel ``stream_resume_chat`` uses to ferry resume payloads. -""" - -from __future__ import annotations - -from typing import Any - -from langchain.tools import ToolRuntime - -from .constants import DEFAULT_SUBAGENT_RECURSION_LIMIT - - -def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]: - """RunnableConfig for the nested invoke; raises ``recursion_limit`` to the parent's budget.""" - merged: dict[str, Any] = dict(runtime.config) if runtime.config else {} - current_limit = merged.get("recursion_limit") - try: - current_int = int(current_limit) if current_limit is not None else 0 - except (TypeError, ValueError): - current_int = 0 - if current_int < DEFAULT_SUBAGENT_RECURSION_LIMIT: - merged["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT - return merged - - -def consume_surfsense_resume(runtime: ToolRuntime) -> Any: - """Pop the resume payload; siblings share ``configurable`` by reference.""" - cfg = runtime.config or {} - configurable = cfg.get("configurable") if isinstance(cfg, dict) else None - if not isinstance(configurable, dict): - return None - return configurable.pop("surfsense_resume_value", None) - - -def has_surfsense_resume(runtime: ToolRuntime) -> bool: - """True iff a resume payload is queued on this runtime (non-destructive).""" - cfg = runtime.config or {} - configurable = cfg.get("configurable") if isinstance(cfg, dict) else None - if not isinstance(configurable, dict): - return False - return "surfsense_resume_value" in configurable diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/deepagent_stack.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/deepagent_stack.py deleted file mode 100644 index 74e47cfab..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/deepagent_stack.py +++ /dev/null @@ -1,506 +0,0 @@ -"""Assemble the main-agent deep-agent middleware list (LangChain + SurfSense + deepagents).""" - -from __future__ import annotations - -import logging -from collections.abc import Sequence -from typing import Any - -from deepagents import SubAgent -from deepagents.backends import StateBackend -from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware -from deepagents.middleware.skills import SkillsMiddleware -from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT -from langchain.agents.middleware import ( - LLMToolSelectorMiddleware, - ModelCallLimitMiddleware, - ModelFallbackMiddleware, - TodoListMiddleware, - ToolCallLimitMiddleware, -) -from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware -from langchain_core.language_models import BaseChatModel -from langchain_core.tools import BaseTool -from langgraph.types import Checkpointer - -from app.agents.multi_agent_chat.subagents import ( - build_subagents, - get_subagents_to_exclude, -) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) -from app.agents.new_chat.feature_flags import AgentFeatureFlags -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.middleware import ( - ActionLogMiddleware, - AnonymousDocumentMiddleware, - BusyMutexMiddleware, - ClearToolUsesEdit, - DedupHITLToolCallsMiddleware, - DoomLoopMiddleware, - FileIntentMiddleware, - KnowledgeBasePersistenceMiddleware, - KnowledgePriorityMiddleware, - KnowledgeTreeMiddleware, - MemoryInjectionMiddleware, - NoopInjectionMiddleware, - OtelSpanMiddleware, - PermissionMiddleware, - RetryAfterMiddleware, - SpillingContextEditingMiddleware, - SpillToBackendEdit, - SurfSenseFilesystemMiddleware, - ToolCallNameRepairMiddleware, - build_skills_backend_factory, - create_surfsense_compaction_middleware, - default_skills_sources, -) -from app.agents.new_chat.permissions import Rule, Ruleset -from app.agents.new_chat.plugin_loader import ( - PluginContext, - load_allowed_plugin_names_from_env, - load_plugin_middlewares, -) -from app.agents.new_chat.tools.registry import BUILTIN_TOOLS -from app.db import ChatVisibility - -from ...context_prune.prune_tool_names import safe_exclude_tools -from .checkpointed_subagent_middleware import SurfSenseCheckpointedSubAgentMiddleware - - -def build_main_agent_deepagent_middleware( - *, - llm: BaseChatModel, - tools: Sequence[BaseTool], - backend_resolver: Any, - filesystem_mode: FilesystemMode, - search_space_id: int, - user_id: str | None, - thread_id: int | None, - visibility: ChatVisibility, - anon_session_id: str | None, - available_connectors: list[str] | None, - available_document_types: list[str] | None, - mentioned_document_ids: list[int] | None, - max_input_tokens: int | None, - flags: AgentFeatureFlags, - subagent_dependencies: dict[str, Any], - checkpointer: Checkpointer, - mcp_tools_by_agent: dict[str, ToolsPermissions] | None = None, - disabled_tools: list[str] | None = None, -) -> list[Any]: - """Build ordered middleware for ``create_agent`` (Nones already stripped).""" - _memory_middleware = MemoryInjectionMiddleware( - user_id=user_id, - search_space_id=search_space_id, - thread_visibility=visibility, - ) - - gp_middleware = [ - TodoListMiddleware(), - _memory_middleware, - FileIntentMiddleware(llm=llm), - SurfSenseFilesystemMiddleware( - backend=backend_resolver, - filesystem_mode=filesystem_mode, - search_space_id=search_space_id, - created_by_id=user_id, - thread_id=thread_id, - ), - create_surfsense_compaction_middleware(llm, StateBackend), - PatchToolCallsMiddleware(), - AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"), - ] - - # Build permission rulesets up front so the GP subagent can mirror ``ask`` - # rules into ``interrupt_on``: tool calls emitted from within ``task`` runs - # never reach the parent's ``PermissionMiddleware``. - is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER - permission_enabled = flags.enable_permission and not flags.disable_new_agent_stack - permission_rulesets: list[Ruleset] = [] - if permission_enabled or is_desktop_fs: - permission_rulesets.append( - Ruleset( - rules=[Rule(permission="*", pattern="*", action="allow")], - origin="surfsense_defaults", - ) - ) - if is_desktop_fs: - permission_rulesets.append( - Ruleset( - rules=[ - Rule(permission="rm", pattern="*", action="ask"), - Rule(permission="rmdir", pattern="*", action="ask"), - Rule(permission="move_file", pattern="*", action="ask"), - Rule(permission="edit_file", pattern="*", action="ask"), - Rule(permission="write_file", pattern="*", action="ask"), - ], - origin="desktop_safety", - ) - ) - - # Tools that self-prompt via ``request_approval`` must not also appear - # as ``ask`` rules — that would double-prompt the user for one call. - _tool_names_in_use = {t.name for t in tools} - - # Deny parent-bound tools whose ``required_connector`` is missing. - # No-op today (connector subagents are pruned upstream); guards future - # additions to the parent's tool list. - if permission_enabled: - _available_set = set(available_connectors or []) - _synthesized: list[Rule] = [] - for tool_def in BUILTIN_TOOLS: - if tool_def.name not in _tool_names_in_use: - continue - rc = tool_def.required_connector - if rc and rc not in _available_set: - _synthesized.append( - Rule(permission=tool_def.name, pattern="*", action="deny") - ) - if _synthesized: - permission_rulesets.append( - Ruleset(rules=_synthesized, origin="connector_synthesized") - ) - gp_interrupt_on: dict[str, bool] = { - rule.permission: True - for rs in permission_rulesets - for rule in rs.rules - if rule.action == "ask" and rule.permission in _tool_names_in_use - } - - general_purpose_spec: SubAgent = { # type: ignore[typeddict-unknown-key] - **GENERAL_PURPOSE_SUBAGENT, - "model": llm, - "tools": tools, - "middleware": gp_middleware, - } - if gp_interrupt_on: - general_purpose_spec["interrupt_on"] = gp_interrupt_on - - # Deny-only on subagents: ``task`` runs bypass the parent's - # PermissionMiddleware, while bucket-based ask gates own the ask path. - subagent_deny_rulesets: list[Ruleset] = [ - Ruleset( - rules=[r for r in rs.rules if r.action == "deny"], - origin=rs.origin, - ) - for rs in permission_rulesets - ] - subagent_deny_rulesets = [rs for rs in subagent_deny_rulesets if rs.rules] - - subagent_deny_permission_mw: PermissionMiddleware | None = ( - PermissionMiddleware(rulesets=subagent_deny_rulesets) - if subagent_deny_rulesets - else None - ) - - if subagent_deny_permission_mw is not None: - # Run deny check on already-repaired tool calls; insert before - # PatchToolCallsMiddleware (append if the slot moves). - _patch_idx = next( - ( - i - for i, m in enumerate(gp_middleware) - if isinstance(m, PatchToolCallsMiddleware) - ), - len(gp_middleware), - ) - gp_middleware.insert(_patch_idx, subagent_deny_permission_mw) - - registry_subagents: list[SubAgent] = [] - try: - subagent_extra_middleware: list[Any] = [ - TodoListMiddleware(), - SurfSenseFilesystemMiddleware( - backend=backend_resolver, - filesystem_mode=filesystem_mode, - search_space_id=search_space_id, - created_by_id=user_id, - thread_id=thread_id, - ), - ] - if subagent_deny_permission_mw is not None: - subagent_extra_middleware.append(subagent_deny_permission_mw) - registry_subagents = build_subagents( - dependencies=subagent_dependencies, - model=llm, - extra_middleware=subagent_extra_middleware, - mcp_tools_by_agent=mcp_tools_by_agent or {}, - exclude=get_subagents_to_exclude(available_connectors), - disabled_tools=disabled_tools, - ) - logging.info( - "Registry subagents: %s", - [s["name"] for s in registry_subagents], - ) - except Exception: - logging.exception("Registry subagent build failed") - raise - - subagent_specs: list[SubAgent] = [general_purpose_spec, *registry_subagents] - - summarization_mw = create_surfsense_compaction_middleware(llm, StateBackend) - - context_edit_mw = None - if ( - flags.enable_context_editing - and not flags.disable_new_agent_stack - and max_input_tokens - ): - spill_edit = SpillToBackendEdit( - trigger=int(max_input_tokens * 0.55), - clear_at_least=int(max_input_tokens * 0.15), - keep=5, - exclude_tools=safe_exclude_tools(tools), - clear_tool_inputs=True, - ) - clear_edit = ClearToolUsesEdit( - trigger=int(max_input_tokens * 0.55), - clear_at_least=int(max_input_tokens * 0.15), - keep=5, - exclude_tools=safe_exclude_tools(tools), - clear_tool_inputs=True, - placeholder="[cleared - older tool output trimmed for context]", - ) - context_edit_mw = SpillingContextEditingMiddleware( - edits=[spill_edit, clear_edit], - backend_resolver=backend_resolver, - ) - - retry_mw = ( - RetryAfterMiddleware(max_retries=3) - if flags.enable_retry_after and not flags.disable_new_agent_stack - else None - ) - fallback_mw: ModelFallbackMiddleware | None = None - if flags.enable_model_fallback and not flags.disable_new_agent_stack: - try: - fallback_mw = ModelFallbackMiddleware( - "openai:gpt-4o-mini", - "anthropic:claude-3-5-haiku-20241022", - ) - except Exception: - logging.warning("ModelFallbackMiddleware init failed; skipping.") - fallback_mw = None - model_call_limit_mw = ( - ModelCallLimitMiddleware( - thread_limit=120, - run_limit=80, - exit_behavior="end", - ) - if flags.enable_model_call_limit and not flags.disable_new_agent_stack - else None - ) - tool_call_limit_mw = ( - ToolCallLimitMiddleware( - thread_limit=300, run_limit=80, exit_behavior="continue" - ) - if flags.enable_tool_call_limit and not flags.disable_new_agent_stack - else None - ) - - noop_mw = ( - NoopInjectionMiddleware() - if flags.enable_compaction_v2 and not flags.disable_new_agent_stack - else None - ) - - repair_mw = None - if flags.enable_tool_call_repair and not flags.disable_new_agent_stack: - registered_names: set[str] = {t.name for t in tools} - registered_names |= { - "write_todos", - "ls", - "read_file", - "write_file", - "edit_file", - "glob", - "grep", - "execute", - "task", - "mkdir", - "cd", - "pwd", - "move_file", - "rm", - "rmdir", - "list_tree", - "execute_code", - } - repair_mw = ToolCallNameRepairMiddleware( - registered_tool_names=registered_names, - fuzzy_match_threshold=None, - ) - - doom_loop_mw = ( - DoomLoopMiddleware(threshold=3) - if flags.enable_doom_loop and not flags.disable_new_agent_stack - else None - ) - - permission_mw: PermissionMiddleware | None = ( - PermissionMiddleware(rulesets=permission_rulesets) - if permission_rulesets - else None - ) - - action_log_mw: ActionLogMiddleware | None = None - if ( - flags.enable_action_log - and not flags.disable_new_agent_stack - and thread_id is not None - ): - try: - tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS} - action_log_mw = ActionLogMiddleware( - thread_id=thread_id, - search_space_id=search_space_id, - user_id=user_id, - tool_definitions=tool_defs_by_name, - ) - except Exception: # pragma: no cover - defensive - logging.warning( - "ActionLogMiddleware init failed; running without it.", - exc_info=True, - ) - action_log_mw = None - - busy_mutex_mw: BusyMutexMiddleware | None = ( - BusyMutexMiddleware() - if flags.enable_busy_mutex and not flags.disable_new_agent_stack - else None - ) - - otel_mw: OtelSpanMiddleware | None = ( - OtelSpanMiddleware() - if flags.enable_otel and not flags.disable_new_agent_stack - else None - ) - - plugin_middlewares: list[Any] = [] - if flags.enable_plugin_loader and not flags.disable_new_agent_stack: - try: - allowed_names = load_allowed_plugin_names_from_env() - if allowed_names: - plugin_middlewares = load_plugin_middlewares( - PluginContext.build( - search_space_id=search_space_id, - user_id=user_id, - thread_visibility=visibility, - llm=llm, - ), - allowed_plugin_names=allowed_names, - ) - except Exception: # pragma: no cover - defensive - logging.warning( - "Plugin loader failed; continuing without plugins.", - exc_info=True, - ) - plugin_middlewares = [] - - skills_mw: SkillsMiddleware | None = None - if flags.enable_skills and not flags.disable_new_agent_stack: - try: - skills_factory = build_skills_backend_factory( - search_space_id=search_space_id - if filesystem_mode == FilesystemMode.CLOUD - else None, - ) - skills_mw = SkillsMiddleware( - backend=skills_factory, - sources=default_skills_sources(), - ) - except Exception as exc: # pragma: no cover - defensive - logging.warning("SkillsMiddleware init failed; skipping: %s", exc) - skills_mw = None - - selector_mw: LLMToolSelectorMiddleware | None = None - if ( - flags.enable_llm_tool_selector - and not flags.disable_new_agent_stack - and len(tools) > 30 - ): - try: - selector_mw = LLMToolSelectorMiddleware( - model="openai:gpt-4o-mini", - max_tools=12, - always_include=[ - name - for name in ( - "update_memory", - "get_connected_accounts", - "scrape_webpage", - ) - if name in {t.name for t in tools} - ], - ) - except Exception: - logging.warning("LLMToolSelectorMiddleware init failed; skipping.") - selector_mw = None - - deepagent_middleware = [ - busy_mutex_mw, - otel_mw, - TodoListMiddleware(), - _memory_middleware, - AnonymousDocumentMiddleware( - anon_session_id=anon_session_id, - ) - if filesystem_mode == FilesystemMode.CLOUD - else None, - KnowledgeTreeMiddleware( - search_space_id=search_space_id, - filesystem_mode=filesystem_mode, - llm=llm, - ) - if filesystem_mode == FilesystemMode.CLOUD - else None, - KnowledgePriorityMiddleware( - llm=llm, - search_space_id=search_space_id, - filesystem_mode=filesystem_mode, - available_connectors=available_connectors, - available_document_types=available_document_types, - mentioned_document_ids=mentioned_document_ids, - ), - FileIntentMiddleware(llm=llm), - SurfSenseFilesystemMiddleware( - backend=backend_resolver, - filesystem_mode=filesystem_mode, - search_space_id=search_space_id, - created_by_id=user_id, - thread_id=thread_id, - ), - KnowledgeBasePersistenceMiddleware( - search_space_id=search_space_id, - created_by_id=user_id, - filesystem_mode=filesystem_mode, - thread_id=thread_id, - ) - if filesystem_mode == FilesystemMode.CLOUD - else None, - skills_mw, - SurfSenseCheckpointedSubAgentMiddleware( - checkpointer=checkpointer, - backend=StateBackend, - subagents=subagent_specs, - ), - selector_mw, - model_call_limit_mw, - tool_call_limit_mw, - context_edit_mw, - summarization_mw, - noop_mw, - retry_mw, - fallback_mw, - repair_mw, - permission_mw, - doom_loop_mw, - action_log_mw, - PatchToolCallsMiddleware(), - DedupHITLToolCallsMiddleware(agent_tools=list(tools)), - *plugin_middlewares, - AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"), - ] - return [m for m in deepagent_middleware if m is not None] diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/__init__.py index 3d4ae977d..593e8da20 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/__init__.py +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/__init__.py @@ -2,6 +2,6 @@ from __future__ import annotations -from .factory import create_surfsense_deep_agent +from .factory import create_multi_agent_chat_deep_agent -__all__ = ["create_surfsense_deep_agent"] +__all__ = ["create_multi_agent_chat_deep_agent"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/agent_cache.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/agent_cache.py new file mode 100644 index 000000000..42f984b79 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/agent_cache.py @@ -0,0 +1,117 @@ +"""Compiled agent graph caching for the multi-agent path.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Sequence +from typing import Any + +from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool +from langgraph.types import Checkpointer + +from app.agents.multi_agent_chat.subagents.shared.permissions import ToolsPermissions +from app.agents.new_chat.agent_cache import ( + flags_signature, + get_cache, + stable_hash, + system_prompt_hash, + tools_signature, +) +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.db import ChatVisibility + +from ..graph.compile_graph_sync import build_compiled_agent_graph_sync + + +def mcp_signature(mcp_tools_by_agent: dict[str, ToolsPermissions]) -> str: + """Hash the per-agent MCP tool surface so a change rotates the cache key.""" + rows = [] + for agent_name in sorted(mcp_tools_by_agent.keys()): + perms = mcp_tools_by_agent[agent_name] + allow_names = sorted(item.get("name", "") for item in perms.get("allow", [])) + ask_names = sorted(item.get("name", "") for item in perms.get("ask", [])) + rows.append((agent_name, allow_names, ask_names)) + return stable_hash(rows) + + +async def build_agent_with_cache( + *, + llm: BaseChatModel, + tools: Sequence[BaseTool], + final_system_prompt: str, + backend_resolver: Any, + filesystem_mode: FilesystemMode, + search_space_id: int, + user_id: str | None, + thread_id: int | None, + visibility: ChatVisibility, + anon_session_id: str | None, + available_connectors: list[str], + available_document_types: list[str], + mentioned_document_ids: list[int] | None, + max_input_tokens: int | None, + flags: AgentFeatureFlags, + checkpointer: Checkpointer, + subagent_dependencies: dict[str, Any], + mcp_tools_by_agent: dict[str, ToolsPermissions], + disabled_tools: list[str] | None, + config_id: str | None, +) -> Any: + """Compile the multi-agent graph, serving from cache when key components are stable.""" + + async def _build() -> Any: + return await asyncio.to_thread( + build_compiled_agent_graph_sync, + llm=llm, + tools=tools, + final_system_prompt=final_system_prompt, + backend_resolver=backend_resolver, + filesystem_mode=filesystem_mode, + search_space_id=search_space_id, + user_id=user_id, + thread_id=thread_id, + visibility=visibility, + anon_session_id=anon_session_id, + available_connectors=available_connectors, + available_document_types=available_document_types, + mentioned_document_ids=mentioned_document_ids, + max_input_tokens=max_input_tokens, + flags=flags, + checkpointer=checkpointer, + subagent_dependencies=subagent_dependencies, + mcp_tools_by_agent=mcp_tools_by_agent, + disabled_tools=disabled_tools, + ) + + if not (flags.enable_agent_cache and not flags.disable_new_agent_stack): + return await _build() + + # Every per-request value any middleware closes over at __init__ must be in + # the key, otherwise a hit will leak state across threads. Bump the schema + # version when the component list changes shape. + cache_key = stable_hash( + "multi-agent-v1", + config_id, + thread_id, + user_id, + search_space_id, + visibility, + filesystem_mode, + anon_session_id, + tools_signature( + tools, + available_connectors=available_connectors, + available_document_types=available_document_types, + ), + mcp_signature(mcp_tools_by_agent), + flags_signature(flags), + system_prompt_hash(final_system_prompt), + max_input_tokens, + sorted(disabled_tools) if disabled_tools else None, + ) + return await get_cache().get_or_build(cache_key, builder=_build) + + +__all__ = ["build_agent_with_cache", "mcp_signature"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py index 6a6fd39b7..cb6410acb 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py @@ -2,7 +2,6 @@ from __future__ import annotations -import asyncio import logging import time from collections.abc import Sequence @@ -26,23 +25,24 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags from app.agents.new_chat.filesystem_backends import build_backend_resolver from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection from app.agents.new_chat.llm_config import AgentConfig +from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_tool from app.agents.new_chat.tools.registry import build_tools_async from app.db import ChatVisibility from app.services.connector_service import ConnectorService from app.utils.perf import get_perf_logger -from ..graph.compile_graph_sync import build_compiled_agent_graph_sync from ..system_prompt import build_main_agent_system_prompt from ..tools import ( MAIN_AGENT_SURFSENSE_TOOL_NAMES, MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED, ) +from .agent_cache import build_agent_with_cache _perf_log = get_perf_logger() -async def create_surfsense_deep_agent( +async def create_multi_agent_chat_deep_agent( llm: BaseChatModel, search_space_id: int, db_session: AsyncSession, @@ -62,6 +62,9 @@ async def create_surfsense_deep_agent( ): """Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled.""" _t_agent_total = time.perf_counter() + + apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id) + filesystem_selection = filesystem_selection or FilesystemSelection() backend_resolver = build_backend_resolver( filesystem_selection, @@ -85,7 +88,18 @@ async def create_surfsense_deep_agent( ) except Exception as e: - logging.warning("Failed to discover available connectors/document types: %s", e) + logging.warning( + "Connector/doc-type discovery failed; excluding connector subagents this turn: %s", + e, + ) + + # Fail closed: a None list short-circuits ``get_subagents_to_exclude`` to "exclude + # nothing", which would silently advertise every connector specialist on a flaky + # discovery call. Empty list excludes connector-gated subagents while keeping builtins. + if available_connectors is None: + available_connectors = [] + if available_document_types is None: + available_document_types = [] _perf_log.info( "[create_agent] Connector/doc-type discovery in %.3fs", time.perf_counter() - _t0, @@ -115,7 +129,18 @@ async def create_surfsense_deep_agent( } _t0 = time.perf_counter() - mcp_tools_by_agent = await load_mcp_tools_by_connector(db_session, search_space_id) + try: + mcp_tools_by_agent = await load_mcp_tools_by_connector( + db_session, search_space_id + ) + except Exception as e: + # Degrade to builtins-only rather than aborting the turn: a transient + # DB or MCP-server hiccup should not deny the user a response. + logging.warning( + "MCP tool discovery failed; subagents will run without MCP tools this turn: %s", + e, + ) + mcp_tools_by_agent = {} _perf_log.info( "[create_agent] load_mcp_tools_by_connector in %.3fs (%d buckets)", time.perf_counter() - _t0, @@ -195,9 +220,10 @@ async def create_surfsense_deep_agent( final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT + config_id = agent_config.config_id if agent_config is not None else None + _t0 = time.perf_counter() - agent = await asyncio.to_thread( - build_compiled_agent_graph_sync, + agent = await build_agent_with_cache( llm=llm, tools=tools, final_system_prompt=final_system_prompt, @@ -217,6 +243,7 @@ async def create_surfsense_deep_agent( subagent_dependencies=dependencies, mcp_tools_by_agent=mcp_tools_by_agent, disabled_tools=disabled_tools, + config_id=config_id, ) _perf_log.info( "[create_agent] Middleware stack + graph compiled in %.3fs", diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/__init__.py new file mode 100644 index 000000000..e6eed9fbe --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/__init__.py @@ -0,0 +1,7 @@ +"""Multi-agent middleware stack assembly.""" + +from __future__ import annotations + +from .stack import build_main_agent_deepagent_middleware + +__all__ = ["build_main_agent_deepagent_middleware"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/action_log.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/action_log.py new file mode 100644 index 000000000..c9f893d97 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/action_log.py @@ -0,0 +1,36 @@ +"""Audit row per tool call (reversibility metadata).""" + +from __future__ import annotations + +import logging + +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.middleware import ActionLogMiddleware +from app.agents.new_chat.tools.registry import BUILTIN_TOOLS + +from ..shared.flags import enabled + + +def build_action_log_mw( + *, + flags: AgentFeatureFlags, + thread_id: int | None, + search_space_id: int, + user_id: str | None, +) -> ActionLogMiddleware | None: + if not enabled(flags, "enable_action_log") or thread_id is None: + return None + try: + tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS} + return ActionLogMiddleware( + thread_id=thread_id, + search_space_id=search_space_id, + user_id=user_id, + tool_definitions=tool_defs_by_name, + ) + except Exception: # pragma: no cover - defensive + logging.warning( + "ActionLogMiddleware init failed; running without it.", + exc_info=True, + ) + return None diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/anonymous_doc.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/anonymous_doc.py new file mode 100644 index 000000000..afd54a2d3 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/anonymous_doc.py @@ -0,0 +1,16 @@ +"""Anonymous document hydration from Redis (cloud only).""" + +from __future__ import annotations + +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.middleware import AnonymousDocumentMiddleware + + +def build_anonymous_doc_mw( + *, + filesystem_mode: FilesystemMode, + anon_session_id: str | None, +) -> AnonymousDocumentMiddleware | None: + if filesystem_mode != FilesystemMode.CLOUD: + return None + return AnonymousDocumentMiddleware(anon_session_id=anon_session_id) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/busy_mutex.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/busy_mutex.py new file mode 100644 index 000000000..0ea53bf16 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/busy_mutex.py @@ -0,0 +1,12 @@ +"""Per-thread cooperative lock around the whole turn.""" + +from __future__ import annotations + +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.middleware import BusyMutexMiddleware + +from ..shared.flags import enabled + + +def build_busy_mutex_mw(flags: AgentFeatureFlags) -> BusyMutexMiddleware | None: + return BusyMutexMiddleware() if enabled(flags, "enable_busy_mutex") else None diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/__init__.py rename to surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/config.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/config.py new file mode 100644 index 000000000..ac232b92a --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/config.py @@ -0,0 +1,90 @@ +"""RunnableConfig wiring for nested subagent invocations. + +Forwards the parent's ``runtime.config`` (thread_id, …) into the subagent and +exposes the side-channel ``stream_resume_chat`` uses to ferry resume payloads. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from langchain.tools import ToolRuntime + +from .constants import DEFAULT_SUBAGENT_RECURSION_LIMIT + +logger = logging.getLogger(__name__) + +# langgraph stores the parent task's scratchpad under this configurable key; +# subagents inherit the chain via ``parent_scratchpad`` fallback. +_LANGGRAPH_SCRATCHPAD_KEY = "__pregel_scratchpad" + + +def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]: + """RunnableConfig for the nested invoke; raises ``recursion_limit`` to the parent's budget.""" + merged: dict[str, Any] = dict(runtime.config) if runtime.config else {} + current_limit = merged.get("recursion_limit") + try: + current_int = int(current_limit) if current_limit is not None else 0 + except (TypeError, ValueError): + current_int = 0 + if current_int < DEFAULT_SUBAGENT_RECURSION_LIMIT: + merged["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT + return merged + + +def consume_surfsense_resume(runtime: ToolRuntime) -> Any: + """Pop the resume payload; siblings share ``configurable`` by reference.""" + cfg = runtime.config or {} + configurable = cfg.get("configurable") if isinstance(cfg, dict) else None + if not isinstance(configurable, dict): + return None + return configurable.pop("surfsense_resume_value", None) + + +def has_surfsense_resume(runtime: ToolRuntime) -> bool: + """True iff a resume payload is queued on this runtime (non-destructive).""" + cfg = runtime.config or {} + configurable = cfg.get("configurable") if isinstance(cfg, dict) else None + if not isinstance(configurable, dict): + return False + return "surfsense_resume_value" in configurable + + +def drain_parent_null_resume(runtime: ToolRuntime) -> None: + """Consume the parent's lingering ``NULL_TASK_ID/RESUME`` write before delegating. + + ``stream_resume_chat`` wakes the main agent with + ``Command(resume={"decisions": [...]})`` so the propagated + ``_lg_interrupt(...)`` can return. langgraph stores that payload as the + parent task's ``null_resume`` pending write, which only gets consumed + *after* ``subagent.[a]invoke`` returns (when the post-call propagation + re-fires). While the subagent is mid-execution, any *new* ``interrupt()`` + inside it (e.g. a follow-up tool call after a mixed approve/reject) walks + ``subagent_scratchpad → parent_scratchpad.get_null_resume`` and picks up + the parent's still-live decisions — mismatching against a different number + of hanging tool calls and crashing ``HumanInTheLoopMiddleware``. + + Draining the write here closes that cross-graph leak so subagent + interrupts pause cleanly and re-propagate as a fresh approval card. + """ + cfg = runtime.config or {} + configurable = cfg.get("configurable") if isinstance(cfg, dict) else None + if not isinstance(configurable, dict): + return + scratchpad = configurable.get(_LANGGRAPH_SCRATCHPAD_KEY) + if scratchpad is None: + return + consume = getattr(scratchpad, "get_null_resume", None) + if not callable(consume): + return + try: + consume(True) + except Exception: + # Defensive: if langgraph's internal scratchpad shape changes we don't + # want to break the resume path. Worst case the original ValueError + # still surfaces — same behavior as before this fix. + logger.debug( + "drain_parent_null_resume: scratchpad.get_null_resume raised", + exc_info=True, + ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/constants.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/constants.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/constants.py rename to surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/constants.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/middleware.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/middleware.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/middleware.py rename to surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/middleware.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/propagation.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/propagation.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/propagation.py rename to surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/propagation.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/resume.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/resume.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/resume.py rename to surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/resume.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/task_tool.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/task_tool.py similarity index 91% rename from surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/task_tool.py rename to surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/task_tool.py index d23dc33a9..7c0dd8624 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/checkpointed_subagent_middleware/task_tool.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/task_tool.py @@ -20,6 +20,7 @@ from langgraph.types import Command from .config import ( consume_surfsense_resume, + drain_parent_null_resume, has_surfsense_resume, subagent_invoke_config, ) @@ -69,9 +70,16 @@ def build_task_tool_with_parent_config( raise ValueError(msg) state_update = {k: v for k, v in result.items() if k not in EXCLUDED_STATE_KEYS} - message_text = ( - result["messages"][-1].text.rstrip() if result["messages"][-1].text else "" - ) + messages = result["messages"] + if not messages: + msg = ( + "CompiledSubAgent returned an empty 'messages' list. " + "Subagents must produce at least one message so the parent has " + "output to forward back to the user." + ) + raise ValueError(msg) + last_text = getattr(messages[-1], "text", None) or "" + message_text = last_text.rstrip() return Command( update={ **state_update, @@ -150,6 +158,9 @@ def build_task_tool_with_parent_config( ) expected = hitlrequest_action_count(pending_value) resume_value = fan_out_decisions_to_match(resume_value, expected) + # Prevent the parent's resume payload from leaking into subagent + # interrupts via langgraph's parent_scratchpad fallback. + drain_parent_null_resume(runtime) result = subagent.invoke( build_resume_command(resume_value, pending_id), config=sub_config, @@ -214,6 +225,9 @@ def build_task_tool_with_parent_config( ) expected = hitlrequest_action_count(pending_value) resume_value = fan_out_decisions_to_match(resume_value, expected) + # Prevent the parent's resume payload from leaking into subagent + # interrupts via langgraph's parent_scratchpad fallback. + drain_parent_null_resume(runtime) result = await subagent.ainvoke( build_resume_command(resume_value, pending_id), config=sub_config, diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/context_editing.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/context_editing.py new file mode 100644 index 000000000..e8f99933e --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/context_editing.py @@ -0,0 +1,50 @@ +"""Spill + clear-tool-uses passes to keep payloads under budget.""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +from langchain_core.tools import BaseTool + +from app.agents.multi_agent_chat.main_agent.context_prune.prune_tool_names import ( + safe_exclude_tools, +) +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.middleware import ( + ClearToolUsesEdit, + SpillingContextEditingMiddleware, + SpillToBackendEdit, +) + +from ..shared.flags import enabled + + +def build_context_editing_mw( + *, + flags: AgentFeatureFlags, + max_input_tokens: int | None, + tools: Sequence[BaseTool], + backend_resolver: Any, +) -> SpillingContextEditingMiddleware | None: + if not enabled(flags, "enable_context_editing") or not max_input_tokens: + return None + spill_edit = SpillToBackendEdit( + trigger=int(max_input_tokens * 0.55), + clear_at_least=int(max_input_tokens * 0.15), + keep=5, + exclude_tools=safe_exclude_tools(tools), + clear_tool_inputs=True, + ) + clear_edit = ClearToolUsesEdit( + trigger=int(max_input_tokens * 0.55), + clear_at_least=int(max_input_tokens * 0.15), + keep=5, + exclude_tools=safe_exclude_tools(tools), + clear_tool_inputs=True, + placeholder="[cleared - older tool output trimmed for context]", + ) + return SpillingContextEditingMiddleware( + edits=[spill_edit, clear_edit], + backend_resolver=backend_resolver, + ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/dedup_hitl.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/dedup_hitl.py new file mode 100644 index 000000000..66cae300b --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/dedup_hitl.py @@ -0,0 +1,13 @@ +"""Drop duplicate HITL tool calls before execution.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from langchain_core.tools import BaseTool + +from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware + + +def build_dedup_hitl_mw(tools: Sequence[BaseTool]) -> DedupHITLToolCallsMiddleware: + return DedupHITLToolCallsMiddleware(agent_tools=list(tools)) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/doom_loop.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/doom_loop.py new file mode 100644 index 000000000..d67b8d518 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/doom_loop.py @@ -0,0 +1,14 @@ +"""Stop N identical tool calls in a row via interrupt.""" + +from __future__ import annotations + +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.middleware import DoomLoopMiddleware + +from ..shared.flags import enabled + + +def build_doom_loop_mw(flags: AgentFeatureFlags) -> DoomLoopMiddleware | None: + return ( + DoomLoopMiddleware(threshold=3) if enabled(flags, "enable_doom_loop") else None + ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/kb_persistence.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/kb_persistence.py new file mode 100644 index 000000000..4b27581e7 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/kb_persistence.py @@ -0,0 +1,23 @@ +"""Commit staged cloud filesystem mutations to Postgres at end of turn.""" + +from __future__ import annotations + +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.middleware import KnowledgeBasePersistenceMiddleware + + +def build_kb_persistence_mw( + *, + filesystem_mode: FilesystemMode, + search_space_id: int, + user_id: str | None, + thread_id: int | None, +) -> KnowledgeBasePersistenceMiddleware | None: + if filesystem_mode != FilesystemMode.CLOUD: + return None + return KnowledgeBasePersistenceMiddleware( + search_space_id=search_space_id, + created_by_id=user_id, + filesystem_mode=filesystem_mode, + thread_id=thread_id, + ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/knowledge_priority.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/knowledge_priority.py new file mode 100644 index 000000000..395d2a7af --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/knowledge_priority.py @@ -0,0 +1,27 @@ +"""KB priority planner: injection.""" + +from __future__ import annotations + +from langchain_core.language_models import BaseChatModel + +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.middleware import KnowledgePriorityMiddleware + + +def build_knowledge_priority_mw( + *, + llm: BaseChatModel, + search_space_id: int, + filesystem_mode: FilesystemMode, + available_connectors: list[str] | None, + available_document_types: list[str] | None, + mentioned_document_ids: list[int] | None, +) -> KnowledgePriorityMiddleware: + return KnowledgePriorityMiddleware( + llm=llm, + search_space_id=search_space_id, + filesystem_mode=filesystem_mode, + available_connectors=available_connectors, + available_document_types=available_document_types, + mentioned_document_ids=mentioned_document_ids, + ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/knowledge_tree.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/knowledge_tree.py new file mode 100644 index 000000000..404082401 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/knowledge_tree.py @@ -0,0 +1,23 @@ +""" injection (cloud only).""" + +from __future__ import annotations + +from langchain_core.language_models import BaseChatModel + +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.middleware import KnowledgeTreeMiddleware + + +def build_knowledge_tree_mw( + *, + filesystem_mode: FilesystemMode, + search_space_id: int, + llm: BaseChatModel, +) -> KnowledgeTreeMiddleware | None: + if filesystem_mode != FilesystemMode.CLOUD: + return None + return KnowledgeTreeMiddleware( + search_space_id=search_space_id, + filesystem_mode=filesystem_mode, + llm=llm, + ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/noop_injection.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/noop_injection.py new file mode 100644 index 000000000..6e6467ad0 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/noop_injection.py @@ -0,0 +1,12 @@ +"""Provider-compat: append a `_noop` tool when tools=[] but history has tool calls.""" + +from __future__ import annotations + +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.middleware import NoopInjectionMiddleware + +from ..shared.flags import enabled + + +def build_noop_injection_mw(flags: AgentFeatureFlags) -> NoopInjectionMiddleware | None: + return NoopInjectionMiddleware() if enabled(flags, "enable_compaction_v2") else None diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/otel.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/otel.py new file mode 100644 index 000000000..bd7516e65 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/otel.py @@ -0,0 +1,12 @@ +"""OTel spans on model and tool calls.""" + +from __future__ import annotations + +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.middleware import OtelSpanMiddleware + +from ..shared.flags import enabled + + +def build_otel_mw(flags: AgentFeatureFlags) -> OtelSpanMiddleware | None: + return OtelSpanMiddleware() if enabled(flags, "enable_otel") else None diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/plugins.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/plugins.py new file mode 100644 index 000000000..4418e3806 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/plugins.py @@ -0,0 +1,49 @@ +"""Tail-of-stack plugin slot driven by env allowlist.""" + +from __future__ import annotations + +import logging +from typing import Any + +from langchain_core.language_models import BaseChatModel + +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.plugin_loader import ( + PluginContext, + load_allowed_plugin_names_from_env, + load_plugin_middlewares, +) +from app.db import ChatVisibility + +from ..shared.flags import enabled + + +def build_plugin_middlewares( + *, + flags: AgentFeatureFlags, + search_space_id: int, + user_id: str | None, + visibility: ChatVisibility, + llm: BaseChatModel, +) -> list[Any]: + if not enabled(flags, "enable_plugin_loader"): + return [] + try: + allowed_names = load_allowed_plugin_names_from_env() + if not allowed_names: + return [] + return load_plugin_middlewares( + PluginContext.build( + search_space_id=search_space_id, + user_id=user_id, + thread_visibility=visibility, + llm=llm, + ), + allowed_plugin_names=allowed_names, + ) + except Exception: # pragma: no cover - defensive + logging.warning( + "Plugin loader failed; continuing without plugins.", + exc_info=True, + ) + return [] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/repair.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/repair.py new file mode 100644 index 000000000..378b61be1 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/repair.py @@ -0,0 +1,50 @@ +"""Repair miscased / unknown tool names to the registered set or invalid_tool.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from langchain_core.tools import BaseTool + +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.middleware import ToolCallNameRepairMiddleware + +from ..shared.flags import enabled + +# deepagents-built-in tool names the repair pass treats as known. +_DEEPAGENT_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset( + { + "write_todos", + "ls", + "read_file", + "write_file", + "edit_file", + "glob", + "grep", + "execute", + "task", + "mkdir", + "cd", + "pwd", + "move_file", + "rm", + "rmdir", + "list_tree", + "execute_code", + } +) + + +def build_repair_mw( + *, + flags: AgentFeatureFlags, + tools: Sequence[BaseTool], +) -> ToolCallNameRepairMiddleware | None: + if not enabled(flags, "enable_tool_call_repair"): + return None + registered_names: set[str] = {t.name for t in tools} + registered_names |= _DEEPAGENT_BUILTIN_TOOL_NAMES + return ToolCallNameRepairMiddleware( + registered_tool_names=registered_names, + fuzzy_match_threshold=None, + ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/selector.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/selector.py new file mode 100644 index 000000000..8e7a32be8 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/selector.py @@ -0,0 +1,39 @@ +"""LLM-based tool subset selection (only when >30 tools).""" + +from __future__ import annotations + +import logging +from collections.abc import Sequence + +from langchain.agents.middleware import LLMToolSelectorMiddleware +from langchain_core.tools import BaseTool + +from app.agents.new_chat.feature_flags import AgentFeatureFlags + +from ..shared.flags import enabled + + +def build_selector_mw( + *, + flags: AgentFeatureFlags, + tools: Sequence[BaseTool], +) -> LLMToolSelectorMiddleware | None: + if not enabled(flags, "enable_llm_tool_selector") or len(tools) <= 30: + return None + try: + return LLMToolSelectorMiddleware( + model="openai:gpt-4o-mini", + max_tools=12, + always_include=[ + name + for name in ( + "update_memory", + "get_connected_accounts", + "scrape_webpage", + ) + if name in {t.name for t in tools} + ], + ) + except Exception: + logging.warning("LLMToolSelectorMiddleware init failed; skipping.") + return None diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/skills.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/skills.py new file mode 100644 index 000000000..63a57c5a0 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/skills.py @@ -0,0 +1,39 @@ +"""Skill discovery + injection.""" + +from __future__ import annotations + +import logging + +from deepagents.middleware.skills import SkillsMiddleware + +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.middleware import ( + build_skills_backend_factory, + default_skills_sources, +) + +from ..shared.flags import enabled + + +def build_skills_mw( + *, + flags: AgentFeatureFlags, + filesystem_mode: FilesystemMode, + search_space_id: int, +) -> SkillsMiddleware | None: + if not enabled(flags, "enable_skills"): + return None + try: + skills_factory = build_skills_backend_factory( + search_space_id=search_space_id + if filesystem_mode == FilesystemMode.CLOUD + else None, + ) + return SkillsMiddleware( + backend=skills_factory, + sources=default_skills_sources(), + ) + except Exception as exc: # pragma: no cover - defensive + logging.warning("SkillsMiddleware init failed; skipping: %s", exc) + return None diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/anthropic_cache.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/anthropic_cache.py new file mode 100644 index 000000000..f99fb9c7f --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/anthropic_cache.py @@ -0,0 +1,9 @@ +"""Anthropic prompt caching annotations on system/tool/message blocks.""" + +from __future__ import annotations + +from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware + + +def build_anthropic_cache_mw() -> AnthropicPromptCachingMiddleware: + return AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore") diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/compaction.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/compaction.py new file mode 100644 index 000000000..b59e7d2c4 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/compaction.py @@ -0,0 +1,14 @@ +"""Context-window summarization with SurfSense protected sections.""" + +from __future__ import annotations + +from typing import Any + +from deepagents.backends import StateBackend +from langchain_core.language_models import BaseChatModel + +from app.agents.new_chat.middleware import create_surfsense_compaction_middleware + + +def build_compaction_mw(llm: BaseChatModel) -> Any: + return create_surfsense_compaction_middleware(llm, StateBackend) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/file_intent.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/file_intent.py new file mode 100644 index 000000000..5ff65aa12 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/file_intent.py @@ -0,0 +1,11 @@ +"""File-intent classifier that gates strict write contracts.""" + +from __future__ import annotations + +from langchain_core.language_models import BaseChatModel + +from app.agents.new_chat.middleware import FileIntentMiddleware + + +def build_file_intent_mw(llm: BaseChatModel) -> FileIntentMiddleware: + return FileIntentMiddleware(llm=llm) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem.py new file mode 100644 index 000000000..9481f5167 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem.py @@ -0,0 +1,25 @@ +"""SurfSense filesystem tools/middleware.""" + +from __future__ import annotations + +from typing import Any + +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.middleware import SurfSenseFilesystemMiddleware + + +def build_filesystem_mw( + *, + backend_resolver: Any, + filesystem_mode: FilesystemMode, + search_space_id: int, + user_id: str | None, + thread_id: int | None, +) -> SurfSenseFilesystemMiddleware: + return SurfSenseFilesystemMiddleware( + backend=backend_resolver, + filesystem_mode=filesystem_mode, + search_space_id=search_space_id, + created_by_id=user_id, + thread_id=thread_id, + ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/flags.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/flags.py new file mode 100644 index 000000000..69994ae00 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/flags.py @@ -0,0 +1,10 @@ +"""Single source of truth for the feature-flag predicate.""" + +from __future__ import annotations + +from app.agents.new_chat.feature_flags import AgentFeatureFlags + + +def enabled(flags: AgentFeatureFlags, attr: str) -> bool: + """``flags.`` is on AND the new-agent-stack kill switch is off.""" + return getattr(flags, attr) and not flags.disable_new_agent_stack diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/memory.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/memory.py new file mode 100644 index 000000000..9316b3e21 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/memory.py @@ -0,0 +1,19 @@ +"""User/team memory injection prepended to the conversation.""" + +from __future__ import annotations + +from app.agents.new_chat.middleware import MemoryInjectionMiddleware +from app.db import ChatVisibility + + +def build_memory_mw( + *, + user_id: str | None, + search_space_id: int, + visibility: ChatVisibility, +) -> MemoryInjectionMiddleware: + return MemoryInjectionMiddleware( + user_id=user_id, + search_space_id=search_space_id, + thread_visibility=visibility, + ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/patch_tool_calls.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/patch_tool_calls.py new file mode 100644 index 000000000..50036dbbe --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/patch_tool_calls.py @@ -0,0 +1,9 @@ +"""Repair dangling tool-call sequences before each agent turn.""" + +from __future__ import annotations + +from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware + + +def build_patch_tool_calls_mw() -> PatchToolCallsMiddleware: + return PatchToolCallsMiddleware() diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/__init__.py new file mode 100644 index 000000000..4f2228170 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/__init__.py @@ -0,0 +1,12 @@ +"""Permission rulesets fanned out to parent / general-purpose / subagent stacks.""" + +from __future__ import annotations + +from .context import PermissionContext, build_permission_context +from .middleware import build_full_permission_mw + +__all__ = [ + "PermissionContext", + "build_full_permission_mw", + "build_permission_context", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/context.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/context.py new file mode 100644 index 000000000..e121421a0 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/context.py @@ -0,0 +1,107 @@ +"""Derive shared permission context once; fan out to all three stack layers. + +The context carries: +- ``rulesets``: full ask/deny/allow rules for the main-agent permission middleware. +- ``general_purpose_interrupt_on``: ``ask`` rules mirrored as deepagents + ``interrupt_on`` so HITL still triggers from inside ``task`` runs (subagents + bypass the main-agent permission middleware). +- ``subagent_deny_mw``: a deny-only ``PermissionMiddleware`` instance shared + across the general-purpose and registry subagent stacks. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass + +from langchain_core.tools import BaseTool + +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.middleware import PermissionMiddleware +from app.agents.new_chat.permissions import Rule, Ruleset +from app.agents.new_chat.tools.registry import BUILTIN_TOOLS + +from ..flags import enabled + + +@dataclass(frozen=True) +class PermissionContext: + rulesets: list[Ruleset] + general_purpose_interrupt_on: dict[str, bool] + subagent_deny_mw: PermissionMiddleware | None + + +def build_permission_context( + *, + flags: AgentFeatureFlags, + filesystem_mode: FilesystemMode, + tools: Sequence[BaseTool], + available_connectors: list[str] | None, +) -> PermissionContext: + is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER + permission_enabled = enabled(flags, "enable_permission") + + rulesets: list[Ruleset] = [] + if permission_enabled or is_desktop_fs: + rulesets.append( + Ruleset( + rules=[Rule(permission="*", pattern="*", action="allow")], + origin="surfsense_defaults", + ) + ) + if is_desktop_fs: + rulesets.append( + Ruleset( + rules=[ + Rule(permission="rm", pattern="*", action="ask"), + Rule(permission="rmdir", pattern="*", action="ask"), + Rule(permission="move_file", pattern="*", action="ask"), + Rule(permission="edit_file", pattern="*", action="ask"), + Rule(permission="write_file", pattern="*", action="ask"), + ], + origin="desktop_safety", + ) + ) + + tool_names_in_use = {t.name for t in tools} + + if permission_enabled: + available_set = set(available_connectors or []) + synthesized: list[Rule] = [] + for tool_def in BUILTIN_TOOLS: + if tool_def.name not in tool_names_in_use: + continue + rc = tool_def.required_connector + if rc and rc not in available_set: + synthesized.append( + Rule(permission=tool_def.name, pattern="*", action="deny") + ) + if synthesized: + rulesets.append(Ruleset(rules=synthesized, origin="connector_synthesized")) + + general_purpose_interrupt_on: dict[str, bool] = { + rule.permission: True + for rs in rulesets + for rule in rs.rules + if rule.action == "ask" and rule.permission in tool_names_in_use + } + + deny_rulesets = [ + Ruleset( + rules=[r for r in rs.rules if r.action == "deny"], + origin=rs.origin, + ) + for rs in rulesets + ] + deny_rulesets = [rs for rs in deny_rulesets if rs.rules] + + subagent_deny_mw: PermissionMiddleware | None = ( + PermissionMiddleware(rulesets=deny_rulesets) if deny_rulesets else None + ) + + return PermissionContext( + rulesets=rulesets, + general_purpose_interrupt_on=general_purpose_interrupt_on, + subagent_deny_mw=subagent_deny_mw, + ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware.py new file mode 100644 index 000000000..704a26fb3 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware.py @@ -0,0 +1,10 @@ +"""Main-agent permission middleware (full ask/deny/allow rules).""" + +from __future__ import annotations + +from app.agents.new_chat.middleware import PermissionMiddleware +from app.agents.new_chat.permissions import Ruleset + + +def build_full_permission_mw(rulesets: list[Ruleset]) -> PermissionMiddleware | None: + return PermissionMiddleware(rulesets=rulesets) if rulesets else None diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/__init__.py new file mode 100644 index 000000000..92596b771 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/__init__.py @@ -0,0 +1,7 @@ +"""Resilience middleware shared as the same instances across parent / general-purpose / registry.""" + +from __future__ import annotations + +from .bundle import ResilienceBundle, build_resilience_bundle + +__all__ = ["ResilienceBundle", "build_resilience_bundle"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/bundle.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/bundle.py new file mode 100644 index 000000000..45f76a6f3 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/bundle.py @@ -0,0 +1,51 @@ +"""Construct each resilience middleware once; same instances flow into every consumer.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from langchain.agents.middleware import ( + ModelCallLimitMiddleware, + ToolCallLimitMiddleware, +) + +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.middleware import RetryAfterMiddleware +from app.agents.new_chat.middleware.scoped_model_fallback import ( + ScopedModelFallbackMiddleware, +) + +from .fallback import build_fallback_mw +from .model_call_limit import build_model_call_limit_mw +from .retry import build_retry_mw +from .tool_call_limit import build_tool_call_limit_mw + + +@dataclass(frozen=True) +class ResilienceBundle: + retry: RetryAfterMiddleware | None + fallback: ScopedModelFallbackMiddleware | None + model_call_limit: ModelCallLimitMiddleware | None + tool_call_limit: ToolCallLimitMiddleware | None + + def as_list(self) -> list[Any]: + return [ + m + for m in ( + self.retry, + self.fallback, + self.model_call_limit, + self.tool_call_limit, + ) + if m is not None + ] + + +def build_resilience_bundle(flags: AgentFeatureFlags) -> ResilienceBundle: + return ResilienceBundle( + retry=build_retry_mw(flags), + fallback=build_fallback_mw(flags), + model_call_limit=build_model_call_limit_mw(flags), + tool_call_limit=build_tool_call_limit_mw(flags), + ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/fallback.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/fallback.py new file mode 100644 index 000000000..ea68a764e --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/fallback.py @@ -0,0 +1,27 @@ +"""Switch to a fallback model on provider/network errors only.""" + +from __future__ import annotations + +import logging + +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.middleware.scoped_model_fallback import ( + ScopedModelFallbackMiddleware, +) + +from ..flags import enabled + + +def build_fallback_mw( + flags: AgentFeatureFlags, +) -> ScopedModelFallbackMiddleware | None: + if not enabled(flags, "enable_model_fallback"): + return None + try: + return ScopedModelFallbackMiddleware( + "openai:gpt-4o-mini", + "anthropic:claude-3-5-haiku-20241022", + ) + except Exception: + logging.warning("ScopedModelFallbackMiddleware init failed; skipping.") + return None diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/model_call_limit.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/model_call_limit.py new file mode 100644 index 000000000..85707a385 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/model_call_limit.py @@ -0,0 +1,21 @@ +"""Cap model calls per thread / per run to prevent runaway cost.""" + +from __future__ import annotations + +from langchain.agents.middleware import ModelCallLimitMiddleware + +from app.agents.new_chat.feature_flags import AgentFeatureFlags + +from ..flags import enabled + + +def build_model_call_limit_mw( + flags: AgentFeatureFlags, +) -> ModelCallLimitMiddleware | None: + if not enabled(flags, "enable_model_call_limit"): + return None + return ModelCallLimitMiddleware( + thread_limit=120, + run_limit=80, + exit_behavior="end", + ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/retry.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/retry.py new file mode 100644 index 000000000..c98fc4083 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/retry.py @@ -0,0 +1,16 @@ +"""Retry on transient model errors (e.g. Retry-After-bearing 429s).""" + +from __future__ import annotations + +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.middleware import RetryAfterMiddleware + +from ..flags import enabled + + +def build_retry_mw(flags: AgentFeatureFlags) -> RetryAfterMiddleware | None: + return ( + RetryAfterMiddleware(max_retries=3) + if enabled(flags, "enable_retry_after") + else None + ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/tool_call_limit.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/tool_call_limit.py new file mode 100644 index 000000000..dcde81f37 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/resilience/tool_call_limit.py @@ -0,0 +1,21 @@ +"""Cap tool calls per thread / per run to bound infinite-loop blast radius.""" + +from __future__ import annotations + +from langchain.agents.middleware import ToolCallLimitMiddleware + +from app.agents.new_chat.feature_flags import AgentFeatureFlags + +from ..flags import enabled + + +def build_tool_call_limit_mw( + flags: AgentFeatureFlags, +) -> ToolCallLimitMiddleware | None: + if not enabled(flags, "enable_tool_call_limit"): + return None + return ToolCallLimitMiddleware( + thread_limit=300, + run_limit=80, + exit_behavior="continue", + ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/todos.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/todos.py new file mode 100644 index 000000000..ea9173a1d --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/todos.py @@ -0,0 +1,9 @@ +"""Todo-list middleware (each consumer needs its own instance).""" + +from __future__ import annotations + +from langchain.agents.middleware import TodoListMiddleware + + +def build_todos_mw() -> TodoListMiddleware: + return TodoListMiddleware() diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/stack.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/stack.py new file mode 100644 index 000000000..6d8faa3f4 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/stack.py @@ -0,0 +1,216 @@ +"""Main-agent middleware list assembly: one line per slot.""" + +from __future__ import annotations + +import logging +from collections.abc import Sequence +from typing import Any + +from deepagents import SubAgent +from deepagents.backends import StateBackend +from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool +from langgraph.types import Checkpointer + +from app.agents.multi_agent_chat.subagents import ( + build_subagents, + get_subagents_to_exclude, +) +from app.agents.multi_agent_chat.subagents.builtins.general_purpose.agent import ( + build_subagent as build_general_purpose_subagent, +) +from app.agents.multi_agent_chat.subagents.shared.permissions import ToolsPermissions +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.db import ChatVisibility + +from .main_agent.action_log import build_action_log_mw +from .main_agent.anonymous_doc import build_anonymous_doc_mw +from .main_agent.busy_mutex import build_busy_mutex_mw +from .main_agent.checkpointed_subagent_middleware import ( + SurfSenseCheckpointedSubAgentMiddleware, +) +from .main_agent.context_editing import build_context_editing_mw +from .main_agent.dedup_hitl import build_dedup_hitl_mw +from .main_agent.doom_loop import build_doom_loop_mw +from .main_agent.kb_persistence import build_kb_persistence_mw +from .main_agent.knowledge_priority import build_knowledge_priority_mw +from .main_agent.knowledge_tree import build_knowledge_tree_mw +from .main_agent.noop_injection import build_noop_injection_mw +from .main_agent.otel import build_otel_mw +from .main_agent.plugins import build_plugin_middlewares +from .main_agent.repair import build_repair_mw +from .main_agent.selector import build_selector_mw +from .main_agent.skills import build_skills_mw +from .shared.anthropic_cache import build_anthropic_cache_mw +from .shared.compaction import build_compaction_mw +from .shared.file_intent import build_file_intent_mw +from .shared.filesystem import build_filesystem_mw +from .shared.memory import build_memory_mw +from .shared.patch_tool_calls import build_patch_tool_calls_mw +from .shared.permissions import ( + build_full_permission_mw, + build_permission_context, +) +from .shared.resilience import build_resilience_bundle +from .shared.todos import build_todos_mw +from .subagent.extras import build_subagent_extras + + +def build_main_agent_deepagent_middleware( + *, + llm: BaseChatModel, + tools: Sequence[BaseTool], + backend_resolver: Any, + filesystem_mode: FilesystemMode, + search_space_id: int, + user_id: str | None, + thread_id: int | None, + visibility: ChatVisibility, + anon_session_id: str | None, + available_connectors: list[str] | None, + available_document_types: list[str] | None, + mentioned_document_ids: list[int] | None, + max_input_tokens: int | None, + flags: AgentFeatureFlags, + subagent_dependencies: dict[str, Any], + checkpointer: Checkpointer, + mcp_tools_by_agent: dict[str, ToolsPermissions] | None = None, + disabled_tools: list[str] | None = None, +) -> list[Any]: + """Ordered middleware for ``create_agent`` (None entries already stripped).""" + permissions = build_permission_context( + flags=flags, + filesystem_mode=filesystem_mode, + tools=tools, + available_connectors=available_connectors, + ) + resilience = build_resilience_bundle(flags) + + # Single instance threaded into both the main-agent stack and the general-purpose subagent. + memory_mw = build_memory_mw( + user_id=user_id, + search_space_id=search_space_id, + visibility=visibility, + ) + + general_purpose_subagent = build_general_purpose_subagent( + llm=llm, + tools=tools, + backend_resolver=backend_resolver, + filesystem_mode=filesystem_mode, + search_space_id=search_space_id, + user_id=user_id, + thread_id=thread_id, + permissions=permissions, + resilience=resilience, + memory_mw=memory_mw, + ) + + subagents_registry: list[SubAgent] = [] + try: + subagent_extras = build_subagent_extras( + permissions=permissions, + resilience=resilience, + ) + subagents_registry = build_subagents( + dependencies=subagent_dependencies, + model=llm, + extra_middleware=subagent_extras, + mcp_tools_by_agent=mcp_tools_by_agent or {}, + exclude=get_subagents_to_exclude(available_connectors), + disabled_tools=disabled_tools, + ) + logging.debug( + "Subagents registry: %s", + [s["name"] for s in subagents_registry], + ) + except Exception: + # Degrade to general-purpose-only rather than aborting the turn: + # one bad subagent dep should not deny the user a response. + logging.exception( + "Subagents registry build failed; falling back to general-purpose only" + ) + subagents_registry = [] + + subagents: list[SubAgent] = [general_purpose_subagent, *subagents_registry] + + stack: list[Any] = [ + build_busy_mutex_mw(flags), + build_otel_mw(flags), + build_todos_mw(), + memory_mw, + build_anonymous_doc_mw( + filesystem_mode=filesystem_mode, anon_session_id=anon_session_id + ), + build_knowledge_tree_mw( + filesystem_mode=filesystem_mode, + search_space_id=search_space_id, + llm=llm, + ), + build_knowledge_priority_mw( + llm=llm, + search_space_id=search_space_id, + filesystem_mode=filesystem_mode, + available_connectors=available_connectors, + available_document_types=available_document_types, + mentioned_document_ids=mentioned_document_ids, + ), + build_file_intent_mw(llm), + build_filesystem_mw( + backend_resolver=backend_resolver, + filesystem_mode=filesystem_mode, + search_space_id=search_space_id, + user_id=user_id, + thread_id=thread_id, + ), + build_kb_persistence_mw( + filesystem_mode=filesystem_mode, + search_space_id=search_space_id, + user_id=user_id, + thread_id=thread_id, + ), + build_skills_mw( + flags=flags, + filesystem_mode=filesystem_mode, + search_space_id=search_space_id, + ), + SurfSenseCheckpointedSubAgentMiddleware( + checkpointer=checkpointer, + backend=StateBackend, + subagents=subagents, + ), + build_selector_mw(flags=flags, tools=tools), + resilience.model_call_limit, + resilience.tool_call_limit, + build_context_editing_mw( + flags=flags, + max_input_tokens=max_input_tokens, + tools=tools, + backend_resolver=backend_resolver, + ), + build_compaction_mw(llm), + build_noop_injection_mw(flags), + resilience.retry, + resilience.fallback, + build_repair_mw(flags=flags, tools=tools), + build_full_permission_mw(permissions.rulesets), + build_doom_loop_mw(flags), + build_action_log_mw( + flags=flags, + thread_id=thread_id, + search_space_id=search_space_id, + user_id=user_id, + ), + build_patch_tool_calls_mw(), + build_dedup_hitl_mw(tools), + *build_plugin_middlewares( + flags=flags, + search_space_id=search_space_id, + user_id=user_id, + visibility=visibility, + llm=llm, + ), + build_anthropic_cache_mw(), + ] + return [m for m in stack if m is not None] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/subagent/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/subagent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/subagent/extras.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/subagent/extras.py new file mode 100644 index 000000000..46dca8a81 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/subagent/extras.py @@ -0,0 +1,28 @@ +"""Extra middleware threaded into every registry subagent's stack. + +Registry subagents are scoped to one domain (deliverables, research, memory, +connectors, MCP) and never read or write the SurfSense filesystem — that +capability belongs to the main agent and is delegated to the general-purpose +subagent as an escape hatch. Keeping FS off the registry stacks avoids +polluting their tool surface with FS tools they never act on. +""" + +from __future__ import annotations + +from typing import Any + +from ..shared.permissions import PermissionContext +from ..shared.resilience import ResilienceBundle +from ..shared.todos import build_todos_mw + + +def build_subagent_extras( + *, + permissions: PermissionContext, + resilience: ResilienceBundle, +) -> list[Any]: + extras: list[Any] = [build_todos_mw()] + if permissions.subagent_deny_mw is not None: + extras.append(permissions.subagent_deny_mw) + extras.extend(resilience.as_list()) + return extras diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/general_purpose/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/general_purpose/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/general_purpose/agent.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/general_purpose/agent.py new file mode 100644 index 000000000..1c3c44f12 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/general_purpose/agent.py @@ -0,0 +1,105 @@ +"""General-purpose subagent for the multi-agent main agent.""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any, cast + +from deepagents import SubAgent +from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware +from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT +from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware +from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool + +from app.agents.multi_agent_chat.middleware.shared.anthropic_cache import ( + build_anthropic_cache_mw, +) +from app.agents.multi_agent_chat.middleware.shared.compaction import ( + build_compaction_mw, +) +from app.agents.multi_agent_chat.middleware.shared.file_intent import ( + build_file_intent_mw, +) +from app.agents.multi_agent_chat.middleware.shared.filesystem import ( + build_filesystem_mw, +) +from app.agents.multi_agent_chat.middleware.shared.patch_tool_calls import ( + build_patch_tool_calls_mw, +) +from app.agents.multi_agent_chat.middleware.shared.permissions import ( + PermissionContext, +) +from app.agents.multi_agent_chat.middleware.shared.resilience import ( + ResilienceBundle, +) +from app.agents.multi_agent_chat.middleware.shared.todos import build_todos_mw +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.middleware import MemoryInjectionMiddleware + +NAME = "general-purpose" + + +def build_subagent( + *, + llm: BaseChatModel, + tools: Sequence[BaseTool], + backend_resolver: Any, + filesystem_mode: FilesystemMode, + search_space_id: int, + user_id: str | None, + thread_id: int | None, + permissions: PermissionContext, + resilience: ResilienceBundle, + memory_mw: MemoryInjectionMiddleware, +) -> SubAgent: + """Deny + resilience inserts encapsulated here so the orchestrator never mutates the list.""" + middleware: list[Any] = [ + build_todos_mw(), + memory_mw, + build_file_intent_mw(llm), + build_filesystem_mw( + backend_resolver=backend_resolver, + filesystem_mode=filesystem_mode, + search_space_id=search_space_id, + user_id=user_id, + thread_id=thread_id, + ), + build_compaction_mw(llm), + build_patch_tool_calls_mw(), + build_anthropic_cache_mw(), + ] + + if permissions.subagent_deny_mw is not None: + patch_idx = next( + ( + i + for i, m in enumerate(middleware) + if isinstance(m, PatchToolCallsMiddleware) + ), + len(middleware), + ) + middleware.insert(patch_idx, permissions.subagent_deny_mw) + + resilience_mws = resilience.as_list() + if resilience_mws: + cache_idx = next( + ( + i + for i, m in enumerate(middleware) + if isinstance(m, AnthropicPromptCachingMiddleware) + ), + len(middleware), + ) + for offset, mw in enumerate(resilience_mws): + middleware.insert(cache_idx + offset, mw) + + spec: dict[str, Any] = { + **GENERAL_PURPOSE_SUBAGENT, + "model": llm, + "tools": tools, + "middleware": middleware, + } + if permissions.general_purpose_interrupt_on: + spec["interrupt_on"] = permissions.general_purpose_interrupt_on + return cast(SubAgent, spec) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/create_event.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/create_event.py index 37bcf083e..a8183314a 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/create_event.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/create_event.py @@ -168,20 +168,46 @@ def create_create_calendar_event_tool( f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}" ) + tz = context.get("timezone", "UTC") + if ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR ): - from app.utils.google_credentials import build_composio_credentials - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) - else: + if not cca_id: return { "status": "error", "message": "Composio connected account ID not found for this connector.", } + + from app.services.composio_service import ComposioService + + ( + event_id, + html_link, + error, + ) = await ComposioService().create_calendar_event( + connected_account_id=cca_id, + entity_id=f"surfsense_{user_id}", + summary=final_summary, + start_datetime=final_start_datetime, + end_datetime=final_end_datetime, + timezone=tz, + description=final_description, + location=final_location, + attendees=final_attendees, + ) + if error: + return {"status": "error", "message": error} + created = { + "id": event_id, + "summary": final_summary, + "htmlLink": html_link, + } + logger.info( + f"Calendar event created via Composio: id={event_id}, summary={final_summary}" + ) else: config_data = dict(connector.config) @@ -211,70 +237,69 @@ def create_create_calendar_event_tool( expiry=datetime.fromisoformat(exp) if exp else None, ) - service = await asyncio.get_event_loop().run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) - - tz = context.get("timezone", "UTC") - event_body: dict[str, Any] = { - "summary": final_summary, - "start": {"dateTime": final_start_datetime, "timeZone": tz}, - "end": {"dateTime": final_end_datetime, "timeZone": tz}, - } - if final_description: - event_body["description"] = final_description - if final_location: - event_body["location"] = final_location - if final_attendees: - event_body["attendees"] = [ - {"email": e.strip()} for e in final_attendees if e.strip() - ] - - try: - created = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - service.events() - .insert(calendarId="primary", body=event_body) - .execute() - ), + service = await asyncio.get_event_loop().run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) ) - except Exception as api_err: - from googleapiclient.errors import HttpError - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {api_err}" + event_body: dict[str, Any] = { + "summary": final_summary, + "start": {"dateTime": final_start_datetime, "timeZone": tz}, + "end": {"dateTime": final_end_datetime, "timeZone": tz}, + } + if final_description: + event_body["description"] = final_description + if final_location: + event_body["location"] = final_location + if final_attendees: + event_body["attendees"] = [ + {"email": e.strip()} for e in final_attendees if e.strip() + ] + + try: + created = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + service.events() + .insert(calendarId="primary", body=event_body) + .execute() + ), ) - try: - from sqlalchemy.orm.attributes import flag_modified + except Exception as api_err: + from googleapiclient.errors import HttpError - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id - ) - ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: + if isinstance(api_err, HttpError) and api_err.resp.status == 403: logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, + f"Insufficient permissions for connector {actual_connector_id}: {api_err}" ) - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.", - } - raise + try: + from sqlalchemy.orm.attributes import flag_modified - logger.info( - f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}" - ) + _res = await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + logger.info( + f"Calendar event created via Google API: id={created.get('id')}, summary={created.get('summary')}" + ) kb_message_suffix = "" try: diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/delete_event.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/delete_event.py index 4d9d69b4b..3d160e669 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/delete_event.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/delete_event.py @@ -163,16 +163,22 @@ def create_delete_calendar_event_tool( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR ): - from app.utils.google_credentials import build_composio_credentials - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) - else: + if not cca_id: return { "status": "error", "message": "Composio connected account ID not found for this connector.", } + + from app.services.composio_service import ComposioService + + error = await ComposioService().delete_calendar_event( + connected_account_id=cca_id, + entity_id=f"surfsense_{user_id}", + event_id=final_event_id, + ) + if error: + return {"status": "error", "message": error} else: config_data = dict(connector.config) @@ -202,51 +208,51 @@ def create_delete_calendar_event_tool( expiry=datetime.fromisoformat(exp) if exp else None, ) - service = await asyncio.get_event_loop().run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) - - try: - await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - service.events() - .delete(calendarId="primary", eventId=final_event_id) - .execute() - ), + service = await asyncio.get_event_loop().run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) ) - except Exception as api_err: - from googleapiclient.errors import HttpError - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {api_err}" + try: + await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + service.events() + .delete(calendarId="primary", eventId=final_event_id) + .execute() + ), ) - try: - from sqlalchemy.orm.attributes import flag_modified + except Exception as api_err: + from googleapiclient.errors import HttpError - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id - ) - ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: + if isinstance(api_err, HttpError) and api_err.resp.status == 403: logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, + f"Insufficient permissions for connector {actual_connector_id}: {api_err}" ) - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.", - } - raise + try: + from sqlalchemy.orm.attributes import flag_modified + + _res = await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.", + } + raise logger.info(f"Calendar event deleted: event_id={final_event_id}") diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/search_events.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/search_events.py index dc6adb822..6772d5a1e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/search_events.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/search_events.py @@ -16,6 +16,14 @@ _CALENDAR_TYPES = [ ] +def _to_calendar_boundary(value: str, *, is_end: bool) -> str: + """Promote a bare YYYY-MM-DD to RFC3339 with a day-edge time, leave full datetimes alone.""" + if "T" in value: + return value + time = "23:59:59" if is_end else "00:00:00" + return f"{value}T{time}Z" + + def create_search_calendar_events_tool( db_session: AsyncSession | None = None, search_space_id: int | None = None, @@ -61,22 +69,47 @@ def create_search_calendar_events_tool( "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.", } - creds = _build_credentials(connector) + if ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR + ): + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this connector.", + } - from app.connectors.google_calendar_connector import GoogleCalendarConnector + from app.services.composio_service import ComposioService - cal = GoogleCalendarConnector( - credentials=creds, - session=db_session, - user_id=user_id, - connector_id=connector.id, - ) + events_raw, error = await ComposioService().get_calendar_events( + connected_account_id=cca_id, + entity_id=f"surfsense_{user_id}", + time_min=_to_calendar_boundary(start_date, is_end=False), + time_max=_to_calendar_boundary(end_date, is_end=True), + max_results=max_results, + ) + if not events_raw and not error: + error = "No events found in the specified date range." + else: + creds = _build_credentials(connector) - events_raw, error = await cal.get_all_primary_calendar_events( - start_date=start_date, - end_date=end_date, - max_results=max_results, - ) + from app.connectors.google_calendar_connector import ( + GoogleCalendarConnector, + ) + + cal = GoogleCalendarConnector( + credentials=creds, + session=db_session, + user_id=user_id, + connector_id=connector.id, + ) + + events_raw, error = await cal.get_all_primary_calendar_events( + start_date=start_date, + end_date=end_date, + max_results=max_results, + ) if error: if ( diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/update_event.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/update_event.py index 259f52bba..a74979484 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/update_event.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/update_event.py @@ -192,20 +192,62 @@ def create_update_calendar_event_tool( f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}" ) + has_changes = any( + v is not None + for v in ( + final_new_summary, + final_new_start_datetime, + final_new_end_datetime, + final_new_description, + final_new_location, + final_new_attendees, + ) + ) + if not has_changes: + return { + "status": "error", + "message": "No changes specified. Please provide at least one field to update.", + } + if ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR ): - from app.utils.google_credentials import build_composio_credentials - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) - else: + if not cca_id: return { "status": "error", "message": "Composio connected account ID not found for this connector.", } + + from app.services.composio_service import ComposioService + + tz_for_composio: str | None = None + if final_new_start_datetime is not None and not _is_date_only( + final_new_start_datetime + ): + tz_for_composio = ( + context.get("timezone") if isinstance(context, dict) else None + ) + + _, html_link, error = await ComposioService().update_calendar_event( + connected_account_id=cca_id, + entity_id=f"surfsense_{user_id}", + event_id=final_event_id, + summary=final_new_summary, + start_time=final_new_start_datetime, + end_time=final_new_end_datetime, + timezone=tz_for_composio, + description=final_new_description, + location=final_new_location, + attendees=final_new_attendees, + ) + if error: + return {"status": "error", "message": error} + updated = {"htmlLink": html_link} + logger.info( + f"Calendar event updated via Composio: event_id={final_event_id}" + ) else: config_data = dict(connector.config) @@ -235,81 +277,79 @@ def create_update_calendar_event_tool( expiry=datetime.fromisoformat(exp) if exp else None, ) - service = await asyncio.get_event_loop().run_in_executor( - None, lambda: build("calendar", "v3", credentials=creds) - ) - - update_body: dict[str, Any] = {} - if final_new_summary is not None: - update_body["summary"] = final_new_summary - if final_new_start_datetime is not None: - update_body["start"] = _build_time_body( - final_new_start_datetime, context + service = await asyncio.get_event_loop().run_in_executor( + None, lambda: build("calendar", "v3", credentials=creds) ) - if final_new_end_datetime is not None: - update_body["end"] = _build_time_body(final_new_end_datetime, context) - if final_new_description is not None: - update_body["description"] = final_new_description - if final_new_location is not None: - update_body["location"] = final_new_location - if final_new_attendees is not None: - update_body["attendees"] = [ - {"email": e.strip()} for e in final_new_attendees if e.strip() - ] - if not update_body: - return { - "status": "error", - "message": "No changes specified. Please provide at least one field to update.", - } - - try: - updated = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - service.events() - .patch( - calendarId="primary", - eventId=final_event_id, - body=update_body, - ) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {api_err}" + update_body: dict[str, Any] = {} + if final_new_summary is not None: + update_body["summary"] = final_new_summary + if final_new_start_datetime is not None: + update_body["start"] = _build_time_body( + final_new_start_datetime, context ) - try: - from sqlalchemy.orm.attributes import flag_modified + if final_new_end_datetime is not None: + update_body["end"] = _build_time_body( + final_new_end_datetime, context + ) + if final_new_description is not None: + update_body["description"] = final_new_description + if final_new_location is not None: + update_body["location"] = final_new_location + if final_new_attendees is not None: + update_body["attendees"] = [ + {"email": e.strip()} for e in final_new_attendees if e.strip() + ] - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id + try: + updated = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + service.events() + .patch( + calendarId="primary", + eventId=final_event_id, + body=update_body, ) - ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, - ) - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.", - } - raise + .execute() + ), + ) + except Exception as api_err: + from googleapiclient.errors import HttpError - logger.info(f"Calendar event updated: event_id={final_event_id}") + if isinstance(api_err, HttpError) and api_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {actual_connector_id}: {api_err}" + ) + try: + from sqlalchemy.orm.attributes import flag_modified + + _res = await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + logger.info( + f"Calendar event updated via Google API: event_id={final_event_id}" + ) kb_message_suffix = "" if document_id is not None: diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/create_draft.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/create_draft.py index 0bd044695..59e471097 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/create_draft.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/create_draft.py @@ -161,16 +161,39 @@ def create_create_gmail_draft_tool( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR ): - from app.utils.google_credentials import build_composio_credentials - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) - else: + if not cca_id: return { "status": "error", "message": "Composio connected account ID not found for this Gmail connector.", } + + from app.services.composio_service import ComposioService + + ( + draft_id, + draft_message_id, + draft_thread_id, + error, + ) = await ComposioService().create_gmail_draft( + connected_account_id=cca_id, + entity_id=f"surfsense_{user_id}", + to=final_to, + subject=final_subject, + body=final_body, + cc=final_cc, + bcc=final_bcc, + ) + if error: + return {"status": "error", "message": error} + created = { + "id": draft_id, + "message": { + "id": draft_message_id, + "threadId": draft_thread_id, + }, + } + logger.info(f"Gmail draft created via Composio: id={draft_id}") else: from google.oauth2.credentials import Credentials @@ -208,63 +231,65 @@ def create_create_gmail_draft_tool( expiry=datetime.fromisoformat(exp) if exp else None, ) - from googleapiclient.discovery import build + from googleapiclient.discovery import build - gmail_service = build("gmail", "v1", credentials=creds) + gmail_service = build("gmail", "v1", credentials=creds) - message = MIMEText(final_body) - message["to"] = final_to - message["subject"] = final_subject - if final_cc: - message["cc"] = final_cc - if final_bcc: - message["bcc"] = final_bcc - raw = base64.urlsafe_b64encode(message.as_bytes()).decode() + message = MIMEText(final_body) + message["to"] = final_to + message["subject"] = final_subject + if final_cc: + message["cc"] = final_cc + if final_bcc: + message["bcc"] = final_bcc + raw = base64.urlsafe_b64encode(message.as_bytes()).decode() - try: - created = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .drafts() - .create(userId="me", body={"message": {"raw": raw}}) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {api_err}" + try: + created = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + gmail_service.users() + .drafts() + .create(userId="me", body={"message": {"raw": raw}}) + .execute() + ), ) - try: - from sqlalchemy.orm.attributes import flag_modified + except Exception as api_err: + from googleapiclient.errors import HttpError - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id - ) - ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: + if isinstance(api_err, HttpError) and api_err.resp.status == 403: logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, + f"Insufficient permissions for connector {actual_connector_id}: {api_err}" ) - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", - } - raise + try: + from sqlalchemy.orm.attributes import flag_modified - logger.info(f"Gmail draft created: id={created.get('id')}") + _res = await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", + } + raise + + logger.info( + f"Gmail draft created via Google API: id={created.get('id')}" + ) kb_message_suffix = "" try: diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/read_email.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/read_email.py index deec1627c..39526f25e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/read_email.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/read_email.py @@ -50,7 +50,56 @@ def create_read_gmail_email_tool( "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", } - from app.agents.new_chat.tools.gmail.search_emails import _build_credentials + if ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR + ): + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this Gmail connector.", + } + + from app.agents.new_chat.tools.gmail.search_emails import ( + _format_gmail_summary, + ) + from app.services.composio_service import ComposioService + + detail, error = await ComposioService().get_gmail_message_detail( + connected_account_id=cca_id, + entity_id=f"surfsense_{user_id}", + message_id=message_id, + ) + if error: + return {"status": "error", "message": error} + if not detail: + return { + "status": "not_found", + "message": f"Email with ID '{message_id}' not found.", + } + + summary = _format_gmail_summary(detail) + content = ( + f"# {summary['subject']}\n\n" + f"**From:** {summary['from']}\n" + f"**To:** {summary['to']}\n" + f"**Date:** {summary['date']}\n\n" + f"## Message Content\n\n" + f"{detail.get('messageText') or detail.get('snippet') or ''}\n\n" + f"## Message Details\n\n" + f"- **Message ID:** {summary['message_id']}\n" + f"- **Thread ID:** {summary['thread_id']}\n" + ) + return { + "status": "success", + "message_id": summary["message_id"] or message_id, + "content": content, + } + + from app.agents.new_chat.tools.gmail.search_emails import ( + _build_credentials, + ) creds = _build_credentials(connector) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/search_emails.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/search_emails.py index 2e363609e..a9d7cdedf 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/search_emails.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/search_emails.py @@ -1,5 +1,4 @@ import logging -from datetime import datetime from typing import Any from langchain_core.tools import tool @@ -15,57 +14,6 @@ _GMAIL_TYPES = [ SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, ] -_token_encryption_cache: object | None = None - - -def _get_token_encryption(): - global _token_encryption_cache - if _token_encryption_cache is None: - from app.config import config - from app.utils.oauth_security import TokenEncryption - - if not config.SECRET_KEY: - raise RuntimeError("SECRET_KEY not configured for token decryption.") - _token_encryption_cache = TokenEncryption(config.SECRET_KEY) - return _token_encryption_cache - - -def _build_credentials(connector: SearchSourceConnector): - """Build Google OAuth Credentials from a connector's stored config. - - Handles both native OAuth connectors (with encrypted tokens) and - Composio-backed connectors. Shared by Gmail and Calendar tools. - """ - from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES - - if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: - from app.utils.google_credentials import build_composio_credentials - - cca_id = connector.config.get("composio_connected_account_id") - if not cca_id: - raise ValueError("Composio connected account ID not found.") - return build_composio_credentials(cca_id) - - from google.oauth2.credentials import Credentials - - cfg = dict(connector.config) - if cfg.get("_token_encrypted"): - enc = _get_token_encryption() - for key in ("token", "refresh_token", "client_secret"): - if cfg.get(key): - cfg[key] = enc.decrypt_token(cfg[key]) - - exp = (cfg.get("expiry") or "").replace("Z", "") - return Credentials( - token=cfg.get("token"), - refresh_token=cfg.get("refresh_token"), - token_uri=cfg.get("token_uri"), - client_id=cfg.get("client_id"), - client_secret=cfg.get("client_secret"), - scopes=cfg.get("scopes", []), - expiry=datetime.fromisoformat(exp) if exp else None, - ) - def create_search_gmail_tool( db_session: AsyncSession | None = None, @@ -110,6 +58,50 @@ def create_search_gmail_tool( "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", } + if ( + connector.connector_type + == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR + ): + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this Gmail connector.", + } + + from app.agents.new_chat.tools.gmail.search_emails import ( + _format_gmail_summary, + ) + from app.services.composio_service import ComposioService + + ( + messages, + _next, + _estimate, + error, + ) = await ComposioService().get_gmail_messages( + connected_account_id=cca_id, + entity_id=f"surfsense_{user_id}", + query=query, + max_results=max_results, + ) + if error: + return {"status": "error", "message": error} + + emails = [_format_gmail_summary(m) for m in messages] + if not emails: + return { + "status": "success", + "emails": [], + "total": 0, + "message": "No emails found.", + } + return {"status": "success", "emails": emails, "total": len(emails)} + + from app.agents.new_chat.tools.gmail.search_emails import ( + _build_credentials, + ) + creds = _build_credentials(connector) from app.connectors.google_gmail_connector import GoogleGmailConnector diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/send_email.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/send_email.py index c3f0999f4..d5de24b62 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/send_email.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/send_email.py @@ -162,16 +162,31 @@ def create_send_gmail_email_tool( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR ): - from app.utils.google_credentials import build_composio_credentials - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) - else: + if not cca_id: return { "status": "error", "message": "Composio connected account ID not found for this Gmail connector.", } + + from app.services.composio_service import ComposioService + + ( + sent_message_id, + sent_thread_id, + error, + ) = await ComposioService().send_gmail_email( + connected_account_id=cca_id, + entity_id=f"surfsense_{user_id}", + to=final_to, + subject=final_subject, + body=final_body, + cc=final_cc, + bcc=final_bcc, + ) + if error: + return {"status": "error", "message": error} + sent = {"id": sent_message_id, "threadId": sent_thread_id} else: from google.oauth2.credentials import Credentials @@ -209,61 +224,61 @@ def create_send_gmail_email_tool( expiry=datetime.fromisoformat(exp) if exp else None, ) - from googleapiclient.discovery import build + from googleapiclient.discovery import build - gmail_service = build("gmail", "v1", credentials=creds) + gmail_service = build("gmail", "v1", credentials=creds) - message = MIMEText(final_body) - message["to"] = final_to - message["subject"] = final_subject - if final_cc: - message["cc"] = final_cc - if final_bcc: - message["bcc"] = final_bcc - raw = base64.urlsafe_b64encode(message.as_bytes()).decode() + message = MIMEText(final_body) + message["to"] = final_to + message["subject"] = final_subject + if final_cc: + message["cc"] = final_cc + if final_bcc: + message["bcc"] = final_bcc + raw = base64.urlsafe_b64encode(message.as_bytes()).decode() - try: - sent = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .messages() - .send(userId="me", body={"raw": raw}) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {api_err}" + try: + sent = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + gmail_service.users() + .messages() + .send(userId="me", body={"raw": raw}) + .execute() + ), ) - try: - from sqlalchemy.orm.attributes import flag_modified + except Exception as api_err: + from googleapiclient.errors import HttpError - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id - ) - ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: + if isinstance(api_err, HttpError) and api_err.resp.status == 403: logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, + f"Insufficient permissions for connector {actual_connector_id}: {api_err}" ) - return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", - } - raise + try: + from sqlalchemy.orm.attributes import flag_modified + + _res = await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", + } + raise logger.info( f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}" diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/trash_email.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/trash_email.py index 1f1f6227a..b78f88934 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/trash_email.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/trash_email.py @@ -162,16 +162,22 @@ def create_trash_gmail_email_tool( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR ): - from app.utils.google_credentials import build_composio_credentials - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) - else: + if not cca_id: return { "status": "error", "message": "Composio connected account ID not found for this Gmail connector.", } + + from app.services.composio_service import ComposioService + + error = await ComposioService().trash_gmail_message( + connected_account_id=cca_id, + entity_id=f"surfsense_{user_id}", + message_id=final_message_id, + ) + if error: + return {"status": "error", "message": error} else: from google.oauth2.credentials import Credentials @@ -209,49 +215,49 @@ def create_trash_gmail_email_tool( expiry=datetime.fromisoformat(exp) if exp else None, ) - from googleapiclient.discovery import build + from googleapiclient.discovery import build - gmail_service = build("gmail", "v1", credentials=creds) + gmail_service = build("gmail", "v1", credentials=creds) - try: - await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .messages() - .trash(userId="me", id=final_message_id) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {connector.id}: {api_err}" + try: + await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + gmail_service.users() + .messages() + .trash(userId="me", id=final_message_id) + .execute() + ), ) - try: - from sqlalchemy.orm.attributes import flag_modified + except Exception as api_err: + from googleapiclient.errors import HttpError - if not connector.config.get("auth_expired"): - connector.config = { - **connector.config, - "auth_expired": True, - } - flag_modified(connector, "config") - await db_session.commit() - except Exception: + if isinstance(api_err, HttpError) and api_err.resp.status == 403: logger.warning( - "Failed to persist auth_expired for connector %s", - connector.id, - exc_info=True, + f"Insufficient permissions for connector {connector.id}: {api_err}" ) - return { - "status": "insufficient_permissions", - "connector_id": connector.id, - "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", - } - raise + try: + from sqlalchemy.orm.attributes import flag_modified + + if not connector.config.get("auth_expired"): + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + connector.id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": connector.id, + "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", + } + raise logger.info(f"Gmail email trashed: message_id={final_message_id}") diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/update_draft.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/update_draft.py index 91178cd21..b6688ac53 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/update_draft.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/update_draft.py @@ -192,16 +192,51 @@ def create_update_gmail_draft_tool( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR ): - from app.utils.google_credentials import build_composio_credentials - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - creds = build_composio_credentials(cca_id) - else: + if not cca_id: return { "status": "error", "message": "Composio connected account ID not found for this Gmail connector.", } + + if not final_draft_id: + return { + "status": "error", + "message": ( + "Could not find this draft in Gmail. " + "It may have already been sent or deleted." + ), + } + + from app.services.composio_service import ComposioService + + ( + new_draft_id, + new_message_id, + error, + ) = await ComposioService().update_gmail_draft( + connected_account_id=cca_id, + entity_id=f"surfsense_{user_id}", + draft_id=final_draft_id, + to=final_to or None, + subject=final_subject, + body=final_body, + cc=final_cc, + bcc=final_bcc, + ) + if error: + if "not found" in error.lower() or "no longer" in error.lower(): + return { + "status": "error", + "message": "Draft no longer exists in Gmail. It may have been sent or deleted.", + } + return {"status": "error", "message": error} + + updated = { + "id": new_draft_id or final_draft_id, + "message": {"id": new_message_id} if new_message_id else {}, + } + logger.info(f"Gmail draft updated via Composio: id={updated.get('id')}") else: from google.oauth2.credentials import Credentials @@ -239,88 +274,90 @@ def create_update_gmail_draft_tool( expiry=datetime.fromisoformat(exp) if exp else None, ) - from googleapiclient.discovery import build + from googleapiclient.discovery import build - gmail_service = build("gmail", "v1", credentials=creds) + gmail_service = build("gmail", "v1", credentials=creds) - # Resolve draft_id if not already available - if not final_draft_id: - logger.info( - f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}" - ) - final_draft_id = await _find_draft_id_by_message( - gmail_service, message_id - ) - - if not final_draft_id: - return { - "status": "error", - "message": ( - "Could not find this draft in Gmail. " - "It may have already been sent or deleted." - ), - } - - message = MIMEText(final_body) - if final_to: - message["to"] = final_to - message["subject"] = final_subject - if final_cc: - message["cc"] = final_cc - if final_bcc: - message["bcc"] = final_bcc - raw = base64.urlsafe_b64encode(message.as_bytes()).decode() - - try: - updated = await asyncio.get_event_loop().run_in_executor( - None, - lambda: ( - gmail_service.users() - .drafts() - .update( - userId="me", - id=final_draft_id, - body={"message": {"raw": raw}}, - ) - .execute() - ), - ) - except Exception as api_err: - from googleapiclient.errors import HttpError - - if isinstance(api_err, HttpError) and api_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {connector.id}: {api_err}" + # Resolve draft_id if not already available + if not final_draft_id: + logger.info( + f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}" + ) + final_draft_id = await _find_draft_id_by_message( + gmail_service, message_id ) - try: - from sqlalchemy.orm.attributes import flag_modified - if not connector.config.get("auth_expired"): - connector.config = { - **connector.config, - "auth_expired": True, - } - flag_modified(connector, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - connector.id, - exc_info=True, - ) - return { - "status": "insufficient_permissions", - "connector_id": connector.id, - "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", - } - if isinstance(api_err, HttpError) and api_err.resp.status == 404: + if not final_draft_id: return { "status": "error", - "message": "Draft no longer exists in Gmail. It may have been sent or deleted.", + "message": ( + "Could not find this draft in Gmail. " + "It may have already been sent or deleted." + ), } - raise - logger.info(f"Gmail draft updated: id={updated.get('id')}") + message = MIMEText(final_body) + if final_to: + message["to"] = final_to + message["subject"] = final_subject + if final_cc: + message["cc"] = final_cc + if final_bcc: + message["bcc"] = final_bcc + raw = base64.urlsafe_b64encode(message.as_bytes()).decode() + + try: + updated = await asyncio.get_event_loop().run_in_executor( + None, + lambda: ( + gmail_service.users() + .drafts() + .update( + userId="me", + id=final_draft_id, + body={"message": {"raw": raw}}, + ) + .execute() + ), + ) + except Exception as api_err: + from googleapiclient.errors import HttpError + + if isinstance(api_err, HttpError) and api_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {connector.id}: {api_err}" + ) + try: + from sqlalchemy.orm.attributes import flag_modified + + if not connector.config.get("auth_expired"): + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + connector.id, + exc_info=True, + ) + return { + "status": "insufficient_permissions", + "connector_id": connector.id, + "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.", + } + if isinstance(api_err, HttpError) and api_err.resp.status == 404: + return { + "status": "error", + "message": "Draft no longer exists in Gmail. It may have been sent or deleted.", + } + raise + + logger.info( + f"Gmail draft updated via Google API: id={updated.get('id')}" + ) kb_message_suffix = "" if document_id: diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/create_file.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/create_file.py index f36db8f3f..9e9a30429 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/create_file.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/create_file.py @@ -179,59 +179,96 @@ def create_create_google_drive_file_tool( f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}" ) - pre_built_creds = None + async def _flag_auth_expired() -> None: + try: + from sqlalchemy.orm.attributes import flag_modified + + _res = await db_session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == actual_connector_id + ) + ) + _conn = _res.scalar_one_or_none() + if _conn and not _conn.config.get("auth_expired"): + _conn.config = {**_conn.config, "auth_expired": True} + flag_modified(_conn, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + actual_connector_id, + exc_info=True, + ) + if ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR ): - from app.utils.google_credentials import build_composio_credentials - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - pre_built_creds = build_composio_credentials(cca_id) + if not cca_id: + return { + "status": "error", + "message": "Composio connected account ID not found for this Google Drive connector.", + } - client = GoogleDriveClient( - session=db_session, - connector_id=actual_connector_id, - credentials=pre_built_creds, - ) - try: - created = await client.create_file( + from app.services.composio_service import ComposioService + + created, error = await ComposioService().create_drive_file_from_text( + connected_account_id=cca_id, + entity_id=f"surfsense_{user_id}", name=final_name, mime_type=mime_type, - parent_folder_id=final_parent_folder_id, content=final_content, + parent_id=final_parent_folder_id, ) - except HttpError as http_err: - if http_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {actual_connector_id}: {http_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - _res = await db_session.execute( - select(SearchSourceConnector).where( - SearchSourceConnector.id == actual_connector_id - ) - ) - _conn = _res.scalar_one_or_none() - if _conn and not _conn.config.get("auth_expired"): - _conn.config = {**_conn.config, "auth_expired": True} - flag_modified(_conn, "config") - await db_session.commit() - except Exception: + if error or not created: + err_lower = (error or "").lower() + if ( + "insufficient" in err_lower + or "permission" in err_lower + or "403" in err_lower + ): logger.warning( - "Failed to persist auth_expired for connector %s", - actual_connector_id, - exc_info=True, + f"Insufficient permissions for Composio Drive connector {actual_connector_id}: {error}" ) + await _flag_auth_expired() + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.", + } + logger.error( + f"Composio Drive create_file failed for connector {actual_connector_id}: {error}" + ) return { - "status": "insufficient_permissions", - "connector_id": actual_connector_id, - "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.", + "status": "error", + "message": "Something went wrong while creating the file. Please try again.", } - raise + else: + client = GoogleDriveClient( + session=db_session, + connector_id=actual_connector_id, + ) + try: + created = await client.create_file( + name=final_name, + mime_type=mime_type, + parent_folder_id=final_parent_folder_id, + content=final_content, + ) + except HttpError as http_err: + if http_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {actual_connector_id}: {http_err}" + ) + await _flag_auth_expired() + return { + "status": "insufficient_permissions", + "connector_id": actual_connector_id, + "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.", + } + raise logger.info( f"Google Drive file created: id={created.get('id')}, name={created.get('name')}" diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/trash_file.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/trash_file.py index 832afff0d..f7531cf3d 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/trash_file.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/trash_file.py @@ -158,51 +158,84 @@ def create_delete_google_drive_file_tool( f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}" ) - pre_built_creds = None + async def _flag_auth_expired() -> None: + try: + from sqlalchemy.orm.attributes import flag_modified + + if not connector.config.get("auth_expired"): + connector.config = { + **connector.config, + "auth_expired": True, + } + flag_modified(connector, "config") + await db_session.commit() + except Exception: + logger.warning( + "Failed to persist auth_expired for connector %s", + connector.id, + exc_info=True, + ) + if ( connector.connector_type == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR ): - from app.utils.google_credentials import build_composio_credentials - cca_id = connector.config.get("composio_connected_account_id") - if cca_id: - pre_built_creds = build_composio_credentials(cca_id) - - client = GoogleDriveClient( - session=db_session, - connector_id=connector.id, - credentials=pre_built_creds, - ) - try: - await client.trash_file(file_id=final_file_id) - except HttpError as http_err: - if http_err.resp.status == 403: - logger.warning( - f"Insufficient permissions for connector {connector.id}: {http_err}" - ) - try: - from sqlalchemy.orm.attributes import flag_modified - - if not connector.config.get("auth_expired"): - connector.config = { - **connector.config, - "auth_expired": True, - } - flag_modified(connector, "config") - await db_session.commit() - except Exception: - logger.warning( - "Failed to persist auth_expired for connector %s", - connector.id, - exc_info=True, - ) + if not cca_id: return { - "status": "insufficient_permissions", - "connector_id": connector.id, - "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.", + "status": "error", + "message": "Composio connected account ID not found for this Google Drive connector.", } - raise + + from app.services.composio_service import ComposioService + + error = await ComposioService().trash_drive_file( + connected_account_id=cca_id, + entity_id=f"surfsense_{user_id}", + file_id=final_file_id, + ) + if error: + err_lower = error.lower() + if ( + "insufficient" in err_lower + or "permission" in err_lower + or "403" in err_lower + ): + logger.warning( + f"Insufficient permissions for Composio Drive connector {connector.id}: {error}" + ) + await _flag_auth_expired() + return { + "status": "insufficient_permissions", + "connector_id": connector.id, + "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.", + } + logger.error( + f"Composio Drive trash_file failed for connector {connector.id}: {error}" + ) + return { + "status": "error", + "message": "Something went wrong while trashing the file. Please try again.", + } + else: + client = GoogleDriveClient( + session=db_session, + connector_id=connector.id, + ) + try: + await client.trash_file(file_id=final_file_id) + except HttpError as http_err: + if http_err.resp.status == 403: + logger.warning( + f"Insufficient permissions for connector {connector.id}: {http_err}" + ) + await _flag_auth_expired() + return { + "status": "insufficient_permissions", + "connector_id": connector.id, + "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.", + } + raise logger.info( f"Google Drive file deleted (moved to trash): file_id={final_file_id}" diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/__init__.py index 768738118..dc721013a 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/__init__.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/__init__.py @@ -1,11 +1,3 @@ -"""Jira tools for creating, updating, and deleting issues.""" +"""Jira route: native tool factories are empty; MCP supplies tools when configured.""" -from .create_issue import create_create_jira_issue_tool -from .delete_issue import create_delete_jira_issue_tool -from .update_issue import create_update_jira_issue_tool - -__all__ = [ - "create_create_jira_issue_tool", - "create_delete_jira_issue_tool", - "create_update_jira_issue_tool", -] +__all__: list[str] = [] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py index 342f120be..08b0e005e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py @@ -6,29 +6,9 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import ( ToolsPermissions, ) -from .create_issue import create_create_jira_issue_tool -from .delete_issue import create_delete_jira_issue_tool -from .update_issue import create_update_jira_issue_tool - def load_tools( *, dependencies: dict[str, Any] | None = None, **kwargs: Any ) -> ToolsPermissions: - d = {**(dependencies or {}), **kwargs} - common = { - "db_session": d["db_session"], - "search_space_id": d["search_space_id"], - "user_id": d["user_id"], - "connector_id": d.get("connector_id"), - } - create = create_create_jira_issue_tool(**common) - update = create_update_jira_issue_tool(**common) - delete = create_delete_jira_issue_tool(**common) - return { - "allow": [], - "ask": [ - {"name": getattr(create, "name", "") or "", "tool": create}, - {"name": getattr(update, "name", "") or "", "tool": update}, - {"name": getattr(delete, "name", "") or "", "tool": delete}, - ], - } + _ = {**(dependencies or {}), **kwargs} + return {"allow": [], "ask": []} diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/__init__.py index 31acf1e2a..5b464a9df 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/__init__.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/__init__.py @@ -1,11 +1,3 @@ -"""Linear tools for creating, updating, and deleting issues.""" +"""Linear route: native tool factories are empty; MCP supplies tools when configured.""" -from .create_issue import create_create_linear_issue_tool -from .delete_issue import create_delete_linear_issue_tool -from .update_issue import create_update_linear_issue_tool - -__all__ = [ - "create_create_linear_issue_tool", - "create_delete_linear_issue_tool", - "create_update_linear_issue_tool", -] +__all__: list[str] = [] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/create_issue.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/create_issue.py deleted file mode 100644 index ff254e133..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/create_issue.py +++ /dev/null @@ -1,248 +0,0 @@ -import logging -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import request_approval -from app.connectors.linear_connector import LinearAPIError, LinearConnector -from app.services.linear import LinearToolMetadataService - -logger = logging.getLogger(__name__) - - -def create_create_linear_issue_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, - connector_id: int | None = None, -): - """ - Factory function to create the create_linear_issue tool. - - Args: - db_session: Database session for accessing the Linear connector - search_space_id: Search space ID to find the Linear connector - user_id: User ID for fetching user-specific context - connector_id: Optional specific connector ID (if known) - - Returns: - Configured create_linear_issue tool - """ - - @tool - async def create_linear_issue( - title: str, - description: str | None = None, - ) -> dict[str, Any]: - """Create a new issue in Linear. - - Use this tool when the user explicitly asks to create, add, or file - a new issue / ticket / task in Linear. The user MUST describe the issue - before you call this tool. If the request is vague, ask what the issue - should be about. Never call this tool without a clear topic from the user. - - Args: - title: Short, descriptive issue title. Infer from the user's request. - description: Optional markdown body for the issue. Generate from context. - - Returns: - Dictionary with: - - status: "success", "rejected", or "error" - - issue_id: Linear issue UUID (if success) - - identifier: Human-readable ID like "ENG-42" (if success) - - url: URL to the created issue (if success) - - message: Result message - - IMPORTANT: If status is "rejected", the user explicitly declined the action. - Respond with a brief acknowledgment (e.g., "Understood, I won't create the issue.") - and move on. Do NOT retry, troubleshoot, or suggest alternatives. - - Examples: - - "Create a Linear issue for the login bug" - - "File a ticket about the payment timeout problem" - - "Add an issue for the broken search feature" - """ - logger.info(f"create_linear_issue called: title='{title}'") - - if db_session is None or search_space_id is None or user_id is None: - logger.error( - "Linear tool not properly configured - missing required parameters" - ) - return { - "status": "error", - "message": "Linear tool not properly configured. Please contact support.", - } - - try: - metadata_service = LinearToolMetadataService(db_session) - context = await metadata_service.get_creation_context( - search_space_id, user_id - ) - - if "error" in context: - logger.error(f"Failed to fetch creation context: {context['error']}") - return {"status": "error", "message": context["error"]} - - workspaces = context.get("workspaces", []) - if workspaces and all(w.get("auth_expired") for w in workspaces): - logger.warning("All Linear accounts have expired authentication") - return { - "status": "auth_error", - "message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.", - "connector_type": "linear", - } - - logger.info(f"Requesting approval for creating Linear issue: '{title}'") - result = request_approval( - action_type="linear_issue_creation", - tool_name="create_linear_issue", - params={ - "title": title, - "description": description, - "team_id": None, - "state_id": None, - "assignee_id": None, - "priority": None, - "label_ids": [], - "connector_id": connector_id, - }, - context=context, - ) - - if result.rejected: - logger.info("Linear issue creation rejected by user") - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_title = result.params.get("title", title) - final_description = result.params.get("description", description) - final_team_id = result.params.get("team_id") - final_state_id = result.params.get("state_id") - final_assignee_id = result.params.get("assignee_id") - final_priority = result.params.get("priority") - final_label_ids = result.params.get("label_ids") or [] - final_connector_id = result.params.get("connector_id", connector_id) - - if not final_title or not final_title.strip(): - logger.error("Title is empty or contains only whitespace") - return {"status": "error", "message": "Issue title cannot be empty."} - if not final_team_id: - return { - "status": "error", - "message": "A team must be selected to create an issue.", - } - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - actual_connector_id = final_connector_id - if actual_connector_id is None: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "No Linear connector found. Please connect Linear in your workspace settings.", - } - actual_connector_id = connector.id - logger.info(f"Found Linear connector: id={actual_connector_id}") - else: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == actual_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, - ) - ) - connector = result.scalars().first() - if not connector: - return { - "status": "error", - "message": "Selected Linear connector is invalid or has been disconnected.", - } - logger.info(f"Validated Linear connector: id={actual_connector_id}") - - logger.info( - f"Creating Linear issue with final params: title='{final_title}'" - ) - linear_client = LinearConnector( - session=db_session, connector_id=actual_connector_id - ) - result = await linear_client.create_issue( - team_id=final_team_id, - title=final_title, - description=final_description, - state_id=final_state_id, - assignee_id=final_assignee_id, - priority=final_priority, - label_ids=final_label_ids if final_label_ids else None, - ) - - if result.get("status") == "error": - logger.error(f"Failed to create Linear issue: {result.get('message')}") - return {"status": "error", "message": result.get("message")} - - logger.info( - f"Linear issue created: {result.get('identifier')} - {result.get('title')}" - ) - - kb_message_suffix = "" - try: - from app.services.linear import LinearKBSyncService - - kb_service = LinearKBSyncService(db_session) - kb_result = await kb_service.sync_after_create( - issue_id=result.get("id"), - issue_identifier=result.get("identifier", ""), - issue_title=result.get("title", final_title), - issue_url=result.get("url"), - description=final_description, - connector_id=actual_connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - if kb_result["status"] == "success": - kb_message_suffix = " Your knowledge base has also been updated." - else: - kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." - except Exception as kb_err: - logger.warning(f"KB sync after create failed: {kb_err}") - kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync." - - return { - "status": "success", - "issue_id": result.get("id"), - "identifier": result.get("identifier"), - "url": result.get("url"), - "message": (result.get("message", "") + kb_message_suffix), - } - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - - logger.error(f"Error creating Linear issue: {e}", exc_info=True) - if isinstance(e, ValueError | LinearAPIError): - message = str(e) - else: - message = ( - "Something went wrong while creating the issue. Please try again." - ) - return {"status": "error", "message": message} - - return create_linear_issue diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/delete_issue.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/delete_issue.py deleted file mode 100644 index 29ef0cdf2..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/delete_issue.py +++ /dev/null @@ -1,245 +0,0 @@ -import logging -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.new_chat.tools.hitl import request_approval -from app.connectors.linear_connector import LinearAPIError, LinearConnector -from app.services.linear import LinearToolMetadataService - -logger = logging.getLogger(__name__) - - -def create_delete_linear_issue_tool( - db_session: AsyncSession | None = None, - search_space_id: int | None = None, - user_id: str | None = None, - connector_id: int | None = None, -): - """ - Factory function to create the delete_linear_issue tool. - - Args: - db_session: Database session for accessing the Linear connector - search_space_id: Search space ID to find the Linear connector - user_id: User ID for finding the correct Linear connector - connector_id: Optional specific connector ID (if known) - - Returns: - Configured delete_linear_issue tool - """ - - @tool - async def delete_linear_issue( - issue_ref: str, - delete_from_kb: bool = False, - ) -> dict[str, Any]: - """Archive (delete) a Linear issue. - - Use this tool when the user asks to delete, remove, or archive a Linear issue. - Note that Linear archives issues rather than permanently deleting them - (they can be restored from the archive). - - - Args: - issue_ref: The issue to delete. Can be the issue title (e.g. "Fix login bug"), - the identifier (e.g. "ENG-42"), or the full document title - (e.g. "ENG-42: Fix login bug"). - delete_from_kb: Whether to also remove the issue from the knowledge base. - Default is False. Set to True to remove from both Linear - and the knowledge base. - - Returns: - Dictionary with: - - status: "success", "rejected", "not_found", or "error" - - identifier: Human-readable ID like "ENG-42" (if success) - - message: Success or error message - - deleted_from_kb: Whether the issue was also removed from the knowledge base (if success) - - IMPORTANT: - - If status is "rejected", the user explicitly declined the action. - Respond with a brief acknowledgment (e.g., "Understood, I won't delete the issue.") - and move on. Do NOT ask for alternatives or troubleshoot. - - If status is "not_found", inform the user conversationally using the exact message - provided. Do NOT treat this as an error. Simply relay the message and ask the user - to verify the issue title or identifier, or check if it has been indexed. - Examples: - - "Delete the 'Fix login bug' Linear issue" - - "Archive ENG-42" - - "Remove the 'Old payment flow' issue from Linear" - """ - logger.info( - f"delete_linear_issue called: issue_ref='{issue_ref}', delete_from_kb={delete_from_kb}" - ) - - if db_session is None or search_space_id is None or user_id is None: - logger.error( - "Linear tool not properly configured - missing required parameters" - ) - return { - "status": "error", - "message": "Linear tool not properly configured. Please contact support.", - } - - try: - metadata_service = LinearToolMetadataService(db_session) - context = await metadata_service.get_delete_context( - search_space_id, user_id, issue_ref - ) - - if "error" in context: - error_msg = context["error"] - if context.get("auth_expired"): - logger.warning(f"Auth expired for delete context: {error_msg}") - return { - "status": "auth_error", - "message": error_msg, - "connector_id": context.get("connector_id"), - "connector_type": "linear", - } - if "not found" in error_msg.lower(): - logger.warning(f"Issue not found: {error_msg}") - return {"status": "not_found", "message": error_msg} - else: - logger.error(f"Failed to fetch delete context: {error_msg}") - return {"status": "error", "message": error_msg} - - issue_id = context["issue"]["id"] - issue_identifier = context["issue"].get("identifier", "") - document_id = context["issue"]["document_id"] - connector_id_from_context = context.get("workspace", {}).get("id") - - logger.info( - f"Requesting approval for deleting Linear issue: '{issue_ref}' " - f"(id={issue_id}, delete_from_kb={delete_from_kb})" - ) - result = request_approval( - action_type="linear_issue_deletion", - tool_name="delete_linear_issue", - params={ - "issue_id": issue_id, - "connector_id": connector_id_from_context, - "delete_from_kb": delete_from_kb, - }, - context=context, - ) - - if result.rejected: - logger.info("Linear issue deletion rejected by user") - return { - "status": "rejected", - "message": "User declined. Do not retry or suggest alternatives.", - } - - final_issue_id = result.params.get("issue_id", issue_id) - final_connector_id = result.params.get( - "connector_id", connector_id_from_context - ) - final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb) - - logger.info( - f"Deleting Linear issue with final params: issue_id={final_issue_id}, " - f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}" - ) - - from sqlalchemy.future import select - - from app.db import SearchSourceConnector, SearchSourceConnectorType - - if final_connector_id: - result = await db_session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == final_connector_id, - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.LINEAR_CONNECTOR, - ) - ) - connector = result.scalars().first() - if not connector: - logger.error( - f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}" - ) - return { - "status": "error", - "message": "Selected Linear connector is invalid or has been disconnected.", - } - actual_connector_id = connector.id - logger.info(f"Validated Linear connector: id={actual_connector_id}") - else: - logger.error("No connector found for this issue") - return { - "status": "error", - "message": "No connector found for this issue.", - } - - linear_client = LinearConnector( - session=db_session, connector_id=actual_connector_id - ) - - result = await linear_client.archive_issue(issue_id=final_issue_id) - - logger.info( - f"archive_issue result: {result.get('status')} - {result.get('message', '')}" - ) - - deleted_from_kb = False - if ( - result.get("status") == "success" - and final_delete_from_kb - and document_id - ): - try: - from app.db import Document - - doc_result = await db_session.execute( - select(Document).filter(Document.id == document_id) - ) - document = doc_result.scalars().first() - if document: - await db_session.delete(document) - await db_session.commit() - deleted_from_kb = True - logger.info( - f"Deleted document {document_id} from knowledge base" - ) - else: - logger.warning(f"Document {document_id} not found in KB") - except Exception as e: - logger.error(f"Failed to delete document from KB: {e}") - await db_session.rollback() - result["warning"] = ( - f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}" - ) - - if result.get("status") == "success": - result["deleted_from_kb"] = deleted_from_kb - if issue_identifier: - result["message"] = ( - f"Issue {issue_identifier} archived successfully." - ) - if deleted_from_kb: - result["message"] = ( - f"{result.get('message', '')} Also removed from the knowledge base." - ) - - return result - - except Exception as e: - from langgraph.errors import GraphInterrupt - - if isinstance(e, GraphInterrupt): - raise - - logger.error(f"Error deleting Linear issue: {e}", exc_info=True) - if isinstance(e, ValueError | LinearAPIError): - message = str(e) - else: - message = ( - "Something went wrong while deleting the issue. Please try again." - ) - return {"status": "error", "message": message} - - return delete_linear_issue diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py index f1ee49964..08b0e005e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py @@ -6,29 +6,9 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import ( ToolsPermissions, ) -from .create_issue import create_create_linear_issue_tool -from .delete_issue import create_delete_linear_issue_tool -from .update_issue import create_update_linear_issue_tool - def load_tools( *, dependencies: dict[str, Any] | None = None, **kwargs: Any ) -> ToolsPermissions: - d = {**(dependencies or {}), **kwargs} - common = { - "db_session": d["db_session"], - "search_space_id": d["search_space_id"], - "user_id": d["user_id"], - "connector_id": d.get("connector_id"), - } - create = create_create_linear_issue_tool(**common) - update = create_update_linear_issue_tool(**common) - delete = create_delete_linear_issue_tool(**common) - return { - "allow": [], - "ask": [ - {"name": getattr(create, "name", "") or "", "tool": create}, - {"name": getattr(update, "name", "") or "", "tool": update}, - {"name": getattr(delete, "name", "") or "", "tool": delete}, - ], - } + _ = {**(dependencies or {}), **kwargs} + return {"allow": [], "ask": []} diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index 1f4024d9d..605c31416 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -31,7 +31,6 @@ from langchain.agents import create_agent from langchain.agents.middleware import ( LLMToolSelectorMiddleware, ModelCallLimitMiddleware, - ModelFallbackMiddleware, TodoListMiddleware, ToolCallLimitMiddleware, ) @@ -77,6 +76,9 @@ from app.agents.new_chat.middleware import ( create_surfsense_compaction_middleware, default_skills_sources, ) +from app.agents.new_chat.middleware.scoped_model_fallback import ( + ScopedModelFallbackMiddleware, +) from app.agents.new_chat.permissions import Rule, Ruleset from app.agents.new_chat.plugin_loader import ( PluginContext, @@ -792,15 +794,15 @@ def _build_compiled_agent_blocking( # Fallback chain — primary is the agent's own model; we add cheap # alternatives. Off by default; only the first call site that # configures the chain via env should enable it. - fallback_mw: ModelFallbackMiddleware | None = None + fallback_mw: ScopedModelFallbackMiddleware | None = None if flags.enable_model_fallback and not flags.disable_new_agent_stack: try: - fallback_mw = ModelFallbackMiddleware( + fallback_mw = ScopedModelFallbackMiddleware( "openai:gpt-4o-mini", "anthropic:claude-3-5-haiku-20241022", ) except Exception: - logging.warning("ModelFallbackMiddleware init failed; skipping.") + logging.warning("ScopedModelFallbackMiddleware init failed; skipping.") fallback_mw = None model_call_limit_mw = ( ModelCallLimitMiddleware( diff --git a/surfsense_backend/app/agents/new_chat/context.py b/surfsense_backend/app/agents/new_chat/context.py index d720b524b..a20a43a66 100644 --- a/surfsense_backend/app/agents/new_chat/context.py +++ b/surfsense_backend/app/agents/new_chat/context.py @@ -46,6 +46,10 @@ class SurfSenseContextSchema: Read by ``KnowledgePriorityMiddleware`` to seed its priority list. Stays out of the compiled-agent cache key — that's the whole point of putting it here. + mentioned_folder_ids: KB folders the user @-mentioned this turn + (cloud filesystem mode). Surfaced as ``[USER-MENTIONED]`` + entries in ```` so the agent prioritises + walking those folders with ``ls`` / ``find_documents``. file_operation_contract: One-shot file operation contract emitted by ``FileIntentMiddleware`` for the upcoming turn. turn_id / request_id: Correlation IDs surfaced by the streaming @@ -59,6 +63,7 @@ class SurfSenseContextSchema: search_space_id: int | None = None mentioned_document_ids: list[int] = field(default_factory=list) + mentioned_folder_ids: list[int] = field(default_factory=list) file_operation_contract: FileOperationContractState | None = None turn_id: str | None = None request_id: str | None = None diff --git a/surfsense_backend/app/agents/new_chat/feature_flags.py b/surfsense_backend/app/agents/new_chat/feature_flags.py index b3dc0fa82..3cea051ef 100644 --- a/surfsense_backend/app/agents/new_chat/feature_flags.py +++ b/surfsense_backend/app/agents/new_chat/feature_flags.py @@ -28,7 +28,6 @@ Defaults: SURFSENSE_ENABLE_PERMISSION=true SURFSENSE_ENABLE_DOOM_LOOP=true SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call - SURFSENSE_ENABLE_STREAM_PARITY_V2=true Master kill-switch (overrides everything else): @@ -88,15 +87,6 @@ class AgentFeatureFlags: enable_action_log: bool = True enable_revert_route: bool = True - # Streaming parity v2 — opt in to LangChain's structured - # ``AIMessageChunk`` content (typed reasoning blocks, tool-input - # deltas) and propagate the real ``tool_call_id`` to the SSE layer. - # When OFF the ``stream_new_chat`` task falls back to the str-only - # text path and the synthetic ``call_`` tool-call id (no - # ``langchainToolCallId`` propagation). Schema migrations 135/136 - # ship unconditionally because they're forward-compatible. - enable_stream_parity_v2: bool = True - # Plugins enable_plugin_loader: bool = False @@ -169,7 +159,6 @@ class AgentFeatureFlags: enable_kb_planner_runnable=False, enable_action_log=False, enable_revert_route=False, - enable_stream_parity_v2=False, enable_plugin_loader=False, enable_otel=False, enable_agent_cache=False, @@ -208,10 +197,6 @@ class AgentFeatureFlags: # Snapshot / revert enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", True), enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", True), - # Streaming parity v2 - enable_stream_parity_v2=_env_bool( - "SURFSENSE_ENABLE_STREAM_PARITY_V2", True - ), # Plugins enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False), # Observability diff --git a/surfsense_backend/app/agents/new_chat/mention_resolver.py b/surfsense_backend/app/agents/new_chat/mention_resolver.py new file mode 100644 index 000000000..00bb7e71f --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/mention_resolver.py @@ -0,0 +1,281 @@ +"""Resolve @-mention chips to canonical virtual paths and substitute the +user-visible ``@title`` tokens with backtick-wrapped paths in the prompt +the agent sees. + +The frontend's mention seam is a single discriminated-union list of +``{kind: "doc" | "folder", id, title, document_type?}`` chips (see +``surfsense_web/atoms/chat/mentioned-documents.atom.ts``). When a turn +reaches the backend stream task we have three needs that this module +centralises: + +1. Map each chip to its canonical virtual path + (``/documents/.../file.xml`` for docs, ``/documents/MyFolder/`` for + folders) so the agent sees concrete filesystem locations instead of + ambiguous ``@``-titles. +2. Substitute ``@title`` tokens in the user-typed text with backtick- + wrapped paths so the path becomes part of the ``HumanMessage`` body + the LLM consumes — without rewriting the persisted user message + text (which keeps ``@title`` so chip rendering on reload is + unchanged). +3. Surface the resolved id sets (docs + folders) to the priority + middleware so it can render ``[USER-MENTIONED]`` priority entries + without re-doing path resolution. + +This is intentionally one module — see the architectural note in +``mention-paths-and-folders`` plan: previously the doc-resolution lived +inline in ``stream_new_chat`` and the folder mention had no resolution +at all. Centralising both behind a single ``resolve_mentions`` call +turns a leaky multi-field seam into a single deeper interface. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.path_resolver import ( + DOCUMENTS_ROOT, + build_path_index, + doc_to_virtual_path, +) +from app.db import Document, Folder +from app.schemas.new_chat import MentionedDocumentInfo + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class ResolvedMention: + """Canonical view of a single @-mention chip. + + ``virtual_path`` is the path the agent will see (no trailing slash + for documents, trailing ``/`` for folders to match the convention + used by ``KnowledgeTreeMiddleware``). + """ + + kind: str # "doc" | "folder" + id: int + title: str + virtual_path: str + + +@dataclass +class ResolvedMentionSet: + """Aggregate result of resolving a turn's mention chips. + + ``token_to_path`` maps ``@title`` (the literal token the user typed + and the editor emitted) to the canonical virtual path for that + chip. It is produced longest-token-first so substitution mirrors + ``parseMentionSegments`` on the frontend (a longer title like + ``@Project Roadmap`` is never shadowed by a shorter prefix + ``@Project``). + + ``mentioned_document_ids`` collapses doc + surfsense_doc chips into + a single ordered, deduped list because the priority middleware + treats them uniformly downstream — see + ``KnowledgePriorityMiddleware._compute_priority_paths``. + """ + + mentions: list[ResolvedMention] = field(default_factory=list) + token_to_path: list[tuple[str, str]] = field(default_factory=list) + mentioned_document_ids: list[int] = field(default_factory=list) + mentioned_folder_ids: list[int] = field(default_factory=list) + + +def _folder_virtual_path(folder_id: int, folder_paths: dict[int, str]) -> str: + """Return ``/documents/Folder/Sub/`` for a folder id. + + Falls back to the documents root when the folder is missing from + the index (deleted or in a different search space). Trailing slash + matches ``KnowledgeTreeMiddleware`` (``/documents/MyFolder/``) so + the agent's ``ls`` can dispatch on it as a directory. + """ + base = folder_paths.get(folder_id, DOCUMENTS_ROOT) + return f"{base}/" if not base.endswith("/") else base + + +async def resolve_mentions( + session: AsyncSession, + *, + search_space_id: int, + mentioned_documents: list[MentionedDocumentInfo] | None, + mentioned_document_ids: list[int] | None = None, + mentioned_surfsense_doc_ids: list[int] | None = None, + mentioned_folder_ids: list[int] | None = None, +) -> ResolvedMentionSet: + """Resolve every @-mention chip on a turn into virtual paths. + + The function takes both the ``mentioned_documents`` discriminated + list (chip metadata used for substitution + persistence) and the + parallel id arrays (``mentioned_document_ids``, + ``mentioned_surfsense_doc_ids``, ``mentioned_folder_ids``) for two + reasons: + + * Legacy clients that haven't migrated to the unified chip list + still send the id arrays — we treat the union as authoritative. + * The id arrays are the canonical input to + ``KnowledgePriorityMiddleware`` (via ``SurfSenseContextSchema``); + returning the deduped, validated lists lets the route forward + them unchanged. + + Resolution is best-effort: a chip whose id no longer exists (e.g. + document was deleted between mention and submit) is silently + dropped. The agent still sees the user's original text, just + without a backtick-path substitution for that chip. + """ + chip_doc_ids: list[int] = [] + chip_folder_ids: list[int] = [] + chip_titles_by_id: dict[tuple[str, int], str] = {} + if mentioned_documents: + for chip in mentioned_documents: + kind = chip.kind + if kind == "folder": + chip_folder_ids.append(chip.id) + else: + chip_doc_ids.append(chip.id) + chip_titles_by_id[(kind, chip.id)] = chip.title + + doc_id_pool: list[int] = list( + dict.fromkeys( + [ + *(mentioned_document_ids or []), + *(mentioned_surfsense_doc_ids or []), + *chip_doc_ids, + ] + ) + ) + folder_id_pool: list[int] = list( + dict.fromkeys([*(mentioned_folder_ids or []), *chip_folder_ids]) + ) + + if not doc_id_pool and not folder_id_pool: + return ResolvedMentionSet() + + index = await build_path_index(session, search_space_id) + + doc_rows: dict[int, Document] = {} + if doc_id_pool: + result = await session.execute( + select(Document).where( + Document.search_space_id == search_space_id, + Document.id.in_(doc_id_pool), + ) + ) + for row in result.scalars().all(): + doc_rows[row.id] = row + + folder_rows: dict[int, Folder] = {} + if folder_id_pool: + result = await session.execute( + select(Folder).where( + Folder.search_space_id == search_space_id, + Folder.id.in_(folder_id_pool), + ) + ) + for row in result.scalars().all(): + folder_rows[row.id] = row + + resolved: list[ResolvedMention] = [] + accepted_doc_ids: list[int] = [] + accepted_folder_ids: list[int] = [] + + for doc_id in doc_id_pool: + row = doc_rows.get(doc_id) + if row is None: + logger.debug( + "mention_resolver: dropping doc id=%s (not found in space=%s)", + doc_id, + search_space_id, + ) + continue + title = chip_titles_by_id.get(("doc", doc_id), str(row.title or "")) + path = doc_to_virtual_path( + doc_id=row.id, + title=str(row.title or "untitled"), + folder_id=row.folder_id, + index=index, + ) + resolved.append( + ResolvedMention(kind="doc", id=row.id, title=title, virtual_path=path) + ) + accepted_doc_ids.append(row.id) + + for folder_id in folder_id_pool: + row = folder_rows.get(folder_id) + if row is None: + logger.debug( + "mention_resolver: dropping folder id=%s (not found in space=%s)", + folder_id, + search_space_id, + ) + continue + title = chip_titles_by_id.get(("folder", folder_id), str(row.name or "")) + path = _folder_virtual_path(row.id, index.folder_paths) + resolved.append( + ResolvedMention(kind="folder", id=row.id, title=title, virtual_path=path) + ) + accepted_folder_ids.append(row.id) + + token_to_path: list[tuple[str, str]] = [] + seen_tokens: set[str] = set() + for mention in resolved: + if not mention.title: + continue + token = f"@{mention.title}" + if token in seen_tokens: + continue + seen_tokens.add(token) + token_to_path.append((token, mention.virtual_path)) + token_to_path.sort(key=lambda pair: len(pair[0]), reverse=True) + + return ResolvedMentionSet( + mentions=resolved, + token_to_path=token_to_path, + mentioned_document_ids=accepted_doc_ids, + mentioned_folder_ids=accepted_folder_ids, + ) + + +def substitute_in_text(text: str, token_to_path: list[tuple[str, str]]) -> str: + """Replace each ``@title`` token with a backtick-wrapped virtual path. + + Mirrors ``parseMentionSegments`` on the frontend: longest token + first, single forward pass, no regex (titles can contain regex + metacharacters). The substitution is idempotent for already- + substituted text because the backtick-wrapped path no longer + starts with ``@``. + + Empty / no-op cases short-circuit so callers can pass this through + unconditionally without paying for a scan. + """ + if not text or not token_to_path: + return text + + out: list[str] = [] + i = 0 + n = len(text) + while i < n: + matched: tuple[str, str] | None = None + for token, path in token_to_path: + if text.startswith(token, i): + matched = (token, path) + break + if matched is None: + out.append(text[i]) + i += 1 + continue + token, path = matched + out.append(f"`{path}`") + i += len(token) + return "".join(out) + + +__all__ = [ + "ResolvedMention", + "ResolvedMentionSet", + "resolve_mentions", + "substitute_in_text", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py index ee5c1d182..a859220b1 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py +++ b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py @@ -54,6 +54,7 @@ from app.db import ( NATIVE_TO_LEGACY_DOCTYPE, Chunk, Document, + Folder, shielded_async_session, ) from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever @@ -832,6 +833,22 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] mention_ids = list(self.mentioned_document_ids) self.mentioned_document_ids = [] + # Folder mentions live alongside doc mentions on the runtime + # context. They never feed hybrid search (folders aren't + # embedded) — they're surfaced purely as ``[USER-MENTIONED]`` + # priority entries so the agent walks the folder with ``ls`` / + # ``find_documents`` instead of ignoring it. Cloud filesystem + # mode only. + folder_mention_ids: list[int] = [] + if ( + ctx is not None + and getattr(self, "filesystem_mode", FilesystemMode.CLOUD) + == FilesystemMode.CLOUD + ): + ctx_folders = getattr(ctx, "mentioned_folder_ids", None) + if ctx_folders: + folder_mention_ids = list(ctx_folders) + mentioned_results: list[dict[str, Any]] = [] if mention_ids: mentioned_results = await fetch_mentioned_documents( @@ -876,16 +893,21 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] priority, matched_chunk_ids = await self._materialize_priority(merged) + if folder_mention_ids: + folder_entries = await self._materialize_folder_priority(folder_mention_ids) + priority = folder_entries + priority + new_messages = list(messages) insert_at = max(len(new_messages) - 1, 0) new_messages.insert(insert_at, _render_priority_message(priority)) _perf_log.info( - "[kb_priority] completed in %.3fs query=%r priority=%d mentioned=%d", + "[kb_priority] completed in %.3fs query=%r priority=%d mentioned=%d folders=%d", asyncio.get_event_loop().time() - t0, user_text[:80], len(priority), len(mentioned_results), + len(folder_mention_ids), ) return { @@ -894,6 +916,58 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] "messages": new_messages, } + async def _materialize_folder_priority( + self, folder_ids: list[int] + ) -> list[dict[str, Any]]: + """Resolve user-mentioned folder ids to ```` entries. + + Each entry uses the canonical ``/documents/Folder/Sub/`` virtual + path (matching ``KnowledgeTreeMiddleware`` and the agent's + ``ls`` adapter) and is flagged ``mentioned=True`` so the + rendered line carries ``[USER-MENTIONED]``. ``score`` is left + ``None`` so the renderer prints ``n/a`` — folders aren't + ranked, the agent decides which children to read. + """ + if not folder_ids: + return [] + async with shielded_async_session() as session: + index: PathIndex = await build_path_index(session, self.search_space_id) + folder_rows = await session.execute( + select(Folder.id, Folder.name).where( + Folder.search_space_id == self.search_space_id, + Folder.id.in_(folder_ids), + ) + ) + folder_titles: dict[int, str] = { + row.id: row.name for row in folder_rows.all() + } + + entries: list[dict[str, Any]] = [] + seen: set[int] = set() + for folder_id in folder_ids: + if folder_id in seen: + continue + seen.add(folder_id) + base = index.folder_paths.get(folder_id) + if base is None: + logger.debug( + "kb_priority: dropping folder id=%s (missing from path index)", + folder_id, + ) + continue + path = base if base.endswith("/") else f"{base}/" + entries.append( + { + "path": path, + "score": None, + "document_id": None, + "folder_id": folder_id, + "title": folder_titles.get(folder_id, ""), + "mentioned": True, + } + ) + return entries + async def _materialize_priority( self, merged: list[dict[str, Any]] ) -> tuple[list[dict[str, Any]], dict[int, list[int]]]: diff --git a/surfsense_backend/app/agents/new_chat/middleware/scoped_model_fallback.py b/surfsense_backend/app/agents/new_chat/middleware/scoped_model_fallback.py new file mode 100644 index 000000000..99eb2d74a --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/scoped_model_fallback.py @@ -0,0 +1,91 @@ +"""Fallback only on provider/network errors; let programming bugs raise.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from langchain.agents.middleware import ModelFallbackMiddleware + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + from langchain.agents.middleware.types import ModelRequest, ModelResponse + from langchain_core.messages import AIMessage + + +# Matched by class name across the MRO so we don't have to import every +# provider SDK (openai/anthropic/google/...). Extend as new providers ship. +_FALLBACK_ELIGIBLE_NAMES: frozenset[str] = frozenset( + { + "RateLimitError", + "APIStatusError", + "InternalServerError", + "ServiceUnavailableError", + "BadGatewayError", + "GatewayTimeoutError", + "APIConnectionError", + "APITimeoutError", + "ConnectError", + "ConnectTimeout", + "ReadTimeout", + "RemoteProtocolError", + "TimeoutError", + "TimeoutException", + } +) + + +def _is_fallback_eligible(exc: BaseException) -> bool: + return any(cls.__name__ in _FALLBACK_ELIGIBLE_NAMES for cls in type(exc).__mro__) + + +class ScopedModelFallbackMiddleware(ModelFallbackMiddleware): + """Re-raise non-provider exceptions instead of walking the fallback chain.""" + + def wrap_model_call( # type: ignore[override] + self, + request: ModelRequest[Any], + handler: Callable[[ModelRequest[Any]], ModelResponse[Any]], + ) -> ModelResponse[Any] | AIMessage: + last_exception: Exception + try: + return handler(request) + except Exception as e: + if not _is_fallback_eligible(e): + raise + last_exception = e + + for fallback_model in self.models: + try: + return handler(request.override(model=fallback_model)) + except Exception as e: + if not _is_fallback_eligible(e): + raise + last_exception = e + continue + + raise last_exception + + async def awrap_model_call( # type: ignore[override] + self, + request: ModelRequest[Any], + handler: Callable[[ModelRequest[Any]], Awaitable[ModelResponse[Any]]], + ) -> ModelResponse[Any] | AIMessage: + last_exception: Exception + try: + return await handler(request) + except Exception as e: + if not _is_fallback_eligible(e): + raise + last_exception = e + + for fallback_model in self.models: + try: + return await handler(request.override(model=fallback_model)) + except Exception as e: + if not _is_fallback_eligible(e): + raise + last_exception = e + continue + + raise last_exception diff --git a/surfsense_backend/app/indexing_pipeline/document_chunker.py b/surfsense_backend/app/indexing_pipeline/document_chunker.py index 4f3c698ef..6ae81b7a8 100644 --- a/surfsense_backend/app/indexing_pipeline/document_chunker.py +++ b/surfsense_backend/app/indexing_pipeline/document_chunker.py @@ -1,5 +1,15 @@ +import re + from app.config import config +# Regex that matches a Markdown table block (header + separator + one or more rows) +# A table block starts with a | at the beginning of a line and ends when a +# non-table line (or end of string) is encountered. +_TABLE_BLOCK_RE = re.compile( + r"(?:(?:^|\n)(?=[ \t]*\|)(?:[ \t]*\|[^\n]*\n)+)", + re.MULTILINE, +) + def chunk_text(text: str, use_code_chunker: bool = False) -> list[str]: """Chunk a text string using the configured chunker and return the chunk texts.""" @@ -7,3 +17,43 @@ def chunk_text(text: str, use_code_chunker: bool = False) -> list[str]: config.code_chunker_instance if use_code_chunker else config.chunker_instance ) return [c.text for c in chunker.chunk(text)] + + +def chunk_text_hybrid(text: str) -> list[str]: + """Table-aware chunker that prevents Markdown tables from being split mid-row. + + Algorithm: + 1. Scan the document for Markdown table blocks. + 2. Each table block is emitted as a single, unmodified chunk so that its + header, separator row, and data rows always stay together. + 3. The non-table prose segments between (and around) tables are passed through + the normal ``chunk_text`` chunker and their sub-chunks are interleaved in + document order. + + This ensures that table data is never sliced in the middle by the token-based + chunker, which would otherwise produce garbled rows that are useless for RAG. + + Fixes #1334. + """ + chunks: list[str] = [] + cursor = 0 + + for match in _TABLE_BLOCK_RE.finditer(text): + # Prose before this table + prose = text[cursor : match.start()].strip() + if prose: + chunks.extend(chunk_text(prose)) + + # The table itself is kept as one indivisible chunk + table_block = match.group(0).strip() + if table_block: + chunks.append(table_block) + + cursor = match.end() + + # Remaining prose after the last table (or entire text if no tables) + trailing = text[cursor:].strip() + if trailing: + chunks.extend(chunk_text(trailing)) + + return chunks diff --git a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py index e6b2458f3..2339647ea 100644 --- a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py +++ b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py @@ -19,7 +19,7 @@ from app.db import ( DocumentType, ) from app.indexing_pipeline.connector_document import ConnectorDocument -from app.indexing_pipeline.document_chunker import chunk_text +from app.indexing_pipeline.document_chunker import chunk_text, chunk_text_hybrid from app.indexing_pipeline.document_embedder import embed_texts from app.indexing_pipeline.document_hashing import ( compute_content_hash, @@ -387,11 +387,19 @@ class IndexingPipelineService: ) t_step = time.perf_counter() - chunk_texts = await asyncio.to_thread( - chunk_text, - connector_doc.source_markdown, - use_code_chunker=connector_doc.should_use_code_chunker, - ) + if connector_doc.should_use_code_chunker: + chunk_texts = await asyncio.to_thread( + chunk_text, + connector_doc.source_markdown, + use_code_chunker=True, + ) + else: + # Use the table-aware hybrid chunker so Markdown tables are not + # split mid-row (see issue #1334). + chunk_texts = await asyncio.to_thread( + chunk_text_hybrid, + connector_doc.source_markdown, + ) texts_to_embed = [content, *chunk_texts] embeddings = await asyncio.to_thread(embed_texts, texts_to_embed) diff --git a/surfsense_backend/app/routes/documents_routes.py b/surfsense_backend/app/routes/documents_routes.py index f1ca3b6bf..96c5d2344 100644 --- a/surfsense_backend/app/routes/documents_routes.py +++ b/surfsense_backend/app/routes/documents_routes.py @@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm import selectinload +from app.agents.new_chat.path_resolver import virtual_path_to_doc from app.db import ( Chunk, Document, @@ -752,7 +753,24 @@ async def get_document_by_virtual_path( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Resolve a knowledge-base document id by exact virtual path.""" + """Resolve a knowledge-base document by its agent-facing virtual path. + + The agent renders every document under ``/documents/...`` with a + ``.xml`` extension appended via ``safe_filename`` (so a PDF titled + ``2025-W2.pdf`` becomes ``/documents/2025-W2.pdf.xml``). When the user + clicks that path in an answer, this endpoint must round-trip back to + the underlying ``Document`` row regardless of its type — agent-created + NOTE docs (which carry ``virtual_path`` in metadata), uploaded PDFs, + and connector docs all flow through here. + + Resolution is delegated to :func:`virtual_path_to_doc`, the single + source of truth that handles: + + * ``unique_identifier_hash`` lookup (agent NOTE fast path) + * ``" ().xml"`` disambiguation suffixes + * ``.xml`` extension stripping for title-based fallback + * ``safe_filename`` round-trip for connector titles with lossy chars + """ try: await check_permission( session, @@ -762,24 +780,19 @@ async def get_document_by_virtual_path( "You don't have permission to read documents in this search space", ) - result = await session.execute( - select( - Document.id, - Document.title, - Document.document_type, - ).filter( - Document.search_space_id == search_space_id, - Document.document_metadata["virtual_path"].as_string() == virtual_path, - ) + document = await virtual_path_to_doc( + session, + search_space_id=search_space_id, + virtual_path=virtual_path, ) - row = result.first() - if row is None: + if document is None: raise HTTPException(status_code=404, detail="Document not found") return DocumentTitleRead( - id=row.id, - title=row.title, - document_type=row.document_type, + id=document.id, + title=document.title, + document_type=document.document_type, + folder_id=document.folder_id, ) except HTTPException: raise diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index ad96654f5..44fc1c392 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -71,7 +71,10 @@ from app.schemas.new_chat import ( TokenUsageSummary, TurnStatusResponse, ) -from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat +from app.tasks.chat.stream_new_chat import ( + stream_new_chat, + stream_resume_chat, +) from app.users import current_active_user from app.utils.perf import get_perf_logger from app.utils.rbac import check_permission @@ -1778,6 +1781,7 @@ async def handle_new_chat( llm_config_id=llm_config_id, mentioned_document_ids=request.mentioned_document_ids, mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids, + mentioned_folder_ids=request.mentioned_folder_ids, mentioned_documents=mentioned_documents_payload, needs_history_bootstrap=thread.needs_history_bootstrap, thread_visibility=thread.visibility, @@ -2263,6 +2267,7 @@ async def regenerate_response( llm_config_id=llm_config_id, mentioned_document_ids=request.mentioned_document_ids, mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids, + mentioned_folder_ids=request.mentioned_folder_ids, mentioned_documents=mentioned_documents_payload, checkpoint_id=target_checkpoint_id, needs_history_bootstrap=thread.needs_history_bootstrap, diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index 95d183433..c809d6235 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -201,18 +201,34 @@ class NewChatUserImagePart(BaseModel): class MentionedDocumentInfo(BaseModel): - """Display metadata for a single ``@``-mentioned document. + """Display metadata for a single ``@``-mention chip. - The full triple ``{id, title, document_type}`` is forwarded by the - frontend mention chip so the server can embed it in the persisted - user message ``ContentPart[]`` (single ``mentioned-documents`` part). - The history loader then renders the chips on reload without an extra + Carries either a knowledge-base document or a knowledge-base folder + (discriminated by ``kind``). The full triple + ``{id, title, document_type}`` is forwarded by the frontend mention + chip so the server can embed it in the persisted user message + ``ContentPart[]`` (single ``mentioned-documents`` part). The + history loader then renders the chips on reload without an extra fetch — mirrors the pre-refactor frontend ``persistUserTurn`` shape. + + ``kind`` defaults to ``"doc"`` so legacy clients and persisted rows + that predate folder mentions deserialise unchanged. """ id: int title: str = Field(..., min_length=1, max_length=500) document_type: str = Field(..., min_length=1, max_length=100) + kind: Literal["doc", "folder"] = Field( + default="doc", + description=( + "Discriminator for the chip's referent: ``doc`` is a " + "knowledge-base ``Document`` row, ``folder`` is a " + "knowledge-base ``Folder`` row. Folders carry the sentinel " + "``document_type='FOLDER'`` to keep the frontend dedup key " + "``(kind:document_type:id)`` from colliding doc and folder " + "ids that happen to share an integer value." + ), + ) class NewChatRequest(BaseModel): @@ -228,15 +244,26 @@ class NewChatRequest(BaseModel): mentioned_surfsense_doc_ids: list[int] | None = ( None # Optional SurfSense documentation IDs mentioned with @ in the chat ) + mentioned_folder_ids: list[int] | None = Field( + default=None, + description=( + "Optional knowledge-base folder IDs the user mentioned with " + "@. Resolved to virtual paths (``/documents/.../``) by " + "``mention_resolver`` and surfaced to the agent via " + "(a) backtick-wrapped substitution in ``user_query`` and " + "(b) a ``[USER-MENTIONED]`` entry in ````. " + "The agent's ``ls`` tool can then walk the folder itself." + ), + ) mentioned_documents: list[MentionedDocumentInfo] | None = Field( default=None, description=( - "Display metadata (id, title, document_type) for every " - "@-mentioned document. Persisted as a ``mentioned-documents`` " - "ContentPart on the user message so reload renders chips " - "without an extra fetch. Optional and additive — when None " - "the user message is persisted without a mentioned-documents " - "part." + "Display metadata (id, title, document_type, kind) for every " + "@-mention chip — both documents and folders. Persisted as a " + "``mentioned-documents`` ContentPart on the user message so " + "reload renders chips without an extra fetch. Optional and " + "additive — when None the user message is persisted without " + "a mentioned-documents part." ), ) disabled_tools: list[str] | None = ( @@ -290,14 +317,22 @@ class RegenerateRequest(BaseModel): ) mentioned_document_ids: list[int] | None = None mentioned_surfsense_doc_ids: list[int] | None = None + mentioned_folder_ids: list[int] | None = Field( + default=None, + description=( + "Optional knowledge-base folder IDs the user mentioned with " + "@ on the edited user turn. Only used when ``user_query`` is " + "non-None (edit). Mirrors ``NewChatRequest.mentioned_folder_ids``." + ), + ) mentioned_documents: list[MentionedDocumentInfo] | None = Field( default=None, description=( - "Display metadata (id, title, document_type) for every " - "@-mentioned document on the edited user turn. Only used " - "when ``user_query`` is non-None (edit). Persisted as a " - "``mentioned-documents`` ContentPart on the new user " - "message. None means no chip metadata." + "Display metadata (id, title, document_type, kind) for every " + "@-mention chip on the edited user turn — both documents and " + "folders. Only used when ``user_query`` is non-None (edit). " + "Persisted as a ``mentioned-documents`` ContentPart on the " + "new user message. None means no chip metadata." ), ) disabled_tools: list[str] | None = None @@ -373,6 +408,16 @@ class ResumeRequest(BaseModel): filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud" client_platform: Literal["web", "desktop"] = "web" local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None + mentioned_folder_ids: list[int] | None = Field( + default=None, + description=( + "Forwarded for symmetry with /new_chat and /regenerate. " + "Resume reuses the original interrupted user turn so this " + "field is informational only — the originating turn's " + "folder mentions already shaped the priority hints baked " + "into the agent's checkpoint." + ), + ) mentioned_documents: list[MentionedDocumentInfo] | None = Field( default=None, description=( @@ -380,7 +425,7 @@ class ResumeRequest(BaseModel): "/regenerate. Resume reuses the original interrupted user " "turn so the server does not write a new user message. " "Currently unused but accepted to keep request bodies " - "uniform across the three streaming entrypoints." + "uniform across new-message, regenerate, and resume stream routes." ), ) diff --git a/surfsense_backend/app/services/composio_service.py b/surfsense_backend/app/services/composio_service.py index edfab1d15..d73a0d4ce 100644 --- a/surfsense_backend/app/services/composio_service.py +++ b/surfsense_backend/app/services/composio_service.py @@ -1027,6 +1027,505 @@ class ComposioService: logger.error(f"Failed to list Calendar events: {e!s}") return [], str(e) + @staticmethod + def _unwrap_response_data(data: Any) -> Any: + """Composio responses often nest the meaningful payload under + ``data.data.response_data``. Walk that envelope safely and return + whichever inner dict actually has the result keys.""" + if not isinstance(data, dict): + return data + inner = data.get("data", data) + if isinstance(inner, dict): + return inner.get("response_data", inner) + return inner + + @staticmethod + def _split_email_csv(value: str | None) -> list[str] | None: + """Tools accept comma-separated cc/bcc strings; Composio expects an array.""" + if not value: + return None + addrs = [e.strip() for e in value.split(",") if e.strip()] + return addrs or None + + # ===== Gmail write methods ===== + + async def send_gmail_email( + self, + connected_account_id: str, + entity_id: str, + to: str, + subject: str, + body: str, + cc: str | None = None, + bcc: str | None = None, + is_html: bool = False, + ) -> tuple[str | None, str | None, str | None]: + """Send a Gmail message via the Composio ``GMAIL_SEND_EMAIL`` toolkit. + + Returns: + Tuple of (message_id, thread_id, error). On success ``error`` is + None and at least one of the IDs is populated when Composio + returns them; on failure both IDs are None. + """ + try: + params: dict[str, Any] = { + "recipient_email": to, + "subject": subject, + "body": body, + "is_html": is_html, + } + if cc: + cc_list = self._split_email_csv(cc) + if cc_list: + params["cc"] = cc_list + if bcc: + bcc_list = self._split_email_csv(bcc) + if bcc_list: + params["bcc"] = bcc_list + + result = await self.execute_tool( + connected_account_id=connected_account_id, + tool_name="GMAIL_SEND_EMAIL", + params=params, + entity_id=entity_id, + ) + if not result.get("success"): + return None, None, result.get("error", "Unknown error") + + payload = self._unwrap_response_data(result.get("data", {})) + message_id = None + thread_id = None + if isinstance(payload, dict): + message_id = ( + payload.get("id") + or payload.get("message_id") + or payload.get("messageId") + ) + thread_id = payload.get("threadId") or payload.get("thread_id") + return message_id, thread_id, None + except Exception as e: + logger.error(f"Failed to send Gmail email: {e!s}") + return None, None, str(e) + + async def create_gmail_draft( + self, + connected_account_id: str, + entity_id: str, + to: str, + subject: str, + body: str, + cc: str | None = None, + bcc: str | None = None, + is_html: bool = False, + ) -> tuple[str | None, str | None, str | None, str | None]: + """Create a Gmail draft via the Composio ``GMAIL_CREATE_EMAIL_DRAFT`` toolkit. + + Returns: + Tuple of (draft_id, message_id, thread_id, error). On success + ``error`` is None and ``draft_id`` is populated. + """ + try: + params: dict[str, Any] = { + "recipient_email": to, + "subject": subject, + "body": body, + "is_html": is_html, + } + cc_list = self._split_email_csv(cc) + if cc_list: + params["cc"] = cc_list + bcc_list = self._split_email_csv(bcc) + if bcc_list: + params["bcc"] = bcc_list + + result = await self.execute_tool( + connected_account_id=connected_account_id, + tool_name="GMAIL_CREATE_EMAIL_DRAFT", + params=params, + entity_id=entity_id, + ) + if not result.get("success"): + return None, None, None, result.get("error", "Unknown error") + + payload = self._unwrap_response_data(result.get("data", {})) + draft_id = None + message_id = None + thread_id = None + if isinstance(payload, dict): + draft_id = payload.get("id") or payload.get("draft_id") + draft_message = payload.get("message") or {} + if isinstance(draft_message, dict): + message_id = draft_message.get("id") or draft_message.get( + "message_id" + ) + thread_id = draft_message.get("threadId") or draft_message.get( + "thread_id" + ) + if message_id is None: + message_id = payload.get("message_id") or payload.get("messageId") + if thread_id is None: + thread_id = payload.get("thread_id") or payload.get("threadId") + return draft_id, message_id, thread_id, None + except Exception as e: + logger.error(f"Failed to create Gmail draft: {e!s}") + return None, None, None, str(e) + + async def update_gmail_draft( + self, + connected_account_id: str, + entity_id: str, + draft_id: str, + to: str | None = None, + subject: str | None = None, + body: str | None = None, + cc: str | None = None, + bcc: str | None = None, + is_html: bool = False, + ) -> tuple[str | None, str | None, str | None]: + """Update an existing Gmail draft via ``GMAIL_UPDATE_DRAFT``. + + Returns: + Tuple of (draft_id, message_id, error). + """ + try: + params: dict[str, Any] = { + "draft_id": draft_id, + "is_html": is_html, + } + if to: + params["recipient_email"] = to + if subject is not None: + params["subject"] = subject + if body is not None: + params["body"] = body + cc_list = self._split_email_csv(cc) + if cc_list: + params["cc"] = cc_list + bcc_list = self._split_email_csv(bcc) + if bcc_list: + params["bcc"] = bcc_list + + result = await self.execute_tool( + connected_account_id=connected_account_id, + tool_name="GMAIL_UPDATE_DRAFT", + params=params, + entity_id=entity_id, + ) + if not result.get("success"): + return None, None, result.get("error", "Unknown error") + + payload = self._unwrap_response_data(result.get("data", {})) + new_draft_id = draft_id + message_id = None + if isinstance(payload, dict): + new_draft_id = payload.get("id") or payload.get("draft_id") or draft_id + draft_message = payload.get("message") or {} + if isinstance(draft_message, dict): + message_id = draft_message.get("id") or draft_message.get( + "message_id" + ) + if message_id is None: + message_id = payload.get("message_id") or payload.get("messageId") + return new_draft_id, message_id, None + except Exception as e: + logger.error(f"Failed to update Gmail draft: {e!s}") + return None, None, str(e) + + async def trash_gmail_message( + self, + connected_account_id: str, + entity_id: str, + message_id: str, + ) -> str | None: + """Move a Gmail message to trash via ``GMAIL_MOVE_TO_TRASH``. + + Returns the error message on failure, ``None`` on success. + """ + try: + result = await self.execute_tool( + connected_account_id=connected_account_id, + tool_name="GMAIL_MOVE_TO_TRASH", + params={"message_id": message_id}, + entity_id=entity_id, + ) + if not result.get("success"): + return result.get("error", "Unknown error") + return None + except Exception as e: + logger.error(f"Failed to trash Gmail message: {e!s}") + return str(e) + + # ===== Google Calendar write methods ===== + + async def create_calendar_event( + self, + connected_account_id: str, + entity_id: str, + summary: str, + start_datetime: str, + end_datetime: str, + timezone: str | None = None, + description: str | None = None, + location: str | None = None, + attendees: list[str] | None = None, + calendar_id: str = "primary", + ) -> tuple[str | None, str | None, str | None]: + """Create a Google Calendar event via ``GOOGLECALENDAR_CREATE_EVENT``. + + Composio strips trailing timezone info on ``start_datetime`` / + ``end_datetime`` and uses the ``timezone`` field as the IANA name, + so callers may pass ISO 8601 strings with or without offsets. + + Returns: + Tuple of (event_id, html_link, error). + """ + try: + params: dict[str, Any] = { + "summary": summary, + "start_datetime": start_datetime, + "end_datetime": end_datetime, + "calendar_id": calendar_id, + } + if timezone: + params["timezone"] = timezone + if description: + params["description"] = description + if location: + params["location"] = location + if attendees: + params["attendees"] = [a for a in attendees if a] + + result = await self.execute_tool( + connected_account_id=connected_account_id, + tool_name="GOOGLECALENDAR_CREATE_EVENT", + params=params, + entity_id=entity_id, + ) + if not result.get("success"): + return None, None, result.get("error", "Unknown error") + + payload = self._unwrap_response_data(result.get("data", {})) + event_id = None + html_link = None + if isinstance(payload, dict): + event_id = payload.get("id") or payload.get("event_id") + html_link = payload.get("htmlLink") or payload.get("html_link") + return event_id, html_link, None + except Exception as e: + logger.error(f"Failed to create Calendar event: {e!s}") + return None, None, str(e) + + async def update_calendar_event( + self, + connected_account_id: str, + entity_id: str, + event_id: str, + summary: str | None = None, + start_time: str | None = None, + end_time: str | None = None, + timezone: str | None = None, + description: str | None = None, + location: str | None = None, + attendees: list[str] | None = None, + calendar_id: str = "primary", + ) -> tuple[str | None, str | None, str | None]: + """Patch an existing Google Calendar event via ``GOOGLECALENDAR_PATCH_EVENT``. + + Uses PATCH (not PUT) semantics so omitted fields are preserved. + + Returns: + Tuple of (event_id, html_link, error). + """ + try: + params: dict[str, Any] = { + "event_id": event_id, + "calendar_id": calendar_id, + } + if summary is not None: + params["summary"] = summary + if start_time is not None: + params["start_time"] = start_time + if end_time is not None: + params["end_time"] = end_time + if timezone: + params["timezone"] = timezone + if description is not None: + params["description"] = description + if location is not None: + params["location"] = location + if attendees is not None: + params["attendees"] = [a for a in attendees if a] + + result = await self.execute_tool( + connected_account_id=connected_account_id, + tool_name="GOOGLECALENDAR_PATCH_EVENT", + params=params, + entity_id=entity_id, + ) + if not result.get("success"): + return None, None, result.get("error", "Unknown error") + + payload = self._unwrap_response_data(result.get("data", {})) + new_event_id = event_id + html_link = None + if isinstance(payload, dict): + new_event_id = payload.get("id") or payload.get("event_id") or event_id + html_link = payload.get("htmlLink") or payload.get("html_link") + return new_event_id, html_link, None + except Exception as e: + logger.error(f"Failed to patch Calendar event: {e!s}") + return None, None, str(e) + + async def delete_calendar_event( + self, + connected_account_id: str, + entity_id: str, + event_id: str, + calendar_id: str = "primary", + ) -> str | None: + """Delete a Google Calendar event via ``GOOGLECALENDAR_DELETE_EVENT``. + + Returns the error message on failure, ``None`` on success (idempotent + on already-deleted events). + """ + try: + result = await self.execute_tool( + connected_account_id=connected_account_id, + tool_name="GOOGLECALENDAR_DELETE_EVENT", + params={ + "event_id": event_id, + "calendar_id": calendar_id, + }, + entity_id=entity_id, + ) + if not result.get("success"): + return result.get("error", "Unknown error") + return None + except Exception as e: + logger.error(f"Failed to delete Calendar event: {e!s}") + return str(e) + + # ===== Google Drive write methods ===== + + @staticmethod + def _drive_web_view_link(file_id: str, mime_type: str | None) -> str: + """Synthesize a Google Drive ``webViewLink`` from id + mimeType. + + Composio's ``GOOGLEDRIVE_CREATE_FILE_FROM_TEXT`` returns flat + metadata (id, name, mimeType) but does not always include a + ``webViewLink``. We rebuild the canonical UI URL based on the + Workspace MIME type so callers can keep using a single field. + """ + if not file_id: + return "" + mt = (mime_type or "").lower() + if mt == "application/vnd.google-apps.document": + return f"https://docs.google.com/document/d/{file_id}/edit" + if mt == "application/vnd.google-apps.spreadsheet": + return f"https://docs.google.com/spreadsheets/d/{file_id}/edit" + if mt == "application/vnd.google-apps.presentation": + return f"https://docs.google.com/presentation/d/{file_id}/edit" + if mt == "application/vnd.google-apps.folder": + return f"https://drive.google.com/drive/folders/{file_id}" + return f"https://drive.google.com/file/d/{file_id}/view" + + async def create_drive_file_from_text( + self, + connected_account_id: str, + entity_id: str, + name: str, + mime_type: str, + content: str | None = None, + parent_id: str | None = None, + ) -> tuple[dict[str, Any] | None, str | None]: + """Create a Google Drive file from text via ``GOOGLEDRIVE_CREATE_FILE_FROM_TEXT``. + + Composio's tool requires ``text_content`` even for "empty" files; + an empty string is accepted. Native Workspace types (Docs, Sheets) + are produced by setting ``mime_type`` to the Google Apps MIME, and + Drive auto-converts the text payload (e.g. CSV → Sheet). + + Returns: + Tuple of (file_meta, error). ``file_meta`` keys: + ``id``, ``name``, ``mimeType``, ``webViewLink``. + """ + try: + params: dict[str, Any] = { + "file_name": name, + "mime_type": mime_type, + "text_content": content if content is not None else "", + } + if parent_id: + params["parent_id"] = parent_id + + result = await self.execute_tool( + connected_account_id=connected_account_id, + tool_name="GOOGLEDRIVE_CREATE_FILE_FROM_TEXT", + params=params, + entity_id=entity_id, + ) + if not result.get("success"): + return None, result.get("error", "Unknown error") + + payload = self._unwrap_response_data(result.get("data", {})) + file_id: str | None = None + file_name: str | None = name + mime: str | None = mime_type + web_view_link: str | None = None + + if isinstance(payload, dict): + file_id = ( + payload.get("id") or payload.get("file_id") or payload.get("fileId") + ) + file_name = payload.get("name") or payload.get("file_name") or name + mime = payload.get("mimeType") or payload.get("mime_type") or mime_type + web_view_link = payload.get("webViewLink") or payload.get( + "web_view_link" + ) + + if not file_id: + return None, "Composio response did not include a file id" + + if not web_view_link: + web_view_link = self._drive_web_view_link(file_id, mime) + + return ( + { + "id": file_id, + "name": file_name, + "mimeType": mime, + "webViewLink": web_view_link, + }, + None, + ) + except Exception as e: + logger.error(f"Failed to create Drive file: {e!s}") + return None, str(e) + + async def trash_drive_file( + self, + connected_account_id: str, + entity_id: str, + file_id: str, + ) -> str | None: + """Move a Google Drive file to trash via ``GOOGLEDRIVE_TRASH_FILE``. + + Returns the error message on failure, ``None`` on success. + """ + try: + result = await self.execute_tool( + connected_account_id=connected_account_id, + tool_name="GOOGLEDRIVE_TRASH_FILE", + params={"file_id": file_id}, + entity_id=entity_id, + ) + if not result.get("success"): + return result.get("error", "Unknown error") + return None + except Exception as e: + logger.error(f"Failed to trash Drive file: {e!s}") + return str(e) + # ===== User Info Methods ===== async def get_connected_account_email( diff --git a/surfsense_backend/app/services/new_streaming_service.py b/surfsense_backend/app/services/new_streaming_service.py index 55129668c..ba0cb8753 100644 --- a/surfsense_backend/app/services/new_streaming_service.py +++ b/surfsense_backend/app/services/new_streaming_service.py @@ -456,6 +456,8 @@ class VercelStreamingService: title: str, status: str = "in_progress", items: list[str] | None = None, + *, + metadata: dict[str, Any] | None = None, ) -> str: """ Format a thinking step for chain-of-thought display (SurfSense specific). @@ -469,15 +471,15 @@ class VercelStreamingService: Returns: str: SSE formatted thinking step data part """ - return self.format_data( - "thinking-step", - { - "id": step_id, - "title": title, - "status": status, - "items": items or [], - }, - ) + payload: dict[str, Any] = { + "id": step_id, + "title": title, + "status": status, + "items": items or [], + } + if metadata: + payload["metadata"] = metadata + return self.format_data("thinking-step", payload) def format_thread_title_update(self, thread_id: int, title: str) -> str: """ @@ -601,6 +603,7 @@ class VercelStreamingService: tool_name: str, *, langchain_tool_call_id: str | None = None, + metadata: dict[str, Any] | None = None, ) -> str: """ Format the start of tool input streaming. @@ -608,15 +611,14 @@ class VercelStreamingService: Args: tool_call_id: The unique tool call identifier. May be EITHER the synthetic ``call_`` id derived from LangGraph - ``run_id`` (legacy / ``SURFSENSE_ENABLE_STREAM_PARITY_V2`` - OFF, or the unmatched-fallback path under parity_v2) OR - the authoritative LangChain ``tool_call.id`` (parity_v2 - path: when the provider streams ``tool_call_chunks`` we - register the ``index`` and reuse the lc-id as the card - id so live ``tool-input-delta`` events can be routed - without a downstream join). Either way, the same id is - preserved across ``tool-input-start`` / ``-delta`` / - ``-available`` / ``tool-output-available`` for one call. + ``run_id`` (unmatched chunk fallback when no ``index`` was + registered) OR the authoritative LangChain ``tool_call.id`` + (when the provider streams ``tool_call_chunks`` we register + the ``index`` and reuse the lc-id as the card id so live + ``tool-input-delta`` events route without a downstream join). + Either way, the same id is preserved across + ``tool-input-start`` / ``-delta`` / ``-available`` / + ``tool-output-available`` for one call. tool_name: The name of the tool being called. langchain_tool_call_id: Optional authoritative LangChain ``tool_call.id``. When set, surfaces as @@ -636,6 +638,8 @@ class VercelStreamingService: } if langchain_tool_call_id: payload["langchainToolCallId"] = langchain_tool_call_id + if metadata: + payload["metadata"] = metadata return self._format_sse(payload) def format_tool_input_delta(self, tool_call_id: str, input_text_delta: str) -> str: @@ -667,6 +671,7 @@ class VercelStreamingService: input_data: dict[str, Any], *, langchain_tool_call_id: str | None = None, + metadata: dict[str, Any] | None = None, ) -> str: """ Format the completion of tool input. @@ -692,6 +697,8 @@ class VercelStreamingService: } if langchain_tool_call_id: payload["langchainToolCallId"] = langchain_tool_call_id + if metadata: + payload["metadata"] = metadata return self._format_sse(payload) def format_tool_output_available( @@ -700,6 +707,7 @@ class VercelStreamingService: output: Any, *, langchain_tool_call_id: str | None = None, + metadata: dict[str, Any] | None = None, ) -> str: """ Format tool execution output. @@ -726,6 +734,8 @@ class VercelStreamingService: } if langchain_tool_call_id: payload["langchainToolCallId"] = langchain_tool_call_id + if metadata: + payload["metadata"] = metadata return self._format_sse(payload) # ========================================================================= diff --git a/surfsense_backend/app/services/streaming/__init__.py b/surfsense_backend/app/services/streaming/__init__.py new file mode 100644 index 000000000..3ec9b9cf1 --- /dev/null +++ b/surfsense_backend/app/services/streaming/__init__.py @@ -0,0 +1,20 @@ +"""Single-responsibility split of the streaming SSE protocol. + +Layout: +* ``envelope/`` - SSE wire framing + ID generators +* ``emitter/`` - identity of the agent that emitted an event + runtime registry +* ``events/`` - one module per SSE event family +* ``service.py`` - composition root used when emitting chat SSE +* ``interrupt_correlation.py`` - id-aware lookup over LangGraph state + +Naming on the wire: +* AI SDK protocol fields keep their existing camelCase + (``toolCallId``, ``messageId``, ``inputTextDelta``, ``langchainToolCallId``). +* Every SurfSense-added field uses ``snake_case``, including the + top-level ``emitted_by`` envelope and all inner ``data`` payloads. + +Production chat uses ``app.services.new_streaming_service`` from +``app.tasks.chat.stream_new_chat`` and related routes. +""" + +from __future__ import annotations diff --git a/surfsense_backend/app/services/streaming/emitter/__init__.py b/surfsense_backend/app/services/streaming/emitter/__init__.py new file mode 100644 index 000000000..7814894f3 --- /dev/null +++ b/surfsense_backend/app/services/streaming/emitter/__init__.py @@ -0,0 +1,29 @@ +"""Identity of the agent that emitted a streamed event. + +The wire field is ``emitted_by``; the Python identity is :class:`Emitter`. +``EmitterRegistry`` resolves which emitter owns a LangGraph event, with +LangGraph's own namespace metadata as the primary key and a parent_ids +walk as a fallback for cases where context vars don't propagate. +""" + +from __future__ import annotations + +from .emitter import ( + MAIN_EMITTER, + Emitter, + EmitterLevel, + attach_emitted_by, + main_emitter, + subagent_emitter, +) +from .registry import EmitterRegistry + +__all__ = [ + "MAIN_EMITTER", + "Emitter", + "EmitterLevel", + "EmitterRegistry", + "attach_emitted_by", + "main_emitter", + "subagent_emitter", +] diff --git a/surfsense_backend/app/services/streaming/emitter/emitter.py b/surfsense_backend/app/services/streaming/emitter/emitter.py new file mode 100644 index 000000000..08f625a69 --- /dev/null +++ b/surfsense_backend/app/services/streaming/emitter/emitter.py @@ -0,0 +1,61 @@ +"""Identity payload describing which agent produced a stream event.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Literal + +EmitterLevel = Literal["main", "subagent"] + + +@dataclass(frozen=True) +class Emitter: + level: EmitterLevel + subagent_type: str | None = None + subagent_run_id: str | None = None + parent_tool_call_id: str | None = None + extra: dict[str, Any] = field(default_factory=dict) + + def to_payload(self) -> dict[str, Any]: + payload: dict[str, Any] = {"level": self.level} + if self.subagent_type is not None: + payload["subagent_type"] = self.subagent_type + if self.subagent_run_id is not None: + payload["subagent_run_id"] = self.subagent_run_id + if self.parent_tool_call_id is not None: + payload["parent_tool_call_id"] = self.parent_tool_call_id + if self.extra: + payload.update(self.extra) + return payload + + +MAIN_EMITTER = Emitter(level="main") + + +def main_emitter() -> Emitter: + return MAIN_EMITTER + + +def subagent_emitter( + *, + subagent_type: str, + subagent_run_id: str, + parent_tool_call_id: str | None = None, + extra: dict[str, Any] | None = None, +) -> Emitter: + return Emitter( + level="subagent", + subagent_type=subagent_type, + subagent_run_id=subagent_run_id, + parent_tool_call_id=parent_tool_call_id, + extra=dict(extra or {}), + ) + + +def attach_emitted_by( + payload: dict[str, Any], emitter: Emitter | None +) -> dict[str, Any]: + if emitter is None: + return payload + payload["emitted_by"] = emitter.to_payload() + return payload diff --git a/surfsense_backend/app/services/streaming/emitter/registry.py b/surfsense_backend/app/services/streaming/emitter/registry.py new file mode 100644 index 000000000..066689691 --- /dev/null +++ b/surfsense_backend/app/services/streaming/emitter/registry.py @@ -0,0 +1,49 @@ +"""Resolve which agent owns a streamed event from its LangGraph run lineage.""" + +from __future__ import annotations + +from collections.abc import Iterable + +from .emitter import Emitter, main_emitter + + +class EmitterRegistry: + def __init__(self) -> None: + self._by_run_id: dict[str, Emitter] = {} + + def register(self, run_id: str, emitter: Emitter) -> None: + if not run_id: + return + self._by_run_id[run_id] = emitter + + def unregister(self, run_id: str) -> Emitter | None: + if not run_id: + return None + return self._by_run_id.pop(run_id, None) + + def get(self, run_id: str | None) -> Emitter | None: + if not run_id: + return None + return self._by_run_id.get(run_id) + + def resolve( + self, + *, + run_id: str | None, + parent_ids: Iterable[str] | None, + ) -> Emitter: + own = self.get(run_id) + if own is not None: + return own + if parent_ids: + for ancestor in reversed(list(parent_ids)): + emitter = self.get(ancestor) + if emitter is not None: + return emitter + return main_emitter() + + def has_active_subagents(self) -> bool: + return any(emitter.level == "subagent" for emitter in self._by_run_id.values()) + + def clear(self) -> None: + self._by_run_id.clear() diff --git a/surfsense_backend/app/services/streaming/envelope/__init__.py b/surfsense_backend/app/services/streaming/envelope/__init__.py new file mode 100644 index 000000000..862e84c8d --- /dev/null +++ b/surfsense_backend/app/services/streaming/envelope/__init__.py @@ -0,0 +1,23 @@ +"""Wire framing layer.""" + +from __future__ import annotations + +from .identifiers import ( + generate_message_id, + generate_reasoning_id, + generate_subagent_run_id, + generate_text_id, + generate_tool_call_id, +) +from .sse import format_done, format_sse, get_response_headers + +__all__ = [ + "format_done", + "format_sse", + "generate_message_id", + "generate_reasoning_id", + "generate_subagent_run_id", + "generate_text_id", + "generate_tool_call_id", + "get_response_headers", +] diff --git a/surfsense_backend/app/services/streaming/envelope/identifiers.py b/surfsense_backend/app/services/streaming/envelope/identifiers.py new file mode 100644 index 000000000..2fdd6ff09 --- /dev/null +++ b/surfsense_backend/app/services/streaming/envelope/identifiers.py @@ -0,0 +1,25 @@ +"""Prefixed UUID generators for stream parts.""" + +from __future__ import annotations + +import uuid + + +def generate_message_id() -> str: + return f"msg_{uuid.uuid4().hex}" + + +def generate_text_id() -> str: + return f"text_{uuid.uuid4().hex}" + + +def generate_reasoning_id() -> str: + return f"reasoning_{uuid.uuid4().hex}" + + +def generate_tool_call_id() -> str: + return f"call_{uuid.uuid4().hex}" + + +def generate_subagent_run_id() -> str: + return f"subagent_{uuid.uuid4().hex}" diff --git a/surfsense_backend/app/services/streaming/envelope/sse.py b/surfsense_backend/app/services/streaming/envelope/sse.py new file mode 100644 index 000000000..508fc1b1c --- /dev/null +++ b/surfsense_backend/app/services/streaming/envelope/sse.py @@ -0,0 +1,25 @@ +"""Server-Sent-Events wire framing.""" + +from __future__ import annotations + +import json +from typing import Any + + +def format_sse(data: Any) -> str: + if isinstance(data, str): + return f"data: {data}\n\n" + return f"data: {json.dumps(data)}\n\n" + + +def format_done() -> str: + return "data: [DONE]\n\n" + + +def get_response_headers() -> dict[str, str]: + return { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "x-vercel-ai-ui-message-stream": "v1", + } diff --git a/surfsense_backend/app/services/streaming/events/__init__.py b/surfsense_backend/app/services/streaming/events/__init__.py new file mode 100644 index 000000000..91a8ff854 --- /dev/null +++ b/surfsense_backend/app/services/streaming/events/__init__.py @@ -0,0 +1,29 @@ +"""SSE event payload formatters, one module per event family.""" + +from __future__ import annotations + +from . import ( + action_log, + data, + error, + interrupt, + lifecycle, + reasoning, + source, + subagent_lifecycle, + text, + tool, +) + +__all__ = [ + "action_log", + "data", + "error", + "interrupt", + "lifecycle", + "reasoning", + "source", + "subagent_lifecycle", + "text", + "tool", +] diff --git a/surfsense_backend/app/services/streaming/events/action_log.py b/surfsense_backend/app/services/streaming/events/action_log.py new file mode 100644 index 000000000..0a8e46f0a --- /dev/null +++ b/surfsense_backend/app/services/streaming/events/action_log.py @@ -0,0 +1,24 @@ +"""Action-log events relayed from ``ActionLogMiddleware`` custom dispatches.""" + +from __future__ import annotations + +from typing import Any + +from ..emitter import Emitter +from .data import format_data + + +def format_action_log( + payload: dict[str, Any], + *, + emitter: Emitter | None = None, +) -> str: + return format_data("action-log", payload, emitter=emitter) + + +def format_action_log_updated( + payload: dict[str, Any], + *, + emitter: Emitter | None = None, +) -> str: + return format_data("action-log-updated", payload, emitter=emitter) diff --git a/surfsense_backend/app/services/streaming/events/data.py b/surfsense_backend/app/services/streaming/events/data.py new file mode 100644 index 000000000..f6e190578 --- /dev/null +++ b/surfsense_backend/app/services/streaming/events/data.py @@ -0,0 +1,118 @@ +"""Generic ``data-*`` envelopes and SurfSense-specific data parts. + +Inner ``data`` dict fields use snake_case. Legacy ``threadId`` / +``messageId`` keys are preserved where they cross the AI SDK boundary. +""" + +from __future__ import annotations + +from typing import Any + +from ..emitter import Emitter, attach_emitted_by +from ..envelope import format_sse + + +def format_data( + data_type: str, + data: Any, + *, + emitter: Emitter | None = None, +) -> str: + payload: dict[str, Any] = {"type": f"data-{data_type}", "data": data} + return format_sse(attach_emitted_by(payload, emitter)) + + +def format_terminal_info( + text: str, + *, + message_type: str = "info", + emitter: Emitter | None = None, +) -> str: + return format_data( + "terminal-info", + {"text": text, "type": message_type}, + emitter=emitter, + ) + + +def format_further_questions( + questions: list[str], + *, + emitter: Emitter | None = None, +) -> str: + return format_data("further-questions", {"questions": questions}, emitter=emitter) + + +def format_thinking_step( + *, + step_id: str, + title: str, + status: str = "in_progress", + items: list[str] | None = None, + emitter: Emitter | None = None, +) -> str: + return format_data( + "thinking-step", + { + "id": step_id, + "title": title, + "status": status, + "items": items or [], + }, + emitter=emitter, + ) + + +def format_thread_title_update( + *, + thread_id: int, + title: str, + emitter: Emitter | None = None, +) -> str: + return format_data( + "thread-title-update", + {"threadId": thread_id, "title": title}, + emitter=emitter, + ) + + +def format_turn_info( + *, + chat_turn_id: str, + emitter: Emitter | None = None, +) -> str: + return format_data("turn-info", {"chat_turn_id": chat_turn_id}, emitter=emitter) + + +def format_turn_status( + *, + status: str, + emitter: Emitter | None = None, +) -> str: + return format_data("turn-status", {"status": status}, emitter=emitter) + + +def format_user_message_id( + *, + message_id: str, + turn_id: str, + emitter: Emitter | None = None, +) -> str: + return format_data( + "user-message-id", + {"message_id": message_id, "turn_id": turn_id}, + emitter=emitter, + ) + + +def format_assistant_message_id( + *, + message_id: str, + turn_id: str, + emitter: Emitter | None = None, +) -> str: + return format_data( + "assistant-message-id", + {"message_id": message_id, "turn_id": turn_id}, + emitter=emitter, + ) diff --git a/surfsense_backend/app/services/streaming/events/error.py b/surfsense_backend/app/services/streaming/events/error.py new file mode 100644 index 000000000..a1e8e01ca --- /dev/null +++ b/surfsense_backend/app/services/streaming/events/error.py @@ -0,0 +1,23 @@ +"""Single terminal error path chat streaming must route through.""" + +from __future__ import annotations + +from typing import Any + +from ..emitter import Emitter, attach_emitted_by +from ..envelope import format_sse + + +def format_error( + error_text: str, + *, + error_code: str | None = None, + extra: dict[str, Any] | None = None, + emitter: Emitter | None = None, +) -> str: + payload: dict[str, Any] = {"type": "error", "errorText": error_text} + if error_code: + payload["errorCode"] = error_code + if extra: + payload.update(extra) + return format_sse(attach_emitted_by(payload, emitter)) diff --git a/surfsense_backend/app/services/streaming/events/interrupt.py b/surfsense_backend/app/services/streaming/events/interrupt.py new file mode 100644 index 000000000..0334b10b3 --- /dev/null +++ b/surfsense_backend/app/services/streaming/events/interrupt.py @@ -0,0 +1,56 @@ +"""Interrupt-request events with a single canonical payload shape.""" + +from __future__ import annotations + +from typing import Any + +from ..emitter import Emitter +from .data import format_data + + +def normalize_interrupt_payload(interrupt_value: dict[str, Any]) -> dict[str, Any]: + if "action_requests" in interrupt_value and "review_configs" in interrupt_value: + return interrupt_value + + interrupt_type = interrupt_value.get("type", "unknown") + message = interrupt_value.get("message") + action = interrupt_value.get("action", {}) or {} + context = interrupt_value.get("context", {}) or {} + + normalized: dict[str, Any] = { + "action_requests": [ + { + "name": action.get("tool", "unknown_tool"), + "args": action.get("params", {}), + } + ], + "review_configs": [ + { + "action_name": action.get("tool", "unknown_tool"), + "allowed_decisions": ["approve", "edit", "reject"], + } + ], + "interrupt_type": interrupt_type, + "context": context, + } + if message: + normalized["message"] = message + return normalized + + +def format_interrupt_request( + interrupt_value: dict[str, Any], + *, + interrupt_id: str | None = None, + pending_interrupt_count: int | None = None, + chat_turn_id: str | None = None, + emitter: Emitter | None = None, +) -> str: + payload = normalize_interrupt_payload(interrupt_value) + if interrupt_id is not None: + payload["interrupt_id"] = interrupt_id + if pending_interrupt_count is not None: + payload["pending_interrupt_count"] = pending_interrupt_count + if chat_turn_id is not None: + payload["chat_turn_id"] = chat_turn_id + return format_data("interrupt-request", payload, emitter=emitter) diff --git a/surfsense_backend/app/services/streaming/events/lifecycle.py b/surfsense_backend/app/services/streaming/events/lifecycle.py new file mode 100644 index 000000000..019718b67 --- /dev/null +++ b/surfsense_backend/app/services/streaming/events/lifecycle.py @@ -0,0 +1,29 @@ +"""High-level message and step lifecycle events. + +Wire verbs are fixed by the AI SDK protocol (``start`` / ``finish`` for +the whole message, ``start-step`` / ``finish-step`` for each step). +Python helpers always read ``format__`` so pairs are +visible at the call site. +""" + +from __future__ import annotations + +from ..emitter import Emitter, attach_emitted_by +from ..envelope import format_sse + + +def format_message_start(message_id: str, *, emitter: Emitter | None = None) -> str: + payload = {"type": "start", "messageId": message_id} + return format_sse(attach_emitted_by(payload, emitter)) + + +def format_message_finish(*, emitter: Emitter | None = None) -> str: + return format_sse(attach_emitted_by({"type": "finish"}, emitter)) + + +def format_step_start(*, emitter: Emitter | None = None) -> str: + return format_sse(attach_emitted_by({"type": "start-step"}, emitter)) + + +def format_step_finish(*, emitter: Emitter | None = None) -> str: + return format_sse(attach_emitted_by({"type": "finish-step"}, emitter)) diff --git a/surfsense_backend/app/services/streaming/events/reasoning.py b/surfsense_backend/app/services/streaming/events/reasoning.py new file mode 100644 index 000000000..12843e2c3 --- /dev/null +++ b/surfsense_backend/app/services/streaming/events/reasoning.py @@ -0,0 +1,32 @@ +"""Reasoning-block streaming events.""" + +from __future__ import annotations + +from ..emitter import Emitter, attach_emitted_by +from ..envelope import format_sse + + +def format_reasoning_start(reasoning_id: str, *, emitter: Emitter | None = None) -> str: + return format_sse( + attach_emitted_by({"type": "reasoning-start", "id": reasoning_id}, emitter) + ) + + +def format_reasoning_delta( + reasoning_id: str, + delta: str, + *, + emitter: Emitter | None = None, +) -> str: + return format_sse( + attach_emitted_by( + {"type": "reasoning-delta", "id": reasoning_id, "delta": delta}, + emitter, + ) + ) + + +def format_reasoning_end(reasoning_id: str, *, emitter: Emitter | None = None) -> str: + return format_sse( + attach_emitted_by({"type": "reasoning-end", "id": reasoning_id}, emitter) + ) diff --git a/surfsense_backend/app/services/streaming/events/source.py b/surfsense_backend/app/services/streaming/events/source.py new file mode 100644 index 000000000..54541e8d2 --- /dev/null +++ b/surfsense_backend/app/services/streaming/events/source.py @@ -0,0 +1,59 @@ +"""Source and file reference events.""" + +from __future__ import annotations + +from typing import Any + +from ..emitter import Emitter, attach_emitted_by +from ..envelope import format_sse + + +def format_source_url( + url: str, + *, + source_id: str | None = None, + title: str | None = None, + emitter: Emitter | None = None, +) -> str: + payload: dict[str, Any] = { + "type": "source-url", + "sourceId": source_id or url, + "url": url, + } + if title: + payload["title"] = title + return format_sse(attach_emitted_by(payload, emitter)) + + +def format_source_document( + source_id: str, + *, + media_type: str = "file", + title: str | None = None, + description: str | None = None, + emitter: Emitter | None = None, +) -> str: + payload: dict[str, Any] = { + "type": "source-document", + "sourceId": source_id, + "mediaType": media_type, + } + if title: + payload["title"] = title + if description: + payload["description"] = description + return format_sse(attach_emitted_by(payload, emitter)) + + +def format_file( + url: str, + media_type: str, + *, + emitter: Emitter | None = None, +) -> str: + payload: dict[str, Any] = { + "type": "file", + "url": url, + "mediaType": media_type, + } + return format_sse(attach_emitted_by(payload, emitter)) diff --git a/surfsense_backend/app/services/streaming/events/subagent_lifecycle.py b/surfsense_backend/app/services/streaming/events/subagent_lifecycle.py new file mode 100644 index 000000000..6dd2d4eab --- /dev/null +++ b/surfsense_backend/app/services/streaming/events/subagent_lifecycle.py @@ -0,0 +1,86 @@ +"""Sub-agent lifecycle events the FE pairs into one timeline lane. + +A sub-agent run is a high-level boundary (a whole agent invocation), +so we use the ``start`` / ``finish`` verb pair, matching how the AI SDK +spells message- and step-level lifecycles. +""" + +from __future__ import annotations + +from typing import Any + +from ..emitter import Emitter +from .data import format_data + + +def format_subagent_start( + *, + subagent_run_id: str, + subagent_type: str, + parent_tool_call_id: str, + chat_turn_id: str | None = None, + description: str | None = None, + started_at: str | None = None, + emitter: Emitter | None = None, +) -> str: + payload: dict[str, Any] = { + "subagent_run_id": subagent_run_id, + "subagent_type": subagent_type, + "parent_tool_call_id": parent_tool_call_id, + } + if chat_turn_id is not None: + payload["chat_turn_id"] = chat_turn_id + if description is not None: + payload["description"] = description + if started_at is not None: + payload["started_at"] = started_at + return format_data("subagent-start", payload, emitter=emitter) + + +def format_subagent_finish( + *, + subagent_run_id: str, + subagent_type: str, + parent_tool_call_id: str, + status: str = "completed", + ended_at: str | None = None, + duration_ms: int | None = None, + emitter: Emitter | None = None, +) -> str: + payload: dict[str, Any] = { + "subagent_run_id": subagent_run_id, + "subagent_type": subagent_type, + "parent_tool_call_id": parent_tool_call_id, + "status": status, + } + if ended_at is not None: + payload["ended_at"] = ended_at + if duration_ms is not None: + payload["duration_ms"] = duration_ms + return format_data("subagent-finish", payload, emitter=emitter) + + +def format_subagent_error( + *, + subagent_run_id: str, + subagent_type: str, + parent_tool_call_id: str, + error_text: str, + error_type: str | None = None, + ended_at: str | None = None, + duration_ms: int | None = None, + emitter: Emitter | None = None, +) -> str: + payload: dict[str, Any] = { + "subagent_run_id": subagent_run_id, + "subagent_type": subagent_type, + "parent_tool_call_id": parent_tool_call_id, + "error_text": error_text, + } + if error_type is not None: + payload["error_type"] = error_type + if ended_at is not None: + payload["ended_at"] = ended_at + if duration_ms is not None: + payload["duration_ms"] = duration_ms + return format_data("subagent-error", payload, emitter=emitter) diff --git a/surfsense_backend/app/services/streaming/events/text.py b/surfsense_backend/app/services/streaming/events/text.py new file mode 100644 index 000000000..30489ce49 --- /dev/null +++ b/surfsense_backend/app/services/streaming/events/text.py @@ -0,0 +1,27 @@ +"""Text-block streaming events.""" + +from __future__ import annotations + +from ..emitter import Emitter, attach_emitted_by +from ..envelope import format_sse + + +def format_text_start(text_id: str, *, emitter: Emitter | None = None) -> str: + return format_sse(attach_emitted_by({"type": "text-start", "id": text_id}, emitter)) + + +def format_text_delta( + text_id: str, + delta: str, + *, + emitter: Emitter | None = None, +) -> str: + return format_sse( + attach_emitted_by( + {"type": "text-delta", "id": text_id, "delta": delta}, emitter + ) + ) + + +def format_text_end(text_id: str, *, emitter: Emitter | None = None) -> str: + return format_sse(attach_emitted_by({"type": "text-end", "id": text_id}, emitter)) diff --git a/surfsense_backend/app/services/streaming/events/tool.py b/surfsense_backend/app/services/streaming/events/tool.py new file mode 100644 index 000000000..c85dc061b --- /dev/null +++ b/surfsense_backend/app/services/streaming/events/tool.py @@ -0,0 +1,80 @@ +"""Tool-call streaming events. + +``toolCallId`` and ``langchainToolCallId`` are AI SDK protocol fields +and stay camelCase. Sub-agent provenance rides on the snake_case +top-level ``emitted_by`` envelope added by :func:`attach_emitted_by`. +""" + +from __future__ import annotations + +from typing import Any + +from ..emitter import Emitter, attach_emitted_by +from ..envelope import format_sse + + +def format_tool_input_start( + tool_call_id: str, + tool_name: str, + *, + langchain_tool_call_id: str | None = None, + emitter: Emitter | None = None, +) -> str: + payload: dict[str, Any] = { + "type": "tool-input-start", + "toolCallId": tool_call_id, + "toolName": tool_name, + } + if langchain_tool_call_id: + payload["langchainToolCallId"] = langchain_tool_call_id + return format_sse(attach_emitted_by(payload, emitter)) + + +def format_tool_input_delta( + tool_call_id: str, + input_text_delta: str, + *, + emitter: Emitter | None = None, +) -> str: + payload: dict[str, Any] = { + "type": "tool-input-delta", + "toolCallId": tool_call_id, + "inputTextDelta": input_text_delta, + } + return format_sse(attach_emitted_by(payload, emitter)) + + +def format_tool_input_available( + tool_call_id: str, + tool_name: str, + input_data: dict[str, Any], + *, + langchain_tool_call_id: str | None = None, + emitter: Emitter | None = None, +) -> str: + payload: dict[str, Any] = { + "type": "tool-input-available", + "toolCallId": tool_call_id, + "toolName": tool_name, + "input": input_data, + } + if langchain_tool_call_id: + payload["langchainToolCallId"] = langchain_tool_call_id + return format_sse(attach_emitted_by(payload, emitter)) + + +def format_tool_output_available( + tool_call_id: str, + output: Any, + *, + langchain_tool_call_id: str | None = None, + emitter: Emitter | None = None, +) -> str: + payload: dict[str, Any] = { + "type": "tool-output-available", + "toolCallId": tool_call_id, + "output": output, + } + if langchain_tool_call_id: + payload["langchainToolCallId"] = langchain_tool_call_id + return format_sse(attach_emitted_by(payload, emitter)) diff --git a/surfsense_backend/app/services/streaming/interrupt_correlation.py b/surfsense_backend/app/services/streaming/interrupt_correlation.py new file mode 100644 index 000000000..3045dfb6a --- /dev/null +++ b/surfsense_backend/app/services/streaming/interrupt_correlation.py @@ -0,0 +1,84 @@ +"""Id-aware lookup of pending LangGraph interrupts (replaces first-wins).""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True) +class PendingInterrupt: + interrupt_id: str | None + value: dict[str, Any] + source_task_id: str | None = None + + +def list_pending_interrupts(state: Any) -> list[PendingInterrupt]: + out: list[PendingInterrupt] = [] + + for task in getattr(state, "tasks", None) or (): + task_id = _safe_str(getattr(task, "id", None)) + for it in getattr(task, "interrupts", None) or (): + value = _coerce_interrupt_value(it) + if value is None: + continue + interrupt_id = _safe_str(getattr(it, "id", None)) + out.append( + PendingInterrupt( + interrupt_id=interrupt_id, + value=value, + source_task_id=task_id, + ) + ) + + for it in getattr(state, "interrupts", None) or (): + value = _coerce_interrupt_value(it) + if value is None: + continue + interrupt_id = _safe_str(getattr(it, "id", None)) + out.append(PendingInterrupt(interrupt_id=interrupt_id, value=value)) + + return out + + +def get_pending_interrupt_by_id( + state: Any, interrupt_id: str +) -> PendingInterrupt | None: + for pending in list_pending_interrupts(state): + if pending.interrupt_id == interrupt_id: + return pending + return None + + +def get_pending_interrupt_for_tool_call( + state: Any, tool_call_id: str +) -> PendingInterrupt | None: + for pending in list_pending_interrupts(state): + actions = pending.value.get("action_requests") + if not isinstance(actions, list): + continue + for action in actions: + if not isinstance(action, dict): + continue + if action.get("tool_call_id") == tool_call_id: + return pending + return None + + +def first_pending_interrupt(state: Any) -> PendingInterrupt | None: + """Explicit opt-in to legacy first-wins; prefer the id-aware helpers above.""" + pending = list_pending_interrupts(state) + return pending[0] if pending else None + + +def _coerce_interrupt_value(item: Any) -> dict[str, Any] | None: + if isinstance(item, dict): + return item if item else None + value = getattr(item, "value", None) + if isinstance(value, dict): + return value if value else None + return None + + +def _safe_str(value: Any) -> str | None: + return value if isinstance(value, str) and value else None diff --git a/surfsense_backend/app/services/streaming/service.py b/surfsense_backend/app/services/streaming/service.py new file mode 100644 index 000000000..a3f1da785 --- /dev/null +++ b/surfsense_backend/app/services/streaming/service.py @@ -0,0 +1,410 @@ +"""Composition root: bundles every formatter + a per-invocation emitter registry.""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any + +from . import envelope +from .emitter import Emitter, EmitterRegistry +from .events import ( + action_log, + data, + error, + interrupt, + lifecycle, + reasoning, + source, + subagent_lifecycle, + text, + tool, +) + + +class StreamingService: + def __init__(self) -> None: + self._message_id: str | None = None + self.emitter_registry = EmitterRegistry() + + @property + def message_id(self) -> str | None: + return self._message_id + + def begin_message(self, message_id: str | None = None) -> str: + self._message_id = message_id or envelope.generate_message_id() + return self._message_id + + @staticmethod + def generate_text_id() -> str: + return envelope.generate_text_id() + + @staticmethod + def generate_reasoning_id() -> str: + return envelope.generate_reasoning_id() + + @staticmethod + def generate_tool_call_id() -> str: + return envelope.generate_tool_call_id() + + @staticmethod + def generate_subagent_run_id() -> str: + return envelope.generate_subagent_run_id() + + @staticmethod + def get_response_headers() -> dict[str, str]: + return envelope.get_response_headers() + + @staticmethod + def format_done() -> str: + return envelope.format_done() + + def resolve_emitter( + self, + *, + run_id: str | None, + parent_ids: Iterable[str] | None, + ) -> Emitter: + return self.emitter_registry.resolve(run_id=run_id, parent_ids=parent_ids) + + def format_message_start( + self, + message_id: str | None = None, + *, + emitter: Emitter | None = None, + ) -> str: + chosen = self.begin_message(message_id) + return lifecycle.format_message_start(chosen, emitter=emitter) + + def format_message_finish(self, *, emitter: Emitter | None = None) -> str: + return lifecycle.format_message_finish(emitter=emitter) + + def format_step_start(self, *, emitter: Emitter | None = None) -> str: + return lifecycle.format_step_start(emitter=emitter) + + def format_step_finish(self, *, emitter: Emitter | None = None) -> str: + return lifecycle.format_step_finish(emitter=emitter) + + def format_text_start(self, text_id: str, *, emitter: Emitter | None = None) -> str: + return text.format_text_start(text_id, emitter=emitter) + + def format_text_delta( + self, text_id: str, delta: str, *, emitter: Emitter | None = None + ) -> str: + return text.format_text_delta(text_id, delta, emitter=emitter) + + def format_text_end(self, text_id: str, *, emitter: Emitter | None = None) -> str: + return text.format_text_end(text_id, emitter=emitter) + + def format_reasoning_start( + self, reasoning_id: str, *, emitter: Emitter | None = None + ) -> str: + return reasoning.format_reasoning_start(reasoning_id, emitter=emitter) + + def format_reasoning_delta( + self, + reasoning_id: str, + delta: str, + *, + emitter: Emitter | None = None, + ) -> str: + return reasoning.format_reasoning_delta(reasoning_id, delta, emitter=emitter) + + def format_reasoning_end( + self, reasoning_id: str, *, emitter: Emitter | None = None + ) -> str: + return reasoning.format_reasoning_end(reasoning_id, emitter=emitter) + + def format_tool_input_start( + self, + tool_call_id: str, + tool_name: str, + *, + langchain_tool_call_id: str | None = None, + emitter: Emitter | None = None, + ) -> str: + return tool.format_tool_input_start( + tool_call_id, + tool_name, + langchain_tool_call_id=langchain_tool_call_id, + emitter=emitter, + ) + + def format_tool_input_delta( + self, + tool_call_id: str, + input_text_delta: str, + *, + emitter: Emitter | None = None, + ) -> str: + return tool.format_tool_input_delta( + tool_call_id, input_text_delta, emitter=emitter + ) + + def format_tool_input_available( + self, + tool_call_id: str, + tool_name: str, + input_data: dict[str, Any], + *, + langchain_tool_call_id: str | None = None, + emitter: Emitter | None = None, + ) -> str: + return tool.format_tool_input_available( + tool_call_id, + tool_name, + input_data, + langchain_tool_call_id=langchain_tool_call_id, + emitter=emitter, + ) + + def format_tool_output_available( + self, + tool_call_id: str, + output: Any, + *, + langchain_tool_call_id: str | None = None, + emitter: Emitter | None = None, + ) -> str: + return tool.format_tool_output_available( + tool_call_id, + output, + langchain_tool_call_id=langchain_tool_call_id, + emitter=emitter, + ) + + def format_source_url( + self, + url: str, + *, + source_id: str | None = None, + title: str | None = None, + emitter: Emitter | None = None, + ) -> str: + return source.format_source_url( + url, source_id=source_id, title=title, emitter=emitter + ) + + def format_source_document( + self, + source_id: str, + *, + media_type: str = "file", + title: str | None = None, + description: str | None = None, + emitter: Emitter | None = None, + ) -> str: + return source.format_source_document( + source_id, + media_type=media_type, + title=title, + description=description, + emitter=emitter, + ) + + def format_file( + self, url: str, media_type: str, *, emitter: Emitter | None = None + ) -> str: + return source.format_file(url, media_type, emitter=emitter) + + def format_data( + self, data_type: str, payload: Any, *, emitter: Emitter | None = None + ) -> str: + return data.format_data(data_type, payload, emitter=emitter) + + def format_terminal_info( + self, + text_value: str, + *, + message_type: str = "info", + emitter: Emitter | None = None, + ) -> str: + return data.format_terminal_info( + text_value, message_type=message_type, emitter=emitter + ) + + def format_further_questions( + self, + questions: list[str], + *, + emitter: Emitter | None = None, + ) -> str: + return data.format_further_questions(questions, emitter=emitter) + + def format_thinking_step( + self, + *, + step_id: str, + title: str, + status: str = "in_progress", + items: list[str] | None = None, + emitter: Emitter | None = None, + ) -> str: + return data.format_thinking_step( + step_id=step_id, + title=title, + status=status, + items=items, + emitter=emitter, + ) + + def format_thread_title_update( + self, + *, + thread_id: int, + title: str, + emitter: Emitter | None = None, + ) -> str: + return data.format_thread_title_update( + thread_id=thread_id, title=title, emitter=emitter + ) + + def format_turn_info( + self, + *, + chat_turn_id: str, + emitter: Emitter | None = None, + ) -> str: + return data.format_turn_info(chat_turn_id=chat_turn_id, emitter=emitter) + + def format_turn_status( + self, + *, + status: str, + emitter: Emitter | None = None, + ) -> str: + return data.format_turn_status(status=status, emitter=emitter) + + def format_user_message_id( + self, + *, + message_id: str, + turn_id: str, + emitter: Emitter | None = None, + ) -> str: + return data.format_user_message_id( + message_id=message_id, turn_id=turn_id, emitter=emitter + ) + + def format_assistant_message_id( + self, + *, + message_id: str, + turn_id: str, + emitter: Emitter | None = None, + ) -> str: + return data.format_assistant_message_id( + message_id=message_id, turn_id=turn_id, emitter=emitter + ) + + def format_error( + self, + error_text: str, + *, + error_code: str | None = None, + extra: dict[str, Any] | None = None, + emitter: Emitter | None = None, + ) -> str: + return error.format_error( + error_text, + error_code=error_code, + extra=extra, + emitter=emitter, + ) + + def format_interrupt_request( + self, + interrupt_value: dict[str, Any], + *, + interrupt_id: str | None = None, + pending_interrupt_count: int | None = None, + chat_turn_id: str | None = None, + emitter: Emitter | None = None, + ) -> str: + return interrupt.format_interrupt_request( + interrupt_value, + interrupt_id=interrupt_id, + pending_interrupt_count=pending_interrupt_count, + chat_turn_id=chat_turn_id, + emitter=emitter, + ) + + def format_subagent_start( + self, + *, + subagent_run_id: str, + subagent_type: str, + parent_tool_call_id: str, + chat_turn_id: str | None = None, + description: str | None = None, + started_at: str | None = None, + emitter: Emitter | None = None, + ) -> str: + return subagent_lifecycle.format_subagent_start( + subagent_run_id=subagent_run_id, + subagent_type=subagent_type, + parent_tool_call_id=parent_tool_call_id, + chat_turn_id=chat_turn_id, + description=description, + started_at=started_at, + emitter=emitter, + ) + + def format_subagent_finish( + self, + *, + subagent_run_id: str, + subagent_type: str, + parent_tool_call_id: str, + status: str = "completed", + ended_at: str | None = None, + duration_ms: int | None = None, + emitter: Emitter | None = None, + ) -> str: + return subagent_lifecycle.format_subagent_finish( + subagent_run_id=subagent_run_id, + subagent_type=subagent_type, + parent_tool_call_id=parent_tool_call_id, + status=status, + ended_at=ended_at, + duration_ms=duration_ms, + emitter=emitter, + ) + + def format_subagent_error( + self, + *, + subagent_run_id: str, + subagent_type: str, + parent_tool_call_id: str, + error_text: str, + error_type: str | None = None, + ended_at: str | None = None, + duration_ms: int | None = None, + emitter: Emitter | None = None, + ) -> str: + return subagent_lifecycle.format_subagent_error( + subagent_run_id=subagent_run_id, + subagent_type=subagent_type, + parent_tool_call_id=parent_tool_call_id, + error_text=error_text, + error_type=error_type, + ended_at=ended_at, + duration_ms=duration_ms, + emitter=emitter, + ) + + def format_action_log( + self, + payload: dict[str, Any], + *, + emitter: Emitter | None = None, + ) -> str: + return action_log.format_action_log(payload, emitter=emitter) + + def format_action_log_updated( + self, + payload: dict[str, Any], + *, + emitter: Emitter | None = None, + ) -> str: + return action_log.format_action_log_updated(payload, emitter=emitter) diff --git a/surfsense_backend/app/tasks/chat/content_builder.py b/surfsense_backend/app/tasks/chat/content_builder.py index 041cab286..c8f79c045 100644 --- a/surfsense_backend/app/tasks/chat/content_builder.py +++ b/surfsense_backend/app/tasks/chat/content_builder.py @@ -51,6 +51,23 @@ logger = logging.getLogger(__name__) _MEANINGFUL_PART_TYPES: frozenset[str] = frozenset({"text", "reasoning", "tool-call"}) +def _merge_tool_part_metadata( + part: dict[str, Any], metadata: dict[str, Any] | None +) -> None: + """Shallow-merge ``metadata`` into ``part["metadata"]``; first key wins. + + Used for tool-call linkage (``spanId``, ``thinkingStepId``, …): a later + event must not overwrite an existing key so chunk order vs ``on_tool_start`` + stays stable. + """ + if not metadata: + return + md = part.setdefault("metadata", {}) + for k, v in metadata.items(): + if k not in md: + md[k] = v + + class AssistantContentBuilder: """Server-side projection of ``surfsense_web/lib/chat/streaming-state.ts``. @@ -61,6 +78,7 @@ class AssistantContentBuilder: | { type: "reasoning"; text: string } | { type: "tool-call"; toolCallId: str; toolName: str; args: dict; result?: any; argsText?: str; langchainToolCallId?: str; + metadata?: { spanId?: str; thinkingStepId?: str; ... }; state?: "aborted" } | { type: "data-thinking-steps"; data: { steps: ThinkingStepData[] } } | { type: "data-step-separator"; data: { stepIndex: int } } @@ -85,8 +103,8 @@ class AssistantContentBuilder: self._current_text_idx: int = -1 self._current_reasoning_idx: int = -1 # ``ui_id``-keyed indexes for tool-call parts. ``ui_id`` is the - # synthetic ``call_`` (legacy) or the LangChain - # ``tool_call.id`` (parity_v2) — same key the streaming layer + # synthetic ``call_`` (chunk fallback) or the LangChain + # ``tool_call.id`` (indexed chunk path) — same key the streaming layer # threads through every ``tool-input-*`` / ``tool-output-*`` event. self._tool_call_idx_by_ui_id: dict[str, int] = {} # Live argsText accumulator (concatenated ``tool-input-delta`` chunks) @@ -177,21 +195,27 @@ class AssistantContentBuilder: ui_id: str, tool_name: str, langchain_tool_call_id: str | None, + *, + metadata: dict[str, Any] | None = None, ) -> None: - """Register a tool-call card. Args are filled in by later events.""" + """Register a tool-call card. Args are filled in by later events. + + Optional ``metadata`` (``spanId``, ``thinkingStepId``, …) is stored on the + part; duplicate ``tool-input-start`` calls merge with first-key-wins. + """ if not ui_id: return - # Skip duplicate registration: parity_v2 may emit + # Skip duplicate registration: the stream may emit # ``tool-input-start`` from both ``on_chat_model_stream`` # (when tool_call_chunks register a name) and ``on_tool_start`` # (the canonical path). The FE de-dupes via ``toolCallIndices``; # we mirror that here. if ui_id in self._tool_call_idx_by_ui_id: - if langchain_tool_call_id: - idx = self._tool_call_idx_by_ui_id[ui_id] - part = self.parts[idx] - if not part.get("langchainToolCallId"): - part["langchainToolCallId"] = langchain_tool_call_id + idx = self._tool_call_idx_by_ui_id[ui_id] + part = self.parts[idx] + if langchain_tool_call_id and not part.get("langchainToolCallId"): + part["langchainToolCallId"] = langchain_tool_call_id + _merge_tool_part_metadata(part, metadata) return part: dict[str, Any] = { @@ -202,6 +226,8 @@ class AssistantContentBuilder: } if langchain_tool_call_id: part["langchainToolCallId"] = langchain_tool_call_id + if metadata: + part["metadata"] = dict(metadata) self.parts.append(part) self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1 @@ -235,6 +261,8 @@ class AssistantContentBuilder: tool_name: str, args: dict[str, Any], langchain_tool_call_id: str | None, + *, + metadata: dict[str, Any] | None = None, ) -> None: """Finalize the tool-call card's input. @@ -243,7 +271,7 @@ class AssistantContentBuilder: pretty-printed JSON, sets the full ``args`` dict, and backfills ``langchainToolCallId`` if it wasn't known at ``tool-input-start`` time. Also creates the card if no prior ``tool-input-start`` registered it - (legacy parity_v2-OFF / late-registration paths). + (late-registration when no prior ``tool-input-start``). """ if not ui_id: return @@ -264,6 +292,7 @@ class AssistantContentBuilder: part["argsText"] = final_args_text if langchain_tool_call_id and not part.get("langchainToolCallId"): part["langchainToolCallId"] = langchain_tool_call_id + _merge_tool_part_metadata(part, metadata) return # No prior tool-input-start: register the card now. @@ -276,6 +305,7 @@ class AssistantContentBuilder: } if langchain_tool_call_id: new_part["langchainToolCallId"] = langchain_tool_call_id + _merge_tool_part_metadata(new_part, metadata) self.parts.append(new_part) self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1 @@ -287,6 +317,8 @@ class AssistantContentBuilder: ui_id: str, output: Any, langchain_tool_call_id: str | None, + *, + metadata: dict[str, Any] | None = None, ) -> None: """Attach the tool's output (``result``) to the matching card. @@ -305,6 +337,7 @@ class AssistantContentBuilder: part["result"] = output if langchain_tool_call_id and not part.get("langchainToolCallId"): part["langchainToolCallId"] = langchain_tool_call_id + _merge_tool_part_metadata(part, metadata) # ------------------------------------------------------------------ # Thinking steps & step separators @@ -316,6 +349,8 @@ class AssistantContentBuilder: title: str, status: str, items: list[str] | None, + *, + metadata: dict[str, Any] | None = None, ) -> None: """Update / insert the singleton ``data-thinking-steps`` part. @@ -328,12 +363,14 @@ class AssistantContentBuilder: if not step_id: return - new_step = { + new_step: dict[str, Any] = { "id": step_id, "title": title or "", "status": status or "in_progress", "items": list(items) if items else [], } + if metadata: + new_step["metadata"] = dict(metadata) # Find existing data-thinking-steps part. existing_idx = -1 @@ -347,6 +384,8 @@ class AssistantContentBuilder: replaced = False for i, step in enumerate(current_steps): if step.get("id") == step_id: + if not metadata and step.get("metadata"): + new_step["metadata"] = dict(step["metadata"]) current_steps[i] = new_step replaced = True break diff --git a/surfsense_backend/app/tasks/chat/persistence.py b/surfsense_backend/app/tasks/chat/persistence.py index b2b8b6a88..37be50705 100644 --- a/surfsense_backend/app/tasks/chat/persistence.py +++ b/surfsense_backend/app/tasks/chat/persistence.py @@ -109,17 +109,18 @@ def _build_user_content( [{"type": "text", "text": "..."}, {"type": "image", "image": "data:..."}, {"type": "mentioned-documents", "documents": [{"id": int, - "title": str, "document_type": str}, ...]}] + "title": str, "document_type": str, "kind": "doc" | "folder"}, + ...]}] The companion reader is ``app.utils.user_message_multimodal.split_persisted_user_content_parts`` which expects exactly this shape — keep them in sync. - ``mentioned_documents``: optional list of ``{id, title, document_type}`` - dicts. When non-empty (and a ``mentioned-documents`` part is not already - in some other input shape), a single ``{"type": "mentioned-documents", - "documents": [...]}`` part is appended. Mirrors the FE injection at - ``page.tsx:281-286`` (``persistUserTurn``). + ``mentioned_documents``: optional list of mention chip dicts. Each + dict may include a ``kind`` discriminator (``"doc"`` or ``"folder"``) + so the persisted ContentPart round-trips folder chips on reload. + When ``kind`` is missing we default to ``"doc"`` so legacy clients + that haven't migrated to the union schema still persist correctly. """ parts: list[dict[str, Any]] = [{"type": "text", "text": user_query or ""}] for url in user_image_data_urls or (): @@ -135,11 +136,14 @@ def _build_user_content( document_type = doc.get("document_type") if doc_id is None or title is None or document_type is None: continue + kind_raw = doc.get("kind", "doc") + kind = kind_raw if kind_raw in ("doc", "folder") else "doc" normalized.append( { "id": doc_id, "title": str(title), "document_type": str(document_type), + "kind": kind, } ) if normalized: diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 3ba3912eb..818282996 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -9,13 +9,11 @@ Supports loading LLM configurations from: - NewLLMConfig database table (positive IDs for user-created configs with prompt settings) """ -import ast import asyncio import contextlib import gc import json import logging -import re import time from collections.abc import AsyncGenerator from dataclasses import dataclass, field @@ -28,14 +26,11 @@ from langchain_core.messages import HumanMessage from sqlalchemy.future import select from sqlalchemy.orm import selectinload -from app.agents.multi_agent_chat import ( - create_surfsense_deep_agent as create_registry_deep_agent, -) +from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent from app.agents.new_chat.checkpointer import get_checkpointer from app.agents.new_chat.context import SurfSenseContextSchema from app.agents.new_chat.errors import BusyError -from app.agents.new_chat.feature_flags import get_flags from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection from app.agents.new_chat.llm_config import ( AgentConfig, @@ -48,6 +43,7 @@ from app.agents.new_chat.memory_extraction import ( extract_and_save_memory, extract_and_save_team_memory, ) +from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in_text from app.agents.new_chat.middleware.busy_mutex import ( end_turn, get_cancel_state, @@ -79,6 +75,7 @@ from app.services.chat_session_state_service import ( ) from app.services.connector_service import ConnectorService from app.services.new_streaming_service import VercelStreamingService +from app.tasks.chat.streaming.graph_stream.event_stream import stream_output from app.utils.content_utils import bootstrap_history_from_db from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap from app.utils.user_message_multimodal import build_human_message_content @@ -577,6 +574,43 @@ async def _preflight_llm(llm: Any) -> None: ) +async def _build_main_agent_for_thread( + agent_factory: Any, + *, + llm: Any, + search_space_id: int, + db_session: Any, + connector_service: ConnectorService, + checkpointer: Any, + user_id: str | None, + thread_id: int | None, + agent_config: AgentConfig | None, + firecrawl_api_key: str | None, + thread_visibility: ChatVisibility | None, + filesystem_selection: FilesystemSelection | None, + disabled_tools: list[str] | None = None, + mentioned_document_ids: list[int] | None = None, +) -> Any: + """Single (re)build path so the agent factory cannot drift across + initial build, preflight repin, and mid-stream 429 recovery for one + ``thread_id``: a graph swap mid-turn would corrupt checkpointer state.""" + return await agent_factory( + llm=llm, + search_space_id=search_space_id, + db_session=db_session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=thread_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=thread_visibility, + filesystem_selection=filesystem_selection, + disabled_tools=disabled_tools, + mentioned_document_ids=mentioned_document_ids, + ) + + async def _settle_speculative_agent_build(task: asyncio.Task[Any]) -> None: """Wait for a discarded speculative agent build to release shared state. @@ -694,9 +728,9 @@ def _legacy_match_lc_id( ) -> str | None: """Best-effort match a buffered ``tool_call_chunk`` to a tool name. - Pure extract of the legacy in-line match used at ``on_tool_start`` for - parity_v2-OFF and unmatched (chunk path didn't register an index for - this call) tools. Pops the next id-bearing chunk whose ``name`` + Pure extract of the in-line match used at ``on_tool_start`` when the + chunk path didn't register an index for this call. Pops the next + id-bearing chunk whose ``name`` matches ``tool_name`` (or any id-bearing chunk as a fallback) and returns its id. Mutates ``pending_tool_call_chunks`` and ``lc_tool_call_id_by_run`` in place. @@ -768,1505 +802,22 @@ async def _stream_agent_events( Yields: SSE-formatted strings for each event. """ - accumulated_text = "" - current_text_id: str | None = None - thinking_step_counter = 1 if initial_step_id else 0 - tool_step_ids: dict[str, str] = {} - completed_step_ids: set[str] = set() - last_active_step_id: str | None = initial_step_id - last_active_step_title: str = initial_step_title - last_active_step_items: list[str] = initial_step_items or [] - just_finished_tool: bool = False - active_tool_depth: int = 0 # Track nesting: >0 means we're inside a tool - called_update_memory: bool = False + async for sse in stream_output( + agent=agent, + config=config, + input_data=input_data, + streaming_service=streaming_service, + result=result, + step_prefix=step_prefix, + initial_step_id=initial_step_id, + initial_step_title=initial_step_title, + initial_step_items=initial_step_items, + content_builder=content_builder, + runtime_context=runtime_context, + ): + yield sse - # Reasoning-block streaming. We open a reasoning block on the - # first reasoning delta of a step, append deltas as they arrive, and - # close it when text starts (the model has switched to writing its - # answer) or ``on_chat_model_end`` fires for the model node. Reuses - # the same Vercel format-helpers as text-start/delta/end. - current_reasoning_id: str | None = None - - # Streaming-parity v2 feature flag. When OFF we keep the legacy - # shape: str-only content, no reasoning blocks, no - # ``langchainToolCallId`` propagation. The schema migrations - # (135 / 136) ship unconditionally because they're forward-compatible. - parity_v2 = bool(get_flags().enable_stream_parity_v2) - - # Best-effort attach of LangChain ``tool_call_id`` to the synthetic - # ``call_`` card id we already emit. We accumulate - # ``tool_call_chunks`` from ``on_chat_model_stream``, key them by - # name, and pop the next unconsumed entry at ``on_tool_start``. The - # authoritative id is later filled in at ``on_tool_end`` from - # ``ToolMessage.tool_call_id``. Under parity_v2 we ALSO short-circuit - # this list for chunks that already registered into ``index_to_meta`` - # below — so this list is reserved for the parity_v2-OFF / unmatched - # fallback path only and never re-pops a chunk we already streamed. - pending_tool_call_chunks: list[dict[str, Any]] = [] - lc_tool_call_id_by_run: dict[str, str] = {} - file_path_by_run: dict[str, str] = {} - - # parity_v2 only: live tool-call argument streaming. ``index_to_meta`` - # is keyed by the chunk's ``index`` field — LangChain - # ``ToolCallChunk``s for the same call share an index but only the - # first chunk carries id+name (subsequent ones are id=None, - # name=None, args=""). We register an index when both id and - # name are observed on a chunk (per ToolCallChunk semantics they - # arrive together on the first chunk), then route every later chunk - # at that index to the same ``ui_id`` as a ``tool-input-delta``. - # ``ui_tool_call_id_by_run`` maps LangGraph ``run_id`` to the - # ``ui_id`` used for that call's ``tool-input-start`` so the matching - # ``tool-output-available`` (emitted from ``on_tool_end``) lands on - # the same card. - index_to_meta: dict[int, dict[str, str]] = {} - ui_tool_call_id_by_run: dict[str, str] = {} - - # Per-tool-end mutable cache for the LangChain tool_call_id resolved - # at ``on_tool_end``. ``_emit_tool_output`` reads this so every - # ``format_tool_output_available`` call automatically carries the - # authoritative id without duplicating the kwarg at every call site. - current_lc_tool_call_id: dict[str, str | None] = {"value": None} - - def _emit_tool_output(call_id: str, output: Any) -> str: - # Drive the builder before formatting the SSE so the in-memory - # ContentPart[] mirror sees the result attached to the same - # card the FE will render. Builder method is a no-op when - # ``content_builder`` is None (anonymous / legacy paths). - if content_builder is not None: - content_builder.on_tool_output_available( - call_id, output, current_lc_tool_call_id["value"] - ) - return streaming_service.format_tool_output_available( - call_id, - output, - langchain_tool_call_id=current_lc_tool_call_id["value"], - ) - - def _emit_thinking_step( - *, - step_id: str, - title: str, - status: str = "in_progress", - items: list[str] | None = None, - ) -> str: - """Format a thinking-step SSE event and notify the builder. - - Single helper used at every ``format_thinking_step`` yield site - in this generator. Drives ``AssistantContentBuilder.on_thinking_step`` - first so the FE-mirror state lands the update before the SSE - carrying the same data leaves the wire — order matches the FE - pipeline (``processSharedStreamEvent`` updates state, then - flushes). Builder call is a no-op when ``content_builder`` is - None (anonymous / legacy paths). - """ - if content_builder is not None: - content_builder.on_thinking_step(step_id, title, status, items) - return streaming_service.format_thinking_step( - step_id=step_id, - title=title, - status=status, - items=items, - ) - - def next_thinking_step_id() -> str: - nonlocal thinking_step_counter - thinking_step_counter += 1 - return f"{step_prefix}-{thinking_step_counter}" - - def complete_current_step() -> str | None: - nonlocal last_active_step_id - if last_active_step_id and last_active_step_id not in completed_step_ids: - completed_step_ids.add(last_active_step_id) - event = _emit_thinking_step( - step_id=last_active_step_id, - title=last_active_step_title, - status="completed", - items=last_active_step_items if last_active_step_items else None, - ) - last_active_step_id = None - return event - return None - - # Per-invocation runtime context (Phase 1.5). When supplied, - # ``KnowledgePriorityMiddleware`` reads ``mentioned_document_ids`` - # from ``runtime.context`` instead of its constructor closure — the - # prerequisite that lets the compiled-agent cache (Phase 1) reuse a - # single graph across turns. Astream_events_kwargs stays empty when - # callers leave ``runtime_context`` as ``None`` to preserve the - # legacy code path bit-for-bit. - astream_kwargs: dict[str, Any] = {"config": config, "version": "v2"} - if runtime_context is not None: - astream_kwargs["context"] = runtime_context - - async for event in agent.astream_events(input_data, **astream_kwargs): - event_type = event.get("event", "") - - if event_type == "on_chat_model_stream": - if active_tool_depth > 0: - continue # Suppress inner-tool LLM tokens from leaking into chat - if "surfsense:internal" in event.get("tags", []): - continue # Suppress middleware-internal LLM tokens (e.g. KB search classification) - chunk = event.get("data", {}).get("chunk") - if not chunk: - continue - parts = _extract_chunk_parts(chunk) - - reasoning_delta = parts["reasoning"] - text_delta = parts["text"] - - # Reasoning streaming. Open a reasoning block on first - # delta; append every subsequent delta until text begins. - # When text starts we close the reasoning block first so the - # frontend sees the natural hand-off. Gated behind the - # parity-v2 flag so legacy deployments keep today's shape. - if parity_v2 and reasoning_delta: - if current_text_id is not None: - yield streaming_service.format_text_end(current_text_id) - if content_builder is not None: - content_builder.on_text_end(current_text_id) - current_text_id = None - if current_reasoning_id is None: - completion_event = complete_current_step() - if completion_event: - yield completion_event - if just_finished_tool: - last_active_step_id = None - last_active_step_title = "" - last_active_step_items = [] - just_finished_tool = False - current_reasoning_id = streaming_service.generate_reasoning_id() - yield streaming_service.format_reasoning_start(current_reasoning_id) - if content_builder is not None: - content_builder.on_reasoning_start(current_reasoning_id) - yield streaming_service.format_reasoning_delta( - current_reasoning_id, reasoning_delta - ) - if content_builder is not None: - content_builder.on_reasoning_delta( - current_reasoning_id, reasoning_delta - ) - - if text_delta: - if current_reasoning_id is not None: - yield streaming_service.format_reasoning_end(current_reasoning_id) - if content_builder is not None: - content_builder.on_reasoning_end(current_reasoning_id) - current_reasoning_id = None - if current_text_id is None: - completion_event = complete_current_step() - if completion_event: - yield completion_event - if just_finished_tool: - last_active_step_id = None - last_active_step_title = "" - last_active_step_items = [] - just_finished_tool = False - current_text_id = streaming_service.generate_text_id() - yield streaming_service.format_text_start(current_text_id) - if content_builder is not None: - content_builder.on_text_start(current_text_id) - yield streaming_service.format_text_delta(current_text_id, text_delta) - accumulated_text += text_delta - if content_builder is not None: - content_builder.on_text_delta(current_text_id, text_delta) - - # Live tool-call argument streaming. Runs AFTER text/reasoning - # processing so chunks containing both stay in their natural - # wire order (text → text-end → tool-input-start). Active - # text/reasoning are closed inside the registration branch - # before ``tool-input-start`` so the frontend sees a clean - # part boundary even when providers interleave. - if parity_v2 and parts["tool_call_chunks"]: - for tcc in parts["tool_call_chunks"]: - idx = tcc.get("index") - - # Register this index when we first see id+name - # TOGETHER. Per LangChain ToolCallChunk semantics the - # first chunk for a tool call carries both fields - # together; later chunks have id=None, name=None and - # only ``args``. Requiring BOTH keeps wire - # ``tool-input-start`` always carrying a real - # toolName (assistant-ui's typed tool-part dispatch - # keys off it). - if idx is not None and idx not in index_to_meta: - lc_id = tcc.get("id") - name = tcc.get("name") - if lc_id and name: - ui_id = lc_id - - # Close active text/reasoning so wire - # ordering stays clean even on providers - # that interleave text and tool-call chunks - # within the same stream window. - if current_text_id is not None: - yield streaming_service.format_text_end(current_text_id) - if content_builder is not None: - content_builder.on_text_end(current_text_id) - current_text_id = None - if current_reasoning_id is not None: - yield streaming_service.format_reasoning_end( - current_reasoning_id - ) - if content_builder is not None: - content_builder.on_reasoning_end( - current_reasoning_id - ) - current_reasoning_id = None - - index_to_meta[idx] = { - "ui_id": ui_id, - "lc_id": lc_id, - "name": name, - } - yield streaming_service.format_tool_input_start( - ui_id, - name, - langchain_tool_call_id=lc_id, - ) - if content_builder is not None: - content_builder.on_tool_input_start(ui_id, name, lc_id) - - # Emit args delta for any chunk at a registered - # index (including idless continuations). Once an - # index is owned by ``index_to_meta`` we DO NOT - # append to ``pending_tool_call_chunks`` — that list - # is reserved for the parity_v2-OFF / unmatched - # fallback path so it never re-pops chunks already - # consumed here (skip-append). - meta = index_to_meta.get(idx) if idx is not None else None - if meta: - args_chunk = tcc.get("args") or "" - if args_chunk: - yield streaming_service.format_tool_input_delta( - meta["ui_id"], args_chunk - ) - if content_builder is not None: - content_builder.on_tool_input_delta( - meta["ui_id"], args_chunk - ) - else: - pending_tool_call_chunks.append(tcc) - - elif event_type == "on_tool_start": - active_tool_depth += 1 - tool_name = event.get("name", "unknown_tool") - run_id = event.get("run_id", "") - tool_input = event.get("data", {}).get("input", {}) - if tool_name in ("write_file", "edit_file"): - result.write_attempted = True - if isinstance(tool_input, dict): - file_path = tool_input.get("file_path") - if isinstance(file_path, str) and file_path.strip() and run_id: - file_path_by_run[run_id] = file_path.strip() - - if current_text_id is not None: - yield streaming_service.format_text_end(current_text_id) - if content_builder is not None: - content_builder.on_text_end(current_text_id) - current_text_id = None - - if last_active_step_title != "Synthesizing response": - completion_event = complete_current_step() - if completion_event: - yield completion_event - - just_finished_tool = False - tool_step_id = next_thinking_step_id() - tool_step_ids[run_id] = tool_step_id - last_active_step_id = tool_step_id - - if tool_name == "ls": - ls_path = ( - tool_input.get("path", "/") - if isinstance(tool_input, dict) - else str(tool_input) - ) - last_active_step_title = "Listing files" - last_active_step_items = [ls_path] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Listing files", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "read_file": - fp = ( - tool_input.get("file_path", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - display_fp = fp if len(fp) <= 80 else "…" + fp[-77:] - last_active_step_title = "Reading file" - last_active_step_items = [display_fp] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Reading file", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "write_file": - fp = ( - tool_input.get("file_path", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - display_fp = fp if len(fp) <= 80 else "…" + fp[-77:] - last_active_step_title = "Writing file" - last_active_step_items = [display_fp] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Writing file", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "edit_file": - fp = ( - tool_input.get("file_path", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - display_fp = fp if len(fp) <= 80 else "…" + fp[-77:] - last_active_step_title = "Editing file" - last_active_step_items = [display_fp] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Editing file", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "glob": - pat = ( - tool_input.get("pattern", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - base_path = ( - tool_input.get("path", "/") if isinstance(tool_input, dict) else "/" - ) - last_active_step_title = "Searching files" - last_active_step_items = [f"{pat} in {base_path}"] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Searching files", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "grep": - pat = ( - tool_input.get("pattern", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - grep_path = ( - tool_input.get("path", "") if isinstance(tool_input, dict) else "" - ) - display_pat = pat[:60] + ("…" if len(pat) > 60 else "") - last_active_step_title = "Searching content" - last_active_step_items = [ - f'"{display_pat}"' + (f" in {grep_path}" if grep_path else "") - ] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Searching content", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "rm": - rm_path = ( - tool_input.get("path", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - display_path = rm_path if len(rm_path) <= 80 else "…" + rm_path[-77:] - last_active_step_title = "Deleting file" - last_active_step_items = [display_path] if display_path else [] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Deleting file", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "rmdir": - rmdir_path = ( - tool_input.get("path", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - display_path = ( - rmdir_path if len(rmdir_path) <= 80 else "…" + rmdir_path[-77:] - ) - last_active_step_title = "Deleting folder" - last_active_step_items = [display_path] if display_path else [] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Deleting folder", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "mkdir": - mkdir_path = ( - tool_input.get("path", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - display_path = ( - mkdir_path if len(mkdir_path) <= 80 else "…" + mkdir_path[-77:] - ) - last_active_step_title = "Creating folder" - last_active_step_items = [display_path] if display_path else [] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Creating folder", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "move_file": - src = ( - tool_input.get("source_path", "") - if isinstance(tool_input, dict) - else "" - ) - dst = ( - tool_input.get("destination_path", "") - if isinstance(tool_input, dict) - else "" - ) - display_src = src if len(src) <= 60 else "…" + src[-57:] - display_dst = dst if len(dst) <= 60 else "…" + dst[-57:] - last_active_step_title = "Moving file" - last_active_step_items = ( - [f"{display_src} → {display_dst}"] if src or dst else [] - ) - yield _emit_thinking_step( - step_id=tool_step_id, - title="Moving file", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "write_todos": - todos = ( - tool_input.get("todos", []) if isinstance(tool_input, dict) else [] - ) - todo_count = len(todos) if isinstance(todos, list) else 0 - last_active_step_title = "Planning tasks" - last_active_step_items = ( - [f"{todo_count} task{'s' if todo_count != 1 else ''}"] - if todo_count - else [] - ) - yield _emit_thinking_step( - step_id=tool_step_id, - title="Planning tasks", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "save_document": - doc_title = ( - tool_input.get("title", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - display_title = doc_title[:60] + ("…" if len(doc_title) > 60 else "") - last_active_step_title = "Saving document" - last_active_step_items = [display_title] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Saving document", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "generate_image": - prompt = ( - tool_input.get("prompt", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - last_active_step_title = "Generating image" - last_active_step_items = [ - f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}" - ] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Generating image", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "scrape_webpage": - url = ( - tool_input.get("url", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - last_active_step_title = "Scraping webpage" - last_active_step_items = [ - f"URL: {url[:80]}{'...' if len(url) > 80 else ''}" - ] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Scraping webpage", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "generate_podcast": - podcast_title = ( - tool_input.get("podcast_title", "SurfSense Podcast") - if isinstance(tool_input, dict) - else "SurfSense Podcast" - ) - content_len = len( - tool_input.get("source_content", "") - if isinstance(tool_input, dict) - else "" - ) - last_active_step_title = "Generating podcast" - last_active_step_items = [ - f"Title: {podcast_title}", - f"Content: {content_len:,} characters", - "Preparing audio generation...", - ] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Generating podcast", - status="in_progress", - items=last_active_step_items, - ) - elif tool_name == "generate_report": - report_topic = ( - tool_input.get("topic", "Report") - if isinstance(tool_input, dict) - else "Report" - ) - is_revision = bool( - isinstance(tool_input, dict) and tool_input.get("parent_report_id") - ) - step_title = "Revising report" if is_revision else "Generating report" - last_active_step_title = step_title - last_active_step_items = [ - f"Topic: {report_topic}", - "Analyzing source content...", - ] - yield _emit_thinking_step( - step_id=tool_step_id, - title=step_title, - status="in_progress", - items=last_active_step_items, - ) - elif tool_name in ("execute", "execute_code"): - cmd = ( - tool_input.get("command", "") - if isinstance(tool_input, dict) - else str(tool_input) - ) - display_cmd = cmd[:80] + ("…" if len(cmd) > 80 else "") - last_active_step_title = "Running command" - last_active_step_items = [f"$ {display_cmd}"] - yield _emit_thinking_step( - step_id=tool_step_id, - title="Running command", - status="in_progress", - items=last_active_step_items, - ) - else: - # Fallback for tools without a curated thinking-step title - # (typically connector tools, MCP-registered tools, or - # newly added tools that haven't been wired up here yet). - # Render the snake_cased name as a sentence-cased phrase - # so non-technical users see e.g. "Send gmail email" - # rather than the raw identifier "send_gmail_email". - last_active_step_title = ( - tool_name.replace("_", " ").strip().capitalize() or tool_name - ) - last_active_step_items = [] - yield _emit_thinking_step( - step_id=tool_step_id, - title=last_active_step_title, - status="in_progress", - ) - - # Resolve the card identity. If the chunk-emission loop - # already registered an ``index`` for this tool call (parity_v2 - # path), reuse the same ui_id so the card sees: - # tool-input-start → deltas… → tool-input-available → - # tool-output-available all keyed by lc_id. Otherwise fall - # back to the synthetic ``call_`` id and the legacy - # best-effort match against ``pending_tool_call_chunks``. - matched_meta: dict[str, str] | None = None - if parity_v2: - # FIFO over indices 0,1,2…; first unassigned same-name - # match wins. Handles parallel same-name calls (e.g. two - # write_file calls) deterministically as long as the - # model interleaves on_tool_start in the same order it - # streamed the args. - taken_ui_ids = set(ui_tool_call_id_by_run.values()) - for meta in index_to_meta.values(): - if meta["name"] == tool_name and meta["ui_id"] not in taken_ui_ids: - matched_meta = meta - break - - tool_call_id: str - langchain_tool_call_id: str | None = None - if matched_meta is not None: - tool_call_id = matched_meta["ui_id"] - langchain_tool_call_id = matched_meta["lc_id"] - # ``tool-input-start`` already fired during chunk - # emission — skip the duplicate. No pruning is needed - # because the chunk-emission loop intentionally never - # appends registered-index chunks to - # ``pending_tool_call_chunks`` (skip-append). - if run_id: - lc_tool_call_id_by_run[run_id] = matched_meta["lc_id"] - else: - tool_call_id = ( - f"call_{run_id[:32]}" - if run_id - else streaming_service.generate_tool_call_id() - ) - # Legacy fallback: parity_v2 OFF, or parity_v2 ON but the - # provider didn't stream tool_call_chunks for this call - # (no index registered). Run the existing best-effort - # match BEFORE emitting start so we still attach an - # authoritative ``langchainToolCallId`` when possible. - if parity_v2: - langchain_tool_call_id = _legacy_match_lc_id( - pending_tool_call_chunks, - tool_name, - run_id, - lc_tool_call_id_by_run, - ) - yield streaming_service.format_tool_input_start( - tool_call_id, - tool_name, - langchain_tool_call_id=langchain_tool_call_id, - ) - if content_builder is not None: - content_builder.on_tool_input_start( - tool_call_id, tool_name, langchain_tool_call_id - ) - - if run_id: - ui_tool_call_id_by_run[run_id] = tool_call_id - - # Sanitize tool_input: strip runtime-injected non-serializable - # values (e.g. LangChain ToolRuntime) before sending over SSE. - if isinstance(tool_input, dict): - _safe_input: dict[str, Any] = {} - for _k, _v in tool_input.items(): - try: - json.dumps(_v) - _safe_input[_k] = _v - except (TypeError, ValueError, OverflowError): - pass - else: - _safe_input = {"input": tool_input} - yield streaming_service.format_tool_input_available( - tool_call_id, - tool_name, - _safe_input, - langchain_tool_call_id=langchain_tool_call_id, - ) - if content_builder is not None: - content_builder.on_tool_input_available( - tool_call_id, - tool_name, - _safe_input, - langchain_tool_call_id, - ) - - elif event_type == "on_tool_end": - active_tool_depth = max(0, active_tool_depth - 1) - run_id = event.get("run_id", "") - tool_name = event.get("name", "unknown_tool") - raw_output = event.get("data", {}).get("output", "") - staged_file_path = file_path_by_run.pop(run_id, None) if run_id else None - - if tool_name == "update_memory": - called_update_memory = True - - if hasattr(raw_output, "content"): - content = raw_output.content - if isinstance(content, str): - try: - tool_output = json.loads(content) - except (json.JSONDecodeError, TypeError): - tool_output = {"result": content} - elif isinstance(content, dict): - tool_output = content - else: - tool_output = {"result": str(content)} - elif isinstance(raw_output, dict): - tool_output = raw_output - else: - tool_output = {"result": str(raw_output) if raw_output else "completed"} - - if tool_name in ("write_file", "edit_file"): - if _tool_output_has_error(tool_output): - # Keep successful evidence if a previous write/edit in this turn succeeded. - pass - else: - result.write_succeeded = True - result.verification_succeeded = True - - # Look up the SAME card id used at on_tool_start (either the - # parity_v2 lc-id-derived ui_id or the legacy synthetic - # ``call_``) so the output event always lands on the - # same card as start/delta/available. Fallback preserves the - # legacy synthetic shape for parity_v2-OFF / unknown-run paths. - tool_call_id = ui_tool_call_id_by_run.get( - run_id, - f"call_{run_id[:32]}" if run_id else "call_unknown", - ) - original_step_id = tool_step_ids.get( - run_id, f"{step_prefix}-unknown-{run_id[:8]}" - ) - completed_step_ids.add(original_step_id) - - # Authoritative LangChain tool_call_id from the returned - # ``ToolMessage``. Falls back to whatever we matched - # at ``on_tool_start`` time (kept in ``lc_tool_call_id_by_run``) - # if the output isn't a ToolMessage. The value is stored in - # ``current_lc_tool_call_id`` so ``_emit_tool_output`` - # picks it up for every output emit below. - # - # Emitted in BOTH parity_v2 and legacy modes: the chat tool - # card needs the LangChain id to match against the - # ``data-action-log`` SSE event (keyed by ``lc_tool_call_id``) - # so the inline Revert button can light up. Reading - # ``raw_output.tool_call_id`` is a cheap, non-mutating attribute - # access that is safe regardless of feature-flag state. - current_lc_tool_call_id["value"] = None - authoritative = getattr(raw_output, "tool_call_id", None) - if isinstance(authoritative, str) and authoritative: - current_lc_tool_call_id["value"] = authoritative - if run_id: - lc_tool_call_id_by_run[run_id] = authoritative - elif run_id and run_id in lc_tool_call_id_by_run: - current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id] - - if tool_name == "read_file": - yield _emit_thinking_step( - step_id=original_step_id, - title="Reading file", - status="completed", - items=last_active_step_items, - ) - elif tool_name == "write_file": - yield _emit_thinking_step( - step_id=original_step_id, - title="Writing file", - status="completed", - items=last_active_step_items, - ) - elif tool_name == "edit_file": - yield _emit_thinking_step( - step_id=original_step_id, - title="Editing file", - status="completed", - items=last_active_step_items, - ) - elif tool_name == "glob": - yield _emit_thinking_step( - step_id=original_step_id, - title="Searching files", - status="completed", - items=last_active_step_items, - ) - elif tool_name == "grep": - yield _emit_thinking_step( - step_id=original_step_id, - title="Searching content", - status="completed", - items=last_active_step_items, - ) - elif tool_name == "rm": - yield _emit_thinking_step( - step_id=original_step_id, - title="Deleting file", - status="completed", - items=last_active_step_items, - ) - elif tool_name == "rmdir": - yield _emit_thinking_step( - step_id=original_step_id, - title="Deleting folder", - status="completed", - items=last_active_step_items, - ) - elif tool_name == "mkdir": - yield _emit_thinking_step( - step_id=original_step_id, - title="Creating folder", - status="completed", - items=last_active_step_items, - ) - elif tool_name == "move_file": - yield _emit_thinking_step( - step_id=original_step_id, - title="Moving file", - status="completed", - items=last_active_step_items, - ) - elif tool_name == "write_todos": - yield _emit_thinking_step( - step_id=original_step_id, - title="Planning tasks", - status="completed", - items=last_active_step_items, - ) - elif tool_name == "save_document": - result_str = ( - tool_output.get("result", "") - if isinstance(tool_output, dict) - else str(tool_output) - ) - is_error = "Error" in result_str - completed_items = [ - *last_active_step_items, - result_str[:80] if is_error else "Saved to knowledge base", - ] - yield _emit_thinking_step( - step_id=original_step_id, - title="Saving document", - status="completed", - items=completed_items, - ) - elif tool_name == "generate_image": - if isinstance(tool_output, dict) and not tool_output.get("error"): - completed_items = [ - *last_active_step_items, - "Image generated successfully", - ] - else: - error_msg = ( - tool_output.get("error", "Generation failed") - if isinstance(tool_output, dict) - else "Generation failed" - ) - completed_items = [*last_active_step_items, f"Error: {error_msg}"] - yield _emit_thinking_step( - step_id=original_step_id, - title="Generating image", - status="completed", - items=completed_items, - ) - elif tool_name == "scrape_webpage": - if isinstance(tool_output, dict): - title = tool_output.get("title", "Webpage") - word_count = tool_output.get("word_count", 0) - has_error = "error" in tool_output - if has_error: - completed_items = [ - *last_active_step_items, - f"Error: {tool_output.get('error', 'Failed to scrape')[:50]}", - ] - else: - completed_items = [ - *last_active_step_items, - f"Title: {title[:50]}{'...' if len(title) > 50 else ''}", - f"Extracted: {word_count:,} words", - ] - else: - completed_items = [*last_active_step_items, "Content extracted"] - yield _emit_thinking_step( - step_id=original_step_id, - title="Scraping webpage", - status="completed", - items=completed_items, - ) - elif tool_name == "generate_podcast": - podcast_status = ( - tool_output.get("status", "unknown") - if isinstance(tool_output, dict) - else "unknown" - ) - podcast_title = ( - tool_output.get("title", "Podcast") - if isinstance(tool_output, dict) - else "Podcast" - ) - if podcast_status in ("pending", "generating", "processing"): - completed_items = [ - f"Title: {podcast_title}", - "Podcast generation started", - "Processing in background...", - ] - elif podcast_status == "already_generating": - completed_items = [ - f"Title: {podcast_title}", - "Podcast already in progress", - "Please wait for it to complete", - ] - elif podcast_status in ("failed", "error"): - error_msg = ( - tool_output.get("error", "Unknown error") - if isinstance(tool_output, dict) - else "Unknown error" - ) - completed_items = [ - f"Title: {podcast_title}", - f"Error: {error_msg[:50]}", - ] - elif podcast_status in ("ready", "success"): - completed_items = [ - f"Title: {podcast_title}", - "Podcast ready", - ] - else: - completed_items = last_active_step_items - yield _emit_thinking_step( - step_id=original_step_id, - title="Generating podcast", - status="completed", - items=completed_items, - ) - elif tool_name == "generate_video_presentation": - vp_status = ( - tool_output.get("status", "unknown") - if isinstance(tool_output, dict) - else "unknown" - ) - vp_title = ( - tool_output.get("title", "Presentation") - if isinstance(tool_output, dict) - else "Presentation" - ) - if vp_status in ("pending", "generating"): - completed_items = [ - f"Title: {vp_title}", - "Presentation generation started", - "Processing in background...", - ] - elif vp_status == "failed": - error_msg = ( - tool_output.get("error", "Unknown error") - if isinstance(tool_output, dict) - else "Unknown error" - ) - completed_items = [ - f"Title: {vp_title}", - f"Error: {error_msg[:50]}", - ] - else: - completed_items = last_active_step_items - yield _emit_thinking_step( - step_id=original_step_id, - title="Generating video presentation", - status="completed", - items=completed_items, - ) - elif tool_name == "generate_report": - report_status = ( - tool_output.get("status", "unknown") - if isinstance(tool_output, dict) - else "unknown" - ) - report_title = ( - tool_output.get("title", "Report") - if isinstance(tool_output, dict) - else "Report" - ) - word_count = ( - tool_output.get("word_count", 0) - if isinstance(tool_output, dict) - else 0 - ) - is_revision = ( - tool_output.get("is_revision", False) - if isinstance(tool_output, dict) - else False - ) - step_title = "Revising report" if is_revision else "Generating report" - - if report_status == "ready": - completed_items = [ - f"Topic: {report_title}", - f"{word_count:,} words", - "Report ready", - ] - elif report_status == "failed": - error_msg = ( - tool_output.get("error", "Unknown error") - if isinstance(tool_output, dict) - else "Unknown error" - ) - completed_items = [ - f"Topic: {report_title}", - f"Error: {error_msg[:50]}", - ] - else: - completed_items = last_active_step_items - - yield _emit_thinking_step( - step_id=original_step_id, - title=step_title, - status="completed", - items=completed_items, - ) - elif tool_name in ("execute", "execute_code"): - raw_text = ( - tool_output.get("result", "") - if isinstance(tool_output, dict) - else str(tool_output) - ) - m = re.match(r"^Exit code:\s*(\d+)", raw_text) - exit_code_val = int(m.group(1)) if m else None - if exit_code_val is not None and exit_code_val == 0: - completed_items = [ - *last_active_step_items, - "Completed successfully", - ] - elif exit_code_val is not None: - completed_items = [ - *last_active_step_items, - f"Exit code: {exit_code_val}", - ] - else: - completed_items = [*last_active_step_items, "Finished"] - yield _emit_thinking_step( - step_id=original_step_id, - title="Running command", - status="completed", - items=completed_items, - ) - elif tool_name == "ls": - if isinstance(tool_output, dict): - ls_output = tool_output.get("result", "") - elif isinstance(tool_output, str): - ls_output = tool_output - else: - ls_output = str(tool_output) if tool_output else "" - file_names: list[str] = [] - if ls_output: - paths: list[str] = [] - try: - parsed = ast.literal_eval(ls_output) - if isinstance(parsed, list): - paths = [str(p) for p in parsed] - except (ValueError, SyntaxError): - paths = [ - line.strip() - for line in ls_output.strip().split("\n") - if line.strip() - ] - for p in paths: - name = p.rstrip("/").split("/")[-1] - if name and len(name) <= 40: - file_names.append(name) - elif name: - file_names.append(name[:37] + "...") - if file_names: - if len(file_names) <= 5: - completed_items = [f"[{name}]" for name in file_names] - else: - completed_items = [f"[{name}]" for name in file_names[:4]] - completed_items.append(f"(+{len(file_names) - 4} more)") - else: - completed_items = ["No files found"] - yield _emit_thinking_step( - step_id=original_step_id, - title="Listing files", - status="completed", - items=completed_items, - ) - else: - # Fallback completion title — see the matching in-progress - # branch above for the wording rationale. - fallback_title = ( - tool_name.replace("_", " ").strip().capitalize() or tool_name - ) - yield _emit_thinking_step( - step_id=original_step_id, - title=fallback_title, - status="completed", - items=last_active_step_items, - ) - - just_finished_tool = True - last_active_step_id = None - last_active_step_title = "" - last_active_step_items = [] - - if tool_name == "generate_podcast": - yield _emit_tool_output( - tool_call_id, - tool_output - if isinstance(tool_output, dict) - else {"result": tool_output}, - ) - if isinstance(tool_output, dict) and tool_output.get("status") in ( - "pending", - "generating", - "processing", - ): - yield streaming_service.format_terminal_info( - f"Podcast queued: {tool_output.get('title', 'Podcast')}", - "success", - ) - elif isinstance(tool_output, dict) and tool_output.get("status") in ( - "ready", - "success", - ): - yield streaming_service.format_terminal_info( - f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}", - "success", - ) - elif isinstance(tool_output, dict) and tool_output.get("status") in ( - "failed", - "error", - ): - error_msg = tool_output.get("error", "Unknown error") - yield streaming_service.format_terminal_info( - f"Podcast generation failed: {error_msg}", - "error", - ) - elif tool_name == "generate_video_presentation": - yield _emit_tool_output( - tool_call_id, - tool_output - if isinstance(tool_output, dict) - else {"result": tool_output}, - ) - if ( - isinstance(tool_output, dict) - and tool_output.get("status") == "pending" - ): - yield streaming_service.format_terminal_info( - f"Video presentation queued: {tool_output.get('title', 'Presentation')}", - "success", - ) - elif ( - isinstance(tool_output, dict) - and tool_output.get("status") == "failed" - ): - error_msg = ( - tool_output.get("error", "Unknown error") - if isinstance(tool_output, dict) - else "Unknown error" - ) - yield streaming_service.format_terminal_info( - f"Presentation generation failed: {error_msg}", - "error", - ) - elif tool_name == "generate_image": - yield _emit_tool_output( - tool_call_id, - tool_output - if isinstance(tool_output, dict) - else {"result": tool_output}, - ) - if isinstance(tool_output, dict): - if tool_output.get("error"): - yield streaming_service.format_terminal_info( - f"Image generation failed: {tool_output['error'][:60]}", - "error", - ) - else: - yield streaming_service.format_terminal_info( - "Image generated successfully", - "success", - ) - elif tool_name == "scrape_webpage": - if isinstance(tool_output, dict): - display_output = { - k: v for k, v in tool_output.items() if k != "content" - } - if "content" in tool_output: - content = tool_output.get("content", "") - display_output["content_preview"] = ( - content[:500] + "..." if len(content) > 500 else content - ) - yield _emit_tool_output( - tool_call_id, - display_output, - ) - else: - yield _emit_tool_output( - tool_call_id, - {"result": tool_output}, - ) - if isinstance(tool_output, dict) and "error" not in tool_output: - title = tool_output.get("title", "Webpage") - word_count = tool_output.get("word_count", 0) - yield streaming_service.format_terminal_info( - f"Scraped: {title[:40]}{'...' if len(title) > 40 else ''} ({word_count:,} words)", - "success", - ) - else: - error_msg = ( - tool_output.get("error", "Failed to scrape") - if isinstance(tool_output, dict) - else "Failed to scrape" - ) - yield streaming_service.format_terminal_info( - f"Scrape failed: {error_msg}", - "error", - ) - elif tool_name in ("write_file", "edit_file"): - resolved_path = _extract_resolved_file_path( - tool_name=tool_name, - tool_output=tool_output, - tool_input={"file_path": staged_file_path} - if staged_file_path - else None, - ) - result_text = _tool_output_to_text(tool_output) - if _tool_output_has_error(tool_output): - yield _emit_tool_output( - tool_call_id, - { - "status": "error", - "error": result_text, - "path": resolved_path, - }, - ) - else: - yield _emit_tool_output( - tool_call_id, - { - "status": "completed", - "path": resolved_path, - "result": result_text, - }, - ) - elif tool_name == "generate_report": - # Stream the full report result so frontend can render the ReportCard - yield _emit_tool_output( - tool_call_id, - tool_output - if isinstance(tool_output, dict) - else {"result": tool_output}, - ) - # Send appropriate terminal message based on status - if ( - isinstance(tool_output, dict) - and tool_output.get("status") == "ready" - ): - word_count = tool_output.get("word_count", 0) - yield streaming_service.format_terminal_info( - f"Report generated: {tool_output.get('title', 'Report')} ({word_count:,} words)", - "success", - ) - else: - error_msg = ( - tool_output.get("error", "Unknown error") - if isinstance(tool_output, dict) - else "Unknown error" - ) - yield streaming_service.format_terminal_info( - f"Report generation failed: {error_msg}", - "error", - ) - elif tool_name == "generate_resume": - yield _emit_tool_output( - tool_call_id, - tool_output - if isinstance(tool_output, dict) - else {"result": tool_output}, - ) - if ( - isinstance(tool_output, dict) - and tool_output.get("status") == "ready" - ): - yield streaming_service.format_terminal_info( - f"Resume generated: {tool_output.get('title', 'Resume')}", - "success", - ) - else: - error_msg = ( - tool_output.get("error", "Unknown error") - if isinstance(tool_output, dict) - else "Unknown error" - ) - yield streaming_service.format_terminal_info( - f"Resume generation failed: {error_msg}", - "error", - ) - elif tool_name in ( - "create_notion_page", - "update_notion_page", - "delete_notion_page", - "create_linear_issue", - "update_linear_issue", - "delete_linear_issue", - "create_google_drive_file", - "delete_google_drive_file", - "create_onedrive_file", - "delete_onedrive_file", - "create_dropbox_file", - "delete_dropbox_file", - "create_gmail_draft", - "update_gmail_draft", - "send_gmail_email", - "trash_gmail_email", - "create_calendar_event", - "update_calendar_event", - "delete_calendar_event", - "create_jira_issue", - "update_jira_issue", - "delete_jira_issue", - "create_confluence_page", - "update_confluence_page", - "delete_confluence_page", - ): - yield _emit_tool_output( - tool_call_id, - tool_output - if isinstance(tool_output, dict) - else {"result": tool_output}, - ) - elif tool_name in ("execute", "execute_code"): - raw_text = ( - tool_output.get("result", "") - if isinstance(tool_output, dict) - else str(tool_output) - ) - exit_code: int | None = None - output_text = raw_text - m = re.match(r"^Exit code:\s*(\d+)", raw_text) - if m: - exit_code = int(m.group(1)) - om = re.search(r"\nOutput:\n([\s\S]*)", raw_text) - output_text = om.group(1) if om else "" - thread_id_str = config.get("configurable", {}).get("thread_id", "") - - for sf_match in re.finditer( - r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE - ): - fpath = sf_match.group(1).strip() - if fpath and fpath not in result.sandbox_files: - result.sandbox_files.append(fpath) - - yield _emit_tool_output( - tool_call_id, - { - "exit_code": exit_code, - "output": output_text, - "thread_id": thread_id_str, - }, - ) - elif tool_name == "web_search": - xml = ( - tool_output.get("result", str(tool_output)) - if isinstance(tool_output, dict) - else str(tool_output) - ) - citations: dict[str, dict[str, str]] = {} - for m in re.finditer( - r"<!\[CDATA\[(.*?)\]\]>\s*", - xml, - ): - title, url = m.group(1).strip(), m.group(2).strip() - if url.startswith("http") and url not in citations: - citations[url] = {"title": title} - for m in re.finditer( - r"", - xml, - ): - chunk_url, content = m.group(1).strip(), m.group(2).strip() - if ( - chunk_url.startswith("http") - and chunk_url in citations - and content - ): - citations[chunk_url]["snippet"] = ( - content[:200] + "…" if len(content) > 200 else content - ) - yield _emit_tool_output( - tool_call_id, - {"status": "completed", "citations": citations}, - ) - else: - yield _emit_tool_output( - tool_call_id, - {"status": "completed", "result_length": len(str(tool_output))}, - ) - yield streaming_service.format_terminal_info( - f"Tool {tool_name} completed", "success" - ) - - elif event_type == "on_custom_event" and event.get("name") == "report_progress": - # Live progress updates from inside the generate_report tool - data = event.get("data", {}) - message = data.get("message", "") - if message and last_active_step_id: - phase = data.get("phase", "") - # Always keep the "Topic: ..." line - topic_items = [ - item for item in last_active_step_items if item.startswith("Topic:") - ] - - if phase in ("revising_section", "adding_section"): - # During section-level ops: keep plan summary + show current op - plan_items = [ - item - for item in last_active_step_items - if item.startswith("Topic:") - or item.startswith("Modifying ") - or item.startswith("Adding ") - or item.startswith("Removing ") - ] - # Only keep plan_items that don't end with "..." (not progress lines) - plan_items = [ - item for item in plan_items if not item.endswith("...") - ] - last_active_step_items = [*plan_items, message] - else: - # Phase transitions: replace everything after topic - last_active_step_items = [*topic_items, message] - - yield _emit_thinking_step( - step_id=last_active_step_id, - title=last_active_step_title, - status="in_progress", - items=last_active_step_items, - ) - - elif ( - event_type == "on_custom_event" and event.get("name") == "document_created" - ): - data = event.get("data", {}) - if data.get("id"): - yield streaming_service.format_data( - "documents-updated", - { - "action": "created", - "document": data, - }, - ) - - elif event_type == "on_custom_event" and event.get("name") == "action_log": - # Surface a freshly committed AgentActionLog row so the chat - # tool card can render its Revert button immediately. - data = event.get("data", {}) - if data.get("id") is not None: - yield streaming_service.format_data("action-log", data) - - elif ( - event_type == "on_custom_event" - and event.get("name") == "action_log_updated" - ): - # Reversibility flipped in kb_persistence after the SAVEPOINT - # for a destructive op (rm/rmdir/move/edit/write) committed. - # Frontend uses this to flip the card's Revert - # button on without re-fetching the actions list. - data = event.get("data", {}) - if data.get("id") is not None: - yield streaming_service.format_data("action-log-updated", data) - - elif event_type in ("on_chain_end", "on_agent_end"): - if current_text_id is not None: - yield streaming_service.format_text_end(current_text_id) - if content_builder is not None: - content_builder.on_text_end(current_text_id) - current_text_id = None - - if current_text_id is not None: - yield streaming_service.format_text_end(current_text_id) - if content_builder is not None: - content_builder.on_text_end(current_text_id) - - completion_event = complete_current_step() - if completion_event: - yield completion_event + accumulated_text = result.accumulated_text state = await agent.aget_state(config) state_values = getattr(state, "values", {}) or {} @@ -2362,7 +913,6 @@ async def _stream_agent_events( result.commit_gate_reason = "" result.accumulated_text = accumulated_text - result.agent_called_update_memory = called_update_memory _log_file_contract("turn_outcome", result) interrupt_value = _first_interrupt_value(state) @@ -2380,6 +930,7 @@ async def stream_new_chat( llm_config_id: int = -1, mentioned_document_ids: list[int] | None = None, mentioned_surfsense_doc_ids: list[int] | None = None, + mentioned_folder_ids: list[int] | None = None, mentioned_documents: list[dict[str, Any]] | None = None, checkpoint_id: str | None = None, needs_history_bootstrap: bool = False, @@ -2409,6 +960,7 @@ async def stream_new_chat( needs_history_bootstrap: If True, load message history from DB (for cloned chats) mentioned_document_ids: Optional list of document IDs mentioned with @ in the chat mentioned_surfsense_doc_ids: Optional list of SurfSense doc IDs mentioned with @ in the chat + mentioned_folder_ids: Optional list of knowledge-base folder IDs mentioned with @ (cloud mode) checkpoint_id: Optional checkpoint ID to rewind/fork from (for edit/reload operations) Yields: @@ -2767,7 +1319,7 @@ async def stream_new_chat( _t0 = time.perf_counter() agent_factory = ( - create_registry_deep_agent + create_multi_agent_chat_deep_agent if use_multi_agent else create_surfsense_deep_agent ) @@ -2776,7 +1328,8 @@ async def stream_new_chat( # if preflight reports 429 we will discard this future and rebuild # against the freshly pinned config below. agent_build_task = asyncio.create_task( - agent_factory( + _build_main_agent_for_thread( + agent_factory, llm=llm, search_space_id=search_space_id, db_session=session, @@ -2787,9 +1340,9 @@ async def stream_new_chat( agent_config=agent_config, firecrawl_api_key=firecrawl_api_key, thread_visibility=visibility, + filesystem_selection=filesystem_selection, disabled_tools=disabled_tools, mentioned_document_ids=mentioned_document_ids, - filesystem_selection=filesystem_selection, ), name="agent_build:stream_new_chat", ) @@ -2952,6 +1505,53 @@ async def stream_new_chat( ) recent_reports = list(recent_reports_result.scalars().all()) + # Resolve @-mention chips to canonical virtual paths and rewrite + # the user-typed text so the LLM sees ``\`/documents/...\``` instead + # of bare ``@title``. The persisted user-message text keeps + # ``@title`` so chip rendering on reload is unchanged — see + # ``persistence._build_user_content``. + # + # Cloud mode only: local-folder mode keeps the legacy + # ``@title`` text path; mention support there is a follow-up + # task because the path scheme (mount-rooted) and the picker + # UI both need separate work. + accepted_folder_ids: list[int] = [] + if fs_mode == FilesystemMode.CLOUD.value and ( + mentioned_document_ids + or mentioned_surfsense_doc_ids + or mentioned_folder_ids + or mentioned_documents + ): + from app.schemas.new_chat import ( + MentionedDocumentInfo as _MentionedDocumentInfo, + ) + + chip_objs: list[_MentionedDocumentInfo] | None = None + if mentioned_documents: + chip_objs = [] + for raw in mentioned_documents: + if isinstance(raw, _MentionedDocumentInfo): + chip_objs.append(raw) + continue + try: + chip_objs.append(_MentionedDocumentInfo.model_validate(raw)) + except Exception: + logger.debug( + "stream_new_chat: dropping malformed mention chip %r", + raw, + ) + + resolved = await resolve_mentions( + session, + search_space_id=search_space_id, + mentioned_documents=chip_objs, + mentioned_document_ids=mentioned_document_ids, + mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids, + mentioned_folder_ids=mentioned_folder_ids, + ) + user_query = substitute_in_text(user_query, resolved.token_to_path) + accepted_folder_ids = resolved.mentioned_folder_ids + # Format the user query with context (SurfSense docs + reports only) final_query = user_query context_parts = [] @@ -3351,6 +1951,9 @@ async def stream_new_chat( runtime_context = SurfSenseContextSchema( search_space_id=search_space_id, mentioned_document_ids=list(mentioned_document_ids or []), + mentioned_folder_ids=list( + accepted_folder_ids or mentioned_folder_ids or [] + ), request_id=request_id, turn_id=stream_result.turn_id, ) @@ -3466,7 +2069,8 @@ async def stream_new_chat( title_task = None _t0 = time.perf_counter() - agent = await create_surfsense_deep_agent( + agent = await _build_main_agent_for_thread( + agent_factory, llm=llm, search_space_id=search_space_id, db_session=session, @@ -3477,9 +2081,9 @@ async def stream_new_chat( agent_config=agent_config, firecrawl_api_key=firecrawl_api_key, thread_visibility=visibility, + filesystem_selection=filesystem_selection, disabled_tools=disabled_tools, mentioned_document_ids=mentioned_document_ids, - filesystem_selection=filesystem_selection, ) _perf_log.info( "[stream_new_chat] Runtime rate-limit recovery repinned " @@ -4130,12 +2734,13 @@ async def stream_resume_chat( _t0 = time.perf_counter() agent_factory = ( - create_registry_deep_agent + create_multi_agent_chat_deep_agent if _app_config.MULTI_AGENT_CHAT_ENABLED else create_surfsense_deep_agent ) agent_build_task = asyncio.create_task( - agent_factory( + _build_main_agent_for_thread( + agent_factory, llm=llm, search_space_id=search_space_id, db_session=session, @@ -4224,7 +2829,8 @@ async def stream_resume_chat( "fallback_config_id": llm_config_id, }, ) - agent = await agent_factory( + agent = await _build_main_agent_for_thread( + agent_factory, llm=llm, search_space_id=search_space_id, db_session=session, @@ -4409,7 +3015,8 @@ async def stream_resume_chat( raise stream_exc _t0 = time.perf_counter() - agent = await create_surfsense_deep_agent( + agent = await _build_main_agent_for_thread( + agent_factory, llm=llm, search_space_id=search_space_id, db_session=session, @@ -4421,6 +3028,7 @@ async def stream_resume_chat( firecrawl_api_key=firecrawl_api_key, thread_visibility=visibility, filesystem_selection=filesystem_selection, + disabled_tools=disabled_tools, ) _perf_log.info( "[stream_resume] Runtime rate-limit recovery repinned " diff --git a/surfsense_backend/app/tasks/chat/streaming/__init__.py b/surfsense_backend/app/tasks/chat/streaming/__init__.py new file mode 100644 index 000000000..70c99342a --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/__init__.py @@ -0,0 +1,3 @@ +"""Chat streaming helpers (e.g. LangGraph → SSE relay under ``graph_stream``).""" + +from __future__ import annotations diff --git a/surfsense_backend/app/tasks/chat/streaming/errors/__init__.py b/surfsense_backend/app/tasks/chat/streaming/errors/__init__.py new file mode 100644 index 000000000..02284d4b0 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/errors/__init__.py @@ -0,0 +1,3 @@ +"""Error classification, structured logging, and terminal-error SSE emission.""" + +from __future__ import annotations diff --git a/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py b/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py new file mode 100644 index 000000000..3af2b9f9f --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py @@ -0,0 +1,187 @@ +"""Classify stream exceptions for logging and client error payloads.""" + +from __future__ import annotations + +import json +import logging +import time +from typing import Any, Literal + +from app.agents.new_chat.errors import BusyError +from app.agents.new_chat.middleware.busy_mutex import ( + get_cancel_state, + is_cancel_requested, +) + +TURN_CANCELLING_INITIAL_DELAY_MS = 200 +TURN_CANCELLING_BACKOFF_FACTOR = 2 +TURN_CANCELLING_MAX_DELAY_MS = 1500 + + +def compute_turn_cancelling_retry_delay(attempt: int) -> int: + if attempt < 1: + attempt = 1 + delay = TURN_CANCELLING_INITIAL_DELAY_MS * ( + TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1) + ) + return min(delay, TURN_CANCELLING_MAX_DELAY_MS) + + +def log_chat_stream_error( + *, + flow: Literal["new", "resume", "regenerate"], + error_kind: str, + error_code: str | None, + severity: Literal["info", "warn", "error"], + is_expected: bool, + request_id: str | None, + thread_id: int | None, + search_space_id: int | None, + user_id: str | None, + message: str, + extra: dict[str, Any] | None = None, +) -> None: + payload: dict[str, Any] = { + "event": "chat_stream_error", + "flow": flow, + "error_kind": error_kind, + "error_code": error_code, + "severity": severity, + "is_expected": is_expected, + "request_id": request_id or "unknown", + "thread_id": thread_id, + "search_space_id": search_space_id, + "user_id": user_id, + "message": message, + } + if extra: + payload.update(extra) + + logger = logging.getLogger(__name__) + rendered = json.dumps(payload, ensure_ascii=False) + if severity == "error": + logger.error("[chat_stream_error] %s", rendered) + elif severity == "warn": + logger.warning("[chat_stream_error] %s", rendered) + else: + logger.info("[chat_stream_error] %s", rendered) + + +def _parse_error_payload(message: str) -> dict[str, Any] | None: + candidates = [message] + first_brace_idx = message.find("{") + if first_brace_idx >= 0: + candidates.append(message[first_brace_idx:]) + + for candidate in candidates: + try: + parsed = json.loads(candidate) + if isinstance(parsed, dict): + return parsed + except Exception: + continue + return None + + +def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None: + if not isinstance(parsed, dict): + return None + candidates: list[Any] = [parsed.get("code")] + nested = parsed.get("error") + if isinstance(nested, dict): + candidates.append(nested.get("code")) + for value in candidates: + try: + if value is None: + continue + return int(value) + except Exception: + continue + return None + + +def is_provider_rate_limited(exc: BaseException) -> bool: + """Return True if the exception looks like an upstream HTTP 429 / rate limit.""" + raw = str(exc) + lowered = raw.lower() + if "ratelimit" in type(exc).__name__.lower(): + return True + parsed = _parse_error_payload(raw) + provider_code = _extract_provider_error_code(parsed) + if provider_code == 429: + return True + + provider_error_type = "" + if parsed: + top_type = parsed.get("type") + if isinstance(top_type, str): + provider_error_type = top_type.lower() + nested = parsed.get("error") + if isinstance(nested, dict): + nested_type = nested.get("type") + if isinstance(nested_type, str): + provider_error_type = nested_type.lower() + if provider_error_type == "rate_limit_error": + return True + + return ( + "rate limited" in lowered + or "rate-limited" in lowered + or "temporarily rate-limited upstream" in lowered + ) + + +def classify_stream_exception( + exc: Exception, + *, + flow_label: str, +) -> tuple[ + str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None +]: + """Return kind, code, severity, expected flag, message, and optional extra dict.""" + raw = str(exc) + if isinstance(exc, BusyError) or "Thread is busy with another request" in raw: + busy_thread_id = str(exc.request_id) if isinstance(exc, BusyError) else None + if busy_thread_id and is_cancel_requested(busy_thread_id): + cancel_state = get_cancel_state(busy_thread_id) + attempt = cancel_state[0] if cancel_state else 1 + retry_after_ms = compute_turn_cancelling_retry_delay(attempt) + retry_after_at = int(time.time() * 1000) + retry_after_ms + return ( + "thread_busy", + "TURN_CANCELLING", + "info", + True, + "A previous response is still stopping. Please try again in a moment.", + { + "retry_after_ms": retry_after_ms, + "retry_after_at": retry_after_at, + }, + ) + return ( + "thread_busy", + "THREAD_BUSY", + "warn", + True, + "Another response is still finishing for this thread. Please try again in a moment.", + None, + ) + + if is_provider_rate_limited(exc): + return ( + "rate_limited", + "RATE_LIMITED", + "warn", + True, + "This model is temporarily rate-limited. Please try again in a few seconds or switch models.", + None, + ) + + return ( + "server_error", + "SERVER_ERROR", + "error", + False, + f"Error during {flow_label}: {raw}", + None, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/errors/emitter.py b/surfsense_backend/app/tasks/chat/streaming/errors/emitter.py new file mode 100644 index 000000000..95806ab87 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/errors/emitter.py @@ -0,0 +1,38 @@ +"""Emit one terminal error SSE frame and log via the stream error classifier.""" + +from __future__ import annotations + +from typing import Any, Literal + +from .classifier import log_chat_stream_error + + +def emit_stream_terminal_error( + *, + streaming_service: Any, + flow: Literal["new", "resume", "regenerate"], + request_id: str | None, + thread_id: int, + search_space_id: int, + user_id: str | None, + message: str, + error_kind: str = "server_error", + error_code: str = "SERVER_ERROR", + severity: Literal["info", "warn", "error"] = "error", + is_expected: bool = False, + extra: dict[str, Any] | None = None, +) -> str: + log_chat_stream_error( + flow=flow, + error_kind=error_kind, + error_code=error_code, + severity=severity, + is_expected=is_expected, + request_id=request_id, + thread_id=thread_id, + search_space_id=search_space_id, + user_id=user_id, + message=message, + extra=extra, + ) + return streaming_service.format_error(message, error_code=error_code, extra=extra) diff --git a/surfsense_backend/app/tasks/chat/streaming/graph_stream/__init__.py b/surfsense_backend/app/tasks/chat/streaming/graph_stream/__init__.py new file mode 100644 index 000000000..e3bf0426c --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/graph_stream/__init__.py @@ -0,0 +1,21 @@ +"""LangGraph ``astream_events`` → SSE (``stream_output`` + ``StreamingResult``). + +Imports are lazy to avoid a circular import with ``relay.event_relay``. +""" + +from __future__ import annotations + +__all__ = ["StreamingResult", "stream_output"] + + +def __getattr__(name: str): + if name == "stream_output": + from app.tasks.chat.streaming.graph_stream.event_stream import stream_output + + return stream_output + if name == "StreamingResult": + from app.tasks.chat.streaming.graph_stream.result import StreamingResult + + return StreamingResult + msg = f"module {__name__!r} has no attribute {name!r}" + raise AttributeError(msg) diff --git a/surfsense_backend/app/tasks/chat/streaming/graph_stream/event_stream.py b/surfsense_backend/app/tasks/chat/streaming/graph_stream/event_stream.py new file mode 100644 index 000000000..9a309f9d7 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/graph_stream/event_stream.py @@ -0,0 +1,51 @@ +"""Run LangGraph event streams through ``EventRelay``.""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import Any + +from app.tasks.chat.streaming.graph_stream.result import StreamingResult +from app.tasks.chat.streaming.relay.event_relay import EventRelay +from app.tasks.chat.streaming.relay.state import AgentEventRelayState + + +async def stream_output( + *, + agent: Any, + config: dict[str, Any], + input_data: Any, + streaming_service: Any, + result: StreamingResult, + step_prefix: str = "thinking", + initial_step_id: str | None = None, + initial_step_title: str = "", + initial_step_items: list[str] | None = None, + content_builder: Any | None = None, + runtime_context: Any = None, +) -> AsyncIterator[str]: + """Yield SSE frames from agent ``astream_events`` via ``EventRelay``.""" + state = AgentEventRelayState.for_invocation( + initial_step_id=initial_step_id, + initial_step_title=initial_step_title, + initial_step_items=initial_step_items, + ) + + astream_kwargs: dict[str, Any] = {"config": config, "version": "v2"} + if runtime_context is not None: + astream_kwargs["context"] = runtime_context + + events = agent.astream_events(input_data, **astream_kwargs) + relay = EventRelay(streaming_service=streaming_service) + async for frame in relay.relay( + events, + state=state, + result=result, + step_prefix=step_prefix, + content_builder=content_builder, + config=config, + ): + yield frame + + result.accumulated_text = state.accumulated_text + result.agent_called_update_memory = state.called_update_memory diff --git a/surfsense_backend/app/tasks/chat/streaming/graph_stream/result.py b/surfsense_backend/app/tasks/chat/streaming/graph_stream/result.py new file mode 100644 index 000000000..40404e9d0 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/graph_stream/result.py @@ -0,0 +1,28 @@ +"""Mutable facts collected while relaying one agent stream (``stream_output``).""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class StreamingResult: + accumulated_text: str = "" + is_interrupted: bool = False + interrupt_value: dict[str, Any] | None = None + sandbox_files: list[str] = field(default_factory=list) + agent_called_update_memory: bool = False + request_id: str | None = None + turn_id: str = "" + filesystem_mode: str = "cloud" + client_platform: str = "web" + intent_detected: str = "chat_only" + intent_confidence: float = 0.0 + write_attempted: bool = False + write_succeeded: bool = False + verification_succeeded: bool = False + commit_gate_passed: bool = True + commit_gate_reason: str = "" + assistant_message_id: int | None = None + content_builder: Any | None = field(default=None, repr=False) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/__init__.py new file mode 100644 index 000000000..3e2165932 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/__init__.py @@ -0,0 +1,3 @@ +"""LangGraph stream handlers by event kind.""" + +from __future__ import annotations diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/chain_end.py b/surfsense_backend/app/tasks/chat/streaming/handlers/chain_end.py new file mode 100644 index 000000000..c61058ac7 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/chain_end.py @@ -0,0 +1,23 @@ +"""Close open text when a LangGraph chain or agent node finishes.""" + +from __future__ import annotations + +from collections.abc import Iterator +from typing import Any + +from app.tasks.chat.streaming.relay.state import AgentEventRelayState + + +def iter_chain_end_frames( + _event: dict[str, Any], + *, + state: AgentEventRelayState, + streaming_service: Any, + content_builder: Any | None, +) -> Iterator[str]: + """Close the open text stream if one is open.""" + if state.current_text_id is not None: + yield streaming_service.format_text_end(state.current_text_id) + if content_builder is not None: + content_builder.on_text_end(state.current_text_id) + state.current_text_id = None diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/chat_model_stream.py b/surfsense_backend/app/tasks/chat/streaming/handlers/chat_model_stream.py new file mode 100644 index 000000000..c3f6d6d59 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/chat_model_stream.py @@ -0,0 +1,159 @@ +"""Chat model stream: text, reasoning, and tool-call chunk SSE.""" + +from __future__ import annotations + +from collections.abc import Iterator +from typing import Any + +from app.tasks.chat.streaming.helpers.chunk_parts import extract_chunk_parts +from app.tasks.chat.streaming.relay.state import AgentEventRelayState +from app.tasks.chat.streaming.relay.task_span import ensure_pending_task_span_for_lc +from app.tasks.chat.streaming.relay.thinking_step_completion import ( + complete_active_thinking_step, +) + + +def iter_chat_model_stream_frames( + event: dict[str, Any], + *, + state: AgentEventRelayState, + streaming_service: Any, + content_builder: Any | None, + step_prefix: str, +) -> Iterator[str]: + """SSE frames for one chat-model chunk.""" + if state.active_tool_depth > 0: + return + if "surfsense:internal" in event.get("tags", []): + return + chunk = event.get("data", {}).get("chunk") + if not chunk: + return + parts = extract_chunk_parts(chunk) + + reasoning_delta = parts["reasoning"] + text_delta = parts["text"] + + if reasoning_delta: + if state.current_text_id is not None: + yield streaming_service.format_text_end(state.current_text_id) + if content_builder is not None: + content_builder.on_text_end(state.current_text_id) + state.current_text_id = None + if state.current_reasoning_id is None: + comp, new_active = complete_active_thinking_step( + state=state, + streaming_service=streaming_service, + content_builder=content_builder, + last_active_step_id=state.last_active_step_id, + last_active_step_title=state.last_active_step_title, + last_active_step_items=state.last_active_step_items, + completed_step_ids=state.completed_step_ids, + ) + if comp: + yield comp + state.last_active_step_id = new_active + if state.just_finished_tool: + state.last_active_step_id = None + state.last_active_step_title = "" + state.last_active_step_items = [] + state.just_finished_tool = False + state.current_reasoning_id = streaming_service.generate_reasoning_id() + yield streaming_service.format_reasoning_start(state.current_reasoning_id) + if content_builder is not None: + content_builder.on_reasoning_start(state.current_reasoning_id) + yield streaming_service.format_reasoning_delta( + state.current_reasoning_id, reasoning_delta + ) + if content_builder is not None: + content_builder.on_reasoning_delta( + state.current_reasoning_id, reasoning_delta + ) + + if text_delta: + if state.current_reasoning_id is not None: + yield streaming_service.format_reasoning_end(state.current_reasoning_id) + if content_builder is not None: + content_builder.on_reasoning_end(state.current_reasoning_id) + state.current_reasoning_id = None + if state.current_text_id is None: + comp, new_active = complete_active_thinking_step( + state=state, + streaming_service=streaming_service, + content_builder=content_builder, + last_active_step_id=state.last_active_step_id, + last_active_step_title=state.last_active_step_title, + last_active_step_items=state.last_active_step_items, + completed_step_ids=state.completed_step_ids, + ) + if comp: + yield comp + state.last_active_step_id = new_active + if state.just_finished_tool: + state.last_active_step_id = None + state.last_active_step_title = "" + state.last_active_step_items = [] + state.just_finished_tool = False + state.current_text_id = streaming_service.generate_text_id() + yield streaming_service.format_text_start(state.current_text_id) + if content_builder is not None: + content_builder.on_text_start(state.current_text_id) + yield streaming_service.format_text_delta(state.current_text_id, text_delta) + state.accumulated_text += text_delta + if content_builder is not None: + content_builder.on_text_delta(state.current_text_id, text_delta) + + if parts["tool_call_chunks"]: + for tcc in parts["tool_call_chunks"]: + idx = tcc.get("index") + + if idx is not None and idx not in state.index_to_meta: + lc_id = tcc.get("id") + name = tcc.get("name") + if lc_id and name: + ui_id = lc_id + tool_input_metadata: dict[str, Any] | None = None + if name == "task": + sid = ensure_pending_task_span_for_lc(state, str(lc_id)) + tool_input_metadata = {"spanId": sid} + + if state.current_text_id is not None: + yield streaming_service.format_text_end(state.current_text_id) + if content_builder is not None: + content_builder.on_text_end(state.current_text_id) + state.current_text_id = None + if state.current_reasoning_id is not None: + yield streaming_service.format_reasoning_end( + state.current_reasoning_id + ) + if content_builder is not None: + content_builder.on_reasoning_end(state.current_reasoning_id) + state.current_reasoning_id = None + + state.index_to_meta[idx] = { + "ui_id": ui_id, + "lc_id": lc_id, + "name": name, + } + yield streaming_service.format_tool_input_start( + ui_id, + name, + langchain_tool_call_id=lc_id, + metadata=tool_input_metadata, + ) + if content_builder is not None: + content_builder.on_tool_input_start( + ui_id, name, lc_id, metadata=tool_input_metadata + ) + + meta = state.index_to_meta.get(idx) if idx is not None else None + if meta: + args_chunk = tcc.get("args") or "" + if args_chunk: + yield streaming_service.format_tool_input_delta( + meta["ui_id"], args_chunk + ) + if content_builder is not None: + content_builder.on_tool_input_delta(meta["ui_id"], args_chunk) + else: + state.pending_tool_call_chunks.append(tcc) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/custom_event_dispatch.py b/surfsense_backend/app/tasks/chat/streaming/handlers/custom_event_dispatch.py new file mode 100644 index 000000000..69f4b8a24 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/custom_event_dispatch.py @@ -0,0 +1,57 @@ +"""Custom graph events routed to SSE (documents, action logs, report progress).""" + +from __future__ import annotations + +from collections.abc import Iterator +from typing import Any + +from app.tasks.chat.streaming.handlers.custom_events import ( + handle_action_log, + handle_action_log_updated, + handle_document_created, + handle_report_progress, +) +from app.tasks.chat.streaming.relay.state import AgentEventRelayState + + +def iter_custom_event_frames( + event: dict[str, Any], + *, + state: AgentEventRelayState, + streaming_service: Any, + content_builder: Any | None, +) -> Iterator[str]: + """Yield any SSE produced by ad-hoc graph events (documents, action logs, report progress).""" + name = event.get("name") + data = event.get("data", {}) + + if name == "report_progress": + frame, state.last_active_step_items = handle_report_progress( + data, + last_active_step_id=state.last_active_step_id, + last_active_step_title=state.last_active_step_title, + last_active_step_items=state.last_active_step_items, + streaming_service=streaming_service, + content_builder=content_builder, + thinking_metadata=state.span_metadata_if_active(), + ) + if frame: + yield frame + return + + if name == "document_created": + frame = handle_document_created(data, streaming_service=streaming_service) + if frame: + yield frame + return + + if name == "action_log": + frame = handle_action_log(data, streaming_service=streaming_service) + if frame: + yield frame + return + + if name == "action_log_updated": + frame = handle_action_log_updated(data, streaming_service=streaming_service) + if frame: + yield frame diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/custom_events.py b/surfsense_backend/app/tasks/chat/streaming/handlers/custom_events.py new file mode 100644 index 000000000..81116e205 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/custom_events.py @@ -0,0 +1,79 @@ +"""Custom-event payloads turned into SSE (no model/tool stream handling).""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame + + +def handle_report_progress( + data: dict[str, Any], + *, + last_active_step_id: str | None, + last_active_step_title: str, + last_active_step_items: list[str], + streaming_service: Any, + content_builder: Any | None, + thinking_metadata: dict[str, Any] | None = None, +) -> tuple[str | None, list[str]]: + """Update report step items; may emit one thinking SSE frame. + + Returns (frame or None, items list after update). + """ + message = data.get("message", "") + if not message or not last_active_step_id: + return None, last_active_step_items + + phase = data.get("phase", "") + topic_items = [item for item in last_active_step_items if item.startswith("Topic:")] + + if phase in ("revising_section", "adding_section"): + plan_items = [ + item + for item in last_active_step_items + if item.startswith("Topic:") + or item.startswith("Modifying ") + or item.startswith("Adding ") + or item.startswith("Removing ") + ] + plan_items = [item for item in plan_items if not item.endswith("...")] + new_items = [*plan_items, message] + else: + new_items = [*topic_items, message] + + frame = emit_thinking_step_frame( + streaming_service=streaming_service, + content_builder=content_builder, + step_id=last_active_step_id, + title=last_active_step_title, + status="in_progress", + items=new_items, + metadata=thinking_metadata, + ) + return frame, new_items + + +def handle_document_created( + data: dict[str, Any], *, streaming_service: Any +) -> str | None: + if not data.get("id"): + return None + return streaming_service.format_data( + "documents-updated", + {"action": "created", "document": data}, + ) + + +def handle_action_log(data: dict[str, Any], *, streaming_service: Any) -> str | None: + if data.get("id") is None: + return None + return streaming_service.format_data("action-log", data) + + +def handle_action_log_updated( + data: dict[str, Any], *, streaming_service: Any +) -> str | None: + if data.get("id") is None: + return None + return streaming_service.format_data("action-log-updated", data) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tool_end.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tool_end.py new file mode 100644 index 000000000..57ab617c5 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tool_end.py @@ -0,0 +1,119 @@ +"""Tool end: thinking completion, tool output, and terminal SSE.""" + +from __future__ import annotations + +import json +from collections.abc import Iterator +from typing import Any + +from app.tasks.chat.streaming.handlers.tools import ( + ToolCompletionEmissionContext, + iter_tool_completion_emission_frames, + resolve_tool_completed_thinking_step, +) +from app.tasks.chat.streaming.helpers.tool_output import tool_output_has_error +from app.tasks.chat.streaming.relay.state import AgentEventRelayState +from app.tasks.chat.streaming.relay.task_span import ( + clear_task_span_if_delegating_task_ended, +) +from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame + + +def iter_tool_end_frames( + event: dict[str, Any], + *, + state: AgentEventRelayState, + streaming_service: Any, + content_builder: Any | None, + result: Any, + step_prefix: str, + config: dict[str, Any], +) -> Iterator[str]: + """SSE frames when one tool run finishes.""" + state.active_tool_depth = max(0, state.active_tool_depth - 1) + run_id = event.get("run_id", "") + tool_name = event.get("name", "unknown_tool") + raw_output = event.get("data", {}).get("output", "") + staged_file_path = state.file_path_by_run.pop(run_id, None) if run_id else None + + if tool_name == "update_memory": + state.called_update_memory = True + + if hasattr(raw_output, "content"): + content = raw_output.content + if isinstance(content, str): + try: + tool_output = json.loads(content) + except (json.JSONDecodeError, TypeError): + tool_output = {"result": content} + elif isinstance(content, dict): + tool_output = content + else: + tool_output = {"result": str(content)} + elif isinstance(raw_output, dict): + tool_output = raw_output + else: + tool_output = {"result": str(raw_output) if raw_output else "completed"} + + if tool_name in ("write_file", "edit_file"): + if tool_output_has_error(tool_output): + pass + else: + result.write_succeeded = True + result.verification_succeeded = True + + tool_call_id = state.ui_tool_call_id_by_run.get( + run_id, + f"call_{run_id[:32]}" if run_id else "call_unknown", + ) + original_step_id = state.tool_step_ids.get( + run_id, f"{step_prefix}-unknown-{run_id[:8]}" + ) + state.completed_step_ids.add(original_step_id) + + holder = state.current_lc_tool_call_id + holder["value"] = None + authoritative = getattr(raw_output, "tool_call_id", None) + if isinstance(authoritative, str) and authoritative: + holder["value"] = authoritative + if run_id: + state.lc_tool_call_id_by_run[run_id] = authoritative + elif run_id and run_id in state.lc_tool_call_id_by_run: + holder["value"] = state.lc_tool_call_id_by_run[run_id] + + items = state.last_active_step_items + title, completed_items = resolve_tool_completed_thinking_step( + tool_name, tool_output, items + ) + yield emit_thinking_step_frame( + streaming_service=streaming_service, + content_builder=content_builder, + step_id=original_step_id, + title=title, + status="completed", + items=completed_items, + metadata=state.span_metadata_if_active(), + ) + + state.just_finished_tool = True + state.last_active_step_id = None + state.last_active_step_title = "" + state.last_active_step_items = [] + + emission_ctx = ToolCompletionEmissionContext( + tool_name=tool_name, + tool_call_id=tool_call_id, + tool_output=tool_output, + streaming_service=streaming_service, + content_builder=content_builder, + langchain_tool_call_id_holder=holder, + stream_result=result, + langgraph_config=config, + staged_workspace_file_path=staged_file_path, + tool_metadata=state.tool_activity_metadata( + thinking_step_id=original_step_id, + ), + ) + yield from iter_tool_completion_emission_frames(emission_ctx) + + clear_task_span_if_delegating_task_ended(state, tool_name=tool_name, run_id=run_id) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tool_output_frame.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tool_output_frame.py new file mode 100644 index 000000000..4cd8e3274 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tool_output_frame.py @@ -0,0 +1,29 @@ +"""Emit tool-output SSE and optional assistant content updates.""" + +from __future__ import annotations + +from typing import Any + + +def emit_tool_output_available_frame( + *, + streaming_service: Any, + content_builder: Any | None, + langchain_id_holder: dict[str, str | None], + call_id: str, + output: Any, + tool_metadata: dict[str, Any] | None = None, +) -> str: + if content_builder is not None: + content_builder.on_tool_output_available( + call_id, + output, + langchain_id_holder["value"], + metadata=tool_metadata, + ) + return streaming_service.format_tool_output_available( + call_id, + output, + langchain_tool_call_id=langchain_id_holder["value"], + metadata=tool_metadata, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tool_start.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tool_start.py new file mode 100644 index 000000000..e0cac307c --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tool_start.py @@ -0,0 +1,161 @@ +"""Tool start: thinking-step and tool-input SSE.""" + +from __future__ import annotations + +import json +from collections.abc import Iterator +from typing import Any + +from app.tasks.chat.streaming.handlers.tools import resolve_tool_start_thinking +from app.tasks.chat.streaming.helpers.tool_call_matching import ( + match_buffered_langchain_tool_call_id, +) +from app.tasks.chat.streaming.relay.state import AgentEventRelayState +from app.tasks.chat.streaming.relay.task_span import open_task_span +from app.tasks.chat.streaming.relay.thinking_step_completion import ( + complete_active_thinking_step, +) +from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame + + +def iter_tool_start_frames( + event: dict[str, Any], + *, + state: AgentEventRelayState, + streaming_service: Any, + content_builder: Any | None, + result: Any, + step_prefix: str, +) -> Iterator[str]: + """SSE frames for the start of one tool run.""" + state.active_tool_depth += 1 + tool_name = event.get("name", "unknown_tool") + run_id = event.get("run_id", "") + tool_input = event.get("data", {}).get("input", {}) + if tool_name in ("write_file", "edit_file"): + result.write_attempted = True + if isinstance(tool_input, dict): + file_path = tool_input.get("file_path") + if isinstance(file_path, str) and file_path.strip() and run_id: + state.file_path_by_run[run_id] = file_path.strip() + + if state.current_text_id is not None: + yield streaming_service.format_text_end(state.current_text_id) + if content_builder is not None: + content_builder.on_text_end(state.current_text_id) + state.current_text_id = None + + if state.last_active_step_title != "Synthesizing response": + comp, new_active = complete_active_thinking_step( + state=state, + streaming_service=streaming_service, + content_builder=content_builder, + last_active_step_id=state.last_active_step_id, + last_active_step_title=state.last_active_step_title, + last_active_step_items=state.last_active_step_items, + completed_step_ids=state.completed_step_ids, + ) + if comp: + yield comp + state.last_active_step_id = new_active + + state.just_finished_tool = False + tool_step_id = state.next_thinking_step_id(step_prefix) + state.tool_step_ids[run_id] = tool_step_id + state.last_active_step_id = tool_step_id + + matched_meta: dict[str, str] | None = None + taken_ui_ids = set(state.ui_tool_call_id_by_run.values()) + for meta in state.index_to_meta.values(): + if meta["name"] == tool_name and meta["ui_id"] not in taken_ui_ids: + matched_meta = meta + break + + tool_call_id: str + langchain_tool_call_id: str | None = None + if matched_meta is not None: + tool_call_id = matched_meta["ui_id"] + langchain_tool_call_id = matched_meta["lc_id"] + if run_id: + state.lc_tool_call_id_by_run[run_id] = matched_meta["lc_id"] + else: + tool_call_id = ( + f"call_{run_id[:32]}" + if run_id + else streaming_service.generate_tool_call_id() + ) + langchain_tool_call_id = match_buffered_langchain_tool_call_id( + state.pending_tool_call_chunks, + tool_name, + run_id, + state.lc_tool_call_id_by_run, + ) + + if tool_name == "task": + open_task_span( + state, + run_id=run_id, + langchain_tool_call_id=langchain_tool_call_id, + ) + + span_md = state.span_metadata_if_active() + tool_md = state.tool_activity_metadata(thinking_step_id=tool_step_id) + + if matched_meta is None: + yield streaming_service.format_tool_input_start( + tool_call_id, + tool_name, + langchain_tool_call_id=langchain_tool_call_id, + metadata=tool_md, + ) + if content_builder is not None: + content_builder.on_tool_input_start( + tool_call_id, + tool_name, + langchain_tool_call_id, + metadata=tool_md, + ) + + thinking = resolve_tool_start_thinking(tool_name, tool_input) + state.last_active_step_title = thinking.title + state.last_active_step_items = thinking.items + frame_kw: dict[str, Any] = { + "streaming_service": streaming_service, + "content_builder": content_builder, + "step_id": tool_step_id, + "title": thinking.title, + "status": "in_progress", + "metadata": span_md, + } + if thinking.include_items_on_frame: + frame_kw["items"] = thinking.items + yield emit_thinking_step_frame(**frame_kw) + + if run_id: + state.ui_tool_call_id_by_run[run_id] = tool_call_id + + if isinstance(tool_input, dict): + _safe_input: dict[str, Any] = {} + for _k, _v in tool_input.items(): + try: + json.dumps(_v) + _safe_input[_k] = _v + except (TypeError, ValueError, OverflowError): + pass + else: + _safe_input = {"input": tool_input} + yield streaming_service.format_tool_input_available( + tool_call_id, + tool_name, + _safe_input, + langchain_tool_call_id=langchain_tool_call_id, + metadata=tool_md, + ) + if content_builder is not None: + content_builder.on_tool_input_available( + tool_call_id, + tool_name, + _safe_input, + langchain_tool_call_id, + metadata=tool_md, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/__init__.py new file mode 100644 index 000000000..4b191c100 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/__init__.py @@ -0,0 +1,23 @@ +"""Per-tool streaming: thinking-step and completion emission.""" + +from __future__ import annotations + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) +from app.tasks.chat.streaming.handlers.tools.registry import ( + iter_tool_completion_emission_frames, + resolve_tool_completed_thinking_step, + resolve_tool_start_thinking, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + +__all__ = [ + "ToolCompletionEmissionContext", + "ToolStartThinking", + "iter_tool_completion_emission_frames", + "resolve_tool_completed_thinking_step", + "resolve_tool_start_thinking", +] diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/emission.py new file mode 100644 index 000000000..8e19dc224 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/emission.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + out = ctx.tool_output + payload = out if isinstance(out, dict) else {"result": out} + yield ctx.emit_tool_output_card(payload) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/thinking.py new file mode 100644 index 000000000..2eed8855a --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/thinking.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.default import ( + thinking as default_thinking, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + return default_thinking.resolve_start_thinking(tool_name, tool_input) + + +def resolve_completed_thinking( + tool_name: str, + tool_output: Any, + last_items: list[str], +) -> tuple[str, list[str]]: + return default_thinking.resolve_completed_thinking( + tool_name, tool_output, last_items + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/tool_names.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/tool_names.py new file mode 100644 index 000000000..ab698b32d --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/connector/shared/tool_names.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +SHARED_CONNECTOR_TOOLS: frozenset[str] = frozenset( + { + "create_calendar_event", + "create_confluence_page", + "create_dropbox_file", + "create_gmail_draft", + "create_google_drive_file", + "create_jira_issue", + "create_linear_issue", + "create_notion_page", + "create_onedrive_file", + "delete_calendar_event", + "delete_confluence_page", + "delete_dropbox_file", + "delete_google_drive_file", + "delete_jira_issue", + "delete_linear_issue", + "delete_notion_page", + "delete_onedrive_file", + "send_gmail_email", + "trash_gmail_email", + "update_calendar_event", + "update_confluence_page", + "update_gmail_draft", + "update_jira_issue", + "update_linear_issue", + "update_notion_page", + } +) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/default/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/default/__init__.py new file mode 100644 index 000000000..5e84a37f4 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/default/__init__.py @@ -0,0 +1,3 @@ +"""Fallback tool package.""" + +from __future__ import annotations diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/default/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/default/emission.py new file mode 100644 index 000000000..e24c619a7 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/default/emission.py @@ -0,0 +1,24 @@ +"""Default tool-output card and a short completion terminal line.""" + +from __future__ import annotations + +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + yield ctx.emit_tool_output_card( + { + "status": "completed", + "result_length": len(str(ctx.tool_output)), + }, + ) + yield ctx.streaming_service.format_terminal_info( + f"Tool {ctx.tool_name} completed", + "success", + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/default/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/default/thinking.py new file mode 100644 index 000000000..46d15a4e7 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/default/thinking.py @@ -0,0 +1,23 @@ +"""Fallback thinking-step copy for unknown tools and connectors without custom UI.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_input + title = tool_name.replace("_", " ").strip().capitalize() or tool_name + return ToolStartThinking(title=title, items=[], include_items_on_frame=False) + + +def resolve_completed_thinking( + tool_name: str, tool_output: Any, last_items: list[str] +) -> tuple[str, list[str]]: + del tool_output + title = tool_name.replace("_", " ").strip().capitalize() or tool_name + return (title, last_items) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_image/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_image/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_image/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_image/emission.py new file mode 100644 index 000000000..762f75cca --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_image/emission.py @@ -0,0 +1,28 @@ +"""generate_image: tool card + terminal summary.""" + +from __future__ import annotations + +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + out = ctx.tool_output + payload = out if isinstance(out, dict) else {"result": out} + yield ctx.emit_tool_output_card(payload) + if isinstance(out, dict): + if out.get("error"): + yield ctx.streaming_service.format_terminal_info( + f"Image generation failed: {out['error'][:60]}", + "error", + ) + else: + yield ctx.streaming_service.format_terminal_info( + "Image generated successfully", + "success", + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_image/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_image/thinking.py new file mode 100644 index 000000000..c094aabf6 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_image/thinking.py @@ -0,0 +1,41 @@ +"""generate_image: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.deliverables.shared.tool_input import ( + as_tool_input_dict, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + prompt = d.get("prompt", "") if isinstance(tool_input, dict) else str(tool_input) + return ToolStartThinking( + title="Generating image", + items=[f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}"], + ) + + +def resolve_completed_thinking( + tool_name: str, + tool_output: Any, + last_items: list[str], +) -> tuple[str, list[str]]: + del tool_name + items = last_items + if isinstance(tool_output, dict) and not tool_output.get("error"): + completed = [*items, "Image generated successfully"] + else: + error_msg = ( + tool_output.get("error", "Generation failed") + if isinstance(tool_output, dict) + else "Generation failed" + ) + completed = [*items, f"Error: {error_msg}"] + return ("Generating image", completed) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/emission.py new file mode 100644 index 000000000..f1a1e9c37 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/emission.py @@ -0,0 +1,37 @@ +"""generate_podcast: tool card + queue / success / failure terminal lines.""" + +from __future__ import annotations + +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + out = ctx.tool_output + payload = out if isinstance(out, dict) else {"result": out} + yield ctx.emit_tool_output_card(payload) + if isinstance(out, dict) and out.get("status") in ( + "pending", + "generating", + "processing", + ): + yield ctx.streaming_service.format_terminal_info( + f"Podcast queued: {out.get('title', 'Podcast')}", + "success", + ) + elif isinstance(out, dict) and out.get("status") in ("ready", "success"): + yield ctx.streaming_service.format_terminal_info( + f"Podcast generated successfully: {out.get('title', 'Podcast')}", + "success", + ) + elif isinstance(out, dict) and out.get("status") in ("failed", "error"): + error_msg = out.get("error", "Unknown error") + yield ctx.streaming_service.format_terminal_info( + f"Podcast generation failed: {error_msg}", + "error", + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/thinking.py new file mode 100644 index 000000000..5cf78ea72 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/thinking.py @@ -0,0 +1,82 @@ +"""generate_podcast: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.deliverables.shared.tool_input import ( + as_tool_input_dict, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + podcast_title = ( + d.get("podcast_title", "SurfSense Podcast") + if isinstance(tool_input, dict) + else "SurfSense Podcast" + ) + content_len = len( + d.get("source_content", "") if isinstance(tool_input, dict) else "" + ) + return ToolStartThinking( + title="Generating podcast", + items=[ + f"Title: {podcast_title}", + f"Content: {content_len:,} characters", + "Preparing audio generation...", + ], + ) + + +def resolve_completed_thinking( + tool_name: str, + tool_output: Any, + last_items: list[str], +) -> tuple[str, list[str]]: + del tool_name + items = last_items + podcast_status = ( + tool_output.get("status", "unknown") + if isinstance(tool_output, dict) + else "unknown" + ) + podcast_title = ( + tool_output.get("title", "Podcast") + if isinstance(tool_output, dict) + else "Podcast" + ) + if podcast_status in ("pending", "generating", "processing"): + completed = [ + f"Title: {podcast_title}", + "Podcast generation started", + "Processing in background...", + ] + elif podcast_status == "already_generating": + completed = [ + f"Title: {podcast_title}", + "Podcast already in progress", + "Please wait for it to complete", + ] + elif podcast_status in ("failed", "error"): + error_msg = ( + tool_output.get("error", "Unknown error") + if isinstance(tool_output, dict) + else "Unknown error" + ) + completed = [ + f"Title: {podcast_title}", + f"Error: {error_msg[:50]}", + ] + elif podcast_status in ("ready", "success"): + completed = [ + f"Title: {podcast_title}", + "Podcast ready", + ] + else: + completed = items + return ("Generating podcast", completed) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_report/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_report/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_report/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_report/emission.py new file mode 100644 index 000000000..1c5c71b8b --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_report/emission.py @@ -0,0 +1,33 @@ +"""generate_report: full payload + terminal line.""" + +from __future__ import annotations + +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + out = ctx.tool_output + payload = out if isinstance(out, dict) else {"result": out} + yield ctx.emit_tool_output_card(payload) + if isinstance(out, dict) and out.get("status") == "ready": + word_count = out.get("word_count", 0) + yield ctx.streaming_service.format_terminal_info( + f"Report generated: {out.get('title', 'Report')} ({word_count:,} words)", + "success", + ) + else: + error_msg = ( + out.get("error", "Unknown error") + if isinstance(out, dict) + else "Unknown error" + ) + yield ctx.streaming_service.format_terminal_info( + f"Report generation failed: {error_msg}", + "error", + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_report/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_report/thinking.py new file mode 100644 index 000000000..3893cd550 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_report/thinking.py @@ -0,0 +1,77 @@ +"""generate_report: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.deliverables.shared.tool_input import ( + as_tool_input_dict, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + report_topic = ( + d.get("topic", "Report") if isinstance(tool_input, dict) else "Report" + ) + is_revision = bool( + isinstance(tool_input, dict) and tool_input.get("parent_report_id") + ) + step_title = "Revising report" if is_revision else "Generating report" + return ToolStartThinking( + title=step_title, + items=[f"Topic: {report_topic}", "Analyzing source content..."], + ) + + +def resolve_completed_thinking( + tool_name: str, + tool_output: Any, + last_items: list[str], +) -> tuple[str, list[str]]: + del tool_name + items = last_items + report_status = ( + tool_output.get("status", "unknown") + if isinstance(tool_output, dict) + else "unknown" + ) + report_title = ( + tool_output.get("title", "Report") + if isinstance(tool_output, dict) + else "Report" + ) + word_count = ( + tool_output.get("word_count", 0) if isinstance(tool_output, dict) else 0 + ) + is_revision = ( + tool_output.get("is_revision", False) + if isinstance(tool_output, dict) + else False + ) + step_title = "Revising report" if is_revision else "Generating report" + + if report_status == "ready": + completed = [ + f"Topic: {report_title}", + f"{word_count:,} words", + "Report ready", + ] + elif report_status == "failed": + error_msg = ( + tool_output.get("error", "Unknown error") + if isinstance(tool_output, dict) + else "Unknown error" + ) + completed = [ + f"Topic: {report_title}", + f"Error: {error_msg[:50]}", + ] + else: + completed = items + + return (step_title, completed) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_resume/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_resume/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_resume/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_resume/emission.py new file mode 100644 index 000000000..dc8d3c7fc --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_resume/emission.py @@ -0,0 +1,32 @@ +"""generate_resume: full payload + terminal line.""" + +from __future__ import annotations + +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + out = ctx.tool_output + payload = out if isinstance(out, dict) else {"result": out} + yield ctx.emit_tool_output_card(payload) + if isinstance(out, dict) and out.get("status") == "ready": + yield ctx.streaming_service.format_terminal_info( + f"Resume generated: {out.get('title', 'Resume')}", + "success", + ) + else: + error_msg = ( + out.get("error", "Unknown error") + if isinstance(out, dict) + else "Unknown error" + ) + yield ctx.streaming_service.format_terminal_info( + f"Resume generation failed: {error_msg}", + "error", + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_resume/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_resume/thinking.py new file mode 100644 index 000000000..5a54da84d --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_resume/thinking.py @@ -0,0 +1,26 @@ +"""generate_resume: generic thinking titles and items.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.default import ( + thinking as default_thinking, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + return default_thinking.resolve_start_thinking(tool_name, tool_input) + + +def resolve_completed_thinking( + tool_name: str, + tool_output: Any, + last_items: list[str], +) -> tuple[str, list[str]]: + return default_thinking.resolve_completed_thinking( + tool_name, tool_output, last_items + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/emission.py new file mode 100644 index 000000000..21e27d4c3 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/emission.py @@ -0,0 +1,28 @@ +"""generate_video_presentation: tool card + terminal line.""" + +from __future__ import annotations + +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + out = ctx.tool_output + payload = out if isinstance(out, dict) else {"result": out} + yield ctx.emit_tool_output_card(payload) + if isinstance(out, dict) and out.get("status") == "pending": + yield ctx.streaming_service.format_terminal_info( + f"Video presentation queued: {out.get('title', 'Presentation')}", + "success", + ) + elif isinstance(out, dict) and out.get("status") == "failed": + error_msg = out.get("error", "Unknown error") + yield ctx.streaming_service.format_terminal_info( + f"Presentation generation failed: {error_msg}", + "error", + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/thinking.py new file mode 100644 index 000000000..fcf425950 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_video_presentation/thinking.py @@ -0,0 +1,54 @@ +"""generate_video_presentation: generic in-progress thinking; completion is status-driven.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.default import ( + thinking as default_thinking, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + return default_thinking.resolve_start_thinking(tool_name, tool_input) + + +def resolve_completed_thinking( + tool_name: str, + tool_output: Any, + last_items: list[str], +) -> tuple[str, list[str]]: + del tool_name + items = last_items + vp_status = ( + tool_output.get("status", "unknown") + if isinstance(tool_output, dict) + else "unknown" + ) + vp_title = ( + tool_output.get("title", "Presentation") + if isinstance(tool_output, dict) + else "Presentation" + ) + if vp_status in ("pending", "generating"): + completed = [ + f"Title: {vp_title}", + "Presentation generation started", + "Processing in background...", + ] + elif vp_status == "failed": + error_msg = ( + tool_output.get("error", "Unknown error") + if isinstance(tool_output, dict) + else "Unknown error" + ) + completed = [ + f"Title: {vp_title}", + f"Error: {error_msg[:50]}", + ] + else: + completed = items + return ("Generating video presentation", completed) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/save_document/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/save_document/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/save_document/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/save_document/emission.py new file mode 100644 index 000000000..68c93dede --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/save_document/emission.py @@ -0,0 +1,16 @@ +"""save_document: default completion card and terminal line.""" + +from __future__ import annotations + +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.default import emission as _default +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + yield from _default.iter_completion_emission_frames(ctx) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/save_document/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/save_document/thinking.py new file mode 100644 index 000000000..96e2b6743 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/save_document/thinking.py @@ -0,0 +1,40 @@ +"""save_document: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.deliverables.shared.tool_input import ( + as_tool_input_dict, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + doc_title = d.get("title", "") if isinstance(tool_input, dict) else str(tool_input) + display_title = doc_title[:60] + ("…" if len(doc_title) > 60 else "") + return ToolStartThinking(title="Saving document", items=[display_title]) + + +def resolve_completed_thinking( + tool_name: str, + tool_output: Any, + last_items: list[str], +) -> tuple[str, list[str]]: + del tool_name + items = last_items + result_str = ( + tool_output.get("result", "") + if isinstance(tool_output, dict) + else str(tool_output) + ) + is_error = "Error" in result_str + completed = [ + *items, + result_str[:80] if is_error else "Saved to knowledge base", + ] + return ("Saving document", completed) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/shared/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/shared/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/shared/tool_input.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/shared/tool_input.py new file mode 100644 index 000000000..1303cf09f --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/shared/tool_input.py @@ -0,0 +1,9 @@ +"""Tool-call args for deliverable thinking modules.""" + +from __future__ import annotations + +from typing import Any + + +def as_tool_input_dict(tool_input: Any) -> dict[str, Any]: + return tool_input if isinstance(tool_input, dict) else {} diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/tool_names.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/tool_names.py new file mode 100644 index 000000000..5924af196 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/tool_names.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +DELIVERABLE_TOOLS: frozenset[str] = frozenset( + { + "generate_image", + "generate_podcast", + "generate_report", + "generate_resume", + "generate_video_presentation", + "save_document", + } +) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/emission_context.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/emission_context.py new file mode 100644 index 000000000..baa1d7336 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/emission_context.py @@ -0,0 +1,36 @@ +"""Context for one tool-completion emission pass.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from app.tasks.chat.streaming.handlers.tool_output_frame import ( + emit_tool_output_available_frame, +) + + +@dataclass +class ToolCompletionEmissionContext: + """Streaming service, tool output, and ids for completion frames.""" + + tool_name: str + tool_call_id: str + tool_output: Any + streaming_service: Any + content_builder: Any | None + langchain_tool_call_id_holder: dict[str, str | None] + stream_result: Any + langgraph_config: dict[str, Any] + staged_workspace_file_path: str | None + tool_metadata: dict[str, Any] | None = None + + def emit_tool_output_card(self, payload: Any) -> str: + return emit_tool_output_available_frame( + streaming_service=self.streaming_service, + content_builder=self.content_builder, + langchain_id_holder=self.langchain_tool_call_id_holder, + call_id=self.tool_call_id, + output=payload, + tool_metadata=self.tool_metadata, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/edit_file/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/edit_file/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/edit_file/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/edit_file/thinking.py new file mode 100644 index 000000000..7937c26e9 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/edit_file/thinking.py @@ -0,0 +1,29 @@ +"""edit_file: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import ( + as_tool_input_dict, + truncate_path, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + fp = d.get("file_path", "") if isinstance(tool_input, dict) else str(tool_input) + return ToolStartThinking(title="Editing file", items=[truncate_path(fp)]) + + +def resolve_completed_thinking( + tool_name: str, + tool_output: Any, + last_items: list[str], +) -> tuple[str, list[str]]: + del tool_output, tool_name + return ("Editing file", last_items) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/execute/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/execute/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/execute/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/execute/emission.py new file mode 100644 index 000000000..780cd77d8 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/execute/emission.py @@ -0,0 +1,38 @@ +"""execute: exit code, stdout, sandbox file hints.""" + +from __future__ import annotations + +import re +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + out = ctx.tool_output + raw_text = out.get("result", "") if isinstance(out, dict) else str(out) + exit_code: int | None = None + output_text = raw_text + m = re.match(r"^Exit code:\s*(\d+)", raw_text) + if m: + exit_code = int(m.group(1)) + om = re.search(r"\nOutput:\n([\s\S]*)", raw_text) + output_text = om.group(1) if om else "" + thread_id_str = ctx.langgraph_config.get("configurable", {}).get("thread_id", "") + + for sf_match in re.finditer(r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE): + fpath = sf_match.group(1).strip() + if fpath and fpath not in ctx.stream_result.sandbox_files: + ctx.stream_result.sandbox_files.append(fpath) + + yield ctx.emit_tool_output_card( + { + "exit_code": exit_code, + "output": output_text, + "thread_id": thread_id_str, + }, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/execute/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/execute/thinking.py new file mode 100644 index 000000000..968140363 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/execute/thinking.py @@ -0,0 +1,44 @@ +"""execute: sandbox command thinking + completion lines.""" + +from __future__ import annotations + +import re +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import ( + as_tool_input_dict, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + cmd = d.get("command", "") if isinstance(tool_input, dict) else str(tool_input) + display_cmd = cmd[:80] + ("…" if len(cmd) > 80 else "") + return ToolStartThinking(title="Running command", items=[f"$ {display_cmd}"]) + + +def resolve_completed_thinking( + tool_name: str, + tool_output: Any, + last_items: list[str], +) -> tuple[str, list[str]]: + del tool_name + items = last_items + raw_text = ( + tool_output.get("result", "") + if isinstance(tool_output, dict) + else str(tool_output) + ) + m = re.match(r"^Exit code:\s*(\d+)", raw_text) + exit_code_val = int(m.group(1)) if m else None + if exit_code_val is not None and exit_code_val == 0: + completed = [*items, "Completed successfully"] + elif exit_code_val is not None: + completed = [*items, f"Exit code: {exit_code_val}"] + else: + completed = [*items, "Finished"] + return ("Running command", completed) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/glob/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/glob/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/glob/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/glob/thinking.py new file mode 100644 index 000000000..bdcaecd45 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/glob/thinking.py @@ -0,0 +1,29 @@ +"""glob: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import ( + as_tool_input_dict, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + pat = d.get("pattern", "") if isinstance(tool_input, dict) else str(tool_input) + base = d.get("path", "/") if isinstance(tool_input, dict) else "/" + return ToolStartThinking(title="Searching files", items=[f"{pat} in {base}"]) + + +def resolve_completed_thinking( + tool_name: str, + tool_output: Any, + last_items: list[str], +) -> tuple[str, list[str]]: + del tool_output, tool_name + return ("Searching files", last_items) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/grep/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/grep/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/grep/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/grep/thinking.py new file mode 100644 index 000000000..36e077599 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/grep/thinking.py @@ -0,0 +1,33 @@ +"""grep: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import ( + as_tool_input_dict, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + pat = d.get("pattern", "") if isinstance(tool_input, dict) else str(tool_input) + grep_path = d.get("path", "") if isinstance(tool_input, dict) else "" + display_pat = pat[:60] + ("…" if len(pat) > 60 else "") + return ToolStartThinking( + title="Searching content", + items=[f'"{display_pat}"' + (f" in {grep_path}" if grep_path else "")], + ) + + +def resolve_completed_thinking( + tool_name: str, + tool_output: Any, + last_items: list[str], +) -> tuple[str, list[str]]: + del tool_output, tool_name + return ("Searching content", last_items) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/ls/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/ls/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/ls/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/ls/thinking.py new file mode 100644 index 000000000..1e89f3a12 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/ls/thinking.py @@ -0,0 +1,59 @@ +"""ls: thinking-step copy for directory listing.""" + +from __future__ import annotations + +import ast +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + if isinstance(tool_input, dict): + path = tool_input.get("path", "/") + else: + path = str(tool_input) + return ToolStartThinking(title="Listing files", items=[path]) + + +def resolve_completed_thinking( + tool_name: str, + tool_output: Any, + last_items: list[str], +) -> tuple[str, list[str]]: + del tool_name + if isinstance(tool_output, dict): + ls_output = tool_output.get("result", "") + elif isinstance(tool_output, str): + ls_output = tool_output + else: + ls_output = str(tool_output) if tool_output else "" + file_names: list[str] = [] + if ls_output: + paths: list[str] = [] + try: + parsed = ast.literal_eval(ls_output) + if isinstance(parsed, list): + paths = [str(p) for p in parsed] + except (ValueError, SyntaxError): + paths = [ + line.strip() for line in ls_output.strip().split("\n") if line.strip() + ] + for p in paths: + name = p.rstrip("/").split("/")[-1] + if name and len(name) <= 40: + file_names.append(name) + elif name: + file_names.append(name[:37] + "...") + if file_names: + if len(file_names) <= 5: + completed = [f"[{name}]" for name in file_names] + else: + completed = [f"[{name}]" for name in file_names[:4]] + completed.append(f"(+{len(file_names) - 4} more)") + else: + completed = ["No files found"] + return ("Listing files", completed) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/mkdir/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/mkdir/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/mkdir/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/mkdir/thinking.py new file mode 100644 index 000000000..1cfb0fc67 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/mkdir/thinking.py @@ -0,0 +1,31 @@ +"""mkdir: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import ( + as_tool_input_dict, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + p = d.get("path", "") if isinstance(tool_input, dict) else str(tool_input) + display = p if len(p) <= 80 else "…" + p[-77:] + return ToolStartThinking( + title="Creating folder", items=[display] if display else [] + ) + + +def resolve_completed_thinking( + tool_name: str, + tool_output: Any, + last_items: list[str], +) -> tuple[str, list[str]]: + del tool_output, tool_name + return ("Creating folder", last_items) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/move_file/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/move_file/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/move_file/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/move_file/thinking.py new file mode 100644 index 000000000..a83a33ef7 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/move_file/thinking.py @@ -0,0 +1,35 @@ +"""move_file: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import ( + as_tool_input_dict, + truncate_middle, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + src = d.get("source_path", "") if isinstance(tool_input, dict) else "" + dst = d.get("destination_path", "") if isinstance(tool_input, dict) else "" + display_src = truncate_middle(src, max_len=60) + display_dst = truncate_middle(dst, max_len=60) + return ToolStartThinking( + title="Moving file", + items=[f"{display_src} → {display_dst}"] if src or dst else [], + ) + + +def resolve_completed_thinking( + tool_name: str, + tool_output: Any, + last_items: list[str], +) -> tuple[str, list[str]]: + del tool_output, tool_name + return ("Moving file", last_items) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/read_file/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/read_file/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/read_file/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/read_file/thinking.py new file mode 100644 index 000000000..2957aed05 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/read_file/thinking.py @@ -0,0 +1,29 @@ +"""read_file: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import ( + as_tool_input_dict, + truncate_path, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + fp = d.get("file_path", "") if isinstance(tool_input, dict) else str(tool_input) + return ToolStartThinking(title="Reading file", items=[truncate_path(fp)]) + + +def resolve_completed_thinking( + tool_name: str, + tool_output: Any, + last_items: list[str], +) -> tuple[str, list[str]]: + del tool_output, tool_name + return ("Reading file", last_items) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/rm/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/rm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/rm/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/rm/thinking.py new file mode 100644 index 000000000..9d21c306d --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/rm/thinking.py @@ -0,0 +1,30 @@ +"""rm: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import ( + as_tool_input_dict, + truncate_path, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + rm_path = d.get("path", "") if isinstance(tool_input, dict) else str(tool_input) + display = truncate_path(rm_path) + return ToolStartThinking(title="Deleting file", items=[display] if display else []) + + +def resolve_completed_thinking( + tool_name: str, + tool_output: Any, + last_items: list[str], +) -> tuple[str, list[str]]: + del tool_output, tool_name + return ("Deleting file", last_items) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/rmdir/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/rmdir/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/rmdir/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/rmdir/thinking.py new file mode 100644 index 000000000..eab0e78ec --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/rmdir/thinking.py @@ -0,0 +1,31 @@ +"""rmdir: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import ( + as_tool_input_dict, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + p = d.get("path", "") if isinstance(tool_input, dict) else str(tool_input) + display = p if len(p) <= 80 else "…" + p[-77:] + return ToolStartThinking( + title="Deleting folder", items=[display] if display else [] + ) + + +def resolve_completed_thinking( + tool_name: str, + tool_output: Any, + last_items: list[str], +) -> tuple[str, list[str]]: + del tool_output, tool_name + return ("Deleting folder", last_items) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/shared/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/shared/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/shared/tool_input.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/shared/tool_input.py new file mode 100644 index 000000000..507782283 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/shared/tool_input.py @@ -0,0 +1,17 @@ +"""Tool-call args + display truncation for filesystem thinking modules.""" + +from __future__ import annotations + +from typing import Any + + +def as_tool_input_dict(tool_input: Any) -> dict[str, Any]: + return tool_input if isinstance(tool_input, dict) else {} + + +def truncate_path(fp: str, *, max_len: int = 80) -> str: + return fp if len(fp) <= max_len else "…" + fp[-(max_len - 3) :] + + +def truncate_middle(s: str, *, max_len: int = 60) -> str: + return s if len(s) <= max_len else "…" + s[-(max_len - 3) :] diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/tool_names.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/tool_names.py new file mode 100644 index 000000000..e2ad33736 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/tool_names.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +FILESYSTEM_TOOLS: frozenset[str] = frozenset( + { + "read_file", + "glob", + "grep", + "ls", + "mkdir", + "move_file", + "rm", + "rmdir", + "write_todos", + "write_file", + "edit_file", + "execute", + } +) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_file/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_file/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_file/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_file/emission.py new file mode 100644 index 000000000..820235379 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_file/emission.py @@ -0,0 +1,43 @@ +"""write_file: path + status envelope on the tool-output card.""" + +from __future__ import annotations + +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) +from app.tasks.chat.streaming.helpers.tool_output import ( + extract_resolved_file_path, + tool_output_has_error, + tool_output_to_text, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + resolved_path = extract_resolved_file_path( + tool_name=ctx.tool_name, + tool_output=ctx.tool_output, + tool_input={"file_path": ctx.staged_workspace_file_path} + if ctx.staged_workspace_file_path + else None, + ) + result_text = tool_output_to_text(ctx.tool_output) + if tool_output_has_error(ctx.tool_output): + yield ctx.emit_tool_output_card( + { + "status": "error", + "error": result_text, + "path": resolved_path, + }, + ) + else: + yield ctx.emit_tool_output_card( + { + "status": "completed", + "path": resolved_path, + "result": result_text, + }, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_file/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_file/thinking.py new file mode 100644 index 000000000..07a6cb9e1 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_file/thinking.py @@ -0,0 +1,29 @@ +"""write_file: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import ( + as_tool_input_dict, + truncate_path, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + fp = d.get("file_path", "") if isinstance(tool_input, dict) else str(tool_input) + return ToolStartThinking(title="Writing file", items=[truncate_path(fp)]) + + +def resolve_completed_thinking( + tool_name: str, + tool_output: Any, + last_items: list[str], +) -> tuple[str, list[str]]: + del tool_output, tool_name + return ("Writing file", last_items) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_todos/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_todos/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_todos/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_todos/thinking.py new file mode 100644 index 000000000..5cc631fb5 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/filesystem/write_todos/thinking.py @@ -0,0 +1,34 @@ +"""write_todos: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.filesystem.shared.tool_input import ( + as_tool_input_dict, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + todos = d.get("todos", []) if isinstance(tool_input, dict) else [] + todo_count = len(todos) if isinstance(todos, list) else 0 + return ToolStartThinking( + title="Planning tasks", + items=( + [f"{todo_count} task{'s' if todo_count != 1 else ''}"] if todo_count else [] + ), + ) + + +def resolve_completed_thinking( + tool_name: str, + tool_output: Any, + last_items: list[str], +) -> tuple[str, list[str]]: + del tool_output, tool_name + return ("Planning tasks", last_items) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/registry.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/registry.py new file mode 100644 index 000000000..06f06f90a --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/registry.py @@ -0,0 +1,92 @@ +"""Resolve thinking and emission modules by tool name.""" + +from __future__ import annotations + +import importlib +from collections.abc import Iterator +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.connector.shared.tool_names import ( + SHARED_CONNECTOR_TOOLS, +) +from app.tasks.chat.streaming.handlers.tools.deliverables.tool_names import ( + DELIVERABLE_TOOLS, +) +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) +from app.tasks.chat.streaming.handlers.tools.filesystem.tool_names import ( + FILESYSTEM_TOOLS, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + +_BASE = "app.tasks.chat.streaming.handlers.tools" +_CONNECTOR_SHARED = "connector.shared" + +_THINKING_ALIAS: dict[str, str] = { + "execute_code": "filesystem.execute", +} +_EMISSION_ALIAS: dict[str, str] = { + "edit_file": "filesystem.write_file", + "execute_code": "filesystem.execute", +} + + +def _thinking_module(tool_name: str) -> str: + if tool_name in SHARED_CONNECTOR_TOOLS: + return _CONNECTOR_SHARED + if tool_name in FILESYSTEM_TOOLS: + return f"filesystem.{tool_name}" + if tool_name in DELIVERABLE_TOOLS: + return f"deliverables.{tool_name}" + return _THINKING_ALIAS.get(tool_name, tool_name) + + +def _emission_module(tool_name: str) -> str: + if tool_name in _EMISSION_ALIAS: + return _EMISSION_ALIAS[tool_name] + if tool_name in SHARED_CONNECTOR_TOOLS: + return _CONNECTOR_SHARED + if tool_name in DELIVERABLE_TOOLS: + return f"deliverables.{tool_name}" + if tool_name in FILESYSTEM_TOOLS: + return f"filesystem.{tool_name}" + return tool_name + + +def _import_thinking(tool_name: str): + try: + return importlib.import_module( + f"{_BASE}.{_thinking_module(tool_name)}.thinking" + ) + except ModuleNotFoundError: + return importlib.import_module(f"{_BASE}.default.thinking") + + +def _import_emission(tool_name: str): + try: + return importlib.import_module( + f"{_BASE}.{_emission_module(tool_name)}.emission" + ) + except ModuleNotFoundError: + return importlib.import_module(f"{_BASE}.default.emission") + + +def resolve_tool_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + return _import_thinking(tool_name).resolve_start_thinking(tool_name, tool_input) + + +def resolve_tool_completed_thinking_step( + tool_name: str, tool_output: Any, last_items: list[str] +) -> tuple[str, list[str]]: + return _import_thinking(tool_name).resolve_completed_thinking( + tool_name, tool_output, last_items + ) + + +def iter_tool_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + yield from _import_emission(ctx.tool_name).iter_completion_emission_frames(ctx) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/scrape_webpage/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/scrape_webpage/emission.py new file mode 100644 index 000000000..293d2a1e9 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/scrape_webpage/emission.py @@ -0,0 +1,43 @@ +"""scrape_webpage: redacted payload + terminal summary.""" + +from __future__ import annotations + +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + out = ctx.tool_output + if isinstance(out, dict): + display_output = {k: v for k, v in out.items() if k != "content"} + if "content" in out: + content = out.get("content", "") + display_output["content_preview"] = ( + content[:500] + "..." if len(content) > 500 else content + ) + yield ctx.emit_tool_output_card(display_output) + else: + yield ctx.emit_tool_output_card({"result": out}) + + if isinstance(out, dict) and "error" not in out: + title = out.get("title", "Webpage") + word_count = out.get("word_count", 0) + yield ctx.streaming_service.format_terminal_info( + f"Scraped: {title[:40]}{'...' if len(title) > 40 else ''} ({word_count:,} words)", + "success", + ) + else: + error_msg = ( + out.get("error", "Failed to scrape") + if isinstance(out, dict) + else "Failed to scrape" + ) + yield ctx.streaming_service.format_terminal_info( + f"Scrape failed: {error_msg}", + "error", + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/scrape_webpage/shared/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/scrape_webpage/shared/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/scrape_webpage/shared/tool_input.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/scrape_webpage/shared/tool_input.py new file mode 100644 index 000000000..581f0e64a --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/scrape_webpage/shared/tool_input.py @@ -0,0 +1,9 @@ +"""Tool-call args for scrape_webpage thinking.""" + +from __future__ import annotations + +from typing import Any + + +def as_tool_input_dict(tool_input: Any) -> dict[str, Any]: + return tool_input if isinstance(tool_input, dict) else {} diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/scrape_webpage/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/scrape_webpage/thinking.py new file mode 100644 index 000000000..8a04acbe6 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/scrape_webpage/thinking.py @@ -0,0 +1,49 @@ +"""scrape_webpage: thinking-step copy.""" + +from __future__ import annotations + +from typing import Any + +from app.tasks.chat.streaming.handlers.tools.scrape_webpage.shared.tool_input import ( + as_tool_input_dict, +) +from app.tasks.chat.streaming.handlers.tools.shared.model import ( + ToolStartThinking, +) + + +def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking: + del tool_name + d = as_tool_input_dict(tool_input) + url = d.get("url", "") if isinstance(tool_input, dict) else str(tool_input) + return ToolStartThinking( + title="Scraping webpage", + items=[f"URL: {url[:80]}{'...' if len(url) > 80 else ''}"], + ) + + +def resolve_completed_thinking( + tool_name: str, + tool_output: Any, + last_items: list[str], +) -> tuple[str, list[str]]: + del tool_name + items = last_items + if isinstance(tool_output, dict): + title = tool_output.get("title", "Webpage") + word_count = tool_output.get("word_count", 0) + has_error = "error" in tool_output + if has_error: + completed = [ + *items, + f"Error: {tool_output.get('error', 'Failed to scrape')[:50]}", + ] + else: + completed = [ + *items, + f"Title: {title[:50]}{'...' if len(title) > 50 else ''}", + f"Extracted: {word_count:,} words", + ] + else: + completed = [*items, "Content extracted"] + return ("Scraping webpage", completed) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/shared/__init__.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/shared/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/shared/model.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/shared/model.py new file mode 100644 index 000000000..047a84374 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/shared/model.py @@ -0,0 +1,12 @@ +"""In-progress thinking-step title and bullet lines.""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class ToolStartThinking: + title: str + items: list[str] + include_items_on_frame: bool = True diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/web_search/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/web_search/emission.py new file mode 100644 index 000000000..3efe45d0c --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/web_search/emission.py @@ -0,0 +1,37 @@ +"""web_search: citations parsed from provider XML.""" + +from __future__ import annotations + +import re +from collections.abc import Iterator + +from app.tasks.chat.streaming.handlers.tools.emission_context import ( + ToolCompletionEmissionContext, +) + + +def iter_completion_emission_frames( + ctx: ToolCompletionEmissionContext, +) -> Iterator[str]: + out = ctx.tool_output + xml = out.get("result", str(out)) if isinstance(out, dict) else str(out) + citations: dict[str, dict[str, str]] = {} + for m in re.finditer( + r"<!\[CDATA\[(.*?)\]\]>\s*", + xml, + ): + title, url = m.group(1).strip(), m.group(2).strip() + if url.startswith("http") and url not in citations: + citations[url] = {"title": title} + for m in re.finditer( + r"", + xml, + ): + chunk_url, content = m.group(1).strip(), m.group(2).strip() + if chunk_url.startswith("http") and chunk_url in citations and content: + citations[chunk_url]["snippet"] = ( + content[:200] + "…" if len(content) > 200 else content + ) + yield ctx.emit_tool_output_card( + {"status": "completed", "citations": citations}, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/helpers/__init__.py b/surfsense_backend/app/tasks/chat/streaming/helpers/__init__.py new file mode 100644 index 000000000..151dfdaac --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/helpers/__init__.py @@ -0,0 +1,3 @@ +"""Pure helpers for chat streaming.""" + +from __future__ import annotations diff --git a/surfsense_backend/app/tasks/chat/streaming/helpers/chunk_parts.py b/surfsense_backend/app/tasks/chat/streaming/helpers/chunk_parts.py new file mode 100644 index 000000000..48b44fc1d --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/helpers/chunk_parts.py @@ -0,0 +1,60 @@ +"""Split a model chunk into text, reasoning, and tool-call fragment lists.""" + +from __future__ import annotations + +from typing import Any + + +def extract_chunk_parts(chunk: Any) -> dict[str, Any]: + """Return dict with keys text, reasoning, and tool_call_chunks (merged from chunk fields).""" + out: dict[str, Any] = {"text": "", "reasoning": "", "tool_call_chunks": []} + if chunk is None: + return out + + content = getattr(chunk, "content", None) + if isinstance(content, str): + if content: + out["text"] = content + elif isinstance(content, list): + text_parts: list[str] = [] + reasoning_parts: list[str] = [] + for block in content: + if not isinstance(block, dict): + continue + block_type = block.get("type") + if block_type == "text": + value = block.get("text") or block.get("content") or "" + if isinstance(value, str) and value: + text_parts.append(value) + elif block_type == "reasoning": + value = ( + block.get("reasoning") + or block.get("text") + or block.get("content") + or "" + ) + if isinstance(value, str) and value: + reasoning_parts.append(value) + elif block_type in ("tool_call_chunk", "tool_use"): + out["tool_call_chunks"].append(block) + if text_parts: + out["text"] = "".join(text_parts) + if reasoning_parts: + out["reasoning"] = "".join(reasoning_parts) + + additional = getattr(chunk, "additional_kwargs", None) or {} + if isinstance(additional, dict): + extra_reasoning = additional.get("reasoning_content") + if isinstance(extra_reasoning, str) and extra_reasoning: + existing = out["reasoning"] + out["reasoning"] = ( + (existing + extra_reasoning) if existing else extra_reasoning + ) + + extra_tool_chunks = getattr(chunk, "tool_call_chunks", None) + if isinstance(extra_tool_chunks, list): + for tcc in extra_tool_chunks: + if isinstance(tcc, dict): + out["tool_call_chunks"].append(tcc) + + return out diff --git a/surfsense_backend/app/tasks/chat/streaming/helpers/interrupt_inspector.py b/surfsense_backend/app/tasks/chat/streaming/helpers/interrupt_inspector.py new file mode 100644 index 000000000..dca099b3f --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/helpers/interrupt_inspector.py @@ -0,0 +1,47 @@ +"""Read the first interrupt payload from a LangGraph state snapshot.""" + +from __future__ import annotations + +from typing import Any + + +def first_interrupt_value(state: Any) -> dict[str, Any] | None: + """Return the first interrupt payload across all snapshot tasks.""" + + def _extract(candidate: Any) -> dict[str, Any] | None: + if isinstance(candidate, dict): + value = candidate.get("value", candidate) + return value if isinstance(value, dict) else None + value = getattr(candidate, "value", None) + if isinstance(value, dict): + return value + if isinstance(candidate, list | tuple): + for item in candidate: + extracted = _extract(item) + if extracted is not None: + return extracted + return None + + for task in getattr(state, "tasks", ()) or (): + try: + interrupts = getattr(task, "interrupts", ()) or () + except (AttributeError, IndexError, TypeError): + interrupts = () + if not interrupts: + extracted = _extract(task) + if extracted is not None: + return extracted + continue + for interrupt_item in interrupts: + extracted = _extract(interrupt_item) + if extracted is not None: + return extracted + + try: + state_interrupts = getattr(state, "interrupts", ()) or () + except (AttributeError, IndexError, TypeError): + state_interrupts = () + extracted = _extract(state_interrupts) + if extracted is not None: + return extracted + return None diff --git a/surfsense_backend/app/tasks/chat/streaming/helpers/tool_call_matching.py b/surfsense_backend/app/tasks/chat/streaming/helpers/tool_call_matching.py new file mode 100644 index 000000000..fbe4c94b7 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/helpers/tool_call_matching.py @@ -0,0 +1,32 @@ +"""Match buffered model tool-call chunks to a tool start when ids were missing.""" + +from __future__ import annotations + +from typing import Any + + +def match_buffered_langchain_tool_call_id( + pending_tool_call_chunks: list[dict[str, Any]], + tool_name: str, + run_id: str, + lc_tool_call_id_by_run: dict[str, str], +) -> str | None: + matched_idx: int | None = None + for idx, tcc in enumerate(pending_tool_call_chunks): + if tcc.get("name") == tool_name and tcc.get("id"): + matched_idx = idx + break + if matched_idx is None: + for idx, tcc in enumerate(pending_tool_call_chunks): + if tcc.get("id"): + matched_idx = idx + break + if matched_idx is None: + return None + matched = pending_tool_call_chunks.pop(matched_idx) + candidate = matched.get("id") + if isinstance(candidate, str) and candidate: + if run_id: + lc_tool_call_id_by_run[run_id] = candidate + return candidate + return None diff --git a/surfsense_backend/app/tasks/chat/streaming/helpers/tool_output.py b/surfsense_backend/app/tasks/chat/streaming/helpers/tool_output.py new file mode 100644 index 000000000..a7c401dee --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/helpers/tool_output.py @@ -0,0 +1,43 @@ +"""Normalize filesystem tool payloads for SSE cards and messages.""" + +from __future__ import annotations + +import json +from typing import Any + + +def tool_output_to_text(tool_output: Any) -> str: + if isinstance(tool_output, dict): + if isinstance(tool_output.get("result"), str): + return tool_output["result"] + if isinstance(tool_output.get("error"), str): + return tool_output["error"] + return json.dumps(tool_output, ensure_ascii=False) + return str(tool_output) + + +def tool_output_has_error(tool_output: Any) -> bool: + if isinstance(tool_output, dict): + if tool_output.get("error"): + return True + result = tool_output.get("result") + return bool( + isinstance(result, str) and result.strip().lower().startswith("error:") + ) + if isinstance(tool_output, str): + return tool_output.strip().lower().startswith("error:") + return False + + +def extract_resolved_file_path( + *, tool_name: str, tool_output: Any, tool_input: Any | None = None +) -> str | None: + if isinstance(tool_output, dict): + path_value = tool_output.get("path") + if isinstance(path_value, str) and path_value.strip(): + return path_value.strip() + if tool_name in ("write_file", "edit_file") and isinstance(tool_input, dict): + file_path = tool_input.get("file_path") + if isinstance(file_path, str) and file_path.strip(): + return file_path.strip() + return None diff --git a/surfsense_backend/app/tasks/chat/streaming/relay/__init__.py b/surfsense_backend/app/tasks/chat/streaming/relay/__init__.py new file mode 100644 index 000000000..18eda9a6d --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/relay/__init__.py @@ -0,0 +1,23 @@ +"""Relay: thinking steps, tool bookkeeping, and ``EventRelay``. + +Package imports are lazy so ``relay.thinking_step_sse`` (and siblings) can load +without pulling in ``event_relay`` (which imports handler modules that may +import those siblings). +""" + +from __future__ import annotations + +__all__ = ["EventRelay", "EventRelayConfig"] + + +def __getattr__(name: str): + if name == "EventRelay": + from app.tasks.chat.streaming.relay.event_relay import EventRelay + + return EventRelay + if name == "EventRelayConfig": + from app.tasks.chat.streaming.relay.event_relay import EventRelayConfig + + return EventRelayConfig + msg = f"module {__name__!r} has no attribute {name!r}" + raise AttributeError(msg) diff --git a/surfsense_backend/app/tasks/chat/streaming/relay/event_relay.py b/surfsense_backend/app/tasks/chat/streaming/relay/event_relay.py new file mode 100644 index 000000000..03d6a66e6 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/relay/event_relay.py @@ -0,0 +1,128 @@ +"""Turn LangGraph astream_events into SSE strings via the handler modules.""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from dataclasses import dataclass, field +from typing import Any + +from app.services.streaming.emitter import EmitterRegistry +from app.tasks.chat.streaming.graph_stream.result import StreamingResult +from app.tasks.chat.streaming.handlers.chain_end import iter_chain_end_frames +from app.tasks.chat.streaming.handlers.chat_model_stream import ( + iter_chat_model_stream_frames, +) +from app.tasks.chat.streaming.handlers.custom_event_dispatch import ( + iter_custom_event_frames, +) +from app.tasks.chat.streaming.handlers.tool_end import iter_tool_end_frames +from app.tasks.chat.streaming.handlers.tool_start import iter_tool_start_frames +from app.tasks.chat.streaming.relay.state import AgentEventRelayState +from app.tasks.chat.streaming.relay.thinking_step_completion import ( + complete_active_thinking_step, +) + + +@dataclass +class EventRelayConfig: + """Optional relay tuning (sub-agent tools, text suppression).""" + + subagent_entry_tool_names: frozenset[str] = field( + default_factory=lambda: frozenset({"task"}) + ) + suppress_main_text_inside_tools: bool = True + + +class EventRelay: + """Dispatches graph events to streaming handlers and optional emitters.""" + + def __init__( + self, + *, + streaming_service: Any, + config: EventRelayConfig | None = None, + ) -> None: + self.streaming_service = streaming_service + self.config = config or EventRelayConfig() + reg = getattr(streaming_service, "emitter_registry", None) + self.emitter_registry = reg if reg is not None else EmitterRegistry() + + async def relay( + self, + events: AsyncIterator[dict[str, Any]], + *, + state: AgentEventRelayState, + result: StreamingResult, + step_prefix: str = "thinking", + content_builder: Any | None = None, + config: dict[str, Any] | None = None, + ) -> AsyncIterator[str]: + """Yield SSE for each event from the async iterator, then finalize text/thinking.""" + graph_config = config or {} + async for event in events: + event_type = event.get("event", "") + if event_type == "on_chat_model_stream": + for frame in iter_chat_model_stream_frames( + event, + state=state, + streaming_service=self.streaming_service, + content_builder=content_builder, + step_prefix=step_prefix, + ): + yield frame + elif event_type == "on_tool_start": + for frame in iter_tool_start_frames( + event, + state=state, + streaming_service=self.streaming_service, + content_builder=content_builder, + result=result, + step_prefix=step_prefix, + ): + yield frame + elif event_type == "on_tool_end": + for frame in iter_tool_end_frames( + event, + state=state, + streaming_service=self.streaming_service, + content_builder=content_builder, + result=result, + step_prefix=step_prefix, + config=graph_config, + ): + yield frame + elif event_type == "on_custom_event": + for frame in iter_custom_event_frames( + event, + state=state, + streaming_service=self.streaming_service, + content_builder=content_builder, + ): + yield frame + elif event_type in ("on_chain_end", "on_agent_end"): + for frame in iter_chain_end_frames( + event, + state=state, + streaming_service=self.streaming_service, + content_builder=content_builder, + ): + yield frame + + if state.current_text_id is not None: + yield self.streaming_service.format_text_end(state.current_text_id) + if content_builder is not None: + content_builder.on_text_end(state.current_text_id) + state.current_text_id = None + + completion_event, new_active = complete_active_thinking_step( + state=state, + streaming_service=self.streaming_service, + content_builder=content_builder, + last_active_step_id=state.last_active_step_id, + last_active_step_title=state.last_active_step_title, + last_active_step_items=state.last_active_step_items, + completed_step_ids=state.completed_step_ids, + ) + if completion_event: + yield completion_event + state.last_active_step_id = new_active diff --git a/surfsense_backend/app/tasks/chat/streaming/relay/state.py b/surfsense_backend/app/tasks/chat/streaming/relay/state.py new file mode 100644 index 000000000..27898403d --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/relay/state.py @@ -0,0 +1,98 @@ +"""Mutable counters and maps for one agent stream.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class AgentEventRelayState: + """Tracks text, thinking steps, tool depth, and pending tool-call metadata. + + **Task span (`spanId`)** — ``active_span_id`` groups steps and tools for one + open delegating ``task`` episode. ``active_task_run_id`` is the LangGraph + ``run_id`` of that ``task`` so the span clears only when that run ends, not + when child tools end. Open/close uses ``relay.task_span`` helpers. + + **Tool ↔ thinking link (`thinkingStepId`)** — Each tool run gets a thinking-row + id (``tool_step_ids[run_id]``, emitted as ``data-thinking-step`` ``data.id``). + ``tool_activity_metadata`` supplies ``metadata`` for ``tool-input-start`` / + ``tool-input-available`` (``handlers.tool_start``) and + ``tool-output-available`` (``handlers.tool_end``). + """ + + accumulated_text: str = "" + current_text_id: str | None = None + thinking_step_counter: int = 0 + tool_step_ids: dict[str, str] = field(default_factory=dict) + completed_step_ids: set[str] = field(default_factory=set) + last_active_step_id: str | None = None + last_active_step_title: str = "" + last_active_step_items: list[str] = field(default_factory=list) + just_finished_tool: bool = False + active_tool_depth: int = 0 + called_update_memory: bool = False + current_reasoning_id: str | None = None + pending_tool_call_chunks: list[dict[str, Any]] = field(default_factory=list) + lc_tool_call_id_by_run: dict[str, str] = field(default_factory=dict) + file_path_by_run: dict[str, str] = field(default_factory=dict) + index_to_meta: dict[int, dict[str, str]] = field(default_factory=dict) + ui_tool_call_id_by_run: dict[str, str] = field(default_factory=dict) + current_lc_tool_call_id: dict[str, str | None] = field( + default_factory=lambda: {"value": None} + ) + # Open ``task`` delegation span (one id shared by nested activity); unset outside. + active_span_id: str | None = None + active_task_run_id: str | None = None + # Span id minted when a ``task`` tool_call_chunk registers (before ``on_tool_start``). + pending_task_span_by_lc: dict[str, str] = field(default_factory=dict) + + def span_metadata_if_active(self) -> dict[str, Any] | None: + """``{"spanId": ...}`` when a span is active; ``None`` otherwise.""" + if self.active_span_id: + return {"spanId": self.active_span_id} + return None + + def tool_activity_metadata( + self, *, thinking_step_id: str | None + ) -> dict[str, Any] | None: + """Build ``metadata`` for tool SSE and ``tool-call`` persistence. + + Contract (keys omitted when not applicable): + + - ``spanId`` (str): present while a task-delegation span is active + (same value as ``span_metadata_if_active()``). + - ``thinkingStepId`` (str): equals the thinking-step row ``id`` for this + tool (``data-thinking-step`` payload ``data.id`` on the wire). + + Returns ``None`` if neither applies. Whitespace-only + ``thinking_step_id`` is ignored. + """ + out: dict[str, Any] = {} + if self.active_span_id: + out["spanId"] = self.active_span_id + tid = (thinking_step_id or "").strip() + if tid: + out["thinkingStepId"] = tid + return out if out else None + + @classmethod + def for_invocation( + cls, + *, + initial_step_id: str | None = None, + initial_step_title: str = "", + initial_step_items: list[str] | None = None, + ) -> AgentEventRelayState: + counter = 1 if initial_step_id else 0 + return cls( + thinking_step_counter=counter, + last_active_step_id=initial_step_id, + last_active_step_title=initial_step_title, + last_active_step_items=list(initial_step_items or []), + ) + + def next_thinking_step_id(self, step_prefix: str) -> str: + self.thinking_step_counter += 1 + return f"{step_prefix}-{self.thinking_step_counter}" diff --git a/surfsense_backend/app/tasks/chat/streaming/relay/task_span.py b/surfsense_backend/app/tasks/chat/streaming/relay/task_span.py new file mode 100644 index 000000000..c4cdf24ba --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/relay/task_span.py @@ -0,0 +1,74 @@ +"""Open/close ``active_span_id`` around a delegating ``task`` tool run.""" + +from __future__ import annotations + +import uuid + +from app.tasks.chat.streaming.relay.state import AgentEventRelayState + + +def new_span_id() -> str: + """One delegation-episode id (shared by activity under an open ``task``).""" + return f"spn_{uuid.uuid4().hex}" + + +def _run_key(run_id: str) -> str: + return (run_id or "").strip() + + +def _lc_key(langchain_tool_call_id: str | None) -> str: + return (langchain_tool_call_id or "").strip() + + +def ensure_pending_task_span_for_lc(state: AgentEventRelayState, lc_id: str) -> str: + """Return span id for this LangChain tool call id, storing it in ``pending`` if new. + + Used from ``chat_model_stream`` when the first ``task`` chunk registers so + early ``tool-input-start`` can carry ``metadata.spanId`` before ``on_tool_start``. + """ + key = _lc_key(lc_id) + if not key: + return new_span_id() + existing = state.pending_task_span_by_lc.get(key) + if existing: + return existing + sid = new_span_id() + state.pending_task_span_by_lc[key] = sid + return sid + + +def open_task_span( + state: AgentEventRelayState, + *, + run_id: str, + langchain_tool_call_id: str | None = None, +) -> str: + """Set ``active_span_id`` from pending (same lc) or mint; remember ``active_task_run_id``. + + Call when the ``task`` tool **starts**. Nested ``task`` is not supported: + a second call replaces the previous span without restoring it. + """ + key = _lc_key(langchain_tool_call_id) + sid: str | None = state.pending_task_span_by_lc.pop(key, None) if key else None + if not sid: + sid = new_span_id() + state.active_span_id = sid + state.active_task_run_id = _run_key(run_id) or None + return sid + + +def clear_task_span_if_delegating_task_ended( + state: AgentEventRelayState, + *, + tool_name: str, + run_id: str, +) -> None: + """Clear span state only when this event is the end of the opening ``task`` run.""" + if tool_name != "task": + return + if state.active_task_run_id is None: + return + if state.active_task_run_id != _run_key(run_id): + return + state.active_span_id = None + state.active_task_run_id = None diff --git a/surfsense_backend/app/tasks/chat/streaming/relay/thinking_step_completion.py b/surfsense_backend/app/tasks/chat/streaming/relay/thinking_step_completion.py new file mode 100644 index 000000000..ad0930341 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/relay/thinking_step_completion.py @@ -0,0 +1,34 @@ +"""Close the in-progress thinking step with a completed status frame.""" + +from __future__ import annotations + +from typing import Any + +from .state import AgentEventRelayState +from .thinking_step_sse import emit_thinking_step_frame + + +def complete_active_thinking_step( + *, + state: AgentEventRelayState, + streaming_service: Any, + content_builder: Any | None, + last_active_step_id: str | None, + last_active_step_title: str, + last_active_step_items: list[str], + completed_step_ids: set[str], +) -> tuple[str | None, str | None]: + """Emit a completed thinking-step frame once; return (frame or None, next active step id).""" + if last_active_step_id and last_active_step_id not in completed_step_ids: + completed_step_ids.add(last_active_step_id) + event = emit_thinking_step_frame( + streaming_service=streaming_service, + content_builder=content_builder, + step_id=last_active_step_id, + title=last_active_step_title, + status="completed", + items=last_active_step_items if last_active_step_items else None, + metadata=state.span_metadata_if_active(), + ) + return event, None + return None, last_active_step_id diff --git a/surfsense_backend/app/tasks/chat/streaming/relay/thinking_step_sse.py b/surfsense_backend/app/tasks/chat/streaming/relay/thinking_step_sse.py new file mode 100644 index 000000000..6737f536b --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/relay/thinking_step_sse.py @@ -0,0 +1,28 @@ +"""Thinking-step SSE plus optional content-builder updates.""" + +from __future__ import annotations + +from typing import Any + + +def emit_thinking_step_frame( + *, + streaming_service: Any, + content_builder: Any | None, + step_id: str, + title: str, + status: str = "in_progress", + items: list[str] | None = None, + metadata: dict[str, Any] | None = None, +) -> str: + if content_builder is not None: + content_builder.on_thinking_step( + step_id, title, status, items, metadata=metadata + ) + return streaming_service.format_thinking_step( + step_id=step_id, + title=title, + status=status, + items=items, + metadata=metadata, + ) diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml index 4235ac962..523a8a1ac 100644 --- a/surfsense_backend/pyproject.toml +++ b/surfsense_backend/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "surf-new-backend" -version = "0.0.22" +version = "0.0.23" description = "SurfSense Backend" requires-python = ">=3.12" dependencies = [ diff --git a/surfsense_backend/tests/integration/chat/test_persistence.py b/surfsense_backend/tests/integration/chat/test_persistence.py index 66a04772e..d6f816cc0 100644 --- a/surfsense_backend/tests/integration/chat/test_persistence.py +++ b/surfsense_backend/tests/integration/chat/test_persistence.py @@ -367,18 +367,26 @@ class TestPersistUserTurn: db_thread, patched_shielded_session, ): - """The full ``{id, title, document_type}`` triple forwarded by - the FE must round-trip into a single ``mentioned-documents`` - ContentPart on the persisted user message — the history loader - renders the chips on reload from this part directly. + """The full ``{id, title, document_type, kind}`` chip metadata + forwarded by the FE must round-trip into a single + ``mentioned-documents`` ContentPart on the persisted user + message — the history loader renders the chips on reload from + this part directly. Folder chips ride alongside doc chips so + the FE can render mixed mention bars without a second fetch. """ thread_id = db_thread.id user_id_str = str(db_user.id) turn_id = f"{thread_id}:8200" mentioned = [ - {"id": 11, "title": "Alpha", "document_type": "GENERAL"}, - {"id": 22, "title": "Beta", "document_type": "GENERAL"}, + {"id": 11, "title": "Alpha", "document_type": "GENERAL", "kind": "doc"}, + {"id": 22, "title": "Beta", "document_type": "GENERAL", "kind": "doc"}, + { + "id": 33, + "title": "Reports", + "document_type": "FOLDER", + "kind": "folder", + }, ] msg_id = await persist_user_turn( chat_id=thread_id, @@ -397,8 +405,61 @@ class TestPersistUserTurn: assert row.content[1] == { "type": "mentioned-documents", "documents": [ - {"id": 11, "title": "Alpha", "document_type": "GENERAL"}, - {"id": 22, "title": "Beta", "document_type": "GENERAL"}, + {"id": 11, "title": "Alpha", "document_type": "GENERAL", "kind": "doc"}, + {"id": 22, "title": "Beta", "document_type": "GENERAL", "kind": "doc"}, + { + "id": 33, + "title": "Reports", + "document_type": "FOLDER", + "kind": "folder", + }, + ], + } + + async def test_legacy_chip_without_kind_defaults_to_doc( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + """Pre-folder clients send chips without ``kind``. The persistence + layer defaults them to ``"doc"`` so the round-trip stays + consistent on reload — the FE schema's optional default + produces the same value, but persisting it explicitly keeps + the DB row self-describing. + """ + thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id = f"{thread_id}:8201" + + mentioned = [ + {"id": 77, "title": "Legacy", "document_type": "GENERAL"}, + ] + msg_id = await persist_user_turn( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + user_query="hi", + mentioned_documents=mentioned, + ) + assert isinstance(msg_id, int) + + row = await db_session.get(NewChatMessage, msg_id) + assert row is not None + assert isinstance(row.content, list) + mentioned_part = next( + p for p in row.content if p.get("type") == "mentioned-documents" + ) + assert mentioned_part == { + "type": "mentioned-documents", + "documents": [ + { + "id": 77, + "title": "Legacy", + "document_type": "GENERAL", + "kind": "doc", + }, ], } diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/__init__.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/__init__.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/__init__.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_hitl_bridge.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_hitl_bridge.py new file mode 100644 index 000000000..dbc2c9c00 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_hitl_bridge.py @@ -0,0 +1,208 @@ +"""End-to-end resume-bridge tests against a real LangGraph subagent.""" + +from __future__ import annotations + +import ast + +import pytest +from langchain.tools import ToolRuntime +from langchain_core.messages import HumanMessage +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.types import Command, interrupt +from typing_extensions import TypedDict + +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( + build_task_tool_with_parent_config, +) + + +class _SubagentState(TypedDict, total=False): + messages: list + decision_text: str + + +def _build_single_interrupt_subagent(): + def approve_node(state): + from langchain_core.messages import AIMessage + + decision = interrupt( + { + "action_requests": [ + { + "name": "do_thing", + "args": {"x": 1}, + "description": "test action", + } + ], + "review_configs": [{}], + } + ) + return { + "messages": [AIMessage(content="done")], + "decision_text": repr(decision), + } + + graph = StateGraph(_SubagentState) + graph.add_node("approve", approve_node) + graph.add_edge(START, "approve") + graph.add_edge("approve", END) + return graph.compile(checkpointer=InMemorySaver()) + + +def _make_runtime(config: dict) -> ToolRuntime: + return ToolRuntime( + state={"messages": [HumanMessage(content="seed")]}, + context=None, + config=config, + stream_writer=None, + tool_call_id="parent-tcid-1", + store=None, + ) + + +@pytest.mark.asyncio +async def test_resume_bridge_dispatches_decision_into_pending_subagent(): + """Side-channel decision must reach the subagent's pending interrupt verbatim.""" + subagent = _build_single_interrupt_subagent() + task_tool = build_task_tool_with_parent_config( + [ + { + "name": "approver", + "description": "approves things", + "runnable": subagent, + } + ] + ) + + parent_config: dict = { + "configurable": {"thread_id": "shared-thread"}, + "recursion_limit": 100, + } + await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config) + snap = await subagent.aget_state(parent_config) + assert snap.tasks and snap.tasks[0].interrupts, ( + "fixture broken: subagent should be paused on its interrupt" + ) + + parent_config["configurable"]["surfsense_resume_value"] = { + "decisions": ["APPROVED"] + } + runtime = _make_runtime(parent_config) + + result = await task_tool.coroutine( + description="please approve", + subagent_type="approver", + runtime=runtime, + ) + + assert isinstance(result, Command) + update = result.update + assert update["decision_text"] == repr({"decisions": ["APPROVED"]}) + assert "surfsense_resume_value" not in parent_config["configurable"] + + final = await subagent.aget_state(parent_config) + assert not final.tasks or all(not t.interrupts for t in final.tasks) + + +@pytest.mark.asyncio +async def test_pending_interrupt_without_resume_value_raises_runtime_error(): + """Bridge must fail loud rather than silently replay the user's interrupt.""" + subagent = _build_single_interrupt_subagent() + task_tool = build_task_tool_with_parent_config( + [ + { + "name": "approver", + "description": "approves things", + "runnable": subagent, + } + ] + ) + + parent_config: dict = { + "configurable": {"thread_id": "guard-thread"}, + "recursion_limit": 100, + } + await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config) + snap = await subagent.aget_state(parent_config) + assert snap.tasks and snap.tasks[0].interrupts, "fixture broken" + + runtime = _make_runtime(parent_config) + + with pytest.raises(RuntimeError, match="resume bridge is broken"): + await task_tool.coroutine( + description="please approve", + subagent_type="approver", + runtime=runtime, + ) + + +def _build_bundle_subagent(): + def bundle_node(state): + from langchain_core.messages import AIMessage + + decision = interrupt( + { + "action_requests": [ + {"name": "create_a", "args": {}, "description": ""}, + {"name": "create_b", "args": {}, "description": ""}, + {"name": "create_c", "args": {}, "description": ""}, + ], + "review_configs": [{}, {}, {}], + } + ) + return { + "messages": [AIMessage(content="bundle-done")], + "decision_text": repr(decision), + } + + graph = StateGraph(_SubagentState) + graph.add_node("bundle", bundle_node) + graph.add_edge(START, "bundle") + graph.add_edge("bundle", END) + return graph.compile(checkpointer=InMemorySaver()) + + +@pytest.mark.asyncio +async def test_bundle_three_mixed_decisions_arrive_in_order(): + """Approve / edit / reject for a 3-action bundle must land at ordinals 0/1/2.""" + subagent = _build_bundle_subagent() + task_tool = build_task_tool_with_parent_config( + [ + { + "name": "bundler", + "description": "creates a bundle", + "runnable": subagent, + } + ] + ) + + parent_config: dict = { + "configurable": {"thread_id": "bundle-thread"}, + "recursion_limit": 100, + } + await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config) + + decisions_payload = { + "decisions": [ + {"type": "approve", "args": {}}, + {"type": "edit", "args": {"args": {"name": "edited-b"}}}, + {"type": "reject", "args": {"message": "no thanks"}}, + ] + } + parent_config["configurable"]["surfsense_resume_value"] = decisions_payload + runtime = _make_runtime(parent_config) + + result = await task_tool.coroutine( + description="run bundle", + subagent_type="bundler", + runtime=runtime, + ) + + assert isinstance(result, Command) + received = ast.literal_eval(result.update["decision_text"]) + assert received == decisions_payload + assert received["decisions"][0]["type"] == "approve" + assert received["decisions"][1]["type"] == "edit" + assert received["decisions"][1]["args"] == {"args": {"name": "edited-b"}} + assert received["decisions"][2]["type"] == "reject" diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_pending_interrupt.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_pending_interrupt.py new file mode 100644 index 000000000..75242689d --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_pending_interrupt.py @@ -0,0 +1,55 @@ +"""Pins the first-wins assumption of ``get_first_pending_subagent_interrupt``. + +The bridge currently relies on at-most-one pending interrupt per snapshot +(sequential tool nodes). If parallel tool calls are ever enabled, the bridge +needs an id-aware lookup; these tests will need to be revisited at that point. +""" + +from __future__ import annotations + +from types import SimpleNamespace + +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume import ( + get_first_pending_subagent_interrupt, +) + + +class TestGetFirstPendingSubagentInterrupt: + def test_returns_first_when_multiple_top_level_interrupts_pending(self): + first = SimpleNamespace(id="i-1", value={"decision": "approve"}) + second = SimpleNamespace(id="i-2", value={"decision": "reject"}) + state = SimpleNamespace(interrupts=(first, second), tasks=()) + + assert get_first_pending_subagent_interrupt(state) == ( + "i-1", + {"decision": "approve"}, + ) + + def test_returns_first_when_multiple_subtask_interrupts_pending(self): + first = SimpleNamespace(id="i-A", value="approve") + second = SimpleNamespace(id="i-B", value="reject") + sub_task = SimpleNamespace(interrupts=(first, second)) + state = SimpleNamespace(interrupts=(), tasks=(sub_task,)) + + assert get_first_pending_subagent_interrupt(state) == ("i-A", "approve") + + def test_returns_none_when_no_interrupts(self): + state = SimpleNamespace(interrupts=(), tasks=()) + + assert get_first_pending_subagent_interrupt(state) == (None, None) + + def test_returns_none_when_state_is_none(self): + assert get_first_pending_subagent_interrupt(None) == (None, None) + + def test_skips_interrupts_with_none_value(self): + empty = SimpleNamespace(id="i-empty", value=None) + real = SimpleNamespace(id="i-real", value="approve") + state = SimpleNamespace(interrupts=(empty, real), tasks=()) + + assert get_first_pending_subagent_interrupt(state) == ("i-real", "approve") + + def test_normalizes_non_string_id_to_none(self): + interrupt = SimpleNamespace(id=12345, value="approve") + state = SimpleNamespace(interrupts=(interrupt,), tasks=()) + + assert get_first_pending_subagent_interrupt(state) == (None, "approve") diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_resume_helpers.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_resume_helpers.py new file mode 100644 index 000000000..347b32dbd --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_resume_helpers.py @@ -0,0 +1,71 @@ +"""Resume side-channel must be read exactly once per turn.""" + +from __future__ import annotations + +from langchain.tools import ToolRuntime + +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import ( + consume_surfsense_resume, + has_surfsense_resume, +) + + +def _runtime_with_config(config: dict) -> ToolRuntime: + return ToolRuntime( + state=None, + context=None, + config=config, + stream_writer=None, + tool_call_id="tcid-test", + store=None, + ) + + +class TestConsumeSurfsenseResume: + def test_pops_value_on_first_call(self): + runtime = _runtime_with_config( + {"configurable": {"surfsense_resume_value": {"decisions": ["approve"]}}} + ) + + assert consume_surfsense_resume(runtime) == {"decisions": ["approve"]} + + def test_second_call_returns_none(self): + configurable: dict = {"surfsense_resume_value": {"decisions": ["approve"]}} + runtime = _runtime_with_config({"configurable": configurable}) + + consume_surfsense_resume(runtime) + + assert consume_surfsense_resume(runtime) is None + assert "surfsense_resume_value" not in configurable + + def test_returns_none_when_no_payload_queued(self): + runtime = _runtime_with_config({"configurable": {}}) + + assert consume_surfsense_resume(runtime) is None + + def test_returns_none_when_configurable_missing(self): + runtime = _runtime_with_config({}) + + assert consume_surfsense_resume(runtime) is None + + +class TestHasSurfsenseResume: + def test_true_when_payload_queued(self): + runtime = _runtime_with_config( + {"configurable": {"surfsense_resume_value": "approve"}} + ) + + assert has_surfsense_resume(runtime) is True + + def test_does_not_consume_payload(self): + configurable = {"surfsense_resume_value": "approve"} + runtime = _runtime_with_config({"configurable": configurable}) + + has_surfsense_resume(runtime) + + assert configurable == {"surfsense_resume_value": "approve"} + + def test_false_when_payload_absent(self): + runtime = _runtime_with_config({"configurable": {}}) + + assert has_surfsense_resume(runtime) is False diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/__init__.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/__init__.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/test_subagent_builder.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/test_subagent_builder.py new file mode 100644 index 000000000..648e52115 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/test_subagent_builder.py @@ -0,0 +1,96 @@ +"""Subagent resilience contract: ``extra_middleware`` reaches the agent chain.""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Iterator +from typing import Any + +import pytest +from langchain.agents import create_agent +from langchain.agents.middleware import ModelFallbackMiddleware +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.language_models.fake_chat_models import ( + FakeMessagesListChatModel, +) +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.outputs import ChatGeneration, ChatResult + +from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( + pack_subagent, +) + + +class RateLimitError(Exception): + """Name matches the scoped-fallback eligibility allowlist.""" + + +class _AlwaysFailingChatModel(BaseChatModel): + @property + def _llm_type(self) -> str: + return "always-failing-test-model" + + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + msg = "primary llm exploded" + raise RateLimitError(msg) + + async def _agenerate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + msg = "primary llm exploded" + raise RateLimitError(msg) + + def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]: + msg = "primary llm exploded" + raise RateLimitError(msg) + + async def _astream( + self, *args: Any, **kwargs: Any + ) -> AsyncIterator[ChatGeneration]: + msg = "primary llm exploded" + raise RateLimitError(msg) + yield # pragma: no cover - unreachable, satisfies async generator typing + + +@pytest.mark.asyncio +async def test_subagent_recovers_when_primary_llm_fails(): + """Fallback in ``extra_middleware`` must finish the turn when primary raises.""" + primary = _AlwaysFailingChatModel() + fallback = FakeMessagesListChatModel( + responses=[AIMessage(content="recovered via fallback")] + ) + + spec = pack_subagent( + name="resilience_test", + description="test subagent", + system_prompt="be helpful", + tools=[], + model=primary, + extra_middleware=[ModelFallbackMiddleware(fallback)], + ) + + agent = create_agent( + model=spec["model"], + tools=spec["tools"], + middleware=spec["middleware"], + system_prompt=spec["system_prompt"], + ) + + result = await agent.ainvoke({"messages": [HumanMessage(content="hi")]}) + + final = result["messages"][-1] + assert isinstance(final, AIMessage) + assert final.content == "recovered via fallback" diff --git a/surfsense_backend/tests/unit/agents/new_chat/middleware/__init__.py b/surfsense_backend/tests/unit/agents/new_chat/middleware/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/agents/new_chat/middleware/test_scoped_model_fallback.py b/surfsense_backend/tests/unit/agents/new_chat/middleware/test_scoped_model_fallback.py new file mode 100644 index 000000000..80b9862e7 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/middleware/test_scoped_model_fallback.py @@ -0,0 +1,128 @@ +"""``ScopedModelFallbackMiddleware`` triggers fallback only on provider errors.""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Iterator +from typing import Any + +import pytest +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import AIMessage, BaseMessage +from langchain_core.outputs import ChatGeneration, ChatResult + + +class _RaisingChatModel(BaseChatModel): + exc_to_raise: Any + + @property + def _llm_type(self) -> str: + return "raising-test-model" + + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + raise self.exc_to_raise + + async def _agenerate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + raise self.exc_to_raise + + def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]: + raise self.exc_to_raise + + async def _astream( + self, *args: Any, **kwargs: Any + ) -> AsyncIterator[ChatGeneration]: + raise self.exc_to_raise + yield # pragma: no cover - unreachable + + +class _RecordingChatModel(BaseChatModel): + response_text: str = "fallback-ok" + call_count: int = 0 + + @property + def _llm_type(self) -> str: + return "recording-test-model" + + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + self.call_count += 1 + return ChatResult( + generations=[ChatGeneration(message=AIMessage(content=self.response_text))] + ) + + async def _agenerate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + return self._generate(messages, stop, None, **kwargs) + + +class RateLimitError(Exception): + """Name matches the scoped-fallback eligibility allowlist.""" + + +def _build_agent(primary: BaseChatModel, fallback: BaseChatModel): + from langchain.agents import create_agent + + from app.agents.new_chat.middleware.scoped_model_fallback import ( + ScopedModelFallbackMiddleware, + ) + + return create_agent( + model=primary, + tools=[], + middleware=[ScopedModelFallbackMiddleware(fallback)], + system_prompt="be helpful", + ) + + +@pytest.mark.asyncio +async def test_provider_errors_trigger_fallback(): + """Eligible exception names must drive the fallback chain.""" + primary = _RaisingChatModel(exc_to_raise=RateLimitError("429 from provider")) + fallback = _RecordingChatModel(response_text="recovered") + + agent = _build_agent(primary, fallback) + result = await agent.ainvoke({"messages": [("user", "hi")]}) + + final = result["messages"][-1] + assert isinstance(final, AIMessage) + assert final.content == "recovered" + assert fallback.call_count == 1 + + +@pytest.mark.asyncio +async def test_programming_errors_propagate_without_invoking_fallback(): + """Non-eligible exceptions must propagate; fallback must not be invoked.""" + primary = _RaisingChatModel(exc_to_raise=KeyError("missing_state_field")) + fallback = _RecordingChatModel(response_text="should-never-arrive") + + agent = _build_agent(primary, fallback) + + with pytest.raises(KeyError, match="missing_state_field"): + await agent.ainvoke({"messages": [("user", "hi")]}) + + assert fallback.call_count == 0 diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py index 6800be2af..099aea882 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py @@ -31,7 +31,6 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None: "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "SURFSENSE_ENABLE_ACTION_LOG", "SURFSENSE_ENABLE_REVERT_ROUTE", - "SURFSENSE_ENABLE_STREAM_PARITY_V2", "SURFSENSE_ENABLE_PLUGIN_LOADER", "SURFSENSE_ENABLE_OTEL", "SURFSENSE_ENABLE_AGENT_CACHE", @@ -61,7 +60,6 @@ def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) -> assert flags.enable_kb_planner_runnable is True assert flags.enable_action_log is True assert flags.enable_revert_route is True - assert flags.enable_stream_parity_v2 is True assert flags.enable_plugin_loader is False assert flags.enable_otel is False # Phase 2: agent cache is now default-on (the prerequisite tool @@ -127,7 +125,6 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) -> "enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG", "enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE", - "enable_stream_parity_v2": "SURFSENSE_ENABLE_STREAM_PARITY_V2", "enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER", "enable_otel": "SURFSENSE_ENABLE_OTEL", } diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_mention_resolver.py b/surfsense_backend/tests/unit/agents/new_chat/test_mention_resolver.py new file mode 100644 index 000000000..1f8d35841 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_mention_resolver.py @@ -0,0 +1,285 @@ +"""Tests for the @-mention resolver. + +These tests pin down the contract that ``mention_resolver`` is the +single seam between ``MentionedDocumentInfo`` chips (frontend) and the +canonical ``/documents/...`` virtual paths (agent). The streaming task, +priority middleware, and persistence layer all consume the resolver's +output — keeping the tests focused on substitute-in-text + the +returned id partition keeps the seam stable across refactors. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.agents.new_chat import mention_resolver +from app.agents.new_chat.mention_resolver import ( + ResolvedMention, + ResolvedMentionSet, + resolve_mentions, + substitute_in_text, +) +from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT, PathIndex +from app.schemas.new_chat import MentionedDocumentInfo + +pytestmark = pytest.mark.unit + + +class TestSubstituteInText: + """``substitute_in_text`` is a pure string transform and is exercised + on every cloud-mode turn, so it has to be both fast and behaviour- + identical to the frontend's ``parseMentionSegments`` (longest-token + first, single forward pass).""" + + def test_returns_text_unchanged_when_no_tokens(self): + assert substitute_in_text("hello @foo", []) == "hello @foo" + + def test_returns_text_unchanged_when_empty(self): + assert substitute_in_text("", [("@x", "/documents/x.xml")]) == "" + + def test_replaces_single_token_with_backtick_path(self): + out = substitute_in_text( + "see @notes please", + [("@notes", "/documents/notes.xml")], + ) + assert out == "see `/documents/notes.xml` please" + + def test_longest_token_wins_over_prefix(self): + # ``@Project Roadmap`` must NOT be partially matched by ``@Project``. + # Mirrors the FE's parseMentionSegments contract. + token_to_path = [ + ("@Project Roadmap", "/documents/Roadmap.xml"), + ("@Project", "/documents/Project.xml"), + ] + out = substitute_in_text("about @Project Roadmap today", token_to_path) + assert out == "about `/documents/Roadmap.xml` today" + + def test_handles_repeated_mentions(self): + out = substitute_in_text( + "@A and @A again @B", + [ + ("@A", "/documents/a.xml"), + ("@B", "/documents/b.xml"), + ], + ) + assert ( + out == "`/documents/a.xml` and `/documents/a.xml` again `/documents/b.xml`" + ) + + def test_does_not_match_inside_word(self): + # Substitution is positional — there's no word-boundary semantics. + # ``@Pro`` inside ``foo@Project`` still matches; this is the same + # behaviour as parseMentionSegments. The test pins it so a + # future "fix" doesn't accidentally diverge between FE/BE. + out = substitute_in_text("foo@Pro", [("@Pro", "/documents/p.xml")]) + assert out == "foo`/documents/p.xml`" + + def test_idempotent_after_substitution(self): + # The output starts with a backtick, not ``@``, so re-running + # the substitution leaves it alone. + once = substitute_in_text("@A", [("@A", "/documents/a.xml")]) + twice = substitute_in_text(once, [("@A", "/documents/a.xml")]) + assert once == twice + + +class TestResolveMentions: + """``resolve_mentions`` resolves chip ids → virtual paths and emits + a ``ResolvedMentionSet`` whose id partitions feed + ``KnowledgePriorityMiddleware``.""" + + @pytest.mark.asyncio + async def test_returns_empty_when_no_mentions(self): + session = MagicMock() + session.execute = AsyncMock() + result = await resolve_mentions( + session, + search_space_id=1, + mentioned_documents=None, + ) + assert isinstance(result, ResolvedMentionSet) + assert result.mentions == [] + assert result.token_to_path == [] + assert result.mentioned_document_ids == [] + assert result.mentioned_folder_ids == [] + # No DB roundtrips when there's nothing to resolve. + session.execute.assert_not_awaited() + + @pytest.mark.asyncio + async def test_resolves_doc_chip_to_virtual_path(self, monkeypatch): + chip = MentionedDocumentInfo( + id=42, + title="Notes", + document_type="EXTENSION", + kind="doc", + ) + doc_row = SimpleNamespace(id=42, title="Notes", folder_id=None) + + async def fake_build_index(_session, _ssid): + return PathIndex() + + monkeypatch.setattr(mention_resolver, "build_path_index", fake_build_index) + + scalars = MagicMock() + scalars.all.return_value = [doc_row] + result = MagicMock() + result.scalars.return_value = scalars + session = MagicMock() + session.execute = AsyncMock(return_value=result) + + out = await resolve_mentions( + session, + search_space_id=5, + mentioned_documents=[chip], + ) + assert len(out.mentions) == 1 + mention = out.mentions[0] + assert mention.kind == "doc" + assert mention.id == 42 + assert mention.virtual_path == f"{DOCUMENTS_ROOT}/Notes.xml" + assert out.mentioned_document_ids == [42] + assert out.mentioned_folder_ids == [] + assert ("@Notes", f"{DOCUMENTS_ROOT}/Notes.xml") in out.token_to_path + + @pytest.mark.asyncio + async def test_resolves_folder_chip_with_trailing_slash(self, monkeypatch): + chip = MentionedDocumentInfo( + id=9, + title="Reports", + document_type="FOLDER", + kind="folder", + ) + folder_row = SimpleNamespace(id=9, name="Reports") + + async def fake_build_index(_session, _ssid): + return PathIndex(folder_paths={9: f"{DOCUMENTS_ROOT}/Reports"}) + + monkeypatch.setattr(mention_resolver, "build_path_index", fake_build_index) + + scalars = MagicMock() + scalars.all.return_value = [folder_row] + result = MagicMock() + result.scalars.return_value = scalars + session = MagicMock() + session.execute = AsyncMock(return_value=result) + + out = await resolve_mentions( + session, + search_space_id=3, + mentioned_documents=[chip], + ) + assert len(out.mentions) == 1 + mention = out.mentions[0] + assert mention.kind == "folder" + assert mention.id == 9 + assert mention.virtual_path == f"{DOCUMENTS_ROOT}/Reports/" + assert out.mentioned_document_ids == [] + assert out.mentioned_folder_ids == [9] + + @pytest.mark.asyncio + async def test_drops_chip_when_doc_is_missing(self, monkeypatch): + chip = MentionedDocumentInfo( + id=99, title="ghost", document_type="EXTENSION", kind="doc" + ) + + async def fake_build_index(_session, _ssid): + return PathIndex() + + monkeypatch.setattr(mention_resolver, "build_path_index", fake_build_index) + + scalars = MagicMock() + scalars.all.return_value = [] + result = MagicMock() + result.scalars.return_value = scalars + session = MagicMock() + session.execute = AsyncMock(return_value=result) + + out = await resolve_mentions( + session, + search_space_id=1, + mentioned_documents=[chip], + ) + assert out.mentions == [] + assert out.mentioned_document_ids == [] + assert out.token_to_path == [] + + @pytest.mark.asyncio + async def test_token_to_path_is_longest_first(self, monkeypatch): + # Two chips whose titles are prefixes of each other — the + # resolver MUST sort longest-first so substitution doesn't + # break the ``@Project Roadmap`` vs ``@Project`` invariant. + chip_short = MentionedDocumentInfo( + id=1, title="A", document_type="EXTENSION", kind="doc" + ) + chip_long = MentionedDocumentInfo( + id=2, title="A long one", document_type="EXTENSION", kind="doc" + ) + rows = [ + SimpleNamespace(id=1, title="A", folder_id=None), + SimpleNamespace(id=2, title="A long one", folder_id=None), + ] + + async def fake_build_index(_session, _ssid): + return PathIndex() + + monkeypatch.setattr(mention_resolver, "build_path_index", fake_build_index) + + scalars = MagicMock() + scalars.all.return_value = rows + result = MagicMock() + result.scalars.return_value = scalars + session = MagicMock() + session.execute = AsyncMock(return_value=result) + + out = await resolve_mentions( + session, + search_space_id=1, + mentioned_documents=[chip_short, chip_long], + ) + tokens = [tok for tok, _ in out.token_to_path] + assert tokens == sorted(tokens, key=len, reverse=True) + + @pytest.mark.asyncio + async def test_legacy_id_arrays_resolve_without_chip_metadata(self, monkeypatch): + # ``mentioned_document_ids`` (the legacy parallel array) must + # still resolve when no chip metadata is available — covers + # callers that haven't migrated to the discriminated chip list. + doc_row = SimpleNamespace(id=7, title="Legacy", folder_id=None) + + async def fake_build_index(_session, _ssid): + return PathIndex() + + monkeypatch.setattr(mention_resolver, "build_path_index", fake_build_index) + + scalars = MagicMock() + scalars.all.return_value = [doc_row] + result = MagicMock() + result.scalars.return_value = scalars + session = MagicMock() + session.execute = AsyncMock(return_value=result) + + out = await resolve_mentions( + session, + search_space_id=2, + mentioned_documents=None, + mentioned_document_ids=[7], + ) + assert out.mentioned_document_ids == [7] + assert len(out.mentions) == 1 + assert out.mentions[0].title == "Legacy" + + +class TestResolvedMentionEquality: + """Smoke check on the dataclass behaviour we rely on for asserting + test outputs.""" + + def test_equal_when_fields_equal(self): + a = ResolvedMention( + kind="doc", id=1, title="x", virtual_path="/documents/x.xml" + ) + b = ResolvedMention( + kind="doc", id=1, title="x", virtual_path="/documents/x.xml" + ) + assert a == b diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_path_resolver.py b/surfsense_backend/tests/unit/agents/new_chat/test_path_resolver.py index ddb20330d..ac6f61767 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_path_resolver.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_path_resolver.py @@ -196,3 +196,50 @@ class TestVirtualPathToDoc: ) assert document is target_doc assert session.execute.await_count == 2 + + @pytest.mark.asyncio + async def test_resolves_double_extension_for_uploaded_pdf(self): + # Regression: the agent renders every KB document under + # ``/documents/`` with a trailing ``.xml`` (via ``safe_filename``), + # so an uploaded PDF whose DB title is ``2025-W2.pdf`` shows up as + # ``/documents/2025-W2.pdf.xml`` in answers. Clicking that path + # must round-trip back to the row even though the title itself + # does NOT end in ``.xml``. + target_doc = SimpleNamespace(id=99, title="2025-W2.pdf", folder_id=None) + + session = MagicMock() + session.execute = AsyncMock( + side_effect=[ + _result_from_one(None), + _result_from_scalars([target_doc]), + ] + ) + + document = await virtual_path_to_doc( + session, + search_space_id=5, + virtual_path=f"{DOCUMENTS_ROOT}/2025-W2.pdf.xml", + ) + assert document is target_doc + + @pytest.mark.asyncio + async def test_resolves_path_without_xml_suffix(self): + # The user (or a hand-edited link) may pass the title-only form + # ``/documents/2025-W2.pdf``. The resolver must still find the row + # by literal title equality. + target_doc = SimpleNamespace(id=99, title="2025-W2.pdf", folder_id=None) + + session = MagicMock() + session.execute = AsyncMock( + side_effect=[ + _result_from_one(None), + _result_from_scalars([target_doc]), + ] + ) + + document = await virtual_path_to_doc( + session, + search_space_id=5, + virtual_path=f"{DOCUMENTS_ROOT}/2025-W2.pdf", + ) + assert document is target_doc diff --git a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py index 2933a0504..3529a946b 100644 --- a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py +++ b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py @@ -202,6 +202,15 @@ class FakeBudgetLLM: class TestKnowledgeBaseSearchMiddlewarePlanner: + @pytest.fixture(autouse=True) + def _disable_planner_runnable(self, monkeypatch): + # ``FakeLLM`` is a duck-typed mock; ``create_agent`` (used when the + # planner Runnable path is enabled) calls ``.bind()`` on the LLM, + # which the mock does not implement. Pin the flag off so the + # planner falls through to the legacy ``self.llm.ainvoke`` path + # these tests assert against (``llm.calls[0]["config"]``). + monkeypatch.setenv("SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "false") + def test_render_recent_conversation_prefers_latest_messages_under_budget(self): messages = [ HumanMessage(content="old user context " * 40), diff --git a/surfsense_backend/tests/unit/services/streaming/__init__.py b/surfsense_backend/tests/unit/services/streaming/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/services/streaming/test_emitter.py b/surfsense_backend/tests/unit/services/streaming/test_emitter.py new file mode 100644 index 000000000..6c4e1ff58 --- /dev/null +++ b/surfsense_backend/tests/unit/services/streaming/test_emitter.py @@ -0,0 +1,79 @@ +"""Pin the wire compactness rule and the top-level ``emitted_by`` field name.""" + +from __future__ import annotations + +import pytest + +from app.services.streaming.emitter import ( + Emitter, + attach_emitted_by, + main_emitter, + subagent_emitter, +) + +pytestmark = pytest.mark.unit + + +def test_main_emitter_payload_contains_only_level() -> None: + payload = main_emitter().to_payload() + assert payload == {"level": "main"} + + +def test_subagent_emitter_payload_includes_all_set_fields() -> None: + payload = subagent_emitter( + subagent_type="deliverables", + subagent_run_id="subagent_abc", + parent_tool_call_id="call_xyz", + ).to_payload() + assert payload == { + "level": "subagent", + "subagent_type": "deliverables", + "subagent_run_id": "subagent_abc", + "parent_tool_call_id": "call_xyz", + } + + +def test_subagent_emitter_payload_omits_unset_optional_fields() -> None: + """parent_tool_call_id is None when the run is started outside a tool boundary.""" + payload = Emitter( + level="subagent", + subagent_type="email", + subagent_run_id="subagent_1", + ).to_payload() + assert "parent_tool_call_id" not in payload + assert payload["subagent_type"] == "email" + + +def test_extra_fields_merge_into_payload() -> None: + """Future extension fields (e.g. lane colour, label) flow through ``extra``.""" + emitter = subagent_emitter( + subagent_type="search", + subagent_run_id="r1", + extra={"label": "Web Search"}, + ) + assert emitter.to_payload()["label"] == "Web Search" + + +def test_attach_emitted_by_with_none_is_noop() -> None: + payload = {"type": "text-delta", "delta": "hi"} + result = attach_emitted_by(payload, None) + assert "emitted_by" not in result + assert result is payload + + +def test_attach_emitted_by_adds_payload_under_snake_case_top_level_key() -> None: + payload = {"type": "text-delta", "delta": "hi"} + attach_emitted_by( + payload, + subagent_emitter( + subagent_type="x", + subagent_run_id="y", + parent_tool_call_id="z", + ), + ) + assert payload["emitted_by"] == { + "level": "subagent", + "subagent_type": "x", + "subagent_run_id": "y", + "parent_tool_call_id": "z", + } diff --git a/surfsense_backend/tests/unit/services/streaming/test_emitter_registry.py b/surfsense_backend/tests/unit/services/streaming/test_emitter_registry.py new file mode 100644 index 000000000..e459c946a --- /dev/null +++ b/surfsense_backend/tests/unit/services/streaming/test_emitter_registry.py @@ -0,0 +1,111 @@ +"""Pin the parent_ids walk + parallel sub-agent isolation that drives lane attribution.""" + +from __future__ import annotations + +import pytest + +from app.services.streaming.emitter import ( + Emitter, + EmitterRegistry, + main_emitter, + subagent_emitter, +) + +pytestmark = pytest.mark.unit + + +def _sub(run_id: str, kind: str = "deliverables") -> Emitter: + return subagent_emitter( + subagent_type=kind, + subagent_run_id=f"sub_{run_id}", + parent_tool_call_id=f"call_{run_id}", + ) + + +def test_unregistered_event_resolves_to_main_emitter() -> None: + registry = EmitterRegistry() + resolved = registry.resolve(run_id="run_1", parent_ids=["root"]) + assert resolved is main_emitter() + + +def test_event_owned_by_registered_run_id_returns_that_emitter() -> None: + registry = EmitterRegistry() + emitter = _sub("a") + registry.register("run_task_a", emitter) + assert registry.resolve(run_id="run_task_a", parent_ids=[]) is emitter + + +def test_descendant_resolves_via_parent_ids_chain() -> None: + """A model-call event nested under the task tool inherits its sub-agent emitter.""" + registry = EmitterRegistry() + emitter = _sub("a") + registry.register("run_task_a", emitter) + descendant = registry.resolve( + run_id="run_chat_model", + parent_ids=["root", "run_agent", "run_task_a"], + ) + assert descendant is emitter + + +def test_nearest_registered_ancestor_wins_over_distant_ones() -> None: + """Inner sub-agents owe their emitter to the nearest task tool, not the outer one.""" + registry = EmitterRegistry() + outer = _sub("outer", kind="planner") + inner = _sub("inner", kind="email") + registry.register("run_outer", outer) + registry.register("run_inner", inner) + resolved = registry.resolve( + run_id="run_inner_tool", + parent_ids=["root", "run_outer", "run_inner"], + ) + assert resolved is inner + + +def test_parallel_subagents_do_not_bleed_into_each_other() -> None: + """Two concurrent task tools each own their own descendant events.""" + registry = EmitterRegistry() + a = _sub("a", kind="search") + b = _sub("b", kind="email") + registry.register("run_task_a", a) + registry.register("run_task_b", b) + + from_a = registry.resolve(run_id="x", parent_ids=["root", "run_task_a"]) + from_b = registry.resolve(run_id="y", parent_ids=["root", "run_task_b"]) + from_main = registry.resolve(run_id="z", parent_ids=["root"]) + + assert from_a is a + assert from_b is b + assert from_main is main_emitter() + + +def test_unregister_releases_run_id_so_descendants_fall_back_to_main() -> None: + registry = EmitterRegistry() + emitter = _sub("a") + registry.register("run_task_a", emitter) + registry.unregister("run_task_a") + assert registry.resolve(run_id="x", parent_ids=["run_task_a"]) is main_emitter() + + +def test_unregister_returns_the_previously_registered_emitter() -> None: + """Lets callers emit ``data-subagent-finish`` carrying the same emitter they opened with.""" + registry = EmitterRegistry() + emitter = _sub("a") + registry.register("run_task_a", emitter) + assert registry.unregister("run_task_a") is emitter + + +def test_has_active_subagents_tracks_open_lanes() -> None: + registry = EmitterRegistry() + assert not registry.has_active_subagents() + registry.register("run_task_a", _sub("a")) + assert registry.has_active_subagents() + registry.unregister("run_task_a") + assert not registry.has_active_subagents() + + +def test_empty_run_id_and_parent_ids_resolves_to_main() -> None: + """Defensive: events without identifiers always belong to the main lane.""" + registry = EmitterRegistry() + registry.register("run_task_a", _sub("a")) + assert registry.resolve(run_id=None, parent_ids=None) is main_emitter() + assert registry.resolve(run_id="", parent_ids=[]) is main_emitter() diff --git a/surfsense_backend/tests/unit/services/streaming/test_interrupt_correlation.py b/surfsense_backend/tests/unit/services/streaming/test_interrupt_correlation.py new file mode 100644 index 000000000..df89ca59a --- /dev/null +++ b/surfsense_backend/tests/unit/services/streaming/test_interrupt_correlation.py @@ -0,0 +1,158 @@ +"""Pin id-aware pending-interrupt lookup that replaces the buggy first-wins.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import pytest + +from app.services.streaming.interrupt_correlation import ( + PendingInterrupt, + first_pending_interrupt, + get_pending_interrupt_by_id, + get_pending_interrupt_for_tool_call, + list_pending_interrupts, +) + +pytestmark = pytest.mark.unit + + +@dataclass +class _Interrupt: + value: dict[str, Any] + id: str | None = None + + +@dataclass +class _Task: + interrupts: tuple[_Interrupt, ...] = () + id: str | None = None + + +@dataclass +class _State: + tasks: tuple[_Task, ...] = () + interrupts: tuple[_Interrupt, ...] = () + + +def _hitl(name: str, tool_call_id: str | None = None) -> dict[str, Any]: + """Minimal LangChain HITLRequest payload for one action.""" + action: dict[str, Any] = {"name": name, "args": {}} + if tool_call_id is not None: + action["tool_call_id"] = tool_call_id + return { + "action_requests": [action], + "review_configs": [{"action_name": name, "allowed_decisions": ["approve"]}], + } + + +def test_empty_state_has_no_pending_interrupts() -> None: + state = _State() + assert list_pending_interrupts(state) == [] + assert first_pending_interrupt(state) is None + + +def test_single_pending_interrupt_in_task_is_returned() -> None: + state = _State( + tasks=( + _Task( + id="task_1", + interrupts=(_Interrupt(value=_hitl("send_email"), id="int_1"),), + ), + ) + ) + pending = list_pending_interrupts(state) + assert len(pending) == 1 + assert pending[0] == PendingInterrupt( + interrupt_id="int_1", + value=_hitl("send_email"), + source_task_id="task_1", + ) + + +def test_pending_interrupts_returned_in_task_then_root_order() -> None: + """Determinism matters: callers iterate in this order to render the UI.""" + state = _State( + tasks=( + _Task( + id="task_a", + interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),), + ), + _Task( + id="task_b", + interrupts=(_Interrupt(value=_hitl("b"), id="int_b"),), + ), + ), + interrupts=(_Interrupt(value=_hitl("c"), id="int_c"),), + ) + pending = list_pending_interrupts(state) + ids = [p.interrupt_id for p in pending] + assert ids == ["int_a", "int_b", "int_c"] + + +def test_get_by_id_finds_the_right_interrupt_under_parallel_load() -> None: + """Replacing first-wins: id-aware lookup MUST pick the requested one.""" + state = _State( + tasks=( + _Task(interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),)), + _Task(interrupts=(_Interrupt(value=_hitl("b"), id="int_b"),)), + _Task(interrupts=(_Interrupt(value=_hitl("c"), id="int_c"),)), + ) + ) + found = get_pending_interrupt_by_id(state, "int_b") + assert found is not None + assert found.value["action_requests"][0]["name"] == "b" + + +def test_get_by_id_returns_none_when_id_is_not_pending() -> None: + state = _State( + tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),)),) + ) + assert get_pending_interrupt_by_id(state, "missing") is None + + +def test_get_by_tool_call_id_matches_action_request_payload() -> None: + """HITLRequest carries ``tool_call_id`` per action; lookup uses that.""" + state = _State( + tasks=( + _Task( + interrupts=( + _Interrupt(value=_hitl("a", tool_call_id="call_xxx"), id="int_a"), + _Interrupt(value=_hitl("b", tool_call_id="call_yyy"), id="int_b"), + ) + ), + ) + ) + found = get_pending_interrupt_for_tool_call(state, "call_yyy") + assert found is not None + assert found.interrupt_id == "int_b" + + +def test_first_pending_interrupt_matches_legacy_first_wins_behaviour() -> None: + """Sequential-turn safety: the explicit shortcut still returns the first.""" + state = _State( + tasks=(_Task(interrupts=(_Interrupt(value=_hitl("first"), id="int_1"),)),), + interrupts=(_Interrupt(value=_hitl("second"), id="int_2"),), + ) + first = first_pending_interrupt(state) + assert first is not None + assert first.interrupt_id == "int_1" + + +def test_interrupt_without_id_falls_back_to_none() -> None: + """Snapshots from older LangGraph versions may omit ``id`` — preserve that.""" + state = _State(tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id=None),)),)) + pending = list_pending_interrupts(state) + assert len(pending) == 1 + assert pending[0].interrupt_id is None + + +def test_non_dict_interrupt_values_are_ignored() -> None: + """Defensive: a non-dict value should not crash the iteration.""" + + class _Raw: + value = "not a dict" + + state = _State(tasks=(_Task(interrupts=(_Raw(),)),)) # type: ignore[arg-type] + assert list_pending_interrupts(state) == [] diff --git a/surfsense_backend/tests/unit/services/streaming/test_interrupt_events.py b/surfsense_backend/tests/unit/services/streaming/test_interrupt_events.py new file mode 100644 index 000000000..a6685b524 --- /dev/null +++ b/surfsense_backend/tests/unit/services/streaming/test_interrupt_events.py @@ -0,0 +1,89 @@ +"""Pin interrupt-payload normalisation and the optional correlation fields on the wire.""" + +from __future__ import annotations + +import json + +import pytest + +from app.services.streaming.events.interrupt import ( + format_interrupt_request, + normalize_interrupt_payload, +) + +pytestmark = pytest.mark.unit + + +def _decode(frame: str) -> dict: + body = frame.removeprefix("data: ").removesuffix("\n\n") + return json.loads(body) + + +def test_hitlrequest_shape_is_passed_through_unchanged() -> None: + raw = { + "action_requests": [{"name": "send_email", "args": {"to": "a@b"}}], + "review_configs": [ + {"action_name": "send_email", "allowed_decisions": ["approve"]} + ], + } + assert normalize_interrupt_payload(raw) == raw + + +def test_custom_interrupt_primitive_is_converted_to_canonical_shape() -> None: + raw = { + "type": "permission", + "message": "Allow send?", + "action": {"tool": "send_email", "params": {"to": "a@b"}}, + "context": {"reason": "destructive"}, + } + out = normalize_interrupt_payload(raw) + assert out["action_requests"] == [{"name": "send_email", "args": {"to": "a@b"}}] + assert out["review_configs"] == [ + { + "action_name": "send_email", + "allowed_decisions": ["approve", "edit", "reject"], + } + ] + assert out["interrupt_type"] == "permission" + assert out["message"] == "Allow send?" + assert out["context"] == {"reason": "destructive"} + + +def test_custom_interrupt_without_message_omits_message_key() -> None: + """Optional fields stay optional on the wire; FE does not see ``"message": None``.""" + raw = {"action": {"tool": "send_email"}} + out = normalize_interrupt_payload(raw) + assert "message" not in out + + +def test_custom_interrupt_without_tool_falls_back_to_unknown_tool() -> None: + """Defensive: a malformed ``action`` block must not crash the relay.""" + out = normalize_interrupt_payload({"type": "x", "action": {}}) + assert out["action_requests"][0]["name"] == "unknown_tool" + assert out["review_configs"][0]["action_name"] == "unknown_tool" + + +def test_format_interrupt_request_carries_correlation_fields_on_the_wire() -> None: + frame = format_interrupt_request( + {"action_requests": [], "review_configs": []}, + interrupt_id="int_42", + pending_interrupt_count=3, + chat_turn_id="turn_99", + ) + payload = _decode(frame) + assert payload["type"] == "data-interrupt-request" + inner = payload["data"] + assert inner["interrupt_id"] == "int_42" + assert inner["pending_interrupt_count"] == 3 + assert inner["chat_turn_id"] == "turn_99" + + +def test_format_interrupt_request_omits_correlation_fields_when_unset() -> None: + """Backward compat: legacy single-interrupt callers don't have to supply ids.""" + frame = format_interrupt_request( + {"action_requests": [], "review_configs": []}, + ) + inner = _decode(frame)["data"] + assert "interrupt_id" not in inner + assert "pending_interrupt_count" not in inner + assert "chat_turn_id" not in inner diff --git a/surfsense_backend/tests/unit/services/streaming/test_service_emitter_propagation.py b/surfsense_backend/tests/unit/services/streaming/test_service_emitter_propagation.py new file mode 100644 index 000000000..b381f13bc --- /dev/null +++ b/surfsense_backend/tests/unit/services/streaming/test_service_emitter_propagation.py @@ -0,0 +1,142 @@ +"""Pin that sub-agent emitter reaches every wire event the relay emits.""" + +from __future__ import annotations + +import json + +import pytest + +from app.services.streaming.emitter import subagent_emitter +from app.services.streaming.service import StreamingService + +pytestmark = pytest.mark.unit + + +def _decode(frame: str) -> dict: + body = frame.removeprefix("data: ").removesuffix("\n\n") + return json.loads(body) + + +@pytest.fixture +def service() -> StreamingService: + return StreamingService() + + +@pytest.fixture +def sub_emitter(): + return subagent_emitter( + subagent_type="deliverables", + subagent_run_id="sub_xyz", + parent_tool_call_id="call_parent", + ) + + +def test_text_delta_carries_subagent_emitter_on_the_wire(service, sub_emitter) -> None: + payload = _decode(service.format_text_delta("text_1", "hi", emitter=sub_emitter)) + assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz" + assert payload["delta"] == "hi" + + +def test_reasoning_delta_carries_subagent_emitter_on_the_wire( + service, sub_emitter +) -> None: + payload = _decode( + service.format_reasoning_delta("r_1", "thinking", emitter=sub_emitter) + ) + assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz" + + +def test_tool_input_start_carries_subagent_emitter_and_lc_id( + service, sub_emitter +) -> None: + payload = _decode( + service.format_tool_input_start( + "call_1", + "send_email", + langchain_tool_call_id="lc_1", + emitter=sub_emitter, + ) + ) + assert payload["emitted_by"]["subagent_type"] == "deliverables" + assert payload["langchainToolCallId"] == "lc_1" + assert payload["toolName"] == "send_email" + + +def test_tool_output_available_carries_subagent_emitter(service, sub_emitter) -> None: + payload = _decode( + service.format_tool_output_available( + "call_1", {"ok": True}, emitter=sub_emitter + ) + ) + assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz" + assert payload["output"] == {"ok": True} + + +def test_thinking_step_carries_subagent_emitter(service, sub_emitter) -> None: + payload = _decode( + service.format_thinking_step( + step_id="s1", + title="Sending email", + status="in_progress", + emitter=sub_emitter, + ) + ) + assert payload["type"] == "data-thinking-step" + assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz" + + +def test_action_log_carries_subagent_emitter(service, sub_emitter) -> None: + payload = _decode( + service.format_action_log( + {"id": 1, "tool_name": "send_email", "reversible": False}, + emitter=sub_emitter, + ) + ) + assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz" + assert payload["data"]["tool_name"] == "send_email" + + +def test_subagent_lifecycle_events_share_run_id_for_pairing( + service, sub_emitter +) -> None: + start = _decode( + service.format_subagent_start( + subagent_run_id="sub_xyz", + subagent_type="deliverables", + parent_tool_call_id="call_parent", + emitter=sub_emitter, + ) + ) + finish = _decode( + service.format_subagent_finish( + subagent_run_id="sub_xyz", + subagent_type="deliverables", + parent_tool_call_id="call_parent", + emitter=sub_emitter, + ) + ) + assert start["data"]["subagent_run_id"] == finish["data"]["subagent_run_id"] + assert start["type"] == "data-subagent-start" + assert finish["type"] == "data-subagent-finish" + + +def test_main_emitter_events_omit_emitted_by_field(service) -> None: + payload = _decode(service.format_text_delta("text_1", "hi")) + assert "emitted_by" not in payload + + +def test_resolve_emitter_through_service_uses_registry(service, sub_emitter) -> None: + service.emitter_registry.register("run_task_1", sub_emitter) + resolved = service.resolve_emitter( + run_id="run_chat_model", + parent_ids=["root", "run_task_1"], + ) + assert resolved is sub_emitter + + +def test_message_id_is_assigned_on_message_start_and_reused(service) -> None: + frame = service.format_message_start() + payload = _decode(frame) + assigned = payload["messageId"] + assert assigned.startswith("msg_") + assert service.message_id == assigned diff --git a/surfsense_backend/tests/unit/services/streaming/test_sse_envelope.py b/surfsense_backend/tests/unit/services/streaming/test_sse_envelope.py new file mode 100644 index 000000000..511e4575a --- /dev/null +++ b/surfsense_backend/tests/unit/services/streaming/test_sse_envelope.py @@ -0,0 +1,51 @@ +"""Pin the exact SSE wire bytes the FE parser depends on.""" + +from __future__ import annotations + +import json + +import pytest + +from app.services.streaming.envelope import ( + format_done, + format_sse, + get_response_headers, +) + +pytestmark = pytest.mark.unit + + +class TestFormatSse: + def test_dict_payload_is_json_serialised(self) -> None: + frame = format_sse({"type": "start", "messageId": "msg_1"}) + assert frame.startswith("data: ") + assert frame.endswith("\n\n") + body = frame[len("data: ") : -2] + assert json.loads(body) == {"type": "start", "messageId": "msg_1"} + + def test_string_payload_is_emitted_verbatim(self) -> None: + frame = format_sse('{"already":"json"}') + assert frame == 'data: {"already":"json"}\n\n' + + def test_nested_payload_round_trips(self) -> None: + payload = { + "type": "data-action-log", + "data": {"id": 7, "tool_name": "ls", "reversible": False}, + } + frame = format_sse(payload) + body = frame.removeprefix("data: ").removesuffix("\n\n") + assert json.loads(body) == payload + + +class TestFormatDone: + def test_done_marker_is_literal(self) -> None: + assert format_done() == "data: [DONE]\n\n" + + +class TestResponseHeaders: + def test_headers_pin_ai_sdk_v1_protocol(self) -> None: + headers = get_response_headers() + assert headers["Content-Type"] == "text/event-stream" + assert headers["Cache-Control"] == "no-cache" + assert headers["Connection"] == "keep-alive" + assert headers["x-vercel-ai-ui-message-stream"] == "v1" diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/__init__.py b/surfsense_backend/tests/unit/tasks/chat/streaming/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py new file mode 100644 index 000000000..d598de492 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py @@ -0,0 +1,290 @@ +"""Pin Stage 1 extractions as faithful copies of the old helpers. + +Extractions under ``app.tasks.chat.streaming`` are compared to +``app.tasks.chat.stream_new_chat`` helpers. +For each Stage 1 extraction we assert the new function returns the same +output as the old one for a representative input set. The moment the +two diverge - intentionally or otherwise - this file fails loudly so +the divergence is reviewed rather than shipped silently. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import Any + +import pytest + +from app.agents.new_chat.errors import BusyError +from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel +from app.tasks.chat.stream_new_chat import ( + _classify_stream_exception as old_classify, + _emit_stream_terminal_error as old_emit_terminal_error, + _extract_chunk_parts as old_extract_chunk_parts, + _extract_resolved_file_path as old_extract_resolved_file_path, + _first_interrupt_value as old_first_interrupt_value, + _tool_output_has_error as old_tool_output_has_error, + _tool_output_to_text as old_tool_output_to_text, +) +from app.tasks.chat.streaming.errors.classifier import ( + classify_stream_exception as new_classify, +) +from app.tasks.chat.streaming.errors.emitter import ( + emit_stream_terminal_error as new_emit_terminal_error, +) +from app.tasks.chat.streaming.helpers.chunk_parts import ( + extract_chunk_parts as new_extract_chunk_parts, +) +from app.tasks.chat.streaming.helpers.interrupt_inspector import ( + first_interrupt_value as new_first_interrupt_value, +) +from app.tasks.chat.streaming.helpers.tool_output import ( + extract_resolved_file_path as new_extract_resolved_file_path, + tool_output_has_error as new_tool_output_has_error, + tool_output_to_text as new_tool_output_to_text, +) + +pytestmark = pytest.mark.unit + + +# ---------------------------------------------------------------- chunk parts + + +@dataclass +class _Chunk: + content: Any = "" + additional_kwargs: dict[str, Any] = field(default_factory=dict) + tool_call_chunks: list[dict[str, Any]] = field(default_factory=list) + + +_CHUNK_CASES: list[Any] = [ + None, + _Chunk(content=""), + _Chunk(content="hello"), + _Chunk(content=42), # invalid type, defensively coerced to empty + _Chunk( + content=[ + {"type": "text", "text": "Hello "}, + {"type": "text", "text": "world"}, + ] + ), + _Chunk( + content=[ + {"type": "reasoning", "reasoning": "hmm "}, + {"type": "reasoning", "text": "still"}, + {"type": "text", "text": "answer"}, + ] + ), + _Chunk( + content=[ + {"type": "tool_call_chunk", "id": "c1", "name": "x", "args": "{"}, + {"type": "tool_use", "id": "c2", "name": "y"}, + {"type": "image_url", "url": "ignored"}, + ] + ), + _Chunk( + content="visible", + additional_kwargs={"reasoning_content": "private"}, + ), + _Chunk( + tool_call_chunks=[ + {"id": None, "name": None, "args": '{"a":1}', "index": 0}, + {"id": "c", "name": "n", "args": "}", "index": 0}, + ] + ), + _Chunk( + content=[{"type": "tool_call_chunk", "id": "from-block", "name": "x"}], + tool_call_chunks=[{"id": "from-attr", "name": "y"}], + ), +] + + +@pytest.mark.parametrize("chunk", _CHUNK_CASES) +def test_extract_chunk_parts_matches_old_implementation(chunk: Any) -> None: + assert new_extract_chunk_parts(chunk) == old_extract_chunk_parts(chunk) + + +# ---------------------------------------------------------- interrupt inspector + + +@dataclass +class _Interrupt: + value: dict[str, Any] + + +@dataclass +class _Task: + interrupts: tuple[Any, ...] = () + + +@dataclass +class _State: + tasks: tuple[Any, ...] = () + interrupts: tuple[Any, ...] = () + + +_INTERRUPT_CASES: list[Any] = [ + _State(), + _State(tasks=(_Task(interrupts=(_Interrupt(value={"name": "send"}),)),)), + # Multiple tasks: must return the FIRST one in iteration order. + _State( + tasks=( + _Task(interrupts=(_Interrupt(value={"name": "first"}),)), + _Task(interrupts=(_Interrupt(value={"name": "second"}),)), + ) + ), + # Empty task interrupts -> falls back to root state.interrupts. + _State( + tasks=(_Task(interrupts=()),), + interrupts=(_Interrupt(value={"name": "root"}),), + ), + # Interrupts as plain dicts (not wrapper objects). + _State(interrupts=({"value": {"name": "dict_root"}},)), + # A defective task whose `.interrupts` raises - must be tolerated. + _State(tasks=(object(),)), +] + + +@pytest.mark.parametrize("state", _INTERRUPT_CASES) +def test_first_interrupt_value_matches_old_implementation(state: Any) -> None: + assert new_first_interrupt_value(state) == old_first_interrupt_value(state) + + +# ----------------------------------------------------------- error classifier + + +def _classify_cases() -> list[Exception]: + """Inputs that the FE depends on being mapped to specific error codes.""" + return [ + Exception("totally generic error"), + Exception('{"error":{"type":"rate_limit_error","message":"slow down"}}'), + Exception( + 'OpenrouterException - {"error":{"message":"Provider returned error",' + '"code":429}}' + ), + BusyError(request_id="thread-busy-parity"), + Exception("Thread is busy with another request"), + ] + + +@pytest.mark.parametrize("exc", _classify_cases()) +def test_classify_stream_exception_matches_old_implementation( + exc: Exception, +) -> None: + new = new_classify(exc, flow_label="parity-test") + old = old_classify(exc, flow_label="parity-test") + # Strip the wall-clock retry timestamp before comparing — both + # implementations call ``time.time()`` independently and the call + # order is enough to differ by 1 ms in practice. Every other field + # in the tuple must match exactly. + new_extra = dict(new[5]) if isinstance(new[5], dict) else new[5] + old_extra = dict(old[5]) if isinstance(old[5], dict) else old[5] + if isinstance(new_extra, dict) and isinstance(old_extra, dict): + new_extra.pop("retry_after_at", None) + old_extra.pop("retry_after_at", None) + assert new[:5] == old[:5] + assert new_extra == old_extra + + +def test_classify_turn_cancelling_branch_parity() -> None: + """The TURN_CANCELLING branch reads cancel state for the busy thread id; + both implementations must agree on retry-window semantics, not just the + plain THREAD_BUSY code.""" + thread_id = "parity-cancelling-thread" + reset_cancel(thread_id) + request_cancel(thread_id) + exc = BusyError(request_id=thread_id) + new = new_classify(exc, flow_label="parity-test") + old = old_classify(exc, flow_label="parity-test") + assert new[0] == old[0] == "thread_busy" + assert new[1] == old[1] == "TURN_CANCELLING" + assert isinstance(new[5], dict) and isinstance(old[5], dict) + assert new[5]["retry_after_ms"] == old[5]["retry_after_ms"] + + +# ------------------------------------------------------------ terminal emitter + + +class _FakeStreamingService: + """Duck-types ``format_error`` for both old and new emitters.""" + + def __init__(self) -> None: + self.calls: list[dict[str, Any]] = [] + + def format_error( + self, message: str, *, error_code: str, extra: dict[str, Any] | None = None + ) -> str: + self.calls.append( + {"message": message, "error_code": error_code, "extra": extra} + ) + return f'data: {{"type":"error","errorText":"{message}"}}\n\n' + + +def test_emit_stream_terminal_error_matches_old_output_and_logs(caplog) -> None: + """The new emitter must produce the same SSE frame and log the same + structured payload as the old one for the same arguments.""" + args: dict[str, Any] = { + "flow": "new", + "request_id": "req-parity", + "thread_id": 7, + "search_space_id": 9, + "user_id": "user-parity", + "message": "boom", + "error_kind": "server_error", + "error_code": "SERVER_ERROR", + "severity": "error", + "is_expected": False, + "extra": {"foo": "bar"}, + } + + new_svc = _FakeStreamingService() + old_svc = _FakeStreamingService() + + with caplog.at_level(logging.ERROR): + new_frame = new_emit_terminal_error(streaming_service=new_svc, **args) + old_frame = old_emit_terminal_error(streaming_service=old_svc, **args) + + assert new_frame == old_frame + assert new_svc.calls == old_svc.calls + chat_error_records = [ + r for r in caplog.records if "[chat_stream_error]" in r.message + ] + # One log line per emit call (two emits -> two records). + assert len(chat_error_records) == 2 + + +# ---------------------------------------------------------------- tool output + + +def test_tool_output_helpers_match_old_implementation() -> None: + samples: list[Any] = [ + {"result": "ok"}, + {"error": "bad"}, + {"result": "Error: x"}, + "Error: plain", + "fine", + {"nested": {"a": 1}}, + ] + for s in samples: + assert new_tool_output_to_text(s) == old_tool_output_to_text(s) + assert new_tool_output_has_error(s) == old_tool_output_has_error(s) + + assert new_extract_resolved_file_path( + tool_name="write_file", + tool_output={"path": " /tmp/x "}, + tool_input=None, + ) == old_extract_resolved_file_path( + tool_name="write_file", + tool_output={"path": " /tmp/x "}, + tool_input=None, + ) + assert new_extract_resolved_file_path( + tool_name="write_file", + tool_output={}, + tool_input={"file_path": " /fallback "}, + ) == old_extract_resolved_file_path( + tool_name="write_file", + tool_output={}, + tool_input={"file_path": " /fallback "}, + ) diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_2_parity.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_2_parity.py new file mode 100644 index 000000000..3ee1ab622 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_2_parity.py @@ -0,0 +1,241 @@ +"""Parity tests for Stage 2 extractions (tool matching, thinking step, custom events).""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from app.tasks.chat.stream_new_chat import _legacy_match_lc_id as old_legacy_match +from app.tasks.chat.streaming.handlers.custom_events import ( + handle_action_log, + handle_action_log_updated, + handle_document_created, + handle_report_progress, +) +from app.tasks.chat.streaming.helpers.tool_call_matching import ( + match_buffered_langchain_tool_call_id as new_legacy_match, +) +from app.tasks.chat.streaming.relay.state import AgentEventRelayState +from app.tasks.chat.streaming.relay.thinking_step_completion import ( + complete_active_thinking_step, +) +from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame + +pytestmark = pytest.mark.unit + + +def _copy_chunk_buffer(raw: list[dict[str, Any]]) -> list[dict[str, Any]]: + return [dict(x) for x in raw] + + +def test_legacy_tool_call_match_matches_old_implementation() -> None: + cases: list[tuple[list[dict[str, Any]], str, str, dict[str, str]]] = [ + ( + [ + {"name": "write_file", "id": "lc-a"}, + {"name": "other", "id": "lc-b"}, + ], + "write_file", + "run-1", + {}, + ), + ( + [{"name": "x", "id": None}, {"name": "y", "id": "lc-fallback"}], + "write_file", + "run-2", + {}, + ), + ([{"name": "no_id"}], "write_file", "run-3", {}), + ] + for chunks_template, tool_name, run_id, lc_map_seed in cases: + old_chunks = _copy_chunk_buffer(chunks_template) + new_chunks = _copy_chunk_buffer(chunks_template) + old_map = dict(lc_map_seed) + new_map = dict(lc_map_seed) + old_out = old_legacy_match(old_chunks, tool_name, run_id, old_map) + new_out = new_legacy_match(new_chunks, tool_name, run_id, new_map) + assert new_out == old_out + assert new_chunks == old_chunks + assert new_map == old_map + + +def test_emit_thinking_step_frame_invokes_builder_before_service() -> None: + order: list[str] = [] + builder = MagicMock() + + def on_ts(*args: Any, **kwargs: Any) -> None: + order.append("builder") + + builder.on_thinking_step.side_effect = on_ts + + svc = MagicMock() + + def fmt(**kwargs: Any) -> str: + order.append("service") + return "frame" + + svc.format_thinking_step.side_effect = fmt + + out = emit_thinking_step_frame( + streaming_service=svc, + content_builder=builder, + step_id="thinking-1", + title="Working", + status="in_progress", + items=["a"], + ) + assert out == "frame" + assert order == ["builder", "service"] + builder.on_thinking_step.assert_called_once() + svc.format_thinking_step.assert_called_once() + + +def test_emit_thinking_step_frame_skips_builder_when_none() -> None: + svc = MagicMock(return_value="x") + svc.format_thinking_step.return_value = "frame" + assert ( + emit_thinking_step_frame( + streaming_service=svc, + content_builder=None, + step_id="s", + title="t", + ) + == "frame" + ) + svc.format_thinking_step.assert_called_once() + + +def test_complete_active_thinking_step_mirrors_closure_semantics() -> None: + svc = MagicMock() + svc.format_thinking_step.return_value = "done-frame" + completed: set[str] = set() + relay_state = AgentEventRelayState.for_invocation() + + frame, new_id = complete_active_thinking_step( + state=relay_state, + streaming_service=svc, + content_builder=None, + last_active_step_id="thinking-1", + last_active_step_title="T", + last_active_step_items=["x"], + completed_step_ids=completed, + ) + assert frame == "done-frame" + assert new_id is None + assert "thinking-1" in completed + + frame2, id2 = complete_active_thinking_step( + state=relay_state, + streaming_service=svc, + content_builder=None, + last_active_step_id="thinking-1", + last_active_step_title="T", + last_active_step_items=[], + completed_step_ids=completed, + ) + assert frame2 is None + assert id2 == "thinking-1" + + +def test_agent_event_relay_state_factory_matches_counter_rule() -> None: + s0 = AgentEventRelayState.for_invocation() + assert s0.thinking_step_counter == 0 + assert s0.last_active_step_id is None + + s1 = AgentEventRelayState.for_invocation( + initial_step_id="thinking-resume-1", + initial_step_title="Inherited", + initial_step_items=["Topic: X"], + ) + assert s1.thinking_step_counter == 1 + assert s1.last_active_step_id == "thinking-resume-1" + assert s1.next_thinking_step_id("thinking") == "thinking-2" + + +@pytest.mark.parametrize( + ("phase", "message", "start_items", "expected_tail"), + [ + ( + "revising_section", + "progress line", + ["Topic: Foo", "Modifying bar", "stale..."], + ["Topic: Foo", "Modifying bar", "progress line"], + ), + ( + "other", + "phase msg", + ["Topic: Foo", "old line"], + ["Topic: Foo", "phase msg"], + ), + ], +) +def test_report_progress_items_match_reference( + phase: str, + message: str, + start_items: list[str], + expected_tail: list[str], +) -> None: + svc = MagicMock() + svc.format_thinking_step.return_value = "sse" + + items = list(start_items) + frame, new_items = handle_report_progress( + {"message": message, "phase": phase}, + last_active_step_id="step-1", + last_active_step_title="Report", + last_active_step_items=items, + streaming_service=svc, + content_builder=None, + ) + assert frame == "sse" + assert new_items == expected_tail + kwargs = svc.format_thinking_step.call_args.kwargs + assert kwargs["items"] == expected_tail + + +def test_report_progress_noop_when_missing_message_or_step() -> None: + svc = MagicMock() + items = ["Topic: A"] + f1, i1 = handle_report_progress( + {"message": "", "phase": "x"}, + last_active_step_id="s", + last_active_step_title="t", + last_active_step_items=items, + streaming_service=svc, + content_builder=None, + ) + assert f1 is None and i1 is items + + f2, i2 = handle_report_progress( + {"message": "m", "phase": "x"}, + last_active_step_id=None, + last_active_step_title="t", + last_active_step_items=items, + streaming_service=svc, + content_builder=None, + ) + assert f2 is None and i2 is items + + +def test_document_action_handlers_match_format_data_guards() -> None: + svc = MagicMock() + svc.format_data.return_value = "data-frame" + + assert handle_document_created({}, streaming_service=svc) is None + assert handle_document_created({"id": 0}, streaming_service=svc) is None + handle_document_created({"id": 42, "title": "x"}, streaming_service=svc) + svc.format_data.assert_called_with( + "documents-updated", {"action": "created", "document": {"id": 42, "title": "x"}} + ) + + svc.reset_mock() + assert handle_action_log({"id": None}, streaming_service=svc) is None + handle_action_log({"id": 1}, streaming_service=svc) + svc.format_data.assert_called_once_with("action-log", {"id": 1}) + + svc.reset_mock() + assert handle_action_log_updated({"id": None}, streaming_service=svc) is None + handle_action_log_updated({"id": 2}, streaming_service=svc) + svc.format_data.assert_called_once_with("action-log-updated", {"id": 2}) diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stream_output.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stream_output.py new file mode 100644 index 000000000..c0123b76d --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stream_output.py @@ -0,0 +1,122 @@ +"""Tests for ``stream_output`` (LangGraph events → SSE).""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +import pytest + +from app.tasks.chat.streaming.graph_stream import stream_output +from app.tasks.chat.streaming.graph_stream.result import StreamingResult + +pytestmark = pytest.mark.unit + + +@dataclass +class _Chunk: + content: Any = "" + additional_kwargs: dict[str, Any] = field(default_factory=dict) + tool_call_chunks: list[dict[str, Any]] = field(default_factory=list) + + +class _StreamingService: + def __init__(self) -> None: + self._text_idx = 0 + + def generate_text_id(self) -> str: + self._text_idx += 1 + return f"text-{self._text_idx}" + + def format_text_start(self, text_id: str) -> str: + return f"text_start:{text_id}" + + def format_text_delta(self, text_id: str, text: str) -> str: + return f"text_delta:{text_id}:{text}" + + def format_text_end(self, text_id: str) -> str: + return f"text_end:{text_id}" + + +class _Agent: + def __init__(self, events: list[dict[str, Any]]) -> None: + self.events = list(events) + self.calls: list[tuple[Any, dict[str, Any]]] = [] + + async def astream_events(self, input_data: Any, **kwargs: Any): + self.calls.append((input_data, kwargs)) + for event in self.events: + yield event + + +async def _collect(stream: Any) -> list[str]: + out: list[str] = [] + async for x in stream: + out.append(x) + return out + + +async def test_stream_output_emits_text_lifecycle_and_updates_result() -> None: + service = _StreamingService() + agent = _Agent( + [ + { + "event": "on_chat_model_stream", + "data": {"chunk": _Chunk(content="Hello")}, + }, + { + "event": "on_chat_model_stream", + "data": {"chunk": _Chunk(content=" world")}, + }, + ] + ) + result = StreamingResult() + + frames = await _collect( + stream_output( + agent=agent, + config={"configurable": {"thread_id": "t-1"}}, + input_data={"messages": []}, + streaming_service=service, + result=result, + ) + ) + + assert frames == [ + "text_start:text-1", + "text_delta:text-1:Hello", + "text_delta:text-1: world", + "text_end:text-1", + ] + assert result.accumulated_text == "Hello world" + assert result.agent_called_update_memory is False + + +async def test_stream_output_passes_runtime_context_to_agent() -> None: + service = _StreamingService() + + class _ContextAwareAgent: + async def astream_events(self, input_data: Any, **kwargs: Any): + del input_data + text = "ctx-ok" if kwargs.get("context") else "ctx-missing" + yield {"event": "on_chat_model_stream", "data": {"chunk": _Chunk(text)}} + + agent = _ContextAwareAgent() + result = StreamingResult() + + frames = await _collect( + stream_output( + agent=agent, + config={"configurable": {"thread_id": "t-2"}}, + input_data={"messages": []}, + streaming_service=service, + result=result, + runtime_context={"mentioned_document_ids": [1, 2]}, + ) + ) + + assert frames == [ + "text_start:text-1", + "text_delta:text-1:ctx-ok", + "text_end:text-1", + ] diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_task_span.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_task_span.py new file mode 100644 index 000000000..df680018d --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_task_span.py @@ -0,0 +1,71 @@ +"""Unit tests for ``task_span`` open/close helpers.""" + +from __future__ import annotations + +import pytest + +from app.tasks.chat.streaming.relay.state import AgentEventRelayState +from app.tasks.chat.streaming.relay.task_span import ( + clear_task_span_if_delegating_task_ended, + ensure_pending_task_span_for_lc, + open_task_span, +) + +pytestmark = pytest.mark.unit + + +def test_open_task_span_sets_span_and_run_id() -> None: + state = AgentEventRelayState.for_invocation() + sid = open_task_span(state, run_id="run-abc") + assert sid.startswith("spn_") + assert state.active_span_id == sid + assert state.active_task_run_id == "run-abc" + assert state.span_metadata_if_active() == {"spanId": sid} + + +def test_clear_ignored_for_non_task_tool() -> None: + state = AgentEventRelayState.for_invocation() + open_task_span(state, run_id="run-1") + sid = state.active_span_id + clear_task_span_if_delegating_task_ended( + state, tool_name="web_search", run_id="run-1" + ) + assert state.active_span_id == sid + assert state.active_task_run_id == "run-1" + + +def test_clear_ignored_when_task_run_id_mismatches() -> None: + state = AgentEventRelayState.for_invocation() + open_task_span(state, run_id="run-open") + clear_task_span_if_delegating_task_ended( + state, tool_name="task", run_id="run-other" + ) + assert state.active_span_id is not None + assert state.active_task_run_id == "run-open" + + +def test_clear_on_matching_task_end() -> None: + state = AgentEventRelayState.for_invocation() + open_task_span(state, run_id="run-x") + clear_task_span_if_delegating_task_ended(state, tool_name="task", run_id="run-x") + assert state.active_span_id is None + assert state.active_task_run_id is None + assert state.span_metadata_if_active() is None + + +def test_clear_noop_when_no_open_span() -> None: + state = AgentEventRelayState.for_invocation() + clear_task_span_if_delegating_task_ended(state, tool_name="task", run_id="run-x") + assert state.active_span_id is None + + +def test_pending_then_open_reuses_same_span_id() -> None: + state = AgentEventRelayState.for_invocation() + sid_pending = ensure_pending_task_span_for_lc(state, "lc-task-1") + assert state.pending_task_span_by_lc["lc-task-1"] == sid_pending + sid_active = open_task_span( + state, run_id="run-1", langchain_tool_call_id="lc-task-1" + ) + assert sid_active == sid_pending + assert state.active_span_id == sid_pending + assert "lc-task-1" not in state.pending_task_span_by_lc diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_tool_activity_metadata.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_tool_activity_metadata.py new file mode 100644 index 000000000..c2e68dacd --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_tool_activity_metadata.py @@ -0,0 +1,42 @@ +"""Unit tests for ``AgentEventRelayState.tool_activity_metadata``.""" + +from __future__ import annotations + +import pytest + +from app.tasks.chat.streaming.relay.state import AgentEventRelayState +from app.tasks.chat.streaming.relay.task_span import open_task_span + +pytestmark = pytest.mark.unit + + +def test_returns_none_when_no_span_and_no_thinking_step() -> None: + state = AgentEventRelayState.for_invocation() + assert state.tool_activity_metadata(thinking_step_id=None) is None + assert state.tool_activity_metadata(thinking_step_id="") is None + assert state.tool_activity_metadata(thinking_step_id=" ") is None + + +def test_thinking_step_id_only() -> None: + state = AgentEventRelayState.for_invocation() + assert state.tool_activity_metadata(thinking_step_id="thinking-3") == { + "thinkingStepId": "thinking-3", + } + + +def test_span_only_when_active() -> None: + state = AgentEventRelayState.for_invocation() + open_task_span(state, run_id="run-x") + assert state.tool_activity_metadata(thinking_step_id=None) == { + "spanId": state.active_span_id, + } + + +def test_merges_span_and_thinking_step_when_both_set() -> None: + state = AgentEventRelayState.for_invocation() + open_task_span(state, run_id="run-x") + md = state.tool_activity_metadata(thinking_step_id="thinking-7") + assert md == { + "spanId": state.active_span_id, + "thinkingStepId": "thinking-7", + } diff --git a/surfsense_backend/tests/unit/tasks/chat/test_content_builder.py b/surfsense_backend/tests/unit/tasks/chat/test_content_builder.py index c317eba20..42e62d26b 100644 --- a/surfsense_backend/tests/unit/tasks/chat/test_content_builder.py +++ b/surfsense_backend/tests/unit/tasks/chat/test_content_builder.py @@ -15,6 +15,7 @@ import json import pytest +from app.services.new_streaming_service import VercelStreamingService from app.tasks.chat.content_builder import AssistantContentBuilder pytestmark = pytest.mark.unit @@ -161,7 +162,7 @@ class TestToolHeavyTurn: _assert_jsonb_safe(snap) def test_tool_input_available_without_prior_start_creates_card(self): - # Legacy / parity_v2-OFF path: tool-input-available may be + # Late-registration: tool-input-available may be # emitted without a prior tool-input-start (no streamed # tool_call_chunks). The card should still be created. b = AssistantContentBuilder() @@ -187,7 +188,7 @@ class TestToolHeavyTurn: assert part["result"] == {"matches": 3} def test_tool_input_start_idempotent_for_same_ui_id(self): - # parity_v2: tool-input-start can fire from BOTH the chunk + # tool-input-start can fire from BOTH the chunk # registration path AND the canonical ``on_tool_start`` path. # The second call must not create a duplicate part. b = AssistantContentBuilder() @@ -231,6 +232,151 @@ class TestToolHeavyTurn: ) +# --------------------------------------------------------------------------- +# Task-span metadata on tool-call parts (JSONB persistence) +# --------------------------------------------------------------------------- + + +class TestToolCallSpanMetadata: + def test_input_available_merges_new_metadata_keys_after_start(self): + b = AssistantContentBuilder() + b.on_tool_input_start("call_t", "task", "lc_t", metadata={"spanId": "spn_1"}) + b.on_tool_input_available( + "call_t", + "task", + {"goal": "x"}, + "lc_t", + metadata={"traceId": "tr_1"}, + ) + part = b.snapshot()[0] + assert part["metadata"]["spanId"] == "spn_1" + assert part["metadata"]["traceId"] == "tr_1" + _assert_jsonb_safe(b.snapshot()) + + def test_input_available_does_not_overwrite_existing_metadata_keys(self): + b = AssistantContentBuilder() + b.on_tool_input_start("call_t", "task", "lc_t", metadata={"spanId": "spn_keep"}) + b.on_tool_input_available( + "call_t", "task", {}, "lc_t", metadata={"spanId": "spn_other"} + ) + assert b.snapshot()[0]["metadata"]["spanId"] == "spn_keep" + + def test_late_tool_input_available_carries_metadata(self): + b = AssistantContentBuilder() + b.on_tool_input_available( + "call_l", + "grep", + {"pattern": "TODO"}, + None, + metadata={"spanId": "spn_l"}, + ) + part = b.snapshot()[0] + assert part["metadata"] == {"spanId": "spn_l"} + _assert_jsonb_safe(b.snapshot()) + + def test_output_available_merges_without_clobbering_span_id(self): + b = AssistantContentBuilder() + b.on_tool_input_start("call_t", "ls", "lc", metadata={"spanId": "spn_x"}) + b.on_tool_input_available("call_t", "ls", {"path": "/"}, "lc") + b.on_tool_output_available( + "call_t", + {"ok": True}, + "lc", + metadata={"spanId": "spn_y", "extra": 1}, + ) + md = b.snapshot()[0]["metadata"] + assert md["spanId"] == "spn_x" + assert md["extra"] == 1 + + def test_output_available_adds_thinking_step_id_without_clobbering_span(self): + b = AssistantContentBuilder() + b.on_tool_input_start( + "call_t", + "ls", + "lc", + metadata={"spanId": "spn_x", "thinkingStepId": "thinking-3"}, + ) + b.on_tool_input_available("call_t", "ls", {"path": "/"}, "lc") + b.on_tool_output_available( + "call_t", + {"ok": True}, + "lc", + metadata={"spanId": "spn_x", "thinkingStepId": "thinking-3"}, + ) + md = b.snapshot()[0]["metadata"] + assert md["spanId"] == "spn_x" + assert md["thinkingStepId"] == "thinking-3" + + def test_output_available_with_none_metadata_preserves_prior(self): + b = AssistantContentBuilder() + b.on_tool_input_start("c", "ls", "lc", metadata={"spanId": "spn_1"}) + b.on_tool_input_available("c", "ls", {}, "lc") + b.on_tool_output_available("c", {"r": 1}, "lc", metadata=None) + assert b.snapshot()[0]["metadata"] == {"spanId": "spn_1"} + + def test_available_adds_thinking_step_id_after_chunk_only_start(self): + """Mirrors chunk ``tool-input-start`` then ``on_tool_start`` ``available``.""" + b = AssistantContentBuilder() + b.on_tool_input_start("lc_1", "ls", "lc_1", metadata={"spanId": "spn_a"}) + b.on_tool_input_available( + "lc_1", + "ls", + {"path": "/"}, + "lc_1", + metadata={"spanId": "spn_a", "thinkingStepId": "thinking-2"}, + ) + md = b.snapshot()[0]["metadata"] + assert md["spanId"] == "spn_a" + assert md["thinkingStepId"] == "thinking-2" + + +class TestVercelStreamingServiceToolMetadataWire: + """SSE payloads include optional ``metadata`` for FE grouping.""" + + @staticmethod + def _parse_sse_data_line(raw: str) -> dict: + assert raw.startswith("data: ") + payload = raw.split("data: ", 1)[1].split("\n\n", 1)[0].strip() + return json.loads(payload) + + def test_tool_input_available_includes_metadata_when_set(self): + svc = VercelStreamingService() + raw = svc.format_tool_input_available( + "id1", + "task", + {"a": 1}, + langchain_tool_call_id="lc1", + metadata={"spanId": "spn_w", "thinkingStepId": "thinking-4"}, + ) + body = self._parse_sse_data_line(raw) + assert body["type"] == "tool-input-available" + assert body["metadata"] == { + "spanId": "spn_w", + "thinkingStepId": "thinking-4", + } + + def test_tool_output_available_includes_metadata_when_set(self): + svc = VercelStreamingService() + raw = svc.format_tool_output_available( + "id1", + {"status": "completed"}, + langchain_tool_call_id="lc1", + metadata={"spanId": "spn_o", "thinkingStepId": "thinking-9"}, + ) + body = self._parse_sse_data_line(raw) + assert body["type"] == "tool-output-available" + assert body["metadata"] == { + "spanId": "spn_o", + "thinkingStepId": "thinking-9", + } + + def test_tool_input_available_omits_metadata_key_when_none(self): + svc = VercelStreamingService() + raw = svc.format_tool_input_available("id1", "ls", {}) + body = self._parse_sse_data_line(raw) + assert "metadata" not in body + + # --------------------------------------------------------------------------- # Thinking steps & separators # --------------------------------------------------------------------------- diff --git a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py index 60750396c..ada32d168 100644 --- a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py +++ b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py @@ -1,16 +1,13 @@ """Unit tests for live tool-call argument streaming. -Pins the wire format that ``_stream_agent_events`` emits when -``SURFSENSE_ENABLE_STREAM_PARITY_V2=true``: ``tool-input-start`` → -``tool-input-delta``... → ``tool-input-available`` → ``tool-output-available`` -all keyed by the same LangChain ``tool_call.id``. +Pins the wire format that ``_stream_agent_events`` emits: +``tool-input-start`` → ``tool-input-delta``... → ``tool-input-available`` → +``tool-output-available``, keyed consistently with LangChain ``tool_call.id`` +when the model streams indexed chunks. Identity is tracked in ``index_to_meta`` (per-chunk ``index``) and -``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are private to -``_stream_agent_events`` so we exercise them via the public wire output. - -These tests also lock in the legacy / parity_v2-OFF behaviour so the -synthetic ``call_`` shape stays stable for older clients. +``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are internal to the +streaming layer so we assert on the public SSE payloads. """ from __future__ import annotations @@ -22,8 +19,6 @@ from typing import Any import pytest -import app.tasks.chat.stream_new_chat as stream_module -from app.agents.new_chat.feature_flags import AgentFeatureFlags from app.services.new_streaming_service import VercelStreamingService from app.tasks.chat.stream_new_chat import ( StreamResult, @@ -164,24 +159,6 @@ def _tool_end( } -@pytest.fixture -def parity_v2_on(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr( - stream_module, - "get_flags", - lambda: AgentFeatureFlags(enable_stream_parity_v2=True), - ) - - -@pytest.fixture -def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr( - stream_module, - "get_flags", - lambda: AgentFeatureFlags(enable_stream_parity_v2=False), - ) - - async def _drain( events: list[dict[str, Any]], state: _FakeAgentState | None = None ) -> list[dict[str, Any]]: @@ -253,12 +230,12 @@ class TestLegacyMatch: # --------------------------------------------------------------------------- -# parity_v2 wire format tests. +# Tool input streaming wire format # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None: +async def test_idless_chunk_merging_by_index() -> None: """First chunk carries id+name; later idless chunks at the same ``index`` merge into the SAME ``tool-input-start`` ui id and emit one ``tool-input-delta`` per chunk.""" @@ -302,9 +279,7 @@ async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None: @pytest.mark.asyncio -async def test_two_interleaved_tool_calls_route_by_index( - parity_v2_on: None, -) -> None: +async def test_two_interleaved_tool_calls_route_by_index() -> None: """Two same-name calls with distinct indices keep their deltas routed to the right card.""" events = [ @@ -344,7 +319,7 @@ async def test_two_interleaved_tool_calls_route_by_index( @pytest.mark.asyncio -async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None: +async def test_identity_stable_across_lifecycle() -> None: """Whatever id ``tool-input-start`` chose must be the SAME id used on ``tool-input-available`` AND ``tool-output-available``.""" events = [ @@ -367,7 +342,7 @@ async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None: @pytest.mark.asyncio -async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None: +async def test_no_duplicate_tool_input_start() -> None: """When the chunk-emission loop already fired ``tool-input-start`` for this run, ``on_tool_start`` MUST NOT emit a second one.""" events = [ @@ -386,9 +361,7 @@ async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None: @pytest.mark.asyncio -async def test_active_text_closes_before_early_tool_input_start( - parity_v2_on: None, -) -> None: +async def test_active_text_closes_before_early_tool_input_start() -> None: """Streaming a text-delta then a tool-call chunk in subsequent chunks: the wire MUST contain ``text-end`` before the FIRST ``tool-input-start`` (clean part boundary on the frontend).""" @@ -409,9 +382,7 @@ async def test_active_text_closes_before_early_tool_input_start( @pytest.mark.asyncio -async def test_mixed_text_and_tool_chunk_preserve_order( - parity_v2_on: None, -) -> None: +async def test_mixed_text_and_tool_chunk_preserve_order() -> None: """One AIMessageChunk that carries BOTH ``text`` content AND ``tool_call_chunks`` should emit the text delta FIRST, then close text, then ``tool-input-start``+``tool-input-delta``.""" @@ -441,45 +412,7 @@ async def test_mixed_text_and_tool_chunk_preserve_order( @pytest.mark.asyncio -async def test_parity_v2_off_preserves_legacy_shape( - parity_v2_off: None, -) -> None: - """When the flag is OFF, no deltas are emitted and the ``toolCallId`` - is ``call_`` (NOT the lc id).""" - events = [ - _model_stream( - tool_call_chunks=[ - {"id": "lc-1", "name": "ls", "args": '{"path":"/"}', "index": 0} - ] - ), - _tool_start(name="ls", run_id="run-A", input_payload={"path": "/"}), - _tool_end(name="ls", run_id="run-A", tool_call_id="lc-1"), - ] - payloads = await _drain(events) - - assert _of_type(payloads, "tool-input-delta") == [] - starts = _of_type(payloads, "tool-input-start") - assert len(starts) == 1 - assert starts[0]["toolCallId"].startswith("call_run-A") - # No ``langchainToolCallId`` propagation on ``tool-input-start`` in - # legacy mode (the start event fires before the ToolMessage is - # available, so we can't extract the authoritative LangChain id yet). - assert "langchainToolCallId" not in starts[0] - output = _of_type(payloads, "tool-output-available") - assert output[0]["toolCallId"].startswith("call_run-A") - # ``tool-output-available`` MUST carry ``langchainToolCallId`` even - # in legacy mode: the chat tool card uses it to backfill the - # LangChain id and join against the ``data-action-log`` SSE event - # (keyed by ``lc_tool_call_id``) so the inline Revert button can - # light up. Sourced from the returned ``ToolMessage.tool_call_id``, - # which is populated regardless of feature-flag state. - assert output[0]["langchainToolCallId"] == "lc-1" - - -@pytest.mark.asyncio -async def test_skip_append_prevents_stale_id_reuse( - parity_v2_on: None, -) -> None: +async def test_skip_append_prevents_stale_id_reuse() -> None: """Two same-name tools: the SECOND tool's ``langchainToolCallId`` must NOT come from the first tool's chunk (``pending_tool_call_chunks`` must stay empty for indexed-registered chunks).""" @@ -506,9 +439,7 @@ async def test_skip_append_prevents_stale_id_reuse( @pytest.mark.asyncio -async def test_registration_waits_for_both_id_and_name( - parity_v2_on: None, -) -> None: +async def test_registration_waits_for_both_id_and_name() -> None: """An id-only chunk (no name yet) must NOT emit ``tool-input-start``.""" events = [ _model_stream( @@ -520,12 +451,9 @@ async def test_registration_waits_for_both_id_and_name( @pytest.mark.asyncio -async def test_unmatched_fallback_still_attaches_lc_id( - parity_v2_on: None, -) -> None: - """parity_v2 ON, but the provider didn't include an ``index``: the - legacy fallback path must still emit ``tool-input-start`` with the - matching ``langchainToolCallId``.""" +async def test_unmatched_fallback_still_attaches_lc_id() -> None: + """When the provider omits chunk ``index``, buffered chunks still get a + ``tool-input-start`` with the matching ``langchainToolCallId``.""" events = [ # No index on the chunk → not registered into index_to_meta; # falls through to ``pending_tool_call_chunks`` so the legacy @@ -542,9 +470,7 @@ async def test_unmatched_fallback_still_attaches_lc_id( @pytest.mark.asyncio -async def test_interrupt_request_uses_task_that_contains_interrupt( - parity_v2_on: None, -) -> None: +async def test_interrupt_request_uses_task_that_contains_interrupt() -> None: interrupt_payload = { "type": "calendar_event_create", "action": { diff --git a/surfsense_backend/uv.lock b/surfsense_backend/uv.lock index 4dd5156e7..812be636a 100644 --- a/surfsense_backend/uv.lock +++ b/surfsense_backend/uv.lock @@ -7947,7 +7947,7 @@ wheels = [ [[package]] name = "surf-new-backend" -version = "0.0.22" +version = "0.0.23" source = { editable = "." } dependencies = [ { name = "alembic" }, diff --git a/surfsense_browser_extension/package.json b/surfsense_browser_extension/package.json index b8b5cb2ec..82c0a349a 100644 --- a/surfsense_browser_extension/package.json +++ b/surfsense_browser_extension/package.json @@ -1,7 +1,7 @@ { "name": "surfsense_browser_extension", "displayName": "Surfsense Browser Extension", - "version": "0.0.22", + "version": "0.0.23", "description": "Extension to collect Browsing History for SurfSense.", "author": "https://github.com/MODSetter", "engines": { diff --git a/surfsense_desktop/package.json b/surfsense_desktop/package.json index 744ab65ab..4ef624760 100644 --- a/surfsense_desktop/package.json +++ b/surfsense_desktop/package.json @@ -1,6 +1,6 @@ { "name": "surfsense-desktop", - "version": "0.0.22", + "version": "0.0.23", "description": "SurfSense Desktop App", "main": "dist/main.js", "scripts": { diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 9b5510df3..c431ab304 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -43,13 +43,14 @@ import { type EditMessageDialogChoice, } from "@/components/assistant-ui/edit-message-dialog"; import { StepSeparatorDataUI } from "@/components/assistant-ui/step-separator"; -import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps"; import { Thread } from "@/components/assistant-ui/thread"; import { createTokenUsageStore, type TokenUsageData, TokenUsageProvider, } from "@/components/assistant-ui/token-usage-context"; +import { type HitlDecision, PendingInterruptProvider } from "@/features/chat-messages/hitl"; +import { TimelineDataUI } from "@/features/chat-messages/timeline"; import { applyActionLogSse, applyActionLogUpdatedSse, @@ -63,7 +64,10 @@ import { documentsApiService } from "@/lib/apis/documents-api.service"; import { getBearerToken } from "@/lib/auth-utils"; import { type ChatFlow, classifyChatError } from "@/lib/chat/chat-error-classifier"; import { tagPreAcceptSendFailure, toHttpResponseError } from "@/lib/chat/chat-request-errors"; -import { convertToThreadMessage } from "@/lib/chat/message-utils"; +import { + convertToThreadMessage, + reconcileInterruptedAssistantMessages, +} from "@/lib/chat/message-utils"; import { isPodcastGenerating, looksLikePodcastRequest, @@ -107,7 +111,6 @@ import { type NewChatUserImagePayload, } from "@/lib/chat/user-turn-api-parts"; import { NotFoundError } from "@/lib/error"; -import { type BundleSubmit, HitlBundleProvider } from "@/lib/hitl"; import { trackChatBlocked, trackChatCreated, @@ -126,7 +129,7 @@ const MobileEditorPanel = dynamic( ); const MobileHitlEditPanel = dynamic( () => - import("@/components/hitl-edit-panel/hitl-edit-panel").then((m) => ({ + import("@/features/chat-messages/hitl").then((m) => ({ default: m.MobileHitlEditPanel, })), { ssr: false } @@ -196,12 +199,16 @@ function pairBundleToolCallIds( } /** - * Zod schema for mentioned document info (for type-safe parsing) + * Zod schema for mentioned document info (for type-safe parsing). + * + * ``kind`` defaults to ``"doc"`` so messages persisted before folder + * mentions existed deserialise unchanged. */ const MentionedDocumentInfoSchema = z.object({ id: z.number(), title: z.string(), document_type: z.string(), + kind: z.union([z.literal("doc"), z.literal("folder")]).optional().default("doc"), }); const MentionedDocumentsPartSchema = z.object({ @@ -395,7 +402,7 @@ export default function NewChatPage() { const memberById = new Map(membersData?.map((m) => [m.user_id, m]) ?? []); const prevById = new Map(prev.map((m) => [m.id, m])); - return syncedMessages.map((msg) => { + return reconcileInterruptedAssistantMessages(syncedMessages).map((msg) => { const member = msg.author_id ? (memberById.get(msg.author_id) ?? null) : null; // Preserve existing author info if member lookup fails (e.g., cloned chats) @@ -622,7 +629,9 @@ export default function NewChatPage() { setCurrentThread(threadData); if (messagesResponse.messages && messagesResponse.messages.length > 0) { - const loadedMessages = messagesResponse.messages.map(convertToThreadMessage); + const loadedMessages = reconcileInterruptedAssistantMessages( + messagesResponse.messages + ).map(convertToThreadMessage); setMessages(loadedMessages); for (const msg of messagesResponse.messages) { @@ -908,18 +917,29 @@ export default function NewChatPage() { hasAttachments: userImages.length > 0, hasMentionedDocuments: mentionedDocumentIds.surfsense_doc_ids.length > 0 || - mentionedDocumentIds.document_ids.length > 0, + mentionedDocumentIds.document_ids.length > 0 || + mentionedDocumentIds.folder_ids.length > 0, messageLength: userQuery.length, }); - // Collect unique mentioned docs for display & persistence + // Collect unique mention chips for display & persistence. + // Dedup key is ``kind:document_type:id`` so a folder and a + // doc with the same integer id never collapse into one + // entry. The ``kind`` field is forwarded to the backend + // so the persisted ``mentioned-documents`` content part + // can render the correct chip type on reload. const allMentionedDocs: MentionedDocumentInfo[] = []; const seenDocKeys = new Set(); for (const doc of mentionedDocuments) { - const key = `${doc.document_type}:${doc.id}`; + const key = `${doc.kind}:${doc.document_type}:${doc.id}`; if (seenDocKeys.has(key)) continue; seenDocKeys.add(key); - allMentionedDocs.push({ id: doc.id, title: doc.title, document_type: doc.document_type }); + allMentionedDocs.push({ + id: doc.id, + title: doc.title, + document_type: doc.document_type, + kind: doc.kind, + }); } if (allMentionedDocs.length > 0) { @@ -981,9 +1001,10 @@ export default function NewChatPage() { // Get mentioned document IDs for context (separate fields for backend) const hasDocumentIds = mentionedDocumentIds.document_ids.length > 0; const hasSurfsenseDocIds = mentionedDocumentIds.surfsense_doc_ids.length > 0; + const hasFolderIds = mentionedDocumentIds.folder_ids.length > 0; // Clear mentioned documents after capturing them - if (hasDocumentIds || hasSurfsenseDocIds) { + if (hasDocumentIds || hasSurfsenseDocIds || hasFolderIds) { setMentionedDocuments([]); } @@ -1008,7 +1029,11 @@ export default function NewChatPage() { mentioned_surfsense_doc_ids: hasSurfsenseDocIds ? mentionedDocumentIds.surfsense_doc_ids : undefined, - // Full mention metadata so the BE can embed a + mentioned_folder_ids: hasFolderIds + ? mentionedDocumentIds.folder_ids + : undefined, + // Full mention metadata (docs + folders, with + // ``kind`` discriminator) so the BE can embed a // ``mentioned-documents`` ContentPart on the // persisted user message (replaces the old FE-side // injection in ``persistUserTurn``). @@ -1018,6 +1043,7 @@ export default function NewChatPage() { id: d.id, title: d.title, document_type: d.document_type, + kind: d.kind, })) : undefined, disabled_tools: disabledTools.length > 0 ? disabledTools : undefined, @@ -1388,6 +1414,8 @@ export default function NewChatPage() { const existingMsg = messages.find((m) => m.id === assistantMsgId); if (existingMsg && Array.isArray(existingMsg.content)) { + // See ``ContentPartsState.suppressStepSeparators`` doc. + contentPartsState.suppressStepSeparators = true; for (const part of existingMsg.content) { if (typeof part === "object" && part !== null) { const p = part as Record; @@ -1402,15 +1430,19 @@ export default function NewChatPage() { toolName: String(p.toolName), args: (p.args as Record) ?? {}, result: p.result as unknown, - // Restore argsText so persisted pretty-printed - // JSON survives reloads (assistant-ui prefers - // supplied argsText over JSON.stringify(args)). - // langchainToolCallId restoration also fixes a - // pre-existing dropped-id bug on resume. + // argsText: assistant-ui prefers it over + // JSON.stringify(args), so restoring it keeps + // pretty-printed JSON across reloads. ...(typeof p.argsText === "string" ? { argsText: p.argsText } : {}), ...(typeof p.langchainToolCallId === "string" ? { langchainToolCallId: p.langchainToolCallId } : {}), + // metadata: spanId / thinkingStepId drive the + // timeline's step↔tool join. Dropping these + // here orphans every rehydrated tool-call. + ...(p.metadata && typeof p.metadata === "object" + ? { metadata: p.metadata as Record } + : {}), }); contentPartsState.currentTextPartIndex = -1; } else if (p.type === "data-thinking-steps") { @@ -1730,57 +1762,6 @@ export default function NewChatPage() { return () => window.removeEventListener("hitl-decision", handler); }, [handleResume, pendingInterrupt]); - // Mirror staged bundle decisions onto the cards visually so prev/next nav - // reflects past choices instead of re-prompting. Submit's ``hitl-decision`` - // handler still runs the actual resume. - useEffect(() => { - const handler = (e: Event) => { - const detail = (e as CustomEvent).detail as { - toolCallId: string; - decision: { - type: string; - message?: string; - edited_action?: { name: string; args: Record }; - }; - }; - if (!detail?.toolCallId || !detail?.decision || !pendingInterrupt) return; - setMessages((prev) => - prev.map((m) => { - if (m.id !== pendingInterrupt.assistantMsgId) return m; - const parts = m.content as unknown as Array>; - const newContent = parts.map((part) => { - if (part.toolCallId !== detail.toolCallId) return part; - if (part.type !== "tool-call") return part; - if (typeof part.result !== "object" || part.result === null) return part; - if (!("__interrupt__" in (part.result as Record))) return part; - const decided = detail.decision.type as "approve" | "reject" | "edit"; - if (decided === "edit" && detail.decision.edited_action) { - return { - ...part, - args: detail.decision.edited_action.args, - argsText: JSON.stringify(detail.decision.edited_action.args, null, 2), - result: { - ...(part.result as Record), - __decided__: decided, - }, - }; - } - return { - ...part, - result: { - ...(part.result as Record), - __decided__: decided, - }, - }; - }); - return { ...m, content: newContent as unknown as ThreadMessageLike["content"] }; - }) - ); - }; - window.addEventListener("hitl-stage", handler); - return () => window.removeEventListener("hitl-stage", handler); - }, [pendingInterrupt]); - // Convert message (pass through since already in correct format) const convertMessage = useCallback( (message: ThreadMessageLike): ThreadMessageLike => message, @@ -1895,6 +1876,23 @@ export default function NewChatPage() { const selection = await getAgentFilesystemSelection(searchSpaceId, { localFilesystemEnabled, }); + // Partition the source mentions back into doc/surfsense_doc/folder + // id buckets so the regenerate route can pass them to + // ``stream_new_chat`` and the priority middleware sees the + // same ``[USER-MENTIONED]`` priority entries the original + // turn did. Without this partition the regenerate flow + // silently dropped the agent's mention awareness — same + // architectural bug we fixed on the new-chat path. + const regenerateSurfsenseDocIds = sourceMentionedDocs + .filter((d) => d.kind === "doc" && d.document_type === "SURFSENSE_DOCS") + .map((d) => d.id); + const regenerateDocIds = sourceMentionedDocs + .filter((d) => d.kind === "doc" && d.document_type !== "SURFSENSE_DOCS") + .map((d) => d.id); + const regenerateFolderIds = sourceMentionedDocs + .filter((d) => d.kind === "folder") + .map((d) => d.id); + const requestBody: Record = { search_space_id: searchSpaceId, user_query: newUserQuery, @@ -1902,6 +1900,12 @@ export default function NewChatPage() { filesystem_mode: selection.filesystem_mode, client_platform: selection.client_platform, local_filesystem_mounts: selection.local_filesystem_mounts, + mentioned_document_ids: + regenerateDocIds.length > 0 ? regenerateDocIds : undefined, + mentioned_surfsense_doc_ids: + regenerateSurfsenseDocIds.length > 0 ? regenerateSurfsenseDocIds : undefined, + mentioned_folder_ids: + regenerateFolderIds.length > 0 ? regenerateFolderIds : undefined, // Full mention metadata for the regenerate-specific // source list. Only meaningful for edit (the BE only // re-persists a user row when ``user_query`` is set); @@ -1912,6 +1916,7 @@ export default function NewChatPage() { id: d.id, title: d.title, document_type: d.document_type, + kind: d.kind, })) : undefined, }; @@ -2279,7 +2284,7 @@ export default function NewChatPage() { [handleRegenerate, messages, agentActionItems] ); - const handleBundleSubmit = useCallback((orderedDecisions) => { + const handleApprovalSubmit = useCallback((orderedDecisions: HitlDecision[]) => { window.dispatchEvent( new CustomEvent("hitl-decision", { detail: { decisions: orderedDecisions } }) ); @@ -2353,11 +2358,11 @@ export default function NewChatPage() { return ( - + -
@@ -2367,7 +2372,7 @@ export default function NewChatPage() {
- + { diff --git a/surfsense_web/atoms/chat/mentioned-documents.atom.ts b/surfsense_web/atoms/chat/mentioned-documents.atom.ts index 9c4546237..eafdaf87e 100644 --- a/surfsense_web/atoms/chat/mentioned-documents.atom.ts +++ b/surfsense_web/atoms/chat/mentioned-documents.atom.ts @@ -4,45 +4,108 @@ import { atom } from "jotai"; import type { Document } from "@/contracts/types/document.types"; /** - * Atom to store the full document objects mentioned via @-mention chips - * in the current chat composer. This persists across component remounts. + * Sentinel ``document_type`` used for folder mention chips so the + * dedup key (`kind:document_type:id`) never collides a document with a + * folder that happens to share an integer id. */ -export const mentionedDocumentsAtom = atom[]>([]); +export const FOLDER_MENTION_DOCUMENT_TYPE = "FOLDER"; /** - * Derived read-only atom that maps deduplicated mentioned docs - * into backend payload fields. - */ -export const mentionedDocumentIdsAtom = atom((get) => { - const allDocs = get(mentionedDocumentsAtom); - const seen = new Set(); - const deduped = allDocs.filter((d) => { - const key = `${d.document_type}:${d.id}`; - if (seen.has(key)) return false; - seen.add(key); - return true; - }); - return { - surfsense_doc_ids: deduped - .filter((doc) => doc.document_type === "SURFSENSE_DOCS") - .map((doc) => doc.id), - document_ids: deduped - .filter((doc) => doc.document_type !== "SURFSENSE_DOCS") - .map((doc) => doc.id), - }; -}); - -/** - * Simplified document info for display purposes + * Display metadata for a single ``@``-mention chip. + * + * The ``kind`` discriminator identifies whether the chip is a + * knowledge-base document or a knowledge-base folder. Folders carry + * the sentinel ``document_type === FOLDER_MENTION_DOCUMENT_TYPE`` so + * the editor, picker, and persisted ``mentioned-documents`` content + * part all stay aligned with the backend Pydantic schema. */ export interface MentionedDocumentInfo { id: number; title: string; document_type: string; + kind: "doc" | "folder"; } /** - * Atom to store mentioned documents per message ID. + * Backwards-compatible doc-only chip shape for legacy callers that + * haven't migrated to the discriminated union yet. Keep narrow so + * accidental new callers fail typecheck and route through the + * discriminated type instead. + */ +type LegacyDocMention = Pick; + +/** + * Normalize an arbitrary chip-like input into the discriminated + * ``MentionedDocumentInfo`` shape. Existing call sites that only have + * ``{id, title, document_type}`` flow through here so they don't have + * to thread ``kind`` everywhere — the helper defaults to ``"doc"`` and + * rewrites the document type for folders. + */ +export function toMentionedDocumentInfo( + input: LegacyDocMention | MentionedDocumentInfo +): MentionedDocumentInfo { + if ("kind" in input && (input.kind === "doc" || input.kind === "folder")) { + return input; + } + return { + id: input.id, + title: input.title, + document_type: input.document_type, + kind: "doc", + }; +} + +/** + * Build a folder-mention chip from a folder row (id + name). + */ +export function makeFolderMention(input: { id: number; name: string }): MentionedDocumentInfo { + return { + id: input.id, + title: input.name, + document_type: FOLDER_MENTION_DOCUMENT_TYPE, + kind: "folder", + }; +} + +/** + * Atom to store the full mention objects (documents + folders) attached + * via @-mention chips in the current chat composer. Persists across + * component remounts. + */ +export const mentionedDocumentsAtom = atom([]); + +/** + * Derived read-only atom that maps deduplicated mention chips into + * backend payload fields. Doc chips split by ``document_type`` exactly + * like before; folder chips are projected into a separate + * ``folder_ids`` bucket so the route can forward + * ``mentioned_folder_ids`` to the agent without the priority middleware + * conflating them with hybrid-search ids. + */ +export const mentionedDocumentIdsAtom = atom((get) => { + const allMentions = get(mentionedDocumentsAtom); + const seen = new Set(); + const deduped = allMentions.filter((m) => { + const key = `${m.kind}:${m.document_type}:${m.id}`; + if (seen.has(key)) return false; + seen.add(key); + return true; + }); + const docs = deduped.filter((m) => m.kind === "doc"); + const folders = deduped.filter((m) => m.kind === "folder"); + return { + surfsense_doc_ids: docs + .filter((doc) => doc.document_type === "SURFSENSE_DOCS") + .map((doc) => doc.id), + document_ids: docs + .filter((doc) => doc.document_type !== "SURFSENSE_DOCS") + .map((doc) => doc.id), + folder_ids: folders.map((f) => f.id), + }; +}); + +/** + * Atom to store mentioned chips per message ID. * This allows displaying which documents were mentioned with each user message. */ export const messageDocumentsMapAtom = atom>({}); diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx index 7bccc22ee..00f3acebf 100644 --- a/surfsense_web/components/assistant-ui/assistant-message.tsx +++ b/surfsense_web/components/assistant-ui/assistant-message.tsx @@ -4,6 +4,7 @@ import { AuiIf, ErrorPrimitive, MessagePrimitive, + type ToolCallMessagePartComponent, useAui, useAuiState, } from "@assistant-ui/react"; @@ -36,11 +37,9 @@ import { MarkdownText } from "@/components/assistant-ui/markdown-text"; import { ReasoningMessagePart } from "@/components/assistant-ui/reasoning-message-part"; import { RevertTurnButton } from "@/components/assistant-ui/revert-turn-button"; import { useTokenUsage } from "@/components/assistant-ui/token-usage-context"; -import { ToolFallback } from "@/components/assistant-ui/tool-fallback"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { CommentPanelContainer } from "@/components/chat-comments/comment-panel-container/comment-panel-container"; import { CommentSheet } from "@/components/chat-comments/comment-sheet/comment-sheet"; -import { withBundleStep } from "@/components/hitl-bundle-pager"; import type { SerializableCitation } from "@/components/tool-ui/citation"; import { openSafeNavigationHref, @@ -100,146 +99,6 @@ const GenerateImageToolUI = dynamic( import("@/components/tool-ui/generate-image").then((m) => ({ default: m.GenerateImageToolUI })), { ssr: false } ); -const UpdateMemoryToolUI = dynamic( - () => import("@/components/tool-ui/user-memory").then((m) => ({ default: m.UpdateMemoryToolUI })), - { ssr: false } -); -const SandboxExecuteToolUI = dynamic( - () => - import("@/components/tool-ui/sandbox-execute").then((m) => ({ - default: m.SandboxExecuteToolUI, - })), - { ssr: false } -); -const CreateNotionPageToolUI = dynamic( - () => import("@/components/tool-ui/notion").then((m) => ({ default: m.CreateNotionPageToolUI })), - { ssr: false } -); -const UpdateNotionPageToolUI = dynamic( - () => import("@/components/tool-ui/notion").then((m) => ({ default: m.UpdateNotionPageToolUI })), - { ssr: false } -); -const DeleteNotionPageToolUI = dynamic( - () => import("@/components/tool-ui/notion").then((m) => ({ default: m.DeleteNotionPageToolUI })), - { ssr: false } -); -const CreateLinearIssueToolUI = dynamic( - () => import("@/components/tool-ui/linear").then((m) => ({ default: m.CreateLinearIssueToolUI })), - { ssr: false } -); -const UpdateLinearIssueToolUI = dynamic( - () => import("@/components/tool-ui/linear").then((m) => ({ default: m.UpdateLinearIssueToolUI })), - { ssr: false } -); -const DeleteLinearIssueToolUI = dynamic( - () => import("@/components/tool-ui/linear").then((m) => ({ default: m.DeleteLinearIssueToolUI })), - { ssr: false } -); -const CreateGoogleDriveFileToolUI = dynamic( - () => - import("@/components/tool-ui/google-drive").then((m) => ({ - default: m.CreateGoogleDriveFileToolUI, - })), - { ssr: false } -); -const DeleteGoogleDriveFileToolUI = dynamic( - () => - import("@/components/tool-ui/google-drive").then((m) => ({ - default: m.DeleteGoogleDriveFileToolUI, - })), - { ssr: false } -); -const CreateOneDriveFileToolUI = dynamic( - () => - import("@/components/tool-ui/onedrive").then((m) => ({ default: m.CreateOneDriveFileToolUI })), - { ssr: false } -); -const DeleteOneDriveFileToolUI = dynamic( - () => - import("@/components/tool-ui/onedrive").then((m) => ({ default: m.DeleteOneDriveFileToolUI })), - { ssr: false } -); -const CreateDropboxFileToolUI = dynamic( - () => - import("@/components/tool-ui/dropbox").then((m) => ({ default: m.CreateDropboxFileToolUI })), - { ssr: false } -); -const DeleteDropboxFileToolUI = dynamic( - () => - import("@/components/tool-ui/dropbox").then((m) => ({ default: m.DeleteDropboxFileToolUI })), - { ssr: false } -); -const CreateCalendarEventToolUI = dynamic( - () => - import("@/components/tool-ui/google-calendar").then((m) => ({ - default: m.CreateCalendarEventToolUI, - })), - { ssr: false } -); -const UpdateCalendarEventToolUI = dynamic( - () => - import("@/components/tool-ui/google-calendar").then((m) => ({ - default: m.UpdateCalendarEventToolUI, - })), - { ssr: false } -); -const DeleteCalendarEventToolUI = dynamic( - () => - import("@/components/tool-ui/google-calendar").then((m) => ({ - default: m.DeleteCalendarEventToolUI, - })), - { ssr: false } -); -const CreateGmailDraftToolUI = dynamic( - () => import("@/components/tool-ui/gmail").then((m) => ({ default: m.CreateGmailDraftToolUI })), - { ssr: false } -); -const UpdateGmailDraftToolUI = dynamic( - () => import("@/components/tool-ui/gmail").then((m) => ({ default: m.UpdateGmailDraftToolUI })), - { ssr: false } -); -const SendGmailEmailToolUI = dynamic( - () => import("@/components/tool-ui/gmail").then((m) => ({ default: m.SendGmailEmailToolUI })), - { ssr: false } -); -const TrashGmailEmailToolUI = dynamic( - () => import("@/components/tool-ui/gmail").then((m) => ({ default: m.TrashGmailEmailToolUI })), - { ssr: false } -); -const CreateJiraIssueToolUI = dynamic( - () => import("@/components/tool-ui/jira").then((m) => ({ default: m.CreateJiraIssueToolUI })), - { ssr: false } -); -const UpdateJiraIssueToolUI = dynamic( - () => import("@/components/tool-ui/jira").then((m) => ({ default: m.UpdateJiraIssueToolUI })), - { ssr: false } -); -const DeleteJiraIssueToolUI = dynamic( - () => import("@/components/tool-ui/jira").then((m) => ({ default: m.DeleteJiraIssueToolUI })), - { ssr: false } -); -const CreateConfluencePageToolUI = dynamic( - () => - import("@/components/tool-ui/confluence").then((m) => ({ - default: m.CreateConfluencePageToolUI, - })), - { ssr: false } -); -const UpdateConfluencePageToolUI = dynamic( - () => - import("@/components/tool-ui/confluence").then((m) => ({ - default: m.UpdateConfluencePageToolUI, - })), - { ssr: false } -); -const DeleteConfluencePageToolUI = dynamic( - () => - import("@/components/tool-ui/confluence").then((m) => ({ - default: m.DeleteConfluencePageToolUI, - })), - { ssr: false } -); - function extractDomain(url: string): string | undefined { try { return new URL(url).hostname.replace(/^www\./, ""); @@ -503,50 +362,26 @@ const MessageInfoDropdown: FC = () => { ); }; -// Wrap each tool-ui card with ``withBundleStep`` so multi-card HITL bundles -// page through them and stage decisions instead of firing one resume per card. -const TOOLS_BY_NAME = { - generate_report: withBundleStep(GenerateReportToolUI), - generate_resume: withBundleStep(GenerateResumeToolUI), - generate_podcast: withBundleStep(GeneratePodcastToolUI), - generate_video_presentation: withBundleStep(GenerateVideoPresentationToolUI), - display_image: withBundleStep(GenerateImageToolUI), - generate_image: withBundleStep(GenerateImageToolUI), - update_memory: withBundleStep(UpdateMemoryToolUI), - execute: withBundleStep(SandboxExecuteToolUI), - execute_code: withBundleStep(SandboxExecuteToolUI), - create_notion_page: withBundleStep(CreateNotionPageToolUI), - update_notion_page: withBundleStep(UpdateNotionPageToolUI), - delete_notion_page: withBundleStep(DeleteNotionPageToolUI), - create_linear_issue: withBundleStep(CreateLinearIssueToolUI), - update_linear_issue: withBundleStep(UpdateLinearIssueToolUI), - delete_linear_issue: withBundleStep(DeleteLinearIssueToolUI), - create_google_drive_file: withBundleStep(CreateGoogleDriveFileToolUI), - delete_google_drive_file: withBundleStep(DeleteGoogleDriveFileToolUI), - create_onedrive_file: withBundleStep(CreateOneDriveFileToolUI), - delete_onedrive_file: withBundleStep(DeleteOneDriveFileToolUI), - create_dropbox_file: withBundleStep(CreateDropboxFileToolUI), - delete_dropbox_file: withBundleStep(DeleteDropboxFileToolUI), - create_calendar_event: withBundleStep(CreateCalendarEventToolUI), - update_calendar_event: withBundleStep(UpdateCalendarEventToolUI), - delete_calendar_event: withBundleStep(DeleteCalendarEventToolUI), - create_gmail_draft: withBundleStep(CreateGmailDraftToolUI), - update_gmail_draft: withBundleStep(UpdateGmailDraftToolUI), - send_gmail_email: withBundleStep(SendGmailEmailToolUI), - trash_gmail_email: withBundleStep(TrashGmailEmailToolUI), - create_jira_issue: withBundleStep(CreateJiraIssueToolUI), - update_jira_issue: withBundleStep(UpdateJiraIssueToolUI), - delete_jira_issue: withBundleStep(DeleteJiraIssueToolUI), - create_confluence_page: withBundleStep(CreateConfluencePageToolUI), - update_confluence_page: withBundleStep(UpdateConfluencePageToolUI), - delete_confluence_page: withBundleStep(DeleteConfluencePageToolUI), - web_search: () => null, - link_preview: () => null, - multi_link_preview: () => null, - scrape_webpage: () => null, +/** + * Tools rendered in the message BODY — value-add deliverables only. + * + * Process tools (connector CRUD, sandbox execute, memory updates, + * etc.) are NOT here; they render in the timeline via the slice's + * tool registry (see ``features/chat-messages/timeline``). The body + * opts out of every other tool by registering ``NullBodyTool`` as the + * fallback — any tool name not in this map renders nothing in the + * body and is picked up by the timeline instead. + */ +const BODY_TOOLS = { + generate_report: GenerateReportToolUI, + generate_resume: GenerateResumeToolUI, + generate_podcast: GeneratePodcastToolUI, + generate_video_presentation: GenerateVideoPresentationToolUI, + display_image: GenerateImageToolUI, + generate_image: GenerateImageToolUI, } as const; -const TOOLS_FALLBACK = withBundleStep(ToolFallback); +const NullBodyTool: ToolCallMessagePartComponent = () => null; const AssistantMessageInner: FC = () => { const isMobile = !useMediaQuery("(min-width: 768px)"); @@ -559,8 +394,8 @@ const AssistantMessageInner: FC = () => { Text: MarkdownText, Reasoning: ReasoningMessagePart, tools: { - by_name: TOOLS_BY_NAME, - Fallback: TOOLS_FALLBACK, + by_name: BODY_TOOLS, + Fallback: NullBodyTool, }, }} /> diff --git a/surfsense_web/components/assistant-ui/inline-mention-editor.tsx b/surfsense_web/components/assistant-ui/inline-mention-editor.tsx index c585dc80f..e12556486 100644 --- a/surfsense_web/components/assistant-ui/inline-mention-editor.tsx +++ b/surfsense_web/components/assistant-ui/inline-mention-editor.tsx @@ -1,5 +1,6 @@ "use client"; +import { Folder as FolderIcon } from "lucide-react"; import type { PlateElementProps } from "platejs/react"; import { createPlatePlugin, @@ -9,23 +10,51 @@ import { usePlateEditor, } from "platejs/react"; import { type FC, forwardRef, useCallback, useImperativeHandle, useMemo, useRef } from "react"; +import { FOLDER_MENTION_DOCUMENT_TYPE } from "@/atoms/chat/mentioned-documents.atom"; import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import type { Document } from "@/contracts/types/document.types"; import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; import { cn } from "@/lib/utils"; +export type MentionKind = "doc" | "folder"; + export interface MentionedDocument { id: number; title: string; document_type?: string; + kind: MentionKind; } +/** + * Input shape for inserting a chip. ``kind`` defaults to ``"doc"`` + * when omitted so legacy callers don't have to thread the + * discriminator. Folder callers pass ``kind: "folder"`` and the + * folder ``id`` and ``title``; ``document_type`` defaults to + * ``FOLDER_MENTION_DOCUMENT_TYPE`` inside ``insertMentionChip`` so the + * dedup key (`kind:document_type:id`) never collides with a doc chip + * that happens to share an id. + */ +export type MentionChipInput = { + id: number; + title: string; + document_type?: string; + kind?: MentionKind; +}; + export interface InlineMentionEditorRef { focus: () => void; clear: () => void; setText: (text: string) => void; getText: () => string; getMentionedDocuments: () => MentionedDocument[]; + insertMentionChip: ( + mention: MentionChipInput, + options?: { removeTriggerText?: boolean } + ) => void; + /** + * @deprecated Use ``insertMentionChip``. Kept for one transition + * cycle so we don't break ad-hoc callers; prefer the new name. + */ insertDocumentChip: ( doc: Pick, options?: { removeTriggerText?: boolean } @@ -61,6 +90,13 @@ type MentionElementNode = { id: number; title: string; document_type?: string; + /** + * Discriminator added so a folder chip and a doc chip with the + * same id round-trip cleanly through ``getMentionedDocuments`` + * and the persisted ``mentioned-documents`` content part. + * Defaults to ``"doc"`` for nodes that predate this field. + */ + kind?: MentionKind; statusLabel?: string | null; statusKind?: MentionStatusKind; children: [{ text: "" }]; @@ -90,11 +126,17 @@ const MentionElement: FC> = ({ ? "text-emerald-700" : "text-amber-700"; + const isFolder = element.kind === "folder"; + return ( - {getConnectorIcon(element.document_type ?? "UNKNOWN", "h-3 w-3")} + {isFolder ? ( + + ) : ( + getConnectorIcon(element.document_type ?? "UNKNOWN", "h-3 w-3") + )} {element.title} @@ -153,10 +195,12 @@ function getMentionedDocuments(value: ComposerValue): MentionedDocument[] { for (const block of value) { for (const node of block.children) { if (!isMentionNode(node)) continue; + const kind: MentionKind = node.kind ?? "doc"; const doc: MentionedDocument = { id: node.id, title: node.title, document_type: node.document_type, + kind, }; map.set(getMentionDocKey(doc), doc); } @@ -311,21 +355,23 @@ export const InlineMentionEditor = forwardRef, - options?: { removeTriggerText?: boolean } - ) => { - if (typeof doc.id !== "number" || typeof doc.title !== "string") return; + const insertMentionChip = useCallback( + (mention: MentionChipInput, options?: { removeTriggerText?: boolean }) => { + if (typeof mention.id !== "number" || typeof mention.title !== "string") return; const removeTriggerText = options?.removeTriggerText ?? true; const current = getCurrentValue(); const selection = editor.selection; + const kind: MentionKind = mention.kind ?? "doc"; + const document_type = + mention.document_type ?? + (kind === "folder" ? FOLDER_MENTION_DOCUMENT_TYPE : undefined); const mentionNode: MentionElementNode = { type: MENTION_TYPE, - id: doc.id, - title: doc.title, - document_type: doc.document_type, + id: mention.id, + title: mention.title, + document_type, + kind, children: [{ text: "" }], }; @@ -385,6 +431,19 @@ export const InlineMentionEditor = forwardRef, + options?: { removeTriggerText?: boolean } + ) => { + insertMentionChip({ ...doc, kind: "doc" }, options); + }, + [insertMentionChip] + ); + const removeDocumentChip = useCallback( (docId: number, docType?: string) => { const current = getCurrentValue(); @@ -460,6 +519,7 @@ export const InlineMentionEditor = forwardRef { - event.preventDefault(); - event.stopPropagation(); - void (async () => { - if (electronAPI) { - let resolvedLocalPath = path; - if (electronAPI.getAgentFilesystemMounts) { - try { - const mounts = (await electronAPI.getAgentFilesystemMounts( - resolvedSearchSpaceId - )) as AgentFilesystemMount[]; - resolvedLocalPath = normalizeLocalVirtualPathForEditor(path, mounts); - } catch { - // Fall back to the raw path if mount lookup fails. - } - } - openEditorPanel({ - kind: "local_file", - localFilePath: resolvedLocalPath, - title: resolvedLocalPath.split("/").pop() || resolvedLocalPath, - searchSpaceId: resolvedSearchSpaceId, - }); - return; - } + const { displayName, isFolder } = getVirtualPathDisplay(path); + const icon = isFolder ? ( + + ) : ( + + ); - if (!resolvedSearchSpaceId || !path.startsWith("/documents/")) return; - try { - const doc = await documentsApiService.getDocumentByVirtualPath({ - search_space_id: resolvedSearchSpaceId, - virtual_path: path, - }); - openEditorPanel({ - kind: "document", - documentId: doc.id, - searchSpaceId: resolvedSearchSpaceId, - title: doc.title, - }); - } catch { - toast.error("Document not found in knowledge base."); + const handleClick = useCallback( + (event: React.MouseEvent) => { + event.preventDefault(); + event.stopPropagation(); + void (async () => { + if (electronAPI) { + let resolvedLocalPath = path; + if (electronAPI.getAgentFilesystemMounts) { + try { + const mounts = (await electronAPI.getAgentFilesystemMounts( + resolvedSearchSpaceId + )) as AgentFilesystemMount[]; + resolvedLocalPath = normalizeLocalVirtualPathForEditor(path, mounts); + } catch { + // Fall back to the raw path if mount lookup fails. + } } - })(); - }} - title="Open in editor panel" - > - {path} - + openEditorPanel({ + kind: "local_file", + localFilePath: resolvedLocalPath, + title: resolvedLocalPath.split("/").pop() || resolvedLocalPath, + searchSpaceId: resolvedSearchSpaceId, + }); + return; + } + + if (!resolvedSearchSpaceId || !path.startsWith("/documents/")) return; + try { + const doc = await documentsApiService.getDocumentByVirtualPath({ + search_space_id: resolvedSearchSpaceId, + virtual_path: path, + }); + openEditorPanel({ + kind: "document", + documentId: doc.id, + searchSpaceId: resolvedSearchSpaceId, + title: doc.title, + }); + } catch { + toast.error("Document not found in knowledge base."); + } + })(); + }, + [electronAPI, openEditorPanel, path, resolvedSearchSpaceId] + ); + + // Folders cannot open in the editor panel — keep them as visual chips. + const onClick = isFolder ? undefined : handleClick; + + return ( + ); } diff --git a/surfsense_web/components/assistant-ui/mention-chip.tsx b/surfsense_web/components/assistant-ui/mention-chip.tsx new file mode 100644 index 000000000..9f9c9b177 --- /dev/null +++ b/surfsense_web/components/assistant-ui/mention-chip.tsx @@ -0,0 +1,92 @@ +"use client"; + +import type { MouseEventHandler, ReactNode } from "react"; +import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; +import { cn } from "@/lib/utils"; + +/** + * A single, minimal chip-button used in two places: + * + * 1. User-message mention chips (rendered for every `@`-mention the user + * inserted in the composer). + * 2. AI-answer file/folder paths (rendered when the assistant emits + * `/documents/.../file.xml` or `//.../file.ext`). + * + * Both contexts want the same visual language: a compact, button-styled + * chip with an icon, a truncated label, and an optional tooltip. Sharing + * one component keeps the chat surface visually coherent and means a UX + * tweak (radius, hover, icon size) lands in both places at once. + * + * Styling rules (per shadcn skill): + * - Semantic tokens only (`border`, `bg-background`, `bg-accent`, + * `text-foreground`, `text-muted-foreground`). No raw colors. + * - Layout via `gap-*`, never `space-x-*`. + * - `cn()` for conditional classes. + * - No manual `z-index` — the tooltip handles its own stacking. + */ +export interface MentionChipProps { + /** + * Visual prefix. Keep this small (e.g. `size-3.5`); the chip controls + * its own height and oversized icons will push the label out of place. + */ + icon: ReactNode; + /** Label shown inside the chip; truncated with `…` past the max width. */ + label: string; + /** + * Full title or path shown on hover. Omit to suppress the tooltip + * entirely (e.g. when the label already conveys the full identity). + */ + tooltip?: ReactNode; + /** + * When provided, the chip behaves like a button (focusable, hover + * effect, pointer cursor). Omit for a purely decorative chip. + */ + onClick?: MouseEventHandler; + disabled?: boolean; + className?: string; + /** Optional override for the accessible name; defaults to `label`. */ + ariaLabel?: string; +} + +export function MentionChip({ + icon, + label, + tooltip, + onClick, + disabled, + className, + ariaLabel, +}: MentionChipProps) { + const isInteractive = Boolean(onClick) && !disabled; + + const chip = ( + + ); + + if (!tooltip) return chip; + + return ( + + {chip} + + {tooltip} + + + ); +} diff --git a/surfsense_web/components/assistant-ui/reasoning-message-part.tsx b/surfsense_web/components/assistant-ui/reasoning-message-part.tsx index 70636eab8..6e7aaf048 100644 --- a/surfsense_web/components/assistant-ui/reasoning-message-part.tsx +++ b/surfsense_web/components/assistant-ui/reasoning-message-part.tsx @@ -7,8 +7,8 @@ import { TextShimmerLoader } from "@/components/prompt-kit/loader"; import { cn } from "@/lib/utils"; /** - * Renders the structured `reasoning` part emitted by the backend's - * stream-parity v2 path (A1). + * Renders the structured `reasoning` part emitted by the backend stream + * (typed reasoning deltas from the chat model). * * Behaviour mirrors the existing `ThinkingStepsDisplay`: * - collapsed by default; diff --git a/surfsense_web/components/assistant-ui/thinking-steps.tsx b/surfsense_web/components/assistant-ui/thinking-steps.tsx deleted file mode 100644 index df1cef12c..000000000 --- a/surfsense_web/components/assistant-ui/thinking-steps.tsx +++ /dev/null @@ -1,175 +0,0 @@ -import { makeAssistantDataUI, useAuiState } from "@assistant-ui/react"; -import { ChevronRightIcon } from "lucide-react"; -import type { FC } from "react"; -import { useCallback, useEffect, useState } from "react"; -import { ChainOfThoughtItem } from "@/components/prompt-kit/chain-of-thought"; -import { TextShimmerLoader } from "@/components/prompt-kit/loader"; -import { cn } from "@/lib/utils"; - -export interface ThinkingStep { - id: string; - title: string; - items: string[]; - status: "pending" | "in_progress" | "completed"; -} - -/** - * Chain of thought display component - single collapsible dropdown design - */ -export const ThinkingStepsDisplay: FC<{ steps: ThinkingStep[]; isThreadRunning?: boolean }> = ({ - steps, - isThreadRunning = true, -}) => { - const getEffectiveStatus = useCallback( - (step: ThinkingStep): "pending" | "in_progress" | "completed" => { - if (step.status === "in_progress" && !isThreadRunning) { - return "completed"; - } - return step.status; - }, - [isThreadRunning] - ); - - const inProgressStep = steps.find((s) => getEffectiveStatus(s) === "in_progress"); - const allCompleted = - steps.length > 0 && - !isThreadRunning && - steps.every((s) => getEffectiveStatus(s) === "completed"); - const isProcessing = isThreadRunning && !allCompleted; - const [isOpen, setIsOpen] = useState(() => isProcessing); - - useEffect(() => { - if (isProcessing) { - setIsOpen(true); - return; - } - - if (allCompleted) { - setIsOpen(false); - } - }, [allCompleted, isProcessing]); - - if (steps.length === 0) return null; - - const getHeaderText = () => { - if (allCompleted) { - return "Reviewed"; - } - if (inProgressStep) { - return inProgressStep.title; - } - if (isProcessing) { - return "Processing"; - } - return "Reviewed"; - }; - - return ( -
-
- - -
-
-
- {steps.map((step, index) => { - const effectiveStatus = getEffectiveStatus(step); - const isLast = index === steps.length - 1; - - return ( -
-
- {!isLast && ( -
- )} -
- {effectiveStatus === "in_progress" ? ( - - - - - ) : ( - - )} -
-
- -
-
- {step.title} -
- - {step.items && step.items.length > 0 && ( -
- {step.items.map((item) => ( - - {item} - - ))} -
- )} -
-
- ); - })} -
-
-
-
-
- ); -}; - -/** - * assistant-ui data UI component that renders thinking steps from message content. - * Registered globally via makeAssistantDataUI — renders inside MessagePrimitive.Parts - * at the position of the data part in the content array. - */ -function ThinkingStepsDataRenderer({ data }: { name: string; data: unknown }) { - const isThreadRunning = useAuiState(({ thread }) => thread.isRunning); - const isLastMessage = useAuiState(({ message }) => message?.isLast ?? false); - const isMessageStreaming = isThreadRunning && isLastMessage; - - const steps = (data as { steps: ThinkingStep[] } | null)?.steps ?? []; - if (steps.length === 0) return null; - - return ( -
- -
- ); -} - -export const ThinkingStepsDataUI = makeAssistantDataUI({ - name: "thinking-steps", - render: ThinkingStepsDataRenderer, -}); diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index b4a3b58c6..ec7d19eff 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -36,7 +36,10 @@ import { } from "@/atoms/agent-tools/agent-tools.atoms"; import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; import { currentThreadAtom } from "@/atoms/chat/current-thread.atom"; -import { mentionedDocumentsAtom } from "@/atoms/chat/mentioned-documents.atom"; +import { + type MentionedDocumentInfo, + mentionedDocumentsAtom, +} from "@/atoms/chat/mentioned-documents.atom"; import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom"; import { clearPremiumAlertForThreadAtom, @@ -87,7 +90,6 @@ import { getToolDisplayName, getToolIcon, } from "@/contracts/enums/toolIcons"; -import type { Document } from "@/contracts/types/document.types"; import { useBatchCommentsPreload } from "@/hooks/use-comments"; import { useCommentsSync } from "@/hooks/use-comments-sync"; import { useMediaQuery } from "@/hooks/use-media-query"; @@ -377,9 +379,7 @@ const Composer: FC = () => { const [mentionQuery, setMentionQuery] = useState(""); const [actionQuery, setActionQuery] = useState(""); const editorRef = useRef(null); - const prevMentionedDocsRef = useRef< - Map> - >(new Map()); + const prevMentionedDocsRef = useRef>(new Map()); const documentPickerRef = useRef(null); const promptPickerRef = useRef(null); const { search_space_id, chat_id } = useParams(); @@ -622,20 +622,20 @@ const Composer: FC = () => { ); const handleDocumentsMention = useCallback( - (documents: Pick[]) => { + (mentions: MentionedDocumentInfo[]) => { const editorMentionedDocs = editorRef.current?.getMentionedDocuments() ?? []; const editorDocKeys = new Set(editorMentionedDocs.map((doc) => getMentionDocKey(doc))); - for (const doc of documents) { - const key = getMentionDocKey(doc); + for (const mention of mentions) { + const key = getMentionDocKey(mention); if (editorDocKeys.has(key)) continue; - editorRef.current?.insertDocumentChip(doc); + editorRef.current?.insertMentionChip(mention); } setMentionedDocuments((prev) => { const existingKeySet = new Set(prev.map((d) => getMentionDocKey(d))); - const uniqueNewDocs = documents.filter((doc) => !existingKeySet.has(getMentionDocKey(doc))); - return [...prev, ...uniqueNewDocs]; + const uniqueNew = mentions.filter((m) => !existingKeySet.has(getMentionDocKey(m))); + return [...prev, ...uniqueNew]; }); setMentionQuery(""); @@ -657,7 +657,7 @@ const Composer: FC = () => { for (const [key, doc] of nextDocsMap) { if (prevDocsMap.has(key) || editorKeys.has(key)) continue; - editor.insertDocumentChip(doc, { removeTriggerText: false }); + editor.insertMentionChip(doc, { removeTriggerText: false }); } for (const [key, doc] of prevDocsMap) { diff --git a/surfsense_web/components/assistant-ui/tool-fallback.tsx b/surfsense_web/components/assistant-ui/tool-fallback.tsx deleted file mode 100644 index 06082c9c7..000000000 --- a/surfsense_web/components/assistant-ui/tool-fallback.tsx +++ /dev/null @@ -1,512 +0,0 @@ -import { type ToolCallMessagePartComponent, useAuiState } from "@assistant-ui/react"; -import { useQueryClient } from "@tanstack/react-query"; -import { useAtomValue } from "jotai"; -import { CheckIcon, ChevronDownIcon, RotateCcw, XCircleIcon } from "lucide-react"; -import { useEffect, useMemo, useState } from "react"; -import { toast } from "sonner"; -import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; -import { NestedScroll } from "@/components/assistant-ui/nested-scroll"; -import { - DoomLoopApprovalToolUI, - isDoomLoopInterrupt, -} from "@/components/tool-ui/doom-loop-approval"; -import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval"; -import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, - AlertDialogTrigger, -} from "@/components/ui/alert-dialog"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { Card } from "@/components/ui/card"; -import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible"; -import { Separator } from "@/components/ui/separator"; -import { Spinner } from "@/components/ui/spinner"; -import { getToolDisplayName } from "@/contracts/enums/toolIcons"; -import { markActionRevertedInCache, useAgentActionsQuery } from "@/hooks/use-agent-actions-query"; -import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; -import { AppError } from "@/lib/error"; -import { isInterruptResult } from "@/lib/hitl"; -import { cn } from "@/lib/utils"; - -/** - * Inline Revert button rendered on a tool card when the matching - * ``AgentActionLog`` row is reversible and hasn't been reverted yet. - * - * Reads from the unified ``useAgentActionsQuery`` cache — the SAME - * react-query cache the agent-actions sheet consumes. SSE events - * (``data-action-log`` / ``data-action-log-updated``) and - * ``POST /threads/{id}/revert/{id}`` responses both flow through the - * cache via ``setQueryData`` helpers, so the card and the sheet stay - * in lockstep on every code path: page reload, navigation, live - * stream, post-stream reversibility flip, and explicit revert clicks. - * - * Match key (in priority order): - * 1. ``a.tool_call_id === toolCallId`` — direct hit in parity_v2 when - * the model streamed ``tool_call_chunks`` so the card's synthetic - * id IS the LangChain id. - * 2. ``a.tool_call_id === langchainToolCallId`` — legacy mode (or - * parity_v2 with provider-side chunk emission) where the card's - * synthetic id is ``call_`` and the LangChain id is - * backfilled onto the part by ``tool-output-available``. - * 3. ``(chat_turn_id, tool_name, position-within-turn)`` — fallback - * for cards whose synthetic id is ``call_`` AND whose - * ``langchainToolCallId`` never got backfilled (provider emitted - * the tool_call as a single payload with no chunks AND streaming - * pre-dated the ``tool-output-available langchainToolCallId`` - * backfill, e.g. older threads). Reads the parent message's - * ``chatTurnId`` and ``content`` via ``useAuiState`` so we can - * match position-by-tool-name within the turn against the - * action_log rows the server returned in ``created_at`` order. - */ -function ToolCardRevertButton({ - toolCallId, - toolName, - langchainToolCallId, -}: { - toolCallId: string; - toolName: string; - langchainToolCallId?: string; -}) { - const session = useAtomValue(chatSessionStateAtom); - const threadId = session?.threadId ?? null; - const queryClient = useQueryClient(); - const { findByToolCallId, findByChatTurnAndTool } = useAgentActionsQuery(threadId); - - // Parent message metadata, read via the narrowest possible - // selectors so this card doesn't re-render on every text-delta of - // every other part in the same message during streaming. - // - // IMPORTANT — ``useAuiState`` re-renders the component whenever the - // returned slice's identity changes. Returning ``message?.content`` - // (an array) would re-render on every token because the runtime - // rebuilds the parts array. Returning a PRIMITIVE (the position - // number) lets ``useAuiState``'s ``Object.is`` check short-circuit - // when the position hasn't actually moved — which is the common - // case during text streaming, when only ``text``/``reasoning`` - // parts are mutating and the same-toolName tool-call ordering is - // stable. (See Vercel React rule ``rerender-defer-reads``.) - const chatTurnId = useAuiState(({ message }) => { - const meta = message?.metadata as { custom?: { chatTurnId?: string } } | undefined; - return meta?.custom?.chatTurnId ?? null; - }); - const positionInTurn = useAuiState(({ message }) => { - const content = message?.content; - if (!Array.isArray(content)) return -1; - let n = -1; - for (const part of content) { - if ( - part && - typeof part === "object" && - (part as { type?: string }).type === "tool-call" && - (part as { toolName?: string }).toolName === toolName - ) { - n += 1; - if ((part as { toolCallId?: string }).toolCallId === toolCallId) return n; - } - } - return -1; - }); - - const action = useMemo(() => { - // Tier 1 + 2: O(1) Map-backed direct id match. Covers - // ~all parity_v2 streams and any legacy stream that backfilled - // ``langchainToolCallId`` via ``tool-output-available``. - const direct = findByToolCallId(toolCallId) ?? findByToolCallId(langchainToolCallId); - if (direct) return direct; - // Tier 3: position-within-turn fallback. Only kicks in when the - // card has a synthetic ``call_`` id AND no - // ``langchainToolCallId`` was ever backfilled — i.e. the tool - // was emitted as a single non-chunked payload AND streaming - // pre-dated the on_tool_end backfill. - if (!chatTurnId || positionInTurn < 0) return null; - const turnSameTool = findByChatTurnAndTool(chatTurnId, toolName); - return turnSameTool[positionInTurn] ?? null; - }, [ - findByToolCallId, - findByChatTurnAndTool, - toolCallId, - langchainToolCallId, - chatTurnId, - toolName, - positionInTurn, - ]); - - const [isReverting, setIsReverting] = useState(false); - const [confirmOpen, setConfirmOpen] = useState(false); - - if (!action) return null; - if (!action.reversible) return null; - if (action.reverted_by_action_id !== null && action.reverted_by_action_id !== undefined) - return null; - if (action.is_revert_action) return null; - if (action.error !== null && action.error !== undefined) return null; - if (!threadId) return null; - - const handleRevert = async () => { - setIsReverting(true); - try { - const response = await agentActionsApiService.revert(threadId, action.id); - markActionRevertedInCache(queryClient, threadId, action.id, response.new_action_id ?? null); - toast.success(response.message || "Action reverted."); - } catch (err) { - // 503 means revert is gated off on this deployment — hide the - // button silently rather than nagging the user. Any other error - // is surfaced as a toast so the operator can investigate. - if (err instanceof AppError && err.status === 503) { - return; - } - const message = - err instanceof AppError - ? err.message - : err instanceof Error - ? err.message - : "Failed to revert action."; - toast.error(message); - } finally { - setIsReverting(false); - setConfirmOpen(false); - } - }; - - return ( - - - - - - - Revert this action? - - This will undo{" "} - {getToolDisplayName(action.tool_name)} and add a - new entry to the history. Your chat is preserved — only the changes the agent made to - your knowledge base or connected apps will be rolled back where possible. - - - - Cancel - { - e.preventDefault(); - handleRevert(); - }} - disabled={isReverting} - className="gap-1.5" - > - {isReverting && } - Revert - - - - - ); -} - -/** - * Compact tool-call card. - * - * shadcn composition note: we intentionally use ``Card`` as a visual - * frame WITHOUT ``CardHeader / CardContent``. The full composition's - * ``p-6`` padding doesn't fit a compact collapsible header that IS the - * trigger; using ``Card`` alone preserves the rounded border, shadow, - * and ``bg-card`` token (semantic colors) without forcing a layout - * that doesn't fit. All status colors use semantic tokens — no manual - * dark-mode overrides, no raw hex. - */ -const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => { - const { toolCallId, toolName, argsText, result, status } = props; - // ``langchainToolCallId`` is a SurfSense-specific extension the - // streaming pipeline attaches to the tool-call content part so - // the Revert button can resolve its ``AgentActionLog`` row even - // when only the LC id is known. assistant-ui's - // ``ToolCallMessagePartProps`` doesn't list it, but the runtime - // spreads ``{...part}`` so the prop reaches us at runtime. - const langchainToolCallId = (props as { langchainToolCallId?: string }).langchainToolCallId; - - const isCancelled = status?.type === "incomplete" && status.reason === "cancelled"; - const isError = status?.type === "incomplete" && status.reason === "error"; - const isRunning = status?.type === "running" || status?.type === "requires-action"; - - /* - Per-card expansion state. Initial value is ``isRunning`` so a - card streaming in mounts already-expanded (no flash of - collapsed → expanded on first paint), while a card loaded from - history (status="complete") mounts collapsed. The useEffect - below keeps this in lockstep with this card's own ``isRunning`` - when it transitions: false → true auto-expands (e.g. a tool - that re-runs after edit), true → false auto-collapses once the - tool finishes. Because the dep is per-card ``isRunning`` and - not the chat-level streaming flag, sibling cards on the same - assistant turn each manage their own expansion independently. - Once ``isRunning`` is false the user controls expansion via - ``onOpenChange``. - */ - const [isExpanded, setIsExpanded] = useState(isRunning); - useEffect(() => { - setIsExpanded(isRunning); - }, [isRunning]); - const errorData = status?.type === "incomplete" ? status.error : undefined; - const serializedError = useMemo( - () => (errorData && typeof errorData !== "string" ? JSON.stringify(errorData) : null), - [errorData] - ); - - const serializedResult = useMemo( - () => - result !== undefined && typeof result !== "string" ? JSON.stringify(result, null, 2) : null, - [result] - ); - - const cancelledReason = - isCancelled && status.error - ? typeof status.error === "string" - ? status.error - : serializedError - : null; - const errorReason = - isError && status.error - ? typeof status.error === "string" - ? status.error - : serializedError - : null; - - const displayName = getToolDisplayName(toolName); - const subtitle = errorReason ?? cancelledReason; - - return ( - - {/* - ``group`` lets the chevron (rendered as a sibling of the - main trigger button) read the Collapsible Root's - ``data-[state=open]`` for rotation. The Collapsible is - fully controlled via ``isExpanded`` — the useEffect - above syncs it to ``isRunning`` so the card auto-opens - while a tool streams in and auto-collapses once it - finishes. We deliberately DON'T pass ``disabled`` so - both triggers stay clickable; ``onOpenChange`` is wired - to a setter that no-ops while ``isRunning`` (see - ``handleOpenChange`` below) which keeps the card pinned - open mid-stream without losing keyboard / pointer - affordance the moment streaming ends. - */} - { - // Block manual collapse while the tool is still - // streaming — otherwise a stray click on either - // trigger would close the card and hide the live - // ``argsText`` panel mid-run. After streaming the - // user has full control again. - if (isRunning) return; - setIsExpanded(next); - }} - > - {/* - Header row: main trigger on the left (icon + title - col), Revert + chevron-trigger on the right as - siblings of the main trigger. The chevron is wrapped - in its OWN ``CollapsibleTrigger`` (Radix supports - multiple triggers per Root) so clicking the chevron - toggles the same state as clicking the title row. - The Revert button stays a separate AlertDialog - trigger and stops propagation in its onClick so it - doesn't toggle the collapsible while opening the - confirm dialog. Keeping these as flat siblings — - rather than nesting Revert / chevron inside the - title trigger — avoids invalid HTML - (button-in-button) and lets the Revert button - render in BOTH the collapsed and expanded states. - */} -
- - - - - {/* - Right-side controls. The Revert button is - visible whenever the matching action is - reversible — including the collapsed state — - but ``ToolCardRevertButton`` itself returns - ``null`` while a tool is still running because - no action-log row exists yet, so it doesn't - need an explicit ``isRunning`` gate here. - */} -
- - - - -
-
- - {/* - CollapsibleContent body — auto-open while streaming - (see ``open`` prop above) so the live ``argsText`` - streams into the Inputs panel directly, no need for - a separate "Live input" panel. Native - ``overflow-auto`` instead of ``ScrollArea`` because - Radix's Viewport can let content bleed past - ``max-h-*`` in dynamic flex layouts. ``min-w-0`` on - the column wrappers guarantees ``break-all`` wraps - correctly within the bounded ``max-w-lg`` Card. - */} - - -
- {(argsText || isRunning) && ( -
-

Inputs

- - {argsText ? ( -
-											{argsText}
-										
- ) : ( - // Bridges the brief gap between - // ``tool-input-start`` (creates the - // card, ``argsText`` undefined) and - // the first ``tool-input-delta``. -

- Waiting for input… -

- )} -
-
- )} - {!isCancelled && result !== undefined && ( - <> - -
-

Result

- -
-											{typeof result === "string" ? result : serializedResult}
-										
-
-
- - )} -
-
-
-
- ); -}; - -export const ToolFallback: ToolCallMessagePartComponent = (props) => { - if (isInterruptResult(props.result)) { - if (isDoomLoopInterrupt(props.result)) { - return ; - } - return ; - } - return ; -}; diff --git a/surfsense_web/components/assistant-ui/user-message.tsx b/surfsense_web/components/assistant-ui/user-message.tsx index 145ac2d7e..b09aa7680 100644 --- a/surfsense_web/components/assistant-ui/user-message.tsx +++ b/surfsense_web/components/assistant-ui/user-message.tsx @@ -5,12 +5,16 @@ import { useAuiState, useMessagePartText, } from "@assistant-ui/react"; -import { useAtomValue } from "jotai"; -import { CheckIcon, CopyIcon, Pencil } from "lucide-react"; +import { useAtomValue, useSetAtom } from "jotai"; +import { CheckIcon, CopyIcon, Folder as FolderIcon, Pencil } from "lucide-react"; import Image from "next/image"; -import { type FC, useState } from "react"; +import { useParams } from "next/navigation"; +import { type FC, useCallback, useState } from "react"; +import { toast } from "sonner"; import { currentThreadAtom } from "@/atoms/chat/current-thread.atom"; import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom"; +import { openEditorPanelAtom } from "@/atoms/editor/editor-panel.atom"; +import { MentionChip } from "@/components/assistant-ui/mention-chip"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; @@ -61,27 +65,61 @@ const UserTextPart: FC = () => { const text = (part as { text?: string }).text ?? ""; const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom); const mentionedDocs = (messageId ? messageDocumentsMap[messageId] : undefined) ?? []; + const openEditorPanel = useSetAtom(openEditorPanelAtom); + const params = useParams(); + const searchSpaceIdParam = params?.search_space_id; + const parsedSearchSpaceId = Array.isArray(searchSpaceIdParam) + ? Number(searchSpaceIdParam[0]) + : Number(searchSpaceIdParam); + const resolvedSearchSpaceId = Number.isFinite(parsedSearchSpaceId) + ? parsedSearchSpaceId + : undefined; + + const handleOpenDoc = useCallback( + (docId: number, title: string) => { + if (!resolvedSearchSpaceId) { + toast.error("Cannot open document outside a search space."); + return; + } + openEditorPanel({ + kind: "document", + documentId: docId, + searchSpaceId: resolvedSearchSpaceId, + title, + }); + }, + [openEditorPanel, resolvedSearchSpaceId] + ); const segments = parseMentionSegments(text, mentionedDocs); return ( -

- {segments.map((segment) => - segment.type === "text" ? ( - {segment.value} +

+ {segments.map((segment) => { + if (segment.type === "text") { + return {segment.value}; + } + const isFolder = segment.doc.kind === "folder"; + const icon = isFolder ? ( + ) : ( - - - {getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")} - - {segment.doc.title} - - ) - )} + icon={icon} + label={segment.doc.title} + tooltip={isFolder ? `Folder: ${segment.doc.title}` : segment.doc.title} + onClick={ + isFolder + ? undefined + : () => handleOpenDoc(segment.doc.id, segment.doc.title) + } + className="mx-0.5" + /> + ); + })}

); }; diff --git a/surfsense_web/components/documents/FolderTreeView.tsx b/surfsense_web/components/documents/FolderTreeView.tsx index 2063fbee5..eeeb1e779 100644 --- a/surfsense_web/components/documents/FolderTreeView.tsx +++ b/surfsense_web/components/documents/FolderTreeView.tsx @@ -176,34 +176,25 @@ export function FolderTreeView({ }, [folders, docsByFolder, foldersByParent, effectiveActiveTypes, searchQuery]); const folderSelectionStates = useMemo(() => { + // One folder = one chip. The checkbox now reflects whether the + // folder itself is mentioned, not whether every nested doc is — + // that reverses the old subtree-fanout semantics in + // ``DocumentsSidebar.handleToggleFolderSelect``. We keep the + // ``"all" | "some" | "none"`` tri-state on the type so the + // existing ``FolderNode`` UI (which renders an indeterminate + // glyph for ``"some"``) stays compatible, but only ``"all"`` + // and ``"none"`` are used in practice. const states: Record = {}; - const isSelectable = (d: DocumentNodeDoc) => - d.status?.state !== "pending" && d.status?.state !== "processing"; - - function compute(folderId: number): { selected: number; total: number } { - const directDocs = (docsByFolder[folderId] ?? []).filter(isSelectable); - let selected = directDocs.filter((d) => mentionedDocKeys.has(getMentionDocKey(d))).length; - let total = directDocs.length; - - for (const child of foldersByParent[folderId] ?? []) { - const sub = compute(child.id); - selected += sub.selected; - total += sub.total; - } - - if (total === 0) states[folderId] = "none"; - else if (selected === total) states[folderId] = "all"; - else if (selected > 0) states[folderId] = "some"; - else states[folderId] = "none"; - - return { selected, total }; - } - for (const f of folders) { - if (states[f.id] === undefined) compute(f.id); + const folderMentionKey = getMentionDocKey({ + id: f.id, + document_type: "FOLDER", + kind: "folder", + }); + states[f.id] = mentionedDocKeys.has(folderMentionKey) ? "all" : "none"; } return states; - }, [folders, docsByFolder, foldersByParent, mentionedDocKeys]); + }, [folders, mentionedDocKeys]); const folderMap = useMemo(() => { const map: Record = {}; diff --git a/surfsense_web/components/editor/plate-editor.tsx b/surfsense_web/components/editor/plate-editor.tsx index c42cb991e..51ad7d700 100644 --- a/surfsense_web/components/editor/plate-editor.tsx +++ b/surfsense_web/components/editor/plate-editor.tsx @@ -11,6 +11,7 @@ import { EditorSaveContext } from "@/components/editor/editor-save-context"; import { CitationKit, injectCitationNodes } from "@/components/editor/plugins/citation-kit"; import { type EditorPreset, presetMap } from "@/components/editor/presets"; import { escapeMdxExpressions } from "@/components/editor/utils/escape-mdx"; +import { safeDeserializeMarkdown } from "@/components/editor/utils/safe-deserialize"; import { Editor, EditorContainer } from "@/components/ui/editor"; import { preprocessCitationMarkdown } from "@/lib/citations/citation-parser"; @@ -169,15 +170,17 @@ export function PlateEditor({ : markdown ? (editor) => { if (!enableCitations) { - return editor - .getApi(MarkdownPlugin) - .markdown.deserialize(escapeMdxExpressions(markdown)); + return safeDeserializeMarkdown( + editor, + escapeMdxExpressions(markdown) + ) as Value; } const { content: rewritten, urlMap } = preprocessCitationMarkdown(markdown); - const value = editor - .getApi(MarkdownPlugin) - .markdown.deserialize(escapeMdxExpressions(rewritten)); - return injectCitationNodes(value as Descendant[], urlMap) as Value; + const value = safeDeserializeMarkdown( + editor, + escapeMdxExpressions(rewritten) + ); + return injectCitationNodes(value, urlMap) as Value; } : undefined, }); @@ -200,14 +203,13 @@ export function PlateEditor({ let newValue: Descendant[]; if (enableCitations) { const { content: rewritten, urlMap } = preprocessCitationMarkdown(markdown); - const deserialized = editor - .getApi(MarkdownPlugin) - .markdown.deserialize(escapeMdxExpressions(rewritten)) as Descendant[]; + const deserialized = safeDeserializeMarkdown( + editor, + escapeMdxExpressions(rewritten) + ); newValue = injectCitationNodes(deserialized, urlMap); } else { - newValue = editor - .getApi(MarkdownPlugin) - .markdown.deserialize(escapeMdxExpressions(markdown)) as Descendant[]; + newValue = safeDeserializeMarkdown(editor, escapeMdxExpressions(markdown)); } editor.tf.reset(); editor.tf.setValue(newValue as Value); diff --git a/surfsense_web/components/editor/utils/safe-deserialize.ts b/surfsense_web/components/editor/utils/safe-deserialize.ts new file mode 100644 index 000000000..e359a7791 --- /dev/null +++ b/surfsense_web/components/editor/utils/safe-deserialize.ts @@ -0,0 +1,64 @@ +// --------------------------------------------------------------------------- +// Safe markdown deserialization for the Plate editor +// --------------------------------------------------------------------------- +// `remark-mdx` treats any HTML-like tag as JSX, so unbalanced inline HTML +// (very common in GitHub READMEs, web-scraped pages, PDF conversions) makes +// it throw "Expected a closing tag for ``" and crash the editor. +// +// Per the MDX maintainers' guidance (mdx-js/mdx, ipikuka/next-mdx-remote-client +// #14), MDX is the wrong format for untrusted markdown and the recommended +// fix is to fall back to plain markdown parsing. `MarkdownPlugin.deserialize` +// accepts a per-call `remarkPlugins` override, so we can: +// +// 1. Try with `remarkMdx` (rich MDX features, e.g. JSX-style components). +// 2. On failure, retry without `remarkMdx` (lenient HTML, like GitHub). +// 3. As a last resort, render the raw source in a paragraph so the user +// never sees a crashed editor. +// --------------------------------------------------------------------------- + +import { MarkdownPlugin, remarkMdx } from "@platejs/markdown"; +import type { Descendant } from "platejs"; +import remarkGfm from "remark-gfm"; +import remarkMath from "remark-math"; +import type { PlateEditorInstance } from "@/components/editor/plate-editor"; + +const STRICT_PLUGINS = [remarkGfm, remarkMath, remarkMdx]; +const LENIENT_PLUGINS = [remarkGfm, remarkMath]; + +function plainTextFallback(markdown: string): Descendant[] { + return [ + { + type: "p", + children: [{ text: markdown }], + } as unknown as Descendant, + ]; +} + +/** + * Deserialize markdown into a Plate value, gracefully degrading when the + * MDX-strict parser rejects raw HTML. Always returns a renderable value; + * never throws. + */ +export function safeDeserializeMarkdown( + editor: PlateEditorInstance, + markdown: string +): Descendant[] { + const api = editor.getApi(MarkdownPlugin).markdown; + + try { + return api.deserialize(markdown, { remarkPlugins: STRICT_PLUGINS }) as Descendant[]; + } catch (mdxError) { + if (process.env.NODE_ENV !== "production") { + console.warn( + "[plate-editor] MDX parse failed, retrying without remark-mdx:", + mdxError + ); + } + try { + return api.deserialize(markdown, { remarkPlugins: LENIENT_PLUGINS }) as Descendant[]; + } catch (fallbackError) { + console.error("[plate-editor] markdown deserialize failed:", fallbackError); + return plainTextFallback(markdown); + } + } +} diff --git a/surfsense_web/components/free-chat/free-chat-page.tsx b/surfsense_web/components/free-chat/free-chat-page.tsx index 080d9a2b6..927eaef87 100644 --- a/surfsense_web/components/free-chat/free-chat-page.tsx +++ b/surfsense_web/components/free-chat/free-chat-page.tsx @@ -10,13 +10,13 @@ import { Turnstile, type TurnstileInstance } from "@marsidev/react-turnstile"; import { ShieldCheck } from "lucide-react"; import { useCallback, useEffect, useRef, useState } from "react"; import { StepSeparatorDataUI } from "@/components/assistant-ui/step-separator"; -import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps"; import { createTokenUsageStore, type TokenUsageData, TokenUsageProvider, } from "@/components/assistant-ui/token-usage-context"; import { useAnonymousMode } from "@/contexts/anonymous-mode"; +import { TimelineDataUI } from "@/features/chat-messages/timeline"; import { addStepSeparator, addToolCall, @@ -228,7 +228,8 @@ export function FreeChatPage() { parsed.toolName, {}, false, - parsed.langchainToolCallId + parsed.langchainToolCallId, + parsed.metadata ); forceFlush(); break; @@ -245,6 +246,7 @@ export function FreeChatPage() { args: parsed.input || {}, argsText: finalArgsText, langchainToolCallId: parsed.langchainToolCallId, + metadata: parsed.metadata, }); } else { addToolCall( @@ -254,7 +256,8 @@ export function FreeChatPage() { parsed.toolName, parsed.input || {}, false, - parsed.langchainToolCallId + parsed.langchainToolCallId, + parsed.metadata ); updateToolCall(contentPartsState, parsed.toolCallId, { argsText: finalArgsText, @@ -268,6 +271,7 @@ export function FreeChatPage() { updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output, langchainToolCallId: parsed.langchainToolCallId, + metadata: parsed.metadata, }); forceFlush(); break; @@ -469,7 +473,7 @@ export function FreeChatPage() { return ( - +
diff --git a/surfsense_web/components/hitl-bundle-pager/index.ts b/surfsense_web/components/hitl-bundle-pager/index.ts deleted file mode 100644 index ce434d224..000000000 --- a/surfsense_web/components/hitl-bundle-pager/index.ts +++ /dev/null @@ -1,2 +0,0 @@ -export { PagerChrome } from "./pager-chrome"; -export { withBundleStep } from "./with-bundle-step"; diff --git a/surfsense_web/components/hitl-bundle-pager/pager-chrome.tsx b/surfsense_web/components/hitl-bundle-pager/pager-chrome.tsx deleted file mode 100644 index 77d75fb6d..000000000 --- a/surfsense_web/components/hitl-bundle-pager/pager-chrome.tsx +++ /dev/null @@ -1,61 +0,0 @@ -"use client"; - -import { ChevronLeftIcon, ChevronRightIcon } from "lucide-react"; -import { Button } from "@/components/ui/button"; -import { useHitlBundle } from "@/lib/hitl"; - -/** - * Prev/next nav and Submit for the current step of an active HITL bundle. - * Submission is gated on every action_request having a staged decision. - */ -export function PagerChrome() { - const bundle = useHitlBundle(); - if (!bundle) return null; - - const total = bundle.toolCallIds.length; - const step = bundle.currentStep; - const allStaged = bundle.stagedCount === total; - - return ( -
- - - {step + 1} / {total} - - · - - {bundle.stagedCount} of {total} decided - - -
- -
-
- ); -} diff --git a/surfsense_web/components/hitl-bundle-pager/with-bundle-step.tsx b/surfsense_web/components/hitl-bundle-pager/with-bundle-step.tsx deleted file mode 100644 index 64ac801fb..000000000 --- a/surfsense_web/components/hitl-bundle-pager/with-bundle-step.tsx +++ /dev/null @@ -1,37 +0,0 @@ -"use client"; - -import type { ToolCallMessagePartProps } from "@assistant-ui/react"; -import type { ComponentType } from "react"; -import { ToolCallIdProvider, useHitlBundle } from "@/lib/hitl"; -import { PagerChrome } from "./pager-chrome"; - -/** - * Wrap a tool-ui card so that, when a multi-card HITL bundle is active: - * - cards belonging to the bundle but not the current step render ``null``; - * - the current-step card renders normally and is followed by ``PagerChrome``. - * - * Cards stay completely unchanged — the wrapper provides the - * ``ToolCallIdContext`` that ``useHitlDecision`` reads to stage decisions - * against the right ``toolCallId`` instead of firing the global event. - */ -export function withBundleStep

>( - Component: ComponentType

-): ComponentType

{ - function BundleStepWrapped(props: P) { - const bundle = useHitlBundle(); - const toolCallId = props.toolCallId; - const inBundle = bundle?.isInBundle(toolCallId) ?? false; - const isStep = bundle?.isCurrentStep(toolCallId) ?? false; - - if (bundle && inBundle && !isStep) return null; - - return ( - - - {bundle && isStep ? : null} - - ); - } - BundleStepWrapped.displayName = `withBundleStep(${Component.displayName ?? Component.name ?? "ToolUI"})`; - return BundleStepWrapped as ComponentType

; -} diff --git a/surfsense_web/components/hitl-edit-panel/hitl-edit-panel.tsx b/surfsense_web/components/hitl-edit-panel/hitl-edit-panel.tsx deleted file mode 100644 index b33392f38..000000000 --- a/surfsense_web/components/hitl-edit-panel/hitl-edit-panel.tsx +++ /dev/null @@ -1,405 +0,0 @@ -"use client"; - -import { format } from "date-fns"; -import { TagInput, type Tag as TagType } from "emblor"; -import { useAtomValue, useSetAtom } from "jotai"; -import { CalendarIcon, XIcon } from "lucide-react"; -import dynamic from "next/dynamic"; -import { useCallback, useEffect, useMemo, useRef, useState } from "react"; -import type { ExtraField } from "@/atoms/chat/hitl-edit-panel.atom"; -import { closeHitlEditPanelAtom, hitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom"; -import { Button } from "@/components/ui/button"; -import { Calendar } from "@/components/ui/calendar"; -import { Drawer, DrawerContent, DrawerHandle, DrawerTitle } from "@/components/ui/drawer"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; -import { Skeleton } from "@/components/ui/skeleton"; -import { Textarea } from "@/components/ui/textarea"; -import { useMediaQuery } from "@/hooks/use-media-query"; - -const PlateEditor = dynamic( - () => import("@/components/editor/plate-editor").then((m) => ({ default: m.PlateEditor })), - { ssr: false, loading: () => } -); - -function parseEmailsToTags(value: string): TagType[] { - if (!value.trim()) return []; - return value - .split(",") - .map((s) => s.trim()) - .filter(Boolean) - .map((email, i) => ({ id: `${Date.now()}-${i}`, text: email })); -} - -function tagsToEmailString(tags: TagType[]): string { - return tags.map((t) => t.text).join(", "); -} - -function EmailsTagField({ - id, - value, - onChange, - placeholder, -}: { - id: string; - value: string; - onChange: (value: string) => void; - placeholder?: string; -}) { - const [tags, setTags] = useState(() => parseEmailsToTags(value)); - const [activeTagIndex, setActiveTagIndex] = useState(null); - const isInitialMount = useRef(true); - const onChangeRef = useRef(onChange); - onChangeRef.current = onChange; - - useEffect(() => { - if (isInitialMount.current) { - isInitialMount.current = false; - return; - } - onChangeRef.current(tagsToEmailString(tags)); - }, [tags]); - - const handleSetTags = useCallback((newTags: TagType[] | ((prev: TagType[]) => TagType[])) => { - setTags((prev) => (typeof newTags === "function" ? newTags(prev) : newTags)); - }, []); - - const handleAddTag = useCallback((text: string) => { - const trimmed = text.trim(); - if (!trimmed) return; - setTags((prev) => { - if (prev.some((tag) => tag.text === trimmed)) return prev; - const newTag: TagType = { id: Date.now().toString(), text: trimmed }; - return [...prev, newTag]; - }); - }, []); - - return ( - - ); -} - -function parseDateTimeValue(value: string): { date: Date | undefined; time: string } { - if (!value) return { date: undefined, time: "09:00" }; - try { - const d = new Date(value); - if (Number.isNaN(d.getTime())) return { date: undefined, time: "09:00" }; - return { - date: d, - time: format(d, "HH:mm"), - }; - } catch { - return { date: undefined, time: "09:00" }; - } -} - -function buildLocalDateTimeString(date: Date | undefined, time: string): string { - if (!date) return ""; - const [hours, minutes] = time.split(":").map(Number); - const combined = new Date(date); - combined.setHours(hours ?? 9, minutes ?? 0, 0, 0); - const y = combined.getFullYear(); - const m = String(combined.getMonth() + 1).padStart(2, "0"); - const d = String(combined.getDate()).padStart(2, "0"); - const h = String(combined.getHours()).padStart(2, "0"); - const min = String(combined.getMinutes()).padStart(2, "0"); - return `${y}-${m}-${d}T${h}:${min}:00`; -} - -function DateTimePickerField({ - id, - value, - onChange, -}: { - id: string; - value: string; - onChange: (value: string) => void; -}) { - const parsed = useMemo(() => parseDateTimeValue(value), [value]); - const [selectedDate, setSelectedDate] = useState(parsed.date); - const [time, setTime] = useState(parsed.time); - const [open, setOpen] = useState(false); - - const handleDateSelect = useCallback( - (day: Date | undefined) => { - setSelectedDate(day); - onChange(buildLocalDateTimeString(day, time)); - setOpen(false); - }, - [time, onChange] - ); - - const handleTimeChange = useCallback( - (e: React.ChangeEvent) => { - const newTime = e.target.value; - setTime(newTime); - onChange(buildLocalDateTimeString(selectedDate, newTime)); - }, - [selectedDate, onChange] - ); - - const displayLabel = selectedDate - ? `${format(selectedDate, "MMM d, yyyy")} at ${time}` - : "Pick date & time"; - - return ( -

- - - - - - - - - -
- ); -} - -export function HitlEditPanelContent({ - title: initialTitle, - content: initialContent, - contentFormat, - extraFields, - onSave, - onClose, - showCloseButton = true, -}: { - title: string; - content: string; - toolName: string; - contentFormat?: "markdown" | "html"; - extraFields?: ExtraField[]; - onSave: (title: string, content: string, extraFieldValues?: Record) => void; - onClose?: () => void; - showCloseButton?: boolean; -}) { - const [editedTitle, setEditedTitle] = useState(initialTitle); - const contentRef = useRef(initialContent); - const [isSaving, setIsSaving] = useState(false); - const [extraFieldValues, setExtraFieldValues] = useState>(() => { - if (!extraFields) return {}; - const initial: Record = {}; - for (const field of extraFields) { - initial[field.key] = field.value; - } - return initial; - }); - - const handleContentChange = useCallback((content: string) => { - contentRef.current = content; - }, []); - - const handleExtraFieldChange = useCallback((key: string, value: string) => { - setExtraFieldValues((prev) => ({ ...prev, [key]: value })); - }, []); - - const handleSave = useCallback(() => { - if (!editedTitle.trim()) return; - setIsSaving(true); - const extras = extraFields && extraFields.length > 0 ? extraFieldValues : undefined; - onSave(editedTitle, contentRef.current, extras); - onClose?.(); - }, [editedTitle, onSave, onClose, extraFields, extraFieldValues]); - - return ( - <> -
- setEditedTitle(e.target.value)} - placeholder="Untitled" - className="flex-1 min-w-0 bg-transparent text-sm font-semibold text-foreground outline-none placeholder:text-muted-foreground" - aria-label="Page title" - /> - {onClose && showCloseButton && ( - - )} -
- - {extraFields && extraFields.length > 0 && ( -
- {extraFields.map((field) => ( -
- - {field.type === "emails" ? ( - handleExtraFieldChange(field.key, v)} - placeholder={`Add ${field.label.toLowerCase()}`} - /> - ) : field.type === "datetime-local" ? ( - handleExtraFieldChange(field.key, v)} - /> - ) : field.type === "textarea" ? ( -