Merge branch 'dev' into feat/e2e-testing

Rohan Verma 2026-05-09 16:10:45 -07:00 committed by GitHub
commit fa31da9937
100 changed files with 3751 additions and 1122 deletions

@@ -0,0 +1,37 @@
# Deepening
How to deepen a cluster of shallow modules safely, given its dependencies. Assumes the vocabulary in [LANGUAGE.md](LANGUAGE.md) — **module**, **interface**, **seam**, **adapter**.
## Dependency categories
When assessing a candidate for deepening, classify its dependencies. The category determines how the deepened module is tested across its seam.
### 1. In-process
Pure computation, in-memory state, no I/O. Always deepenable — merge the modules and test through the new interface directly. No adapter needed.
### 2. Local-substitutable
Dependencies that have local test stand-ins (PGLite for Postgres, in-memory filesystem). Deepenable if the stand-in exists. The deepened module is tested with the stand-in running in the test suite. The seam is internal; no port at the module's external interface.
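For example, a minimal sketch assuming a hypothetical `DocumentStore` module, with `sqlite3`'s in-memory mode standing in for the production database the way PGLite stands in for Postgres:

```python
import sqlite3

class DocumentStore:
    """Deepened module: callers see save/find, never SQL or connections."""

    def __init__(self, conn: sqlite3.Connection) -> None:
        self._conn = conn
        self._conn.execute(
            "CREATE TABLE IF NOT EXISTS docs (id TEXT PRIMARY KEY, body TEXT)"
        )

    def save(self, doc_id: str, body: str) -> None:
        self._conn.execute(
            "INSERT OR REPLACE INTO docs VALUES (?, ?)", (doc_id, body)
        )

    def find(self, doc_id: str) -> str | None:
        row = self._conn.execute(
            "SELECT body FROM docs WHERE id = ?", (doc_id,)
        ).fetchone()
        return row[0] if row else None

def test_round_trip() -> None:
    # The stand-in runs in-process; the seam stays internal to the module.
    store = DocumentStore(sqlite3.connect(":memory:"))
    store.save("a", "hello")
    assert store.find("a") == "hello"
```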
### 3. Remote but owned (Ports & Adapters)
Your own services across a network boundary (microservices, internal APIs). Define a **port** (interface) at the seam. The deep module owns the logic; the transport is injected as an **adapter**. Tests use an in-memory adapter. Production uses an HTTP/gRPC/queue adapter.
Recommendation shape: *"Define a port at the seam, implement an HTTP adapter for production and an in-memory adapter for testing, so the logic sits in one deep module even though it's deployed across a network."*
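A minimal sketch of that shape, using hypothetical `NotificationPort` and `OrderNotifier` names:

```python
from typing import Protocol

class NotificationPort(Protocol):
    """Port at the seam: all the deep module needs from the transport."""

    def send(self, recipient: str, message: str) -> None: ...

class HttpNotificationAdapter:
    """Production adapter: crosses the network to the owned service."""

    def __init__(self, base_url: str) -> None:
        self._base_url = base_url

    def send(self, recipient: str, message: str) -> None:
        ...  # POST to the owned service via an HTTP client; elided here

class InMemoryNotificationAdapter:
    """Test adapter: records sends so tests assert through the interface."""

    def __init__(self) -> None:
        self.sent: list[tuple[str, str]] = []

    def send(self, recipient: str, message: str) -> None:
        self.sent.append((recipient, message))

class OrderNotifier:
    """Deep module: owns the logic; the transport is injected at the seam."""

    def __init__(self, port: NotificationPort) -> None:
        self._port = port

    def order_shipped(self, customer: str, order_id: str) -> None:
        self._port.send(customer, f"Order {order_id} has shipped.")
```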
### 4. True external (Mock)
Third-party services (Stripe, Twilio, etc.) you don't control. The deepened module takes the external dependency as an injected port; tests provide a mock adapter.
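Same mechanics as category 3. A sketch with a hypothetical `PaymentPort`, where the production adapter would wrap the vendor SDK and tests inject the fake:

```python
from typing import Protocol

class PaymentPort(Protocol):
    def charge(self, customer_id: str, amount_cents: int) -> str: ...

class FakePaymentAdapter:
    """Mock adapter for tests: records charges and invents fake charge ids."""

    def __init__(self) -> None:
        self.charges: list[tuple[str, int]] = []

    def charge(self, customer_id: str, amount_cents: int) -> str:
        self.charges.append((customer_id, amount_cents))
        return f"ch_test_{len(self.charges)}"
```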
## Seam discipline
- **One adapter means a hypothetical seam. Two adapters means a real one.** Don't introduce a port unless at least two adapters are justified (typically production + test). A single-adapter seam is just indirection.
- **Internal seams vs external seams.** A deep module can have internal seams (private to its implementation, used by its own tests) as well as the external seam at its interface. Don't expose internal seams through the interface just because tests use them.
## Testing strategy: replace, don't layer
- Old unit tests on shallow modules become waste once tests at the deepened module's interface exist — delete them.
- Write new tests at the deepened module's interface. The **interface is the test surface**.
- Tests assert on observable outcomes through the interface, not internal state.
- Tests should survive internal refactors — they describe behaviour, not implementation. If a test has to change when the implementation changes, it's testing past the interface.
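A sketch of a test written at the interface, assuming a hypothetical `Tokenizer` module deepened out of several former helpers:

```python
import string

class Tokenizer:
    """Deepened module: normalisation, splitting, filtering behind one seam."""

    def tokenize(self, text: str) -> list[str]:
        words = (w.strip(string.punctuation) for w in text.lower().split())
        return [w for w in words if w.isalpha()]

def test_tokenize_normalises_and_filters() -> None:
    # Observable outcome through the interface; nothing asserts on how the
    # implementation normalises or filters, so internal refactors survive.
    assert Tokenizer().tokenize("Hello, World 42") == ["hello", "world"]
```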

@@ -0,0 +1,44 @@
# Interface Design
When the user wants to explore alternative interfaces for a chosen deepening candidate, use this parallel sub-agent pattern. Based on "Design It Twice" (Ousterhout) — your first idea is unlikely to be the best.
Uses the vocabulary in [LANGUAGE.md](LANGUAGE.md) — **module**, **interface**, **seam**, **adapter**, **leverage**.
## Process
### 1. Frame the problem space
Before spawning sub-agents, write a user-facing explanation of the problem space for the chosen candidate:
- The constraints any new interface would need to satisfy
- The dependencies it would rely on, and which category they fall into (see [DEEPENING.md](DEEPENING.md))
- A rough illustrative code sketch to ground the constraints — not a proposal, just a way to make the constraints concrete
Show this to the user, then immediately proceed to Step 2. The user reads and thinks while the sub-agents work in parallel.
### 2. Spawn sub-agents
Spawn 3+ sub-agents in parallel using the Agent tool. Each must produce a **radically different** interface for the deepened module.
Prompt each sub-agent with a separate technical brief (file paths, coupling details, dependency category from [DEEPENING.md](DEEPENING.md), what sits behind the seam). The brief is independent of the user-facing problem-space explanation in Step 1. Give each agent a different design constraint:
- Agent 1: "Minimise the interface — aim for 1–3 entry points max. Maximise leverage per entry point."
- Agent 2: "Maximise flexibility — support many use cases and extension."
- Agent 3: "Optimise for the most common caller — make the default case trivial."
- Agent 4 (if applicable): "Design around ports & adapters for cross-seam dependencies."
Include both [LANGUAGE.md](LANGUAGE.md) vocabulary and CONTEXT.md vocabulary in the brief so each sub-agent names things consistently with the architecture language and the project's domain language.
Each sub-agent outputs:
1. Interface (types, methods, params — plus invariants, ordering, error modes)
2. Usage example showing how callers use it
3. What the implementation hides behind the seam
4. Dependency strategy and adapters (see [DEEPENING.md](DEEPENING.md))
5. Trade-offs — where leverage is high, where it's thin
### 3. Present and compare
Present designs sequentially so the user can absorb each one, then compare them in prose. Contrast by **depth** (leverage at the interface), **locality** (where change concentrates), and **seam placement**.
After comparing, give your own recommendation: which design you think is strongest and why. If elements from different designs would combine well, propose a hybrid. Be opinionated — the user wants a strong read, not a menu.

@@ -0,0 +1,53 @@
# Language
Shared vocabulary for every suggestion this skill makes. Use these terms exactly — don't substitute "component," "service," "API," or "boundary." Consistent language is the whole point.
## Terms
**Module**
Anything with an interface and an implementation. Deliberately scale-agnostic — applies equally to a function, class, package, or tier-spanning slice.
_Avoid_: unit, component, service.
**Interface**
Everything a caller must know to use the module correctly. Includes the type signature, but also invariants, ordering constraints, error modes, required configuration, and performance characteristics.
_Avoid_: API, signature (too narrow — those refer only to the type-level surface).
**Implementation**
What's inside a module — its body of code. Distinct from **Adapter**: a thing can be a small adapter with a large implementation (a Postgres repo) or a large adapter with a small implementation (an in-memory fake). Reach for "adapter" when the seam is the topic; "implementation" otherwise.
**Depth**
Leverage at the interface — the amount of behaviour a caller (or test) can exercise per unit of interface they have to learn. A module is **deep** when a large amount of behaviour sits behind a small interface. A module is **shallow** when the interface is nearly as complex as the implementation.
**Seam** _(from Michael Feathers)_
A place where you can alter behaviour without editing in that place. The *location* at which a module's interface lives. Choosing where to put the seam is its own design decision, distinct from what goes behind it.
_Avoid_: boundary (overloaded with DDD's bounded context).
**Adapter**
A concrete thing that satisfies an interface at a seam. Describes *role* (what slot it fills), not substance (what's inside).
**Leverage**
What callers get from depth. More capability per unit of interface they have to learn. One implementation pays back across N call sites and M tests.
**Locality**
What maintainers get from depth. Change, bugs, knowledge, and verification concentrate at one place rather than spreading across callers. Fix once, fixed everywhere.
## Principles
- **Depth is a property of the interface, not the implementation.** A deep module can be internally composed of small, mockable, swappable parts — they just aren't part of the interface. A module can have **internal seams** (private to its implementation, used by its own tests) as well as the **external seam** at its interface.
- **The deletion test.** Imagine deleting the module. If complexity vanishes, the module wasn't hiding anything (it was a pass-through). If complexity reappears across N callers, the module was earning its keep. A sketch follows this list.
- **The interface is the test surface.** Callers and tests cross the same seam. If you want to test *past* the interface, the module is probably the wrong shape.
- **One adapter means a hypothetical seam. Two adapters means a real one.** Don't introduce a seam unless something actually varies across it.
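A minimal sketch of the deletion test, assuming hypothetical `db` and `cache` collaborators. Deleting the pass-through loses nothing; deleting the deep module scatters retry, caching, and degradation logic back across its N callers:

```python
# Shallow pass-through: delete it and the complexity vanishes.
def get_user(db, user_id):
    return db.get_user(user_id)

# Deep module: delete it and retry, caching, and stale-fallback logic
# reappear at every call site.
class UserDirectory:
    def __init__(self, db, cache) -> None:
        self._db, self._cache = db, cache

    def get_user(self, user_id: str):
        if (hit := self._cache.get(user_id)) is not None:
            return hit
        for _attempt in range(3):  # retry transient failures
            try:
                user = self._db.get_user(user_id)
                self._cache.set(user_id, user)
                return user
            except TimeoutError:
                continue
        return self._cache.get_stale(user_id)  # degrade rather than fail
```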
## Relationships
- A **Module** has exactly one **Interface** (the surface it presents to callers and tests).
- **Depth** is a property of a **Module**, measured against its **Interface**.
- A **Seam** is where a **Module**'s **Interface** lives.
- An **Adapter** sits at a **Seam** and satisfies the **Interface**.
- **Depth** produces **Leverage** for callers and **Locality** for maintainers.
## Rejected framings
- **Depth as ratio of implementation-lines to interface-lines** (Ousterhout): rewards padding the implementation. We use depth-as-leverage instead.
- **"Interface" as the TypeScript `interface` keyword or a class's public methods**: too narrow — interface here includes every fact a caller must know.
- **"Boundary"**: overloaded with DDD's bounded context. Say **seam** or **interface**.

@@ -0,0 +1,71 @@
---
name: improve-codebase-architecture
description: Find deepening opportunities in a codebase, informed by the domain language in CONTEXT.md and the decisions in docs/adr/. Use when the user wants to improve architecture, find refactoring opportunities, consolidate tightly-coupled modules, or make a codebase more testable and AI-navigable.
---
# Improve Codebase Architecture
Surface architectural friction and propose **deepening opportunities** — refactors that turn shallow modules into deep ones. The aim is testability and AI-navigability.
## Glossary
Use these terms exactly in every suggestion. Consistent language is the point — don't drift into "component," "service," "API," or "boundary." Full definitions in [LANGUAGE.md](LANGUAGE.md).
- **Module** — anything with an interface and an implementation (function, class, package, slice).
- **Interface** — everything a caller must know to use the module: types, invariants, error modes, ordering, config. Not just the type signature.
- **Implementation** — the code inside.
- **Depth** — leverage at the interface: a lot of behaviour behind a small interface. **Deep** = high leverage. **Shallow** = interface nearly as complex as the implementation.
- **Seam** — where an interface lives; a place behaviour can be altered without editing in place. (Use this, not "boundary.")
- **Adapter** — a concrete thing satisfying an interface at a seam.
- **Leverage** — what callers get from depth.
- **Locality** — what maintainers get from depth: change, bugs, knowledge concentrated in one place.
Key principles (see [LANGUAGE.md](LANGUAGE.md) for the full list):
- **Deletion test**: imagine deleting the module. If complexity vanishes, it was a pass-through. If complexity reappears across N callers, it was earning its keep.
- **The interface is the test surface.**
- **One adapter = hypothetical seam. Two adapters = real seam.**
This skill is _informed_ by the project's domain model. The domain language gives names to good seams; ADRs record decisions the skill should not re-litigate.
## Process
### 1. Explore
Read the project's domain glossary and any ADRs in the area you're touching first.
Then use the Agent tool with `subagent_type=Explore` to walk the codebase. Don't follow rigid heuristics — explore organically and note where you experience friction:
- Where does understanding one concept require bouncing between many small modules?
- Where are modules **shallow** — interface nearly as complex as the implementation?
- Where have pure functions been extracted just for testability, but the real bugs hide in how they're called (no **locality**)?
- Where do tightly-coupled modules leak across their seams?
- Which parts of the codebase are untested, or hard to test through their current interface?
Apply the **deletion test** to anything you suspect is shallow: would deleting it concentrate complexity, or just move it? A "yes, concentrates" is the signal you want.
### 2. Present candidates
Present a numbered list of deepening opportunities. For each candidate:
- **Files** — which files/modules are involved
- **Problem** — why the current architecture is causing friction
- **Solution** — plain English description of what would change
- **Benefits** — explained in terms of locality and leverage, and also in how tests would improve
**Use CONTEXT.md vocabulary for the domain, and [LANGUAGE.md](LANGUAGE.md) vocabulary for the architecture.** If `CONTEXT.md` defines "Order," talk about "the Order intake module" — not "the FooBarHandler," and not "the Order service."
**ADR conflicts**: if a candidate contradicts an existing ADR, only surface it when the friction is real enough to warrant revisiting the ADR. Mark it clearly (e.g. _"contradicts ADR-0007 — but worth reopening because…"_). Don't list every theoretical refactor an ADR forbids.
Do NOT propose interfaces yet. Ask the user: "Which of these would you like to explore?"
### 3. Grilling loop
Once the user picks a candidate, drop into a grilling conversation. Walk the design tree with them — constraints, dependencies, the shape of the deepened module, what sits behind the seam, what tests survive.
Side effects happen inline as decisions crystallize:
- **Naming a deepened module after a concept not in `CONTEXT.md`?** Add the term to `CONTEXT.md` — same discipline as `/grill-with-docs` (see [CONTEXT-FORMAT.md](../grill-with-docs/CONTEXT-FORMAT.md)). Create the file lazily if it doesn't exist.
- **Sharpening a fuzzy term during the conversation?** Update `CONTEXT.md` right there.
- **User rejects the candidate with a load-bearing reason?** Offer an ADR, framed as: _"Want me to record this as an ADR so future architecture reviews don't re-suggest it?"_ Only offer when the reason would actually be needed by a future explorer to avoid re-suggesting the same thing — skip ephemeral reasons ("not worth it right now") and self-evident ones. See [ADR-FORMAT.md](../grill-with-docs/ADR-FORMAT.md).
- **Want to explore alternative interfaces for the deepened module?** See [INTERFACE-DESIGN.md](INTERFACE-DESIGN.md).

.gitignore vendored
@@ -15,3 +15,4 @@ surfsense_web/playwright/.auth/
 surfsense_web/playwright-report/
 surfsense_web/test-results/
 surfsense_web/blob-report/
+hermes-agent/

@@ -1 +1 @@
-0.0.22
+0.0.23

@@ -46,6 +46,12 @@
     "sourceType": "github",
     "computedHash": "ddd61f32254be1303ce4b7be5d507c932de4af53489a0ebb1309bf61de99018c"
   },
+  "improve-codebase-architecture": {
+    "source": "mattpocock/skills",
+    "sourceType": "github",
+    "skillPath": "skills/engineering/improve-codebase-architecture/SKILL.md",
+    "computedHash": "2da1d23b8f53cfe67f2e0b68924ab9f4ec400bb6480de097007eeaeb517d1722"
+  },
   "internal-linking-optimizer": {
     "source": "aaron-he-zhu/seo-geo-claude-skills",
     "sourceType": "github",

@@ -2,6 +2,6 @@
 from __future__ import annotations
-from .main_agent import create_surfsense_deep_agent
+from .main_agent import create_multi_agent_chat_deep_agent
-__all__ = ["create_surfsense_deep_agent"]
+__all__ = ["create_multi_agent_chat_deep_agent"]

@@ -2,6 +2,6 @@
 from __future__ import annotations
-from .runtime import create_surfsense_deep_agent
+from .runtime import create_multi_agent_chat_deep_agent
-__all__ = ["create_surfsense_deep_agent"]
+__all__ = ["create_multi_agent_chat_deep_agent"]

@@ -11,6 +11,9 @@ from langchain_core.language_models import BaseChatModel
 from langchain_core.tools import BaseTool
 from langgraph.types import Checkpointer
+from app.agents.multi_agent_chat.middleware import (
+    build_main_agent_deepagent_middleware,
+)
 from app.agents.multi_agent_chat.subagents.shared.permissions import (
     ToolsPermissions,
 )
@@ -19,8 +22,6 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags
 from app.agents.new_chat.filesystem_selection import FilesystemMode
 from app.db import ChatVisibility
-from .middleware import build_main_agent_deepagent_middleware
 def build_compiled_agent_graph_sync(
     *,

@@ -1,7 +0,0 @@
"""Main-agent graph middleware assembly (SurfSense + LangChain + deepagents)."""
from __future__ import annotations
from .deepagent_stack import build_main_agent_deepagent_middleware
__all__ = ["build_main_agent_deepagent_middleware"]

@@ -1,506 +0,0 @@
"""Assemble the main-agent deep-agent middleware list (LangChain + SurfSense + deepagents)."""
from __future__ import annotations
import logging
from collections.abc import Sequence
from typing import Any
from deepagents import SubAgent
from deepagents.backends import StateBackend
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
from deepagents.middleware.skills import SkillsMiddleware
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
from langchain.agents.middleware import (
LLMToolSelectorMiddleware,
ModelCallLimitMiddleware,
ModelFallbackMiddleware,
TodoListMiddleware,
ToolCallLimitMiddleware,
)
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
from app.agents.multi_agent_chat.subagents import (
build_subagents,
get_subagents_to_exclude,
)
from app.agents.multi_agent_chat.subagents.shared.permissions import (
ToolsPermissions,
)
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware import (
ActionLogMiddleware,
AnonymousDocumentMiddleware,
BusyMutexMiddleware,
ClearToolUsesEdit,
DedupHITLToolCallsMiddleware,
DoomLoopMiddleware,
FileIntentMiddleware,
KnowledgeBasePersistenceMiddleware,
KnowledgePriorityMiddleware,
KnowledgeTreeMiddleware,
MemoryInjectionMiddleware,
NoopInjectionMiddleware,
OtelSpanMiddleware,
PermissionMiddleware,
RetryAfterMiddleware,
SpillingContextEditingMiddleware,
SpillToBackendEdit,
SurfSenseFilesystemMiddleware,
ToolCallNameRepairMiddleware,
build_skills_backend_factory,
create_surfsense_compaction_middleware,
default_skills_sources,
)
from app.agents.new_chat.permissions import Rule, Ruleset
from app.agents.new_chat.plugin_loader import (
PluginContext,
load_allowed_plugin_names_from_env,
load_plugin_middlewares,
)
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
from app.db import ChatVisibility
from ...context_prune.prune_tool_names import safe_exclude_tools
from .checkpointed_subagent_middleware import SurfSenseCheckpointedSubAgentMiddleware
def build_main_agent_deepagent_middleware(
*,
llm: BaseChatModel,
tools: Sequence[BaseTool],
backend_resolver: Any,
filesystem_mode: FilesystemMode,
search_space_id: int,
user_id: str | None,
thread_id: int | None,
visibility: ChatVisibility,
anon_session_id: str | None,
available_connectors: list[str] | None,
available_document_types: list[str] | None,
mentioned_document_ids: list[int] | None,
max_input_tokens: int | None,
flags: AgentFeatureFlags,
subagent_dependencies: dict[str, Any],
checkpointer: Checkpointer,
mcp_tools_by_agent: dict[str, ToolsPermissions] | None = None,
disabled_tools: list[str] | None = None,
) -> list[Any]:
"""Build ordered middleware for ``create_agent`` (Nones already stripped)."""
_memory_middleware = MemoryInjectionMiddleware(
user_id=user_id,
search_space_id=search_space_id,
thread_visibility=visibility,
)
gp_middleware = [
TodoListMiddleware(),
_memory_middleware,
FileIntentMiddleware(llm=llm),
SurfSenseFilesystemMiddleware(
backend=backend_resolver,
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
created_by_id=user_id,
thread_id=thread_id,
),
create_surfsense_compaction_middleware(llm, StateBackend),
PatchToolCallsMiddleware(),
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
]
# Build permission rulesets up front so the GP subagent can mirror ``ask``
# rules into ``interrupt_on``: tool calls emitted from within ``task`` runs
# never reach the parent's ``PermissionMiddleware``.
is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
permission_enabled = flags.enable_permission and not flags.disable_new_agent_stack
permission_rulesets: list[Ruleset] = []
if permission_enabled or is_desktop_fs:
permission_rulesets.append(
Ruleset(
rules=[Rule(permission="*", pattern="*", action="allow")],
origin="surfsense_defaults",
)
)
if is_desktop_fs:
permission_rulesets.append(
Ruleset(
rules=[
Rule(permission="rm", pattern="*", action="ask"),
Rule(permission="rmdir", pattern="*", action="ask"),
Rule(permission="move_file", pattern="*", action="ask"),
Rule(permission="edit_file", pattern="*", action="ask"),
Rule(permission="write_file", pattern="*", action="ask"),
],
origin="desktop_safety",
)
)
# Tools that self-prompt via ``request_approval`` must not also appear
# as ``ask`` rules — that would double-prompt the user for one call.
_tool_names_in_use = {t.name for t in tools}
# Deny parent-bound tools whose ``required_connector`` is missing.
# No-op today (connector subagents are pruned upstream); guards future
# additions to the parent's tool list.
if permission_enabled:
_available_set = set(available_connectors or [])
_synthesized: list[Rule] = []
for tool_def in BUILTIN_TOOLS:
if tool_def.name not in _tool_names_in_use:
continue
rc = tool_def.required_connector
if rc and rc not in _available_set:
_synthesized.append(
Rule(permission=tool_def.name, pattern="*", action="deny")
)
if _synthesized:
permission_rulesets.append(
Ruleset(rules=_synthesized, origin="connector_synthesized")
)
gp_interrupt_on: dict[str, bool] = {
rule.permission: True
for rs in permission_rulesets
for rule in rs.rules
if rule.action == "ask" and rule.permission in _tool_names_in_use
}
general_purpose_spec: SubAgent = { # type: ignore[typeddict-unknown-key]
**GENERAL_PURPOSE_SUBAGENT,
"model": llm,
"tools": tools,
"middleware": gp_middleware,
}
if gp_interrupt_on:
general_purpose_spec["interrupt_on"] = gp_interrupt_on
# Deny-only on subagents: ``task`` runs bypass the parent's
# PermissionMiddleware, while bucket-based ask gates own the ask path.
subagent_deny_rulesets: list[Ruleset] = [
Ruleset(
rules=[r for r in rs.rules if r.action == "deny"],
origin=rs.origin,
)
for rs in permission_rulesets
]
subagent_deny_rulesets = [rs for rs in subagent_deny_rulesets if rs.rules]
subagent_deny_permission_mw: PermissionMiddleware | None = (
PermissionMiddleware(rulesets=subagent_deny_rulesets)
if subagent_deny_rulesets
else None
)
if subagent_deny_permission_mw is not None:
# Run deny check on already-repaired tool calls; insert before
# PatchToolCallsMiddleware (append if the slot moves).
_patch_idx = next(
(
i
for i, m in enumerate(gp_middleware)
if isinstance(m, PatchToolCallsMiddleware)
),
len(gp_middleware),
)
gp_middleware.insert(_patch_idx, subagent_deny_permission_mw)
registry_subagents: list[SubAgent] = []
try:
subagent_extra_middleware: list[Any] = [
TodoListMiddleware(),
SurfSenseFilesystemMiddleware(
backend=backend_resolver,
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
created_by_id=user_id,
thread_id=thread_id,
),
]
if subagent_deny_permission_mw is not None:
subagent_extra_middleware.append(subagent_deny_permission_mw)
registry_subagents = build_subagents(
dependencies=subagent_dependencies,
model=llm,
extra_middleware=subagent_extra_middleware,
mcp_tools_by_agent=mcp_tools_by_agent or {},
exclude=get_subagents_to_exclude(available_connectors),
disabled_tools=disabled_tools,
)
logging.info(
"Registry subagents: %s",
[s["name"] for s in registry_subagents],
)
except Exception:
logging.exception("Registry subagent build failed")
raise
subagent_specs: list[SubAgent] = [general_purpose_spec, *registry_subagents]
summarization_mw = create_surfsense_compaction_middleware(llm, StateBackend)
context_edit_mw = None
if (
flags.enable_context_editing
and not flags.disable_new_agent_stack
and max_input_tokens
):
spill_edit = SpillToBackendEdit(
trigger=int(max_input_tokens * 0.55),
clear_at_least=int(max_input_tokens * 0.15),
keep=5,
exclude_tools=safe_exclude_tools(tools),
clear_tool_inputs=True,
)
clear_edit = ClearToolUsesEdit(
trigger=int(max_input_tokens * 0.55),
clear_at_least=int(max_input_tokens * 0.15),
keep=5,
exclude_tools=safe_exclude_tools(tools),
clear_tool_inputs=True,
placeholder="[cleared - older tool output trimmed for context]",
)
context_edit_mw = SpillingContextEditingMiddleware(
edits=[spill_edit, clear_edit],
backend_resolver=backend_resolver,
)
retry_mw = (
RetryAfterMiddleware(max_retries=3)
if flags.enable_retry_after and not flags.disable_new_agent_stack
else None
)
fallback_mw: ModelFallbackMiddleware | None = None
if flags.enable_model_fallback and not flags.disable_new_agent_stack:
try:
fallback_mw = ModelFallbackMiddleware(
"openai:gpt-4o-mini",
"anthropic:claude-3-5-haiku-20241022",
)
except Exception:
logging.warning("ModelFallbackMiddleware init failed; skipping.")
fallback_mw = None
model_call_limit_mw = (
ModelCallLimitMiddleware(
thread_limit=120,
run_limit=80,
exit_behavior="end",
)
if flags.enable_model_call_limit and not flags.disable_new_agent_stack
else None
)
tool_call_limit_mw = (
ToolCallLimitMiddleware(
thread_limit=300, run_limit=80, exit_behavior="continue"
)
if flags.enable_tool_call_limit and not flags.disable_new_agent_stack
else None
)
noop_mw = (
NoopInjectionMiddleware()
if flags.enable_compaction_v2 and not flags.disable_new_agent_stack
else None
)
repair_mw = None
if flags.enable_tool_call_repair and not flags.disable_new_agent_stack:
registered_names: set[str] = {t.name for t in tools}
registered_names |= {
"write_todos",
"ls",
"read_file",
"write_file",
"edit_file",
"glob",
"grep",
"execute",
"task",
"mkdir",
"cd",
"pwd",
"move_file",
"rm",
"rmdir",
"list_tree",
"execute_code",
}
repair_mw = ToolCallNameRepairMiddleware(
registered_tool_names=registered_names,
fuzzy_match_threshold=None,
)
doom_loop_mw = (
DoomLoopMiddleware(threshold=3)
if flags.enable_doom_loop and not flags.disable_new_agent_stack
else None
)
permission_mw: PermissionMiddleware | None = (
PermissionMiddleware(rulesets=permission_rulesets)
if permission_rulesets
else None
)
action_log_mw: ActionLogMiddleware | None = None
if (
flags.enable_action_log
and not flags.disable_new_agent_stack
and thread_id is not None
):
try:
tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS}
action_log_mw = ActionLogMiddleware(
thread_id=thread_id,
search_space_id=search_space_id,
user_id=user_id,
tool_definitions=tool_defs_by_name,
)
except Exception: # pragma: no cover - defensive
logging.warning(
"ActionLogMiddleware init failed; running without it.",
exc_info=True,
)
action_log_mw = None
busy_mutex_mw: BusyMutexMiddleware | None = (
BusyMutexMiddleware()
if flags.enable_busy_mutex and not flags.disable_new_agent_stack
else None
)
otel_mw: OtelSpanMiddleware | None = (
OtelSpanMiddleware()
if flags.enable_otel and not flags.disable_new_agent_stack
else None
)
plugin_middlewares: list[Any] = []
if flags.enable_plugin_loader and not flags.disable_new_agent_stack:
try:
allowed_names = load_allowed_plugin_names_from_env()
if allowed_names:
plugin_middlewares = load_plugin_middlewares(
PluginContext.build(
search_space_id=search_space_id,
user_id=user_id,
thread_visibility=visibility,
llm=llm,
),
allowed_plugin_names=allowed_names,
)
except Exception: # pragma: no cover - defensive
logging.warning(
"Plugin loader failed; continuing without plugins.",
exc_info=True,
)
plugin_middlewares = []
skills_mw: SkillsMiddleware | None = None
if flags.enable_skills and not flags.disable_new_agent_stack:
try:
skills_factory = build_skills_backend_factory(
search_space_id=search_space_id
if filesystem_mode == FilesystemMode.CLOUD
else None,
)
skills_mw = SkillsMiddleware(
backend=skills_factory,
sources=default_skills_sources(),
)
except Exception as exc: # pragma: no cover - defensive
logging.warning("SkillsMiddleware init failed; skipping: %s", exc)
skills_mw = None
selector_mw: LLMToolSelectorMiddleware | None = None
if (
flags.enable_llm_tool_selector
and not flags.disable_new_agent_stack
and len(tools) > 30
):
try:
selector_mw = LLMToolSelectorMiddleware(
model="openai:gpt-4o-mini",
max_tools=12,
always_include=[
name
for name in (
"update_memory",
"get_connected_accounts",
"scrape_webpage",
)
if name in {t.name for t in tools}
],
)
except Exception:
logging.warning("LLMToolSelectorMiddleware init failed; skipping.")
selector_mw = None
deepagent_middleware = [
busy_mutex_mw,
otel_mw,
TodoListMiddleware(),
_memory_middleware,
AnonymousDocumentMiddleware(
anon_session_id=anon_session_id,
)
if filesystem_mode == FilesystemMode.CLOUD
else None,
KnowledgeTreeMiddleware(
search_space_id=search_space_id,
filesystem_mode=filesystem_mode,
llm=llm,
)
if filesystem_mode == FilesystemMode.CLOUD
else None,
KnowledgePriorityMiddleware(
llm=llm,
search_space_id=search_space_id,
filesystem_mode=filesystem_mode,
available_connectors=available_connectors,
available_document_types=available_document_types,
mentioned_document_ids=mentioned_document_ids,
),
FileIntentMiddleware(llm=llm),
SurfSenseFilesystemMiddleware(
backend=backend_resolver,
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
created_by_id=user_id,
thread_id=thread_id,
),
KnowledgeBasePersistenceMiddleware(
search_space_id=search_space_id,
created_by_id=user_id,
filesystem_mode=filesystem_mode,
thread_id=thread_id,
)
if filesystem_mode == FilesystemMode.CLOUD
else None,
skills_mw,
SurfSenseCheckpointedSubAgentMiddleware(
checkpointer=checkpointer,
backend=StateBackend,
subagents=subagent_specs,
),
selector_mw,
model_call_limit_mw,
tool_call_limit_mw,
context_edit_mw,
summarization_mw,
noop_mw,
retry_mw,
fallback_mw,
repair_mw,
permission_mw,
doom_loop_mw,
action_log_mw,
PatchToolCallsMiddleware(),
DedupHITLToolCallsMiddleware(agent_tools=list(tools)),
*plugin_middlewares,
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
]
return [m for m in deepagent_middleware if m is not None]

@@ -2,6 +2,6 @@
 from __future__ import annotations
-from .factory import create_surfsense_deep_agent
+from .factory import create_multi_agent_chat_deep_agent
-__all__ = ["create_surfsense_deep_agent"]
+__all__ = ["create_multi_agent_chat_deep_agent"]

@@ -0,0 +1,117 @@
"""Compiled agent graph caching for the multi-agent path."""
from __future__ import annotations
import asyncio
from collections.abc import Sequence
from typing import Any
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
from app.agents.multi_agent_chat.subagents.shared.permissions import ToolsPermissions
from app.agents.new_chat.agent_cache import (
flags_signature,
get_cache,
stable_hash,
system_prompt_hash,
tools_signature,
)
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.db import ChatVisibility
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
def mcp_signature(mcp_tools_by_agent: dict[str, ToolsPermissions]) -> str:
"""Hash the per-agent MCP tool surface so a change rotates the cache key."""
rows = []
for agent_name in sorted(mcp_tools_by_agent.keys()):
perms = mcp_tools_by_agent[agent_name]
allow_names = sorted(item.get("name", "") for item in perms.get("allow", []))
ask_names = sorted(item.get("name", "") for item in perms.get("ask", []))
rows.append((agent_name, allow_names, ask_names))
return stable_hash(rows)
async def build_agent_with_cache(
*,
llm: BaseChatModel,
tools: Sequence[BaseTool],
final_system_prompt: str,
backend_resolver: Any,
filesystem_mode: FilesystemMode,
search_space_id: int,
user_id: str | None,
thread_id: int | None,
visibility: ChatVisibility,
anon_session_id: str | None,
available_connectors: list[str],
available_document_types: list[str],
mentioned_document_ids: list[int] | None,
max_input_tokens: int | None,
flags: AgentFeatureFlags,
checkpointer: Checkpointer,
subagent_dependencies: dict[str, Any],
mcp_tools_by_agent: dict[str, ToolsPermissions],
disabled_tools: list[str] | None,
config_id: str | None,
) -> Any:
"""Compile the multi-agent graph, serving from cache when key components are stable."""
async def _build() -> Any:
return await asyncio.to_thread(
build_compiled_agent_graph_sync,
llm=llm,
tools=tools,
final_system_prompt=final_system_prompt,
backend_resolver=backend_resolver,
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
user_id=user_id,
thread_id=thread_id,
visibility=visibility,
anon_session_id=anon_session_id,
available_connectors=available_connectors,
available_document_types=available_document_types,
mentioned_document_ids=mentioned_document_ids,
max_input_tokens=max_input_tokens,
flags=flags,
checkpointer=checkpointer,
subagent_dependencies=subagent_dependencies,
mcp_tools_by_agent=mcp_tools_by_agent,
disabled_tools=disabled_tools,
)
if not (flags.enable_agent_cache and not flags.disable_new_agent_stack):
return await _build()
# Every per-request value any middleware closes over at __init__ must be in
# the key, otherwise a hit will leak state across threads. Bump the schema
# version when the component list changes shape.
cache_key = stable_hash(
"multi-agent-v1",
config_id,
thread_id,
user_id,
search_space_id,
visibility,
filesystem_mode,
anon_session_id,
tools_signature(
tools,
available_connectors=available_connectors,
available_document_types=available_document_types,
),
mcp_signature(mcp_tools_by_agent),
flags_signature(flags),
system_prompt_hash(final_system_prompt),
max_input_tokens,
sorted(disabled_tools) if disabled_tools else None,
)
return await get_cache().get_or_build(cache_key, builder=_build)
__all__ = ["build_agent_with_cache", "mcp_signature"]

@@ -2,7 +2,6 @@
 from __future__ import annotations
-import asyncio
 import logging
 import time
 from collections.abc import Sequence
@@ -26,23 +25,24 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
 from app.agents.new_chat.filesystem_backends import build_backend_resolver
 from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
 from app.agents.new_chat.llm_config import AgentConfig
+from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
 from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_tool
 from app.agents.new_chat.tools.registry import build_tools_async
 from app.db import ChatVisibility
 from app.services.connector_service import ConnectorService
 from app.utils.perf import get_perf_logger
-from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
 from ..system_prompt import build_main_agent_system_prompt
 from ..tools import (
     MAIN_AGENT_SURFSENSE_TOOL_NAMES,
     MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED,
 )
+from .agent_cache import build_agent_with_cache
 _perf_log = get_perf_logger()
-async def create_surfsense_deep_agent(
+async def create_multi_agent_chat_deep_agent(
     llm: BaseChatModel,
     search_space_id: int,
     db_session: AsyncSession,
@@ -62,6 +62,9 @@ async def create_surfsense_deep_agent(
 ):
     """Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled."""
     _t_agent_total = time.perf_counter()
+    apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id)
     filesystem_selection = filesystem_selection or FilesystemSelection()
     backend_resolver = build_backend_resolver(
         filesystem_selection,
@@ -85,7 +88,18 @@
         )
     except Exception as e:
-        logging.warning("Failed to discover available connectors/document types: %s", e)
+        logging.warning(
+            "Connector/doc-type discovery failed; excluding connector subagents this turn: %s",
+            e,
+        )
+        # Fail closed: a None list short-circuits ``get_subagents_to_exclude`` to "exclude
+        # nothing", which would silently advertise every connector specialist on a flaky
+        # discovery call. Empty list excludes connector-gated subagents while keeping builtins.
+        if available_connectors is None:
+            available_connectors = []
+        if available_document_types is None:
+            available_document_types = []
     _perf_log.info(
         "[create_agent] Connector/doc-type discovery in %.3fs",
         time.perf_counter() - _t0,
@@ -115,7 +129,18 @@
     }
     _t0 = time.perf_counter()
-    mcp_tools_by_agent = await load_mcp_tools_by_connector(db_session, search_space_id)
+    try:
+        mcp_tools_by_agent = await load_mcp_tools_by_connector(
+            db_session, search_space_id
+        )
+    except Exception as e:
+        # Degrade to builtins-only rather than aborting the turn: a transient
+        # DB or MCP-server hiccup should not deny the user a response.
+        logging.warning(
+            "MCP tool discovery failed; subagents will run without MCP tools this turn: %s",
+            e,
+        )
+        mcp_tools_by_agent = {}
     _perf_log.info(
         "[create_agent] load_mcp_tools_by_connector in %.3fs (%d buckets)",
         time.perf_counter() - _t0,
@@ -195,9 +220,10 @@
     final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT
+    config_id = agent_config.config_id if agent_config is not None else None
     _t0 = time.perf_counter()
-    agent = await asyncio.to_thread(
-        build_compiled_agent_graph_sync,
+    agent = await build_agent_with_cache(
         llm=llm,
         tools=tools,
         final_system_prompt=final_system_prompt,
@@ -217,6 +243,7 @@
         subagent_dependencies=dependencies,
         mcp_tools_by_agent=mcp_tools_by_agent,
         disabled_tools=disabled_tools,
+        config_id=config_id,
     )
     _perf_log.info(
         "[create_agent] Middleware stack + graph compiled in %.3fs",

@@ -0,0 +1,7 @@
"""Multi-agent middleware stack assembly."""
from __future__ import annotations
from .stack import build_main_agent_deepagent_middleware
__all__ = ["build_main_agent_deepagent_middleware"]

@@ -0,0 +1,36 @@
"""Audit row per tool call (reversibility metadata)."""
from __future__ import annotations
import logging
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware import ActionLogMiddleware
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
from ..shared.flags import enabled
def build_action_log_mw(
*,
flags: AgentFeatureFlags,
thread_id: int | None,
search_space_id: int,
user_id: str | None,
) -> ActionLogMiddleware | None:
if not enabled(flags, "enable_action_log") or thread_id is None:
return None
try:
tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS}
return ActionLogMiddleware(
thread_id=thread_id,
search_space_id=search_space_id,
user_id=user_id,
tool_definitions=tool_defs_by_name,
)
except Exception: # pragma: no cover - defensive
logging.warning(
"ActionLogMiddleware init failed; running without it.",
exc_info=True,
)
return None

@@ -0,0 +1,16 @@
"""Anonymous document hydration from Redis (cloud only)."""
from __future__ import annotations
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware import AnonymousDocumentMiddleware
def build_anonymous_doc_mw(
*,
filesystem_mode: FilesystemMode,
anon_session_id: str | None,
) -> AnonymousDocumentMiddleware | None:
if filesystem_mode != FilesystemMode.CLOUD:
return None
return AnonymousDocumentMiddleware(anon_session_id=anon_session_id)

@@ -0,0 +1,12 @@
"""Per-thread cooperative lock around the whole turn."""
from __future__ import annotations
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware import BusyMutexMiddleware
from ..shared.flags import enabled
def build_busy_mutex_mw(flags: AgentFeatureFlags) -> BusyMutexMiddleware | None:
return BusyMutexMiddleware() if enabled(flags, "enable_busy_mutex") else None

@@ -69,9 +69,16 @@ def build_task_tool_with_parent_config(
         raise ValueError(msg)
     state_update = {k: v for k, v in result.items() if k not in EXCLUDED_STATE_KEYS}
-    message_text = (
-        result["messages"][-1].text.rstrip() if result["messages"][-1].text else ""
-    )
+    messages = result["messages"]
+    if not messages:
+        msg = (
+            "CompiledSubAgent returned an empty 'messages' list. "
+            "Subagents must produce at least one message so the parent has "
+            "output to forward back to the user."
+        )
+        raise ValueError(msg)
+    last_text = getattr(messages[-1], "text", None) or ""
+    message_text = last_text.rstrip()
     return Command(
         update={
             **state_update,

@@ -0,0 +1,50 @@
"""Spill + clear-tool-uses passes to keep payloads under budget."""
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
from langchain_core.tools import BaseTool
from app.agents.multi_agent_chat.main_agent.context_prune.prune_tool_names import (
safe_exclude_tools,
)
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware import (
ClearToolUsesEdit,
SpillingContextEditingMiddleware,
SpillToBackendEdit,
)
from ..shared.flags import enabled
def build_context_editing_mw(
*,
flags: AgentFeatureFlags,
max_input_tokens: int | None,
tools: Sequence[BaseTool],
backend_resolver: Any,
) -> SpillingContextEditingMiddleware | None:
if not enabled(flags, "enable_context_editing") or not max_input_tokens:
return None
spill_edit = SpillToBackendEdit(
trigger=int(max_input_tokens * 0.55),
clear_at_least=int(max_input_tokens * 0.15),
keep=5,
exclude_tools=safe_exclude_tools(tools),
clear_tool_inputs=True,
)
clear_edit = ClearToolUsesEdit(
trigger=int(max_input_tokens * 0.55),
clear_at_least=int(max_input_tokens * 0.15),
keep=5,
exclude_tools=safe_exclude_tools(tools),
clear_tool_inputs=True,
placeholder="[cleared - older tool output trimmed for context]",
)
return SpillingContextEditingMiddleware(
edits=[spill_edit, clear_edit],
backend_resolver=backend_resolver,
)

@@ -0,0 +1,13 @@
"""Drop duplicate HITL tool calls before execution."""
from __future__ import annotations
from collections.abc import Sequence
from langchain_core.tools import BaseTool
from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware
def build_dedup_hitl_mw(tools: Sequence[BaseTool]) -> DedupHITLToolCallsMiddleware:
return DedupHITLToolCallsMiddleware(agent_tools=list(tools))

@@ -0,0 +1,14 @@
"""Stop N identical tool calls in a row via interrupt."""
from __future__ import annotations
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware import DoomLoopMiddleware
from ..shared.flags import enabled
def build_doom_loop_mw(flags: AgentFeatureFlags) -> DoomLoopMiddleware | None:
return (
DoomLoopMiddleware(threshold=3) if enabled(flags, "enable_doom_loop") else None
)

@@ -0,0 +1,23 @@
"""Commit staged cloud filesystem mutations to Postgres at end of turn."""
from __future__ import annotations
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware import KnowledgeBasePersistenceMiddleware
def build_kb_persistence_mw(
*,
filesystem_mode: FilesystemMode,
search_space_id: int,
user_id: str | None,
thread_id: int | None,
) -> KnowledgeBasePersistenceMiddleware | None:
if filesystem_mode != FilesystemMode.CLOUD:
return None
return KnowledgeBasePersistenceMiddleware(
search_space_id=search_space_id,
created_by_id=user_id,
filesystem_mode=filesystem_mode,
thread_id=thread_id,
)

@@ -0,0 +1,27 @@
"""KB priority planner: <priority_documents> injection."""
from __future__ import annotations
from langchain_core.language_models import BaseChatModel
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware import KnowledgePriorityMiddleware
def build_knowledge_priority_mw(
*,
llm: BaseChatModel,
search_space_id: int,
filesystem_mode: FilesystemMode,
available_connectors: list[str] | None,
available_document_types: list[str] | None,
mentioned_document_ids: list[int] | None,
) -> KnowledgePriorityMiddleware:
return KnowledgePriorityMiddleware(
llm=llm,
search_space_id=search_space_id,
filesystem_mode=filesystem_mode,
available_connectors=available_connectors,
available_document_types=available_document_types,
mentioned_document_ids=mentioned_document_ids,
)

@@ -0,0 +1,23 @@
"""<workspace_tree> injection (cloud only)."""
from __future__ import annotations
from langchain_core.language_models import BaseChatModel
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware import KnowledgeTreeMiddleware
def build_knowledge_tree_mw(
*,
filesystem_mode: FilesystemMode,
search_space_id: int,
llm: BaseChatModel,
) -> KnowledgeTreeMiddleware | None:
if filesystem_mode != FilesystemMode.CLOUD:
return None
return KnowledgeTreeMiddleware(
search_space_id=search_space_id,
filesystem_mode=filesystem_mode,
llm=llm,
)

@@ -0,0 +1,12 @@
"""Provider-compat: append a `_noop` tool when tools=[] but history has tool calls."""
from __future__ import annotations
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware import NoopInjectionMiddleware
from ..shared.flags import enabled
def build_noop_injection_mw(flags: AgentFeatureFlags) -> NoopInjectionMiddleware | None:
return NoopInjectionMiddleware() if enabled(flags, "enable_compaction_v2") else None

@@ -0,0 +1,12 @@
"""OTel spans on model and tool calls."""
from __future__ import annotations
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware import OtelSpanMiddleware
from ..shared.flags import enabled
def build_otel_mw(flags: AgentFeatureFlags) -> OtelSpanMiddleware | None:
return OtelSpanMiddleware() if enabled(flags, "enable_otel") else None

@@ -0,0 +1,49 @@
"""Tail-of-stack plugin slot driven by env allowlist."""
from __future__ import annotations
import logging
from typing import Any
from langchain_core.language_models import BaseChatModel
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.plugin_loader import (
PluginContext,
load_allowed_plugin_names_from_env,
load_plugin_middlewares,
)
from app.db import ChatVisibility
from ..shared.flags import enabled
def build_plugin_middlewares(
*,
flags: AgentFeatureFlags,
search_space_id: int,
user_id: str | None,
visibility: ChatVisibility,
llm: BaseChatModel,
) -> list[Any]:
if not enabled(flags, "enable_plugin_loader"):
return []
try:
allowed_names = load_allowed_plugin_names_from_env()
if not allowed_names:
return []
return load_plugin_middlewares(
PluginContext.build(
search_space_id=search_space_id,
user_id=user_id,
thread_visibility=visibility,
llm=llm,
),
allowed_plugin_names=allowed_names,
)
except Exception: # pragma: no cover - defensive
logging.warning(
"Plugin loader failed; continuing without plugins.",
exc_info=True,
)
return []

@@ -0,0 +1,50 @@
"""Repair miscased / unknown tool names to the registered set or invalid_tool."""
from __future__ import annotations
from collections.abc import Sequence
from langchain_core.tools import BaseTool
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware import ToolCallNameRepairMiddleware
from ..shared.flags import enabled
# deepagents-built-in tool names the repair pass treats as known.
_DEEPAGENT_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset(
{
"write_todos",
"ls",
"read_file",
"write_file",
"edit_file",
"glob",
"grep",
"execute",
"task",
"mkdir",
"cd",
"pwd",
"move_file",
"rm",
"rmdir",
"list_tree",
"execute_code",
}
)
def build_repair_mw(
*,
flags: AgentFeatureFlags,
tools: Sequence[BaseTool],
) -> ToolCallNameRepairMiddleware | None:
if not enabled(flags, "enable_tool_call_repair"):
return None
registered_names: set[str] = {t.name for t in tools}
registered_names |= _DEEPAGENT_BUILTIN_TOOL_NAMES
return ToolCallNameRepairMiddleware(
registered_tool_names=registered_names,
fuzzy_match_threshold=None,
)

@@ -0,0 +1,39 @@
"""LLM-based tool subset selection (only when >30 tools)."""
from __future__ import annotations
import logging
from collections.abc import Sequence
from langchain.agents.middleware import LLMToolSelectorMiddleware
from langchain_core.tools import BaseTool
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from ..shared.flags import enabled
def build_selector_mw(
*,
flags: AgentFeatureFlags,
tools: Sequence[BaseTool],
) -> LLMToolSelectorMiddleware | None:
if not enabled(flags, "enable_llm_tool_selector") or len(tools) <= 30:
return None
try:
return LLMToolSelectorMiddleware(
model="openai:gpt-4o-mini",
max_tools=12,
always_include=[
name
for name in (
"update_memory",
"get_connected_accounts",
"scrape_webpage",
)
if name in {t.name for t in tools}
],
)
except Exception:
logging.warning("LLMToolSelectorMiddleware init failed; skipping.")
return None

@@ -0,0 +1,39 @@
"""Skill discovery + injection."""
from __future__ import annotations
import logging
from deepagents.middleware.skills import SkillsMiddleware
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware import (
build_skills_backend_factory,
default_skills_sources,
)
from ..shared.flags import enabled
def build_skills_mw(
*,
flags: AgentFeatureFlags,
filesystem_mode: FilesystemMode,
search_space_id: int,
) -> SkillsMiddleware | None:
if not enabled(flags, "enable_skills"):
return None
try:
skills_factory = build_skills_backend_factory(
search_space_id=search_space_id
if filesystem_mode == FilesystemMode.CLOUD
else None,
)
return SkillsMiddleware(
backend=skills_factory,
sources=default_skills_sources(),
)
except Exception as exc: # pragma: no cover - defensive
logging.warning("SkillsMiddleware init failed; skipping: %s", exc)
return None

@@ -0,0 +1,9 @@
"""Anthropic prompt caching annotations on system/tool/message blocks."""
from __future__ import annotations
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
def build_anthropic_cache_mw() -> AnthropicPromptCachingMiddleware:
return AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")

@@ -0,0 +1,14 @@
"""Context-window summarization with SurfSense protected sections."""
from __future__ import annotations
from typing import Any
from deepagents.backends import StateBackend
from langchain_core.language_models import BaseChatModel
from app.agents.new_chat.middleware import create_surfsense_compaction_middleware
def build_compaction_mw(llm: BaseChatModel) -> Any:
return create_surfsense_compaction_middleware(llm, StateBackend)

@@ -0,0 +1,11 @@
"""File-intent classifier that gates strict write contracts."""
from __future__ import annotations
from langchain_core.language_models import BaseChatModel
from app.agents.new_chat.middleware import FileIntentMiddleware
def build_file_intent_mw(llm: BaseChatModel) -> FileIntentMiddleware:
return FileIntentMiddleware(llm=llm)

@@ -0,0 +1,25 @@
"""SurfSense filesystem tools/middleware."""
from __future__ import annotations
from typing import Any
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware import SurfSenseFilesystemMiddleware
def build_filesystem_mw(
*,
backend_resolver: Any,
filesystem_mode: FilesystemMode,
search_space_id: int,
user_id: str | None,
thread_id: int | None,
) -> SurfSenseFilesystemMiddleware:
return SurfSenseFilesystemMiddleware(
backend=backend_resolver,
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
created_by_id=user_id,
thread_id=thread_id,
)

@@ -0,0 +1,10 @@
"""Single source of truth for the feature-flag predicate."""
from __future__ import annotations
from app.agents.new_chat.feature_flags import AgentFeatureFlags
def enabled(flags: AgentFeatureFlags, attr: str) -> bool:
"""``flags.<attr>`` is on AND the new-agent-stack kill switch is off."""
return getattr(flags, attr) and not flags.disable_new_agent_stack

@@ -0,0 +1,19 @@
"""User/team memory injection prepended to the conversation."""
from __future__ import annotations
from app.agents.new_chat.middleware import MemoryInjectionMiddleware
from app.db import ChatVisibility
def build_memory_mw(
*,
user_id: str | None,
search_space_id: int,
visibility: ChatVisibility,
) -> MemoryInjectionMiddleware:
return MemoryInjectionMiddleware(
user_id=user_id,
search_space_id=search_space_id,
thread_visibility=visibility,
)

@@ -0,0 +1,9 @@
"""Repair dangling tool-call sequences before each agent turn."""
from __future__ import annotations
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
def build_patch_tool_calls_mw() -> PatchToolCallsMiddleware:
return PatchToolCallsMiddleware()

@@ -0,0 +1,12 @@
"""Permission rulesets fanned out to parent / general-purpose / subagent stacks."""
from __future__ import annotations
from .context import PermissionContext, build_permission_context
from .middleware import build_full_permission_mw
__all__ = [
"PermissionContext",
"build_full_permission_mw",
"build_permission_context",
]

@@ -0,0 +1,107 @@
"""Derive shared permission context once; fan out to all three stack layers.
The context carries:
- ``rulesets``: full ask/deny/allow rules for the main-agent permission middleware.
- ``general_purpose_interrupt_on``: ``ask`` rules mirrored as deepagents
``interrupt_on`` so HITL still triggers from inside ``task`` runs (subagents
bypass the main-agent permission middleware).
- ``subagent_deny_mw``: a deny-only ``PermissionMiddleware`` instance shared
across the general-purpose and registry subagent stacks.
"""
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from langchain_core.tools import BaseTool
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware import PermissionMiddleware
from app.agents.new_chat.permissions import Rule, Ruleset
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
from ..flags import enabled
@dataclass(frozen=True)
class PermissionContext:
rulesets: list[Ruleset]
general_purpose_interrupt_on: dict[str, bool]
subagent_deny_mw: PermissionMiddleware | None
def build_permission_context(
*,
flags: AgentFeatureFlags,
filesystem_mode: FilesystemMode,
tools: Sequence[BaseTool],
available_connectors: list[str] | None,
) -> PermissionContext:
is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
permission_enabled = enabled(flags, "enable_permission")
rulesets: list[Ruleset] = []
if permission_enabled or is_desktop_fs:
rulesets.append(
Ruleset(
rules=[Rule(permission="*", pattern="*", action="allow")],
origin="surfsense_defaults",
)
)
if is_desktop_fs:
rulesets.append(
Ruleset(
rules=[
Rule(permission="rm", pattern="*", action="ask"),
Rule(permission="rmdir", pattern="*", action="ask"),
Rule(permission="move_file", pattern="*", action="ask"),
Rule(permission="edit_file", pattern="*", action="ask"),
Rule(permission="write_file", pattern="*", action="ask"),
],
origin="desktop_safety",
)
)
tool_names_in_use = {t.name for t in tools}
if permission_enabled:
available_set = set(available_connectors or [])
synthesized: list[Rule] = []
for tool_def in BUILTIN_TOOLS:
if tool_def.name not in tool_names_in_use:
continue
rc = tool_def.required_connector
if rc and rc not in available_set:
synthesized.append(
Rule(permission=tool_def.name, pattern="*", action="deny")
)
if synthesized:
rulesets.append(Ruleset(rules=synthesized, origin="connector_synthesized"))
general_purpose_interrupt_on: dict[str, bool] = {
rule.permission: True
for rs in rulesets
for rule in rs.rules
if rule.action == "ask" and rule.permission in tool_names_in_use
}
deny_rulesets = [
Ruleset(
rules=[r for r in rs.rules if r.action == "deny"],
origin=rs.origin,
)
for rs in rulesets
]
deny_rulesets = [rs for rs in deny_rulesets if rs.rules]
subagent_deny_mw: PermissionMiddleware | None = (
PermissionMiddleware(rulesets=deny_rulesets) if deny_rulesets else None
)
return PermissionContext(
rulesets=rulesets,
general_purpose_interrupt_on=general_purpose_interrupt_on,
subagent_deny_mw=subagent_deny_mw,
)

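A minimal sketch of how the fan-out behaves for a desktop-filesystem session with permissions enabled. The tool objects and flag values here are hypothetical stand-ins, and it assumes neither toy tool is a registered builtin with a missing connector (so no deny rules get synthesized):

```python
ctx = build_permission_context(
    flags=flags,                                  # assumed: enable_permission on
    filesystem_mode=FilesystemMode.DESKTOP_LOCAL_FOLDER,
    tools=[rm_tool, write_file_tool],             # hypothetical BaseTool instances
    available_connectors=[],
)

# Desktop-safety "ask" rules on tools actually in use are mirrored so HITL
# still fires inside subagent task runs...
assert ctx.general_purpose_interrupt_on == {"rm": True, "write_file": True}

# ...while only deny rules are shared with subagent stacks; none exist here.
assert ctx.subagent_deny_mw is None
```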
View file

@ -0,0 +1,10 @@
"""Main-agent permission middleware (full ask/deny/allow rules)."""
from __future__ import annotations
from app.agents.new_chat.middleware import PermissionMiddleware
from app.agents.new_chat.permissions import Ruleset
def build_full_permission_mw(rulesets: list[Ruleset]) -> PermissionMiddleware | None:
return PermissionMiddleware(rulesets=rulesets) if rulesets else None

View file

@ -0,0 +1,7 @@
"""Resilience middleware shared as the same instances across parent / general-purpose / registry."""
from __future__ import annotations
from .bundle import ResilienceBundle, build_resilience_bundle
__all__ = ["ResilienceBundle", "build_resilience_bundle"]

View file

@ -0,0 +1,51 @@
"""Construct each resilience middleware once; same instances flow into every consumer."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from langchain.agents.middleware import (
ModelCallLimitMiddleware,
ToolCallLimitMiddleware,
)
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware import RetryAfterMiddleware
from app.agents.new_chat.middleware.scoped_model_fallback import (
ScopedModelFallbackMiddleware,
)
from .fallback import build_fallback_mw
from .model_call_limit import build_model_call_limit_mw
from .retry import build_retry_mw
from .tool_call_limit import build_tool_call_limit_mw
@dataclass(frozen=True)
class ResilienceBundle:
retry: RetryAfterMiddleware | None
fallback: ScopedModelFallbackMiddleware | None
model_call_limit: ModelCallLimitMiddleware | None
tool_call_limit: ToolCallLimitMiddleware | None
def as_list(self) -> list[Any]:
return [
m
for m in (
self.retry,
self.fallback,
self.model_call_limit,
self.tool_call_limit,
)
if m is not None
]
def build_resilience_bundle(flags: AgentFeatureFlags) -> ResilienceBundle:
return ResilienceBundle(
retry=build_retry_mw(flags),
fallback=build_fallback_mw(flags),
model_call_limit=build_model_call_limit_mw(flags),
tool_call_limit=build_tool_call_limit_mw(flags),
)

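The bundle is deliberately flag-shaped: disabled features come back as `None`, and `as_list()` drops them while preserving the fixed retry, fallback, model-limit, tool-limit order, so consumers can splice the result in without conditionals. A sketch under hypothetical flag values:

```python
# Assumes flags enable only retry and the tool-call limit.
bundle = build_resilience_bundle(flags)

mws = bundle.as_list()
# Gaps are stripped; relative order of the survivors is unchanged.
assert mws == [bundle.retry, bundle.tool_call_limit]
assert all(m is not None for m in mws)
```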
View file

@ -0,0 +1,27 @@
"""Switch to a fallback model on provider/network errors only."""
from __future__ import annotations
import logging
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware.scoped_model_fallback import (
ScopedModelFallbackMiddleware,
)
from ..flags import enabled
def build_fallback_mw(
flags: AgentFeatureFlags,
) -> ScopedModelFallbackMiddleware | None:
if not enabled(flags, "enable_model_fallback"):
return None
try:
return ScopedModelFallbackMiddleware(
"openai:gpt-4o-mini",
"anthropic:claude-3-5-haiku-20241022",
)
except Exception:
logging.warning("ScopedModelFallbackMiddleware init failed; skipping.")
return None

View file

@ -0,0 +1,21 @@
"""Cap model calls per thread / per run to prevent runaway cost."""
from __future__ import annotations
from langchain.agents.middleware import ModelCallLimitMiddleware
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from ..flags import enabled
def build_model_call_limit_mw(
flags: AgentFeatureFlags,
) -> ModelCallLimitMiddleware | None:
if not enabled(flags, "enable_model_call_limit"):
return None
return ModelCallLimitMiddleware(
thread_limit=120,
run_limit=80,
exit_behavior="end",
)

View file

@ -0,0 +1,16 @@
"""Retry on transient model errors (e.g. Retry-After-bearing 429s)."""
from __future__ import annotations
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware import RetryAfterMiddleware
from ..flags import enabled
def build_retry_mw(flags: AgentFeatureFlags) -> RetryAfterMiddleware | None:
return (
RetryAfterMiddleware(max_retries=3)
if enabled(flags, "enable_retry_after")
else None
)

View file

@ -0,0 +1,21 @@
"""Cap tool calls per thread / per run to bound infinite-loop blast radius."""
from __future__ import annotations
from langchain.agents.middleware import ToolCallLimitMiddleware
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from ..flags import enabled
def build_tool_call_limit_mw(
flags: AgentFeatureFlags,
) -> ToolCallLimitMiddleware | None:
if not enabled(flags, "enable_tool_call_limit"):
return None
return ToolCallLimitMiddleware(
thread_limit=300,
run_limit=80,
exit_behavior="continue",
)

View file

@ -0,0 +1,9 @@
"""Todo-list middleware (each consumer needs its own instance)."""
from __future__ import annotations
from langchain.agents.middleware import TodoListMiddleware
def build_todos_mw() -> TodoListMiddleware:
return TodoListMiddleware()

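Two instance-sharing policies coexist in this package: middleware with per-consumer state, like todos, must be constructed fresh for each stack, while middleware whose state should be visible across stacks (memory injection, the resilience bundle) is built once and threaded through. A sketch of the distinction:

```python
# Fresh instance per consumer: each stack keeps its own todo state.
parent_todos = build_todos_mw()
subagent_todos = build_todos_mw()
assert parent_todos is not subagent_todos

# Shared instance: the orchestrator builds memory_mw once and passes the
# same object into both the main-agent and general-purpose stacks.
```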
View file

@ -0,0 +1,216 @@
"""Main-agent middleware list assembly: one line per slot."""
from __future__ import annotations
import logging
from collections.abc import Sequence
from typing import Any
from deepagents import SubAgent
from deepagents.backends import StateBackend
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
from app.agents.multi_agent_chat.subagents import (
build_subagents,
get_subagents_to_exclude,
)
from app.agents.multi_agent_chat.subagents.builtins.general_purpose.agent import (
build_subagent as build_general_purpose_subagent,
)
from app.agents.multi_agent_chat.subagents.shared.permissions import ToolsPermissions
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.db import ChatVisibility
from .main_agent.action_log import build_action_log_mw
from .main_agent.anonymous_doc import build_anonymous_doc_mw
from .main_agent.busy_mutex import build_busy_mutex_mw
from .main_agent.checkpointed_subagent_middleware import (
SurfSenseCheckpointedSubAgentMiddleware,
)
from .main_agent.context_editing import build_context_editing_mw
from .main_agent.dedup_hitl import build_dedup_hitl_mw
from .main_agent.doom_loop import build_doom_loop_mw
from .main_agent.kb_persistence import build_kb_persistence_mw
from .main_agent.knowledge_priority import build_knowledge_priority_mw
from .main_agent.knowledge_tree import build_knowledge_tree_mw
from .main_agent.noop_injection import build_noop_injection_mw
from .main_agent.otel import build_otel_mw
from .main_agent.plugins import build_plugin_middlewares
from .main_agent.repair import build_repair_mw
from .main_agent.selector import build_selector_mw
from .main_agent.skills import build_skills_mw
from .shared.anthropic_cache import build_anthropic_cache_mw
from .shared.compaction import build_compaction_mw
from .shared.file_intent import build_file_intent_mw
from .shared.filesystem import build_filesystem_mw
from .shared.memory import build_memory_mw
from .shared.patch_tool_calls import build_patch_tool_calls_mw
from .shared.permissions import (
build_full_permission_mw,
build_permission_context,
)
from .shared.resilience import build_resilience_bundle
from .shared.todos import build_todos_mw
from .subagent.extras import build_subagent_extras
def build_main_agent_deepagent_middleware(
*,
llm: BaseChatModel,
tools: Sequence[BaseTool],
backend_resolver: Any,
filesystem_mode: FilesystemMode,
search_space_id: int,
user_id: str | None,
thread_id: int | None,
visibility: ChatVisibility,
anon_session_id: str | None,
available_connectors: list[str] | None,
available_document_types: list[str] | None,
mentioned_document_ids: list[int] | None,
max_input_tokens: int | None,
flags: AgentFeatureFlags,
subagent_dependencies: dict[str, Any],
checkpointer: Checkpointer,
mcp_tools_by_agent: dict[str, ToolsPermissions] | None = None,
disabled_tools: list[str] | None = None,
) -> list[Any]:
"""Ordered middleware for ``create_agent`` (None entries already stripped)."""
permissions = build_permission_context(
flags=flags,
filesystem_mode=filesystem_mode,
tools=tools,
available_connectors=available_connectors,
)
resilience = build_resilience_bundle(flags)
# Single instance threaded into both the main-agent stack and the general-purpose subagent.
memory_mw = build_memory_mw(
user_id=user_id,
search_space_id=search_space_id,
visibility=visibility,
)
general_purpose_subagent = build_general_purpose_subagent(
llm=llm,
tools=tools,
backend_resolver=backend_resolver,
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
user_id=user_id,
thread_id=thread_id,
permissions=permissions,
resilience=resilience,
memory_mw=memory_mw,
)
subagents_registry: list[SubAgent] = []
try:
subagent_extras = build_subagent_extras(
permissions=permissions,
resilience=resilience,
)
subagents_registry = build_subagents(
dependencies=subagent_dependencies,
model=llm,
extra_middleware=subagent_extras,
mcp_tools_by_agent=mcp_tools_by_agent or {},
exclude=get_subagents_to_exclude(available_connectors),
disabled_tools=disabled_tools,
)
logging.debug(
"Subagents registry: %s",
[s["name"] for s in subagents_registry],
)
except Exception:
# Degrade to general-purpose-only rather than aborting the turn:
# one bad subagent dep should not deny the user a response.
logging.exception(
"Subagents registry build failed; falling back to general-purpose only"
)
subagents_registry = []
subagents: list[SubAgent] = [general_purpose_subagent, *subagents_registry]
stack: list[Any] = [
build_busy_mutex_mw(flags),
build_otel_mw(flags),
build_todos_mw(),
memory_mw,
build_anonymous_doc_mw(
filesystem_mode=filesystem_mode, anon_session_id=anon_session_id
),
build_knowledge_tree_mw(
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
llm=llm,
),
build_knowledge_priority_mw(
llm=llm,
search_space_id=search_space_id,
filesystem_mode=filesystem_mode,
available_connectors=available_connectors,
available_document_types=available_document_types,
mentioned_document_ids=mentioned_document_ids,
),
build_file_intent_mw(llm),
build_filesystem_mw(
backend_resolver=backend_resolver,
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
user_id=user_id,
thread_id=thread_id,
),
build_kb_persistence_mw(
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
user_id=user_id,
thread_id=thread_id,
),
build_skills_mw(
flags=flags,
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
),
SurfSenseCheckpointedSubAgentMiddleware(
checkpointer=checkpointer,
backend=StateBackend,
subagents=subagents,
),
build_selector_mw(flags=flags, tools=tools),
resilience.model_call_limit,
resilience.tool_call_limit,
build_context_editing_mw(
flags=flags,
max_input_tokens=max_input_tokens,
tools=tools,
backend_resolver=backend_resolver,
),
build_compaction_mw(llm),
build_noop_injection_mw(flags),
resilience.retry,
resilience.fallback,
build_repair_mw(flags=flags, tools=tools),
build_full_permission_mw(permissions.rulesets),
build_doom_loop_mw(flags),
build_action_log_mw(
flags=flags,
thread_id=thread_id,
search_space_id=search_space_id,
user_id=user_id,
),
build_patch_tool_calls_mw(),
build_dedup_hitl_mw(tools),
*build_plugin_middlewares(
flags=flags,
search_space_id=search_space_id,
user_id=user_id,
visibility=visibility,
llm=llm,
),
build_anthropic_cache_mw(),
]
return [m for m in stack if m is not None]

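The assembly relies on one convention: every builder returns either a middleware instance or `None`, and the single filter at the end strips the gaps, so the slot order stays readable as a literal list. A reduced, self-contained sketch of the same pattern:

```python
from typing import Any


def assemble(flag_on: bool) -> list[Any]:
    # Each slot is one expression; disabled slots evaluate to None.
    stack: list[Any] = [
        "busy-mutex" if flag_on else None,  # stand-in for build_busy_mutex_mw(flags)
        "todos",                            # unconditional slot
        None,                               # a disabled resilience slot
    ]
    return [m for m in stack if m is not None]


assert assemble(False) == ["todos"]
assert assemble(True) == ["busy-mutex", "todos"]
```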
View file

@ -0,0 +1,28 @@
"""Extra middleware threaded into every registry subagent's stack.
Registry subagents are scoped to one domain (deliverables, research, memory,
connectors, MCP) and never read or write the SurfSense filesystem that
capability belongs to the main agent and is delegated to the general-purpose
subagent as an escape hatch. Keeping FS off the registry stacks avoids
polluting their tool surface with FS tools they never act on.
"""
from __future__ import annotations
from typing import Any
from ..shared.permissions import PermissionContext
from ..shared.resilience import ResilienceBundle
from ..shared.todos import build_todos_mw
def build_subagent_extras(
*,
permissions: PermissionContext,
resilience: ResilienceBundle,
) -> list[Any]:
extras: list[Any] = [build_todos_mw()]
if permissions.subagent_deny_mw is not None:
extras.append(permissions.subagent_deny_mw)
extras.extend(resilience.as_list())
return extras

View file

@ -0,0 +1,105 @@
"""General-purpose subagent for the multi-agent main agent."""
from __future__ import annotations
from collections.abc import Sequence
from typing import Any, cast
from deepagents import SubAgent
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from app.agents.multi_agent_chat.middleware.shared.anthropic_cache import (
build_anthropic_cache_mw,
)
from app.agents.multi_agent_chat.middleware.shared.compaction import (
build_compaction_mw,
)
from app.agents.multi_agent_chat.middleware.shared.file_intent import (
build_file_intent_mw,
)
from app.agents.multi_agent_chat.middleware.shared.filesystem import (
build_filesystem_mw,
)
from app.agents.multi_agent_chat.middleware.shared.patch_tool_calls import (
build_patch_tool_calls_mw,
)
from app.agents.multi_agent_chat.middleware.shared.permissions import (
PermissionContext,
)
from app.agents.multi_agent_chat.middleware.shared.resilience import (
ResilienceBundle,
)
from app.agents.multi_agent_chat.middleware.shared.todos import build_todos_mw
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware import MemoryInjectionMiddleware
NAME = "general-purpose"
def build_subagent(
*,
llm: BaseChatModel,
tools: Sequence[BaseTool],
backend_resolver: Any,
filesystem_mode: FilesystemMode,
search_space_id: int,
user_id: str | None,
thread_id: int | None,
permissions: PermissionContext,
resilience: ResilienceBundle,
memory_mw: MemoryInjectionMiddleware,
) -> SubAgent:
"""Deny + resilience inserts encapsulated here so the orchestrator never mutates the list."""
middleware: list[Any] = [
build_todos_mw(),
memory_mw,
build_file_intent_mw(llm),
build_filesystem_mw(
backend_resolver=backend_resolver,
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
user_id=user_id,
thread_id=thread_id,
),
build_compaction_mw(llm),
build_patch_tool_calls_mw(),
build_anthropic_cache_mw(),
]
if permissions.subagent_deny_mw is not None:
patch_idx = next(
(
i
for i, m in enumerate(middleware)
if isinstance(m, PatchToolCallsMiddleware)
),
len(middleware),
)
middleware.insert(patch_idx, permissions.subagent_deny_mw)
resilience_mws = resilience.as_list()
if resilience_mws:
cache_idx = next(
(
i
for i, m in enumerate(middleware)
if isinstance(m, AnthropicPromptCachingMiddleware)
),
len(middleware),
)
for offset, mw in enumerate(resilience_mws):
middleware.insert(cache_idx + offset, mw)
spec: dict[str, Any] = {
**GENERAL_PURPOSE_SUBAGENT,
"model": llm,
"tools": tools,
"middleware": middleware,
}
if permissions.general_purpose_interrupt_on:
spec["interrupt_on"] = permissions.general_purpose_interrupt_on
return cast(SubAgent, spec)

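Both conditional inserts above use the same idiom: find the index of a sentinel middleware by type and insert just before it, falling back to appending when the sentinel is absent. Isolated with string stand-ins, the idiom looks like this:

```python
middleware: list[object] = ["todos", "patch-tool-calls", "cache"]

sentinel_idx = next(
    (i for i, m in enumerate(middleware) if m == "patch-tool-calls"),
    len(middleware),  # fall back to appending if the sentinel is missing
)
middleware.insert(sentinel_idx, "deny-permissions")

assert middleware == ["todos", "deny-permissions", "patch-tool-calls", "cache"]
```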
View file

@ -168,20 +168,46 @@ def create_create_calendar_event_tool(
f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}" f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
) )
tz = context.get("timezone", "UTC")
if ( if (
connector.connector_type connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
): ):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id") cca_id = connector.config.get("composio_connected_account_id")
if cca_id: if not cca_id:
creds = build_composio_credentials(cca_id)
else:
return { return {
"status": "error", "status": "error",
"message": "Composio connected account ID not found for this connector.", "message": "Composio connected account ID not found for this connector.",
} }
from app.services.composio_service import ComposioService
(
event_id,
html_link,
error,
) = await ComposioService().create_calendar_event(
connected_account_id=cca_id,
entity_id=f"surfsense_{user_id}",
summary=final_summary,
start_datetime=final_start_datetime,
end_datetime=final_end_datetime,
timezone=tz,
description=final_description,
location=final_location,
attendees=final_attendees,
)
if error:
return {"status": "error", "message": error}
created = {
"id": event_id,
"summary": final_summary,
"htmlLink": html_link,
}
logger.info(
f"Calendar event created via Composio: id={event_id}, summary={final_summary}"
)
else: else:
config_data = dict(connector.config) config_data = dict(connector.config)
@ -211,70 +237,69 @@ def create_create_calendar_event_tool(
                expiry=datetime.fromisoformat(exp) if exp else None,
            )

            service = await asyncio.get_event_loop().run_in_executor(
                None, lambda: build("calendar", "v3", credentials=creds)
            )

            event_body: dict[str, Any] = {
                "summary": final_summary,
                "start": {"dateTime": final_start_datetime, "timeZone": tz},
                "end": {"dateTime": final_end_datetime, "timeZone": tz},
            }
            if final_description:
                event_body["description"] = final_description
            if final_location:
                event_body["location"] = final_location
            if final_attendees:
                event_body["attendees"] = [
                    {"email": e.strip()} for e in final_attendees if e.strip()
                ]

            try:
                created = await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: (
                        service.events()
                        .insert(calendarId="primary", body=event_body)
                        .execute()
                    ),
                )
            except Exception as api_err:
                from googleapiclient.errors import HttpError

                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        _res = await db_session.execute(
                            select(SearchSourceConnector).where(
                                SearchSourceConnector.id == actual_connector_id
                            )
                        )
                        _conn = _res.scalar_one_or_none()
                        if _conn and not _conn.config.get("auth_expired"):
                            _conn.config = {**_conn.config, "auth_expired": True}
                            flag_modified(_conn, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            actual_connector_id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": actual_connector_id,
                        "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            logger.info(
                f"Calendar event created via Google API: id={created.get('id')}, summary={created.get('summary')}"
            )

        kb_message_suffix = ""
        try:

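The `auth_expired` persistence above follows a SQLAlchemy requirement worth calling out: in-place mutation of a JSON column can go unnoticed by the unit of work, so the code assigns a fresh dict and calls `flag_modified` before committing. A reduced sketch, with the session and connector row as stand-ins:

```python
from sqlalchemy.orm.attributes import flag_modified


async def mark_auth_expired(session, conn) -> None:
    """Stand-in for the repeated auth_expired bookkeeping in these tools."""
    if not conn.config.get("auth_expired"):
        # Assign a new dict rather than mutating in place; flag_modified
        # explicitly marks the JSON column dirty so an UPDATE is emitted.
        conn.config = {**conn.config, "auth_expired": True}
        flag_modified(conn, "config")
        await session.commit()
```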
View file

@ -163,16 +163,22 @@ def create_delete_calendar_event_tool(
            connector.connector_type
            == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
        ):
            cca_id = connector.config.get("composio_connected_account_id")
            if not cca_id:
                return {
                    "status": "error",
                    "message": "Composio connected account ID not found for this connector.",
                }

            from app.services.composio_service import ComposioService

            error = await ComposioService().delete_calendar_event(
                connected_account_id=cca_id,
                entity_id=f"surfsense_{user_id}",
                event_id=final_event_id,
            )
            if error:
                return {"status": "error", "message": error}
        else:
            config_data = dict(connector.config)
@ -202,51 +208,51 @@ def create_delete_calendar_event_tool(
                expiry=datetime.fromisoformat(exp) if exp else None,
            )

            service = await asyncio.get_event_loop().run_in_executor(
                None, lambda: build("calendar", "v3", credentials=creds)
            )

            try:
                await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: (
                        service.events()
                        .delete(calendarId="primary", eventId=final_event_id)
                        .execute()
                    ),
                )
            except Exception as api_err:
                from googleapiclient.errors import HttpError

                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        _res = await db_session.execute(
                            select(SearchSourceConnector).where(
                                SearchSourceConnector.id == actual_connector_id
                            )
                        )
                        _conn = _res.scalar_one_or_none()
                        if _conn and not _conn.config.get("auth_expired"):
                            _conn.config = {**_conn.config, "auth_expired": True}
                            flag_modified(_conn, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            actual_connector_id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": actual_connector_id,
                        "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

        logger.info(f"Calendar event deleted: event_id={final_event_id}")

View file

@ -16,6 +16,14 @@ _CALENDAR_TYPES = [
]


def _to_calendar_boundary(value: str, *, is_end: bool) -> str:
    """Promote a bare YYYY-MM-DD to RFC3339 with a day-edge time, leave full datetimes alone."""
    if "T" in value:
        return value
    time = "23:59:59" if is_end else "00:00:00"
    return f"{value}T{time}Z"


def create_search_calendar_events_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
@ -61,22 +69,47 @@ def create_search_calendar_events_tool(
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.", "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
} }
creds = _build_credentials(connector) if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this connector.",
}
from app.connectors.google_calendar_connector import GoogleCalendarConnector from app.services.composio_service import ComposioService
cal = GoogleCalendarConnector( events_raw, error = await ComposioService().get_calendar_events(
credentials=creds, connected_account_id=cca_id,
session=db_session, entity_id=f"surfsense_{user_id}",
user_id=user_id, time_min=_to_calendar_boundary(start_date, is_end=False),
connector_id=connector.id, time_max=_to_calendar_boundary(end_date, is_end=True),
) max_results=max_results,
)
if not events_raw and not error:
error = "No events found in the specified date range."
else:
creds = _build_credentials(connector)
events_raw, error = await cal.get_all_primary_calendar_events( from app.connectors.google_calendar_connector import (
start_date=start_date, GoogleCalendarConnector,
end_date=end_date, )
max_results=max_results,
) cal = GoogleCalendarConnector(
credentials=creds,
session=db_session,
user_id=user_id,
connector_id=connector.id,
)
events_raw, error = await cal.get_all_primary_calendar_events(
start_date=start_date,
end_date=end_date,
max_results=max_results,
)
if error: if error:
if ( if (

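The `_to_calendar_boundary` helper at the top of this file gives the Composio path the same date ergonomics as the native one: bare dates expand to inclusive day edges, while anything already carrying a time component passes through untouched. For example:

```python
assert _to_calendar_boundary("2026-05-09", is_end=False) == "2026-05-09T00:00:00Z"
assert _to_calendar_boundary("2026-05-09", is_end=True) == "2026-05-09T23:59:59Z"

# A full datetime is left alone, offset and all.
assert (
    _to_calendar_boundary("2026-05-09T10:00:00-07:00", is_end=True)
    == "2026-05-09T10:00:00-07:00"
)
```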
View file

@ -192,20 +192,62 @@ def create_update_calendar_event_tool(
f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}" f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
) )
has_changes = any(
v is not None
for v in (
final_new_summary,
final_new_start_datetime,
final_new_end_datetime,
final_new_description,
final_new_location,
final_new_attendees,
)
)
if not has_changes:
return {
"status": "error",
"message": "No changes specified. Please provide at least one field to update.",
}
if ( if (
connector.connector_type connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
): ):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id") cca_id = connector.config.get("composio_connected_account_id")
if cca_id: if not cca_id:
creds = build_composio_credentials(cca_id)
else:
return { return {
"status": "error", "status": "error",
"message": "Composio connected account ID not found for this connector.", "message": "Composio connected account ID not found for this connector.",
} }
from app.services.composio_service import ComposioService
tz_for_composio: str | None = None
if final_new_start_datetime is not None and not _is_date_only(
final_new_start_datetime
):
tz_for_composio = (
context.get("timezone") if isinstance(context, dict) else None
)
_, html_link, error = await ComposioService().update_calendar_event(
connected_account_id=cca_id,
entity_id=f"surfsense_{user_id}",
event_id=final_event_id,
summary=final_new_summary,
start_time=final_new_start_datetime,
end_time=final_new_end_datetime,
timezone=tz_for_composio,
description=final_new_description,
location=final_new_location,
attendees=final_new_attendees,
)
if error:
return {"status": "error", "message": error}
updated = {"htmlLink": html_link}
logger.info(
f"Calendar event updated via Composio: event_id={final_event_id}"
)
else: else:
config_data = dict(connector.config) config_data = dict(connector.config)
@ -235,81 +277,79 @@ def create_update_calendar_event_tool(
                expiry=datetime.fromisoformat(exp) if exp else None,
            )

            service = await asyncio.get_event_loop().run_in_executor(
                None, lambda: build("calendar", "v3", credentials=creds)
            )

            update_body: dict[str, Any] = {}
            if final_new_summary is not None:
                update_body["summary"] = final_new_summary
            if final_new_start_datetime is not None:
                update_body["start"] = _build_time_body(
                    final_new_start_datetime, context
                )
            if final_new_end_datetime is not None:
                update_body["end"] = _build_time_body(
                    final_new_end_datetime, context
                )
            if final_new_description is not None:
                update_body["description"] = final_new_description
            if final_new_location is not None:
                update_body["location"] = final_new_location
            if final_new_attendees is not None:
                update_body["attendees"] = [
                    {"email": e.strip()} for e in final_new_attendees if e.strip()
                ]

            try:
                updated = await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: (
                        service.events()
                        .patch(
                            calendarId="primary",
                            eventId=final_event_id,
                            body=update_body,
                        )
                        .execute()
                    ),
                )
            except Exception as api_err:
                from googleapiclient.errors import HttpError

                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        _res = await db_session.execute(
                            select(SearchSourceConnector).where(
                                SearchSourceConnector.id == actual_connector_id
                            )
                        )
                        _conn = _res.scalar_one_or_none()
                        if _conn and not _conn.config.get("auth_expired"):
                            _conn.config = {**_conn.config, "auth_expired": True}
                            flag_modified(_conn, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            actual_connector_id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": actual_connector_id,
                        "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            logger.info(
                f"Calendar event updated via Google API: event_id={final_event_id}"
            )

        kb_message_suffix = ""
        if document_id is not None:

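All the native Google API calls in these tools share one constraint: `googleapiclient` is synchronous, so each blocking `.execute()` is pushed onto the default thread pool to keep the event loop free. Reduced to its essentials (the callable here is a stand-in):

```python
import asyncio


async def run_blocking_call(service_call) -> dict:
    # service_call is a zero-argument callable wrapping something like
    # service.events().patch(...).execute(), as in the tools above.
    return await asyncio.get_event_loop().run_in_executor(None, service_call)
```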
View file

@ -161,16 +161,39 @@ def create_create_gmail_draft_tool(
            connector.connector_type
            == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
        ):
            cca_id = connector.config.get("composio_connected_account_id")
            if not cca_id:
                return {
                    "status": "error",
                    "message": "Composio connected account ID not found for this Gmail connector.",
                }

            from app.services.composio_service import ComposioService

            (
                draft_id,
                draft_message_id,
                draft_thread_id,
                error,
            ) = await ComposioService().create_gmail_draft(
                connected_account_id=cca_id,
                entity_id=f"surfsense_{user_id}",
                to=final_to,
                subject=final_subject,
                body=final_body,
                cc=final_cc,
                bcc=final_bcc,
            )
            if error:
                return {"status": "error", "message": error}
            created = {
                "id": draft_id,
                "message": {
                    "id": draft_message_id,
                    "threadId": draft_thread_id,
                },
            }
            logger.info(f"Gmail draft created via Composio: id={draft_id}")
        else:
            from google.oauth2.credentials import Credentials
@ -208,63 +231,65 @@ def create_create_gmail_draft_tool(
                expiry=datetime.fromisoformat(exp) if exp else None,
            )

            from googleapiclient.discovery import build

            gmail_service = build("gmail", "v1", credentials=creds)

            message = MIMEText(final_body)
            message["to"] = final_to
            message["subject"] = final_subject
            if final_cc:
                message["cc"] = final_cc
            if final_bcc:
                message["bcc"] = final_bcc
            raw = base64.urlsafe_b64encode(message.as_bytes()).decode()

            try:
                created = await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: (
                        gmail_service.users()
                        .drafts()
                        .create(userId="me", body={"message": {"raw": raw}})
                        .execute()
                    ),
                )
            except Exception as api_err:
                from googleapiclient.errors import HttpError

                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        _res = await db_session.execute(
                            select(SearchSourceConnector).where(
                                SearchSourceConnector.id == actual_connector_id
                            )
                        )
                        _conn = _res.scalar_one_or_none()
                        if _conn and not _conn.config.get("auth_expired"):
                            _conn.config = {**_conn.config, "auth_expired": True}
                            flag_modified(_conn, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            actual_connector_id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": actual_connector_id,
                        "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            logger.info(
                f"Gmail draft created via Google API: id={created.get('id')}"
            )

        kb_message_suffix = ""
        try:

View file

@ -50,7 +50,56 @@ def create_read_gmail_email_tool(
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.", "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
} }
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
):
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
}
from app.agents.new_chat.tools.gmail.search_emails import (
_format_gmail_summary,
)
from app.services.composio_service import ComposioService
detail, error = await ComposioService().get_gmail_message_detail(
connected_account_id=cca_id,
entity_id=f"surfsense_{user_id}",
message_id=message_id,
)
if error:
return {"status": "error", "message": error}
if not detail:
return {
"status": "not_found",
"message": f"Email with ID '{message_id}' not found.",
}
summary = _format_gmail_summary(detail)
content = (
f"# {summary['subject']}\n\n"
f"**From:** {summary['from']}\n"
f"**To:** {summary['to']}\n"
f"**Date:** {summary['date']}\n\n"
f"## Message Content\n\n"
f"{detail.get('messageText') or detail.get('snippet') or ''}\n\n"
f"## Message Details\n\n"
f"- **Message ID:** {summary['message_id']}\n"
f"- **Thread ID:** {summary['thread_id']}\n"
)
return {
"status": "success",
"message_id": summary["message_id"] or message_id,
"content": content,
}
from app.agents.new_chat.tools.gmail.search_emails import (
_build_credentials,
)
creds = _build_credentials(connector) creds = _build_credentials(connector)

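Every tool in this family now branches the same way: Composio-backed connectors short-circuit into `ComposioService` and return structured results directly, while native OAuth connectors fall through to credential building and the Google client. A schematic of the dispatch; the constant and both helpers here are hypothetical stand-ins for the real branches above:

```python
if connector.connector_type == COMPOSIO_TYPE:       # stand-in constant
    cca_id = connector.config.get("composio_connected_account_id")
    if not cca_id:
        return {"status": "error", "message": "Composio connected account ID not found."}
    result, error = await composio_call(cca_id)     # hypothetical helper
    if error:
        return {"status": "error", "message": error}
else:
    creds = _build_credentials(connector)           # native OAuth path
    result = await native_google_call(creds)        # hypothetical helper
```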
View file

@ -1,5 +1,4 @@
import logging
from typing import Any

from langchain_core.tools import tool
@ -15,57 +14,6 @@ _GMAIL_TYPES = [
    SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
_token_encryption_cache: object | None = None
def _get_token_encryption():
global _token_encryption_cache
if _token_encryption_cache is None:
from app.config import config
from app.utils.oauth_security import TokenEncryption
if not config.SECRET_KEY:
raise RuntimeError("SECRET_KEY not configured for token decryption.")
_token_encryption_cache = TokenEncryption(config.SECRET_KEY)
return _token_encryption_cache
def _build_credentials(connector: SearchSourceConnector):
"""Build Google OAuth Credentials from a connector's stored config.
Handles both native OAuth connectors (with encrypted tokens) and
Composio-backed connectors. Shared by Gmail and Calendar tools.
"""
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
raise ValueError("Composio connected account ID not found.")
return build_composio_credentials(cca_id)
from google.oauth2.credentials import Credentials
cfg = dict(connector.config)
if cfg.get("_token_encrypted"):
enc = _get_token_encryption()
for key in ("token", "refresh_token", "client_secret"):
if cfg.get(key):
cfg[key] = enc.decrypt_token(cfg[key])
exp = (cfg.get("expiry") or "").replace("Z", "")
return Credentials(
token=cfg.get("token"),
refresh_token=cfg.get("refresh_token"),
token_uri=cfg.get("token_uri"),
client_id=cfg.get("client_id"),
client_secret=cfg.get("client_secret"),
scopes=cfg.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
def create_search_gmail_tool(
    db_session: AsyncSession | None = None,
@ -110,6 +58,50 @@ def create_search_gmail_tool(
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.", "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
} }
if (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
):
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this Gmail connector.",
}
from app.agents.new_chat.tools.gmail.search_emails import (
_format_gmail_summary,
)
from app.services.composio_service import ComposioService
(
messages,
_next,
_estimate,
error,
) = await ComposioService().get_gmail_messages(
connected_account_id=cca_id,
entity_id=f"surfsense_{user_id}",
query=query,
max_results=max_results,
)
if error:
return {"status": "error", "message": error}
emails = [_format_gmail_summary(m) for m in messages]
if not emails:
return {
"status": "success",
"emails": [],
"total": 0,
"message": "No emails found.",
}
return {"status": "success", "emails": emails, "total": len(emails)}
from app.agents.new_chat.tools.gmail.search_emails import (
_build_credentials,
)
creds = _build_credentials(connector) creds = _build_credentials(connector)
from app.connectors.google_gmail_connector import GoogleGmailConnector from app.connectors.google_gmail_connector import GoogleGmailConnector

View file

@ -162,16 +162,31 @@ def create_send_gmail_email_tool(
            connector.connector_type
            == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
        ):
            cca_id = connector.config.get("composio_connected_account_id")
            if not cca_id:
                return {
                    "status": "error",
                    "message": "Composio connected account ID not found for this Gmail connector.",
                }

            from app.services.composio_service import ComposioService

            (
                sent_message_id,
                sent_thread_id,
                error,
            ) = await ComposioService().send_gmail_email(
                connected_account_id=cca_id,
                entity_id=f"surfsense_{user_id}",
                to=final_to,
                subject=final_subject,
                body=final_body,
                cc=final_cc,
                bcc=final_bcc,
            )
            if error:
                return {"status": "error", "message": error}
            sent = {"id": sent_message_id, "threadId": sent_thread_id}
        else:
            from google.oauth2.credentials import Credentials
@ -209,61 +224,61 @@ def create_send_gmail_email_tool(
                expiry=datetime.fromisoformat(exp) if exp else None,
            )

            from googleapiclient.discovery import build

            gmail_service = build("gmail", "v1", credentials=creds)

            message = MIMEText(final_body)
            message["to"] = final_to
            message["subject"] = final_subject
            if final_cc:
                message["cc"] = final_cc
            if final_bcc:
                message["bcc"] = final_bcc
            raw = base64.urlsafe_b64encode(message.as_bytes()).decode()

            try:
                sent = await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: (
                        gmail_service.users()
                        .messages()
                        .send(userId="me", body={"raw": raw})
                        .execute()
                    ),
                )
            except Exception as api_err:
                from googleapiclient.errors import HttpError

                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        _res = await db_session.execute(
                            select(SearchSourceConnector).where(
                                SearchSourceConnector.id == actual_connector_id
                            )
                        )
                        _conn = _res.scalar_one_or_none()
                        if _conn and not _conn.config.get("auth_expired"):
                            _conn.config = {**_conn.config, "auth_expired": True}
                            flag_modified(_conn, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            actual_connector_id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": actual_connector_id,
                        "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

        logger.info(
            f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}"

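The native Gmail paths all serialize the outgoing message the same way the Gmail API expects: a MIME message, byte-encoded, then URL-safe base64. In isolation, with an example address:

```python
import base64
from email.mime.text import MIMEText

message = MIMEText("Hello from SurfSense")
message["to"] = "someone@example.com"  # example address
message["subject"] = "Test"

raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
# The API takes this as {"raw": raw} for sends, or nested as
# {"message": {"raw": raw}} for draft create/update calls.
```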
View file

@ -162,16 +162,22 @@ def create_trash_gmail_email_tool(
            connector.connector_type
            == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
        ):
            cca_id = connector.config.get("composio_connected_account_id")
            if not cca_id:
                return {
                    "status": "error",
                    "message": "Composio connected account ID not found for this Gmail connector.",
                }

            from app.services.composio_service import ComposioService

            error = await ComposioService().trash_gmail_message(
                connected_account_id=cca_id,
                entity_id=f"surfsense_{user_id}",
                message_id=final_message_id,
            )
            if error:
                return {"status": "error", "message": error}
        else:
            from google.oauth2.credentials import Credentials
@ -209,49 +215,49 @@ def create_trash_gmail_email_tool(
                expiry=datetime.fromisoformat(exp) if exp else None,
            )

            from googleapiclient.discovery import build

            gmail_service = build("gmail", "v1", credentials=creds)

            try:
                await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: (
                        gmail_service.users()
                        .messages()
                        .trash(userId="me", id=final_message_id)
                        .execute()
                    ),
                )
            except Exception as api_err:
                from googleapiclient.errors import HttpError

                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {connector.id}: {api_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        if not connector.config.get("auth_expired"):
                            connector.config = {
                                **connector.config,
                                "auth_expired": True,
                            }
                            flag_modified(connector, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            connector.id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": connector.id,
                        "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

        logger.info(f"Gmail email trashed: message_id={final_message_id}")

View file

@ -192,16 +192,51 @@ def create_update_gmail_draft_tool(
            connector.connector_type
            == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
        ):
            cca_id = connector.config.get("composio_connected_account_id")
            if not cca_id:
                return {
                    "status": "error",
                    "message": "Composio connected account ID not found for this Gmail connector.",
                }
            if not final_draft_id:
                return {
                    "status": "error",
                    "message": (
                        "Could not find this draft in Gmail. "
                        "It may have already been sent or deleted."
                    ),
                }

            from app.services.composio_service import ComposioService

            (
                new_draft_id,
                new_message_id,
                error,
            ) = await ComposioService().update_gmail_draft(
                connected_account_id=cca_id,
                entity_id=f"surfsense_{user_id}",
                draft_id=final_draft_id,
                to=final_to or None,
                subject=final_subject,
                body=final_body,
                cc=final_cc,
                bcc=final_bcc,
            )
            if error:
                if "not found" in error.lower() or "no longer" in error.lower():
                    return {
                        "status": "error",
                        "message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
                    }
                return {"status": "error", "message": error}
            updated = {
                "id": new_draft_id or final_draft_id,
                "message": {"id": new_message_id} if new_message_id else {},
            }
            logger.info(f"Gmail draft updated via Composio: id={updated.get('id')}")
        else:
            from google.oauth2.credentials import Credentials
@ -239,88 +274,90 @@ def create_update_gmail_draft_tool(
                expiry=datetime.fromisoformat(exp) if exp else None,
            )

            from googleapiclient.discovery import build

            gmail_service = build("gmail", "v1", credentials=creds)

            # Resolve draft_id if not already available
            if not final_draft_id:
                logger.info(
                    f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
                )
                final_draft_id = await _find_draft_id_by_message(
                    gmail_service, message_id
                )

            if not final_draft_id:
                return {
                    "status": "error",
                    "message": (
                        "Could not find this draft in Gmail. "
                        "It may have already been sent or deleted."
                    ),
                }

            message = MIMEText(final_body)
            if final_to:
                message["to"] = final_to
            message["subject"] = final_subject
            if final_cc:
                message["cc"] = final_cc
            if final_bcc:
                message["bcc"] = final_bcc
            raw = base64.urlsafe_b64encode(message.as_bytes()).decode()

            try:
                updated = await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: (
                        gmail_service.users()
                        .drafts()
                        .update(
                            userId="me",
                            id=final_draft_id,
                            body={"message": {"raw": raw}},
                        )
                        .execute()
                    ),
                )
            except Exception as api_err:
                from googleapiclient.errors import HttpError

                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {connector.id}: {api_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        if not connector.config.get("auth_expired"):
                            connector.config = {
                                **connector.config,
                                "auth_expired": True,
                            }
                            flag_modified(connector, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            connector.id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": connector.id,
                        "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                if isinstance(api_err, HttpError) and api_err.resp.status == 404:
                    return {
                        "status": "error",
                        "message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
                    }
                raise

            logger.info(
                f"Gmail draft updated via Google API: id={updated.get('id')}"
            )

        kb_message_suffix = ""
        if document_id:


@ -179,59 +179,96 @@ def create_create_google_drive_file_tool(
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}" f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
) )
pre_built_creds = None async def _flag_auth_expired() -> None:
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
if ( if (
connector.connector_type connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
): ):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id") cca_id = connector.config.get("composio_connected_account_id")
if cca_id: if not cca_id:
pre_built_creds = build_composio_credentials(cca_id) return {
"status": "error",
"message": "Composio connected account ID not found for this Google Drive connector.",
}
client = GoogleDriveClient( from app.services.composio_service import ComposioService
session=db_session,
connector_id=actual_connector_id, created, error = await ComposioService().create_drive_file_from_text(
credentials=pre_built_creds, connected_account_id=cca_id,
) entity_id=f"surfsense_{user_id}",
try:
created = await client.create_file(
name=final_name, name=final_name,
mime_type=mime_type, mime_type=mime_type,
parent_folder_id=final_parent_folder_id,
content=final_content, content=final_content,
parent_id=final_parent_folder_id,
) )
except HttpError as http_err:
if http_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute( if error or not created:
select(SearchSourceConnector).where( err_lower = (error or "").lower()
SearchSourceConnector.id == actual_connector_id if (
) "insufficient" in err_lower
) or "permission" in err_lower
_conn = _res.scalar_one_or_none() or "403" in err_lower
if _conn and not _conn.config.get("auth_expired"): ):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning( logger.warning(
"Failed to persist auth_expired for connector %s", f"Insufficient permissions for Composio Drive connector {actual_connector_id}: {error}"
actual_connector_id,
exc_info=True,
) )
await _flag_auth_expired()
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
}
logger.error(
f"Composio Drive create_file failed for connector {actual_connector_id}: {error}"
)
return { return {
"status": "insufficient_permissions", "status": "error",
"connector_id": actual_connector_id, "message": "Something went wrong while creating the file. Please try again.",
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
} }
raise else:
client = GoogleDriveClient(
session=db_session,
connector_id=actual_connector_id,
)
try:
created = await client.create_file(
name=final_name,
mime_type=mime_type,
parent_folder_id=final_parent_folder_id,
content=final_content,
)
except HttpError as http_err:
if http_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
)
await _flag_auth_expired()
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info( logger.info(
f"Google Drive file created: id={created.get('id')}, name={created.get('name')}" f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"


@ -158,51 +158,84 @@ def create_delete_google_drive_file_tool(
f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}" f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
) )
pre_built_creds = None async def _flag_auth_expired() -> None:
try:
from sqlalchemy.orm.attributes import flag_modified
if not connector.config.get("auth_expired"):
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector.id,
exc_info=True,
)
if ( if (
connector.connector_type connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
): ):
from app.utils.google_credentials import build_composio_credentials
cca_id = connector.config.get("composio_connected_account_id") cca_id = connector.config.get("composio_connected_account_id")
if cca_id: if not cca_id:
pre_built_creds = build_composio_credentials(cca_id)
client = GoogleDriveClient(
session=db_session,
connector_id=connector.id,
credentials=pre_built_creds,
)
try:
await client.trash_file(file_id=final_file_id)
except HttpError as http_err:
if http_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {connector.id}: {http_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
if not connector.config.get("auth_expired"):
connector.config = {
**connector.config,
"auth_expired": True,
}
flag_modified(connector, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
connector.id,
exc_info=True,
)
return { return {
"status": "insufficient_permissions", "status": "error",
"connector_id": connector.id, "message": "Composio connected account ID not found for this Google Drive connector.",
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
} }
raise
from app.services.composio_service import ComposioService
error = await ComposioService().trash_drive_file(
connected_account_id=cca_id,
entity_id=f"surfsense_{user_id}",
file_id=final_file_id,
)
if error:
err_lower = error.lower()
if (
"insufficient" in err_lower
or "permission" in err_lower
or "403" in err_lower
):
logger.warning(
f"Insufficient permissions for Composio Drive connector {connector.id}: {error}"
)
await _flag_auth_expired()
return {
"status": "insufficient_permissions",
"connector_id": connector.id,
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
}
logger.error(
f"Composio Drive trash_file failed for connector {connector.id}: {error}"
)
return {
"status": "error",
"message": "Something went wrong while trashing the file. Please try again.",
}
else:
client = GoogleDriveClient(
session=db_session,
connector_id=connector.id,
)
try:
await client.trash_file(file_id=final_file_id)
except HttpError as http_err:
if http_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {connector.id}: {http_err}"
)
await _flag_auth_expired()
return {
"status": "insufficient_permissions",
"connector_id": connector.id,
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info( logger.info(
f"Google Drive file deleted (moved to trash): file_id={final_file_id}" f"Google Drive file deleted (moved to trash): file_id={final_file_id}"


@ -31,7 +31,6 @@ from langchain.agents import create_agent
from langchain.agents.middleware import (
LLMToolSelectorMiddleware,
ModelCallLimitMiddleware,
TodoListMiddleware,
ToolCallLimitMiddleware,
)
@ -77,6 +76,9 @@ from app.agents.new_chat.middleware import (
create_surfsense_compaction_middleware,
default_skills_sources,
)
from app.agents.new_chat.middleware.scoped_model_fallback import (
ScopedModelFallbackMiddleware,
)
from app.agents.new_chat.permissions import Rule, Ruleset
from app.agents.new_chat.plugin_loader import (
PluginContext,
@ -792,15 +794,15 @@ def _build_compiled_agent_blocking(
# Fallback chain — primary is the agent's own model; we add cheap
# alternatives. Off by default; only the first call site that
# configures the chain via env should enable it.
fallback_mw: ScopedModelFallbackMiddleware | None = None
if flags.enable_model_fallback and not flags.disable_new_agent_stack:
try:
fallback_mw = ScopedModelFallbackMiddleware(
"openai:gpt-4o-mini",
"anthropic:claude-3-5-haiku-20241022",
)
except Exception:
logging.warning("ScopedModelFallbackMiddleware init failed; skipping.")
fallback_mw = None
model_call_limit_mw = (
ModelCallLimitMiddleware(


@ -0,0 +1,91 @@
"""Fallback only on provider/network errors; let programming bugs raise."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from langchain.agents.middleware import ModelFallbackMiddleware
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain_core.messages import AIMessage
# Matched by class name across the MRO so we don't have to import every
# provider SDK (openai/anthropic/google/...). Extend as new providers ship.
_FALLBACK_ELIGIBLE_NAMES: frozenset[str] = frozenset(
{
"RateLimitError",
"APIStatusError",
"InternalServerError",
"ServiceUnavailableError",
"BadGatewayError",
"GatewayTimeoutError",
"APIConnectionError",
"APITimeoutError",
"ConnectError",
"ConnectTimeout",
"ReadTimeout",
"RemoteProtocolError",
"TimeoutError",
"TimeoutException",
}
)
def _is_fallback_eligible(exc: BaseException) -> bool:
return any(cls.__name__ in _FALLBACK_ELIGIBLE_NAMES for cls in type(exc).__mro__)
class ScopedModelFallbackMiddleware(ModelFallbackMiddleware):
"""Re-raise non-provider exceptions instead of walking the fallback chain."""
def wrap_model_call( # type: ignore[override]
self,
request: ModelRequest[Any],
handler: Callable[[ModelRequest[Any]], ModelResponse[Any]],
) -> ModelResponse[Any] | AIMessage:
last_exception: Exception
try:
return handler(request)
except Exception as e:
if not _is_fallback_eligible(e):
raise
last_exception = e
for fallback_model in self.models:
try:
return handler(request.override(model=fallback_model))
except Exception as e:
if not _is_fallback_eligible(e):
raise
last_exception = e
continue
raise last_exception
async def awrap_model_call( # type: ignore[override]
self,
request: ModelRequest[Any],
handler: Callable[[ModelRequest[Any]], Awaitable[ModelResponse[Any]]],
) -> ModelResponse[Any] | AIMessage:
last_exception: Exception
try:
return await handler(request)
except Exception as e:
if not _is_fallback_eligible(e):
raise
last_exception = e
for fallback_model in self.models:
try:
return await handler(request.override(model=fallback_model))
except Exception as e:
if not _is_fallback_eligible(e):
raise
last_exception = e
continue
raise last_exception
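# Illustrative sketch (not part of the module): eligibility is matched by
# class *name* anywhere in the MRO, so a provider SDK's subclass qualifies
# without importing that SDK. `LocalRateLimit` is a hypothetical stand-in.
#
#     class LocalRateLimit(Exception): ...            # not eligible by name
#     class RateLimitError(LocalRateLimit): ...       # eligible by name
#
#     _is_fallback_eligible(RateLimitError("429"))    # -> True, walks fallbacks
#     _is_fallback_eligible(KeyError("bad state"))    # -> False, re-raised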


@ -1,5 +1,15 @@
import re
from app.config import config
# Regex that matches a Markdown table block (header + separator + one or more rows)
# A table block starts with a | at the beginning of a line and ends when a
# non-table line (or end of string) is encountered.
_TABLE_BLOCK_RE = re.compile(
r"(?:(?:^|\n)(?=[ \t]*\|)(?:[ \t]*\|[^\n]*\n)+)",
re.MULTILINE,
)
def chunk_text(text: str, use_code_chunker: bool = False) -> list[str]:
"""Chunk a text string using the configured chunker and return the chunk texts."""
@ -7,3 +17,43 @@ def chunk_text(text: str, use_code_chunker: bool = False) -> list[str]:
config.code_chunker_instance if use_code_chunker else config.chunker_instance
)
return [c.text for c in chunker.chunk(text)]
def chunk_text_hybrid(text: str) -> list[str]:
"""Table-aware chunker that prevents Markdown tables from being split mid-row.
Algorithm:
1. Scan the document for Markdown table blocks.
2. Each table block is emitted as a single, unmodified chunk so that its
header, separator row, and data rows always stay together.
3. The non-table prose segments between (and around) tables are passed through
the normal ``chunk_text`` chunker and their sub-chunks are interleaved in
document order.
This ensures that table data is never sliced in the middle by the token-based
chunker, which would otherwise produce garbled rows that are useless for RAG.
Fixes #1334.
"""
chunks: list[str] = []
cursor = 0
for match in _TABLE_BLOCK_RE.finditer(text):
# Prose before this table
prose = text[cursor : match.start()].strip()
if prose:
chunks.extend(chunk_text(prose))
# The table itself is kept as one indivisible chunk
table_block = match.group(0).strip()
if table_block:
chunks.append(table_block)
cursor = match.end()
# Remaining prose after the last table (or entire text if no tables)
trailing = text[cursor:].strip()
if trailing:
chunks.extend(chunk_text(trailing))
return chunks
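# Minimal usage sketch (assumes config.chunker_instance is initialised): a
# Markdown table surrounded by prose comes back as one intact chunk, while
# the prose segments flow through chunk_text as before.
#
#     doc = "intro text\n| a | b |\n| - | - |\n| 1 | 2 |\ntrailing text"
#     parts = chunk_text_hybrid(doc)
#     assert any(p.startswith("| a | b |") for p in parts)  # table kept whole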


@ -19,7 +19,7 @@ from app.db import (
DocumentType,
)
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_chunker import chunk_text, chunk_text_hybrid
from app.indexing_pipeline.document_embedder import embed_texts
from app.indexing_pipeline.document_hashing import (
compute_content_hash,
@ -387,11 +387,19 @@ class IndexingPipelineService:
)
t_step = time.perf_counter()
if connector_doc.should_use_code_chunker:
chunk_texts = await asyncio.to_thread(
chunk_text,
connector_doc.source_markdown,
use_code_chunker=True,
)
else:
# Use the table-aware hybrid chunker so Markdown tables are not
# split mid-row (see issue #1334).
chunk_texts = await asyncio.to_thread(
chunk_text_hybrid,
connector_doc.source_markdown,
)
texts_to_embed = [content, *chunk_texts]
embeddings = await asyncio.to_thread(embed_texts, texts_to_embed)


@ -1027,6 +1027,505 @@ class ComposioService:
logger.error(f"Failed to list Calendar events: {e!s}") logger.error(f"Failed to list Calendar events: {e!s}")
return [], str(e) return [], str(e)
@staticmethod
def _unwrap_response_data(data: Any) -> Any:
"""Composio responses often nest the meaningful payload under
``data.data.response_data``. Walk that envelope safely and return
whichever inner dict actually has the result keys."""
if not isinstance(data, dict):
return data
inner = data.get("data", data)
if isinstance(inner, dict):
return inner.get("response_data", inner)
return inner
@staticmethod
def _split_email_csv(value: str | None) -> list[str] | None:
"""Tools accept comma-separated cc/bcc strings; Composio expects an array."""
if not value:
return None
addrs = [e.strip() for e in value.split(",") if e.strip()]
return addrs or None
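# Behaviour sketch for the two helpers above (values are illustrative):
#
#     ComposioService._unwrap_response_data(
#         {"data": {"response_data": {"id": "d-1"}}}
#     )  # -> {"id": "d-1"}
#     ComposioService._split_email_csv("a@x.com, b@y.com,")
#     # -> ["a@x.com", "b@y.com"]
#     ComposioService._split_email_csv("")  # -> None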
# ===== Gmail write methods =====
async def send_gmail_email(
self,
connected_account_id: str,
entity_id: str,
to: str,
subject: str,
body: str,
cc: str | None = None,
bcc: str | None = None,
is_html: bool = False,
) -> tuple[str | None, str | None, str | None]:
"""Send a Gmail message via the Composio ``GMAIL_SEND_EMAIL`` toolkit.
Returns:
Tuple of (message_id, thread_id, error). On success ``error`` is
None and at least one of the IDs is populated when Composio
returns them; on failure both IDs are None.
"""
try:
params: dict[str, Any] = {
"recipient_email": to,
"subject": subject,
"body": body,
"is_html": is_html,
}
if cc:
cc_list = self._split_email_csv(cc)
if cc_list:
params["cc"] = cc_list
if bcc:
bcc_list = self._split_email_csv(bcc)
if bcc_list:
params["bcc"] = bcc_list
result = await self.execute_tool(
connected_account_id=connected_account_id,
tool_name="GMAIL_SEND_EMAIL",
params=params,
entity_id=entity_id,
)
if not result.get("success"):
return None, None, result.get("error", "Unknown error")
payload = self._unwrap_response_data(result.get("data", {}))
message_id = None
thread_id = None
if isinstance(payload, dict):
message_id = (
payload.get("id")
or payload.get("message_id")
or payload.get("messageId")
)
thread_id = payload.get("threadId") or payload.get("thread_id")
return message_id, thread_id, None
except Exception as e:
logger.error(f"Failed to send Gmail email: {e!s}")
return None, None, str(e)
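# Hypothetical call-site sketch: each write method below follows the same
# (ids..., error) tuple contract, so callers branch on `error` first instead
# of wrapping the call in try/except. `cca_id` and `user_id` are assumed to
# come from the connector config and request context.
#
#     message_id, thread_id, error = await ComposioService().send_gmail_email(
#         connected_account_id=cca_id,
#         entity_id=f"surfsense_{user_id}",
#         to="someone@example.com",
#         subject="Hi",
#         body="Hello!",
#     )
#     if error:
#         ...  # return a structured tool error instead of raising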
async def create_gmail_draft(
self,
connected_account_id: str,
entity_id: str,
to: str,
subject: str,
body: str,
cc: str | None = None,
bcc: str | None = None,
is_html: bool = False,
) -> tuple[str | None, str | None, str | None, str | None]:
"""Create a Gmail draft via the Composio ``GMAIL_CREATE_EMAIL_DRAFT`` toolkit.
Returns:
Tuple of (draft_id, message_id, thread_id, error). On success
``error`` is None and ``draft_id`` is populated.
"""
try:
params: dict[str, Any] = {
"recipient_email": to,
"subject": subject,
"body": body,
"is_html": is_html,
}
cc_list = self._split_email_csv(cc)
if cc_list:
params["cc"] = cc_list
bcc_list = self._split_email_csv(bcc)
if bcc_list:
params["bcc"] = bcc_list
result = await self.execute_tool(
connected_account_id=connected_account_id,
tool_name="GMAIL_CREATE_EMAIL_DRAFT",
params=params,
entity_id=entity_id,
)
if not result.get("success"):
return None, None, None, result.get("error", "Unknown error")
payload = self._unwrap_response_data(result.get("data", {}))
draft_id = None
message_id = None
thread_id = None
if isinstance(payload, dict):
draft_id = payload.get("id") or payload.get("draft_id")
draft_message = payload.get("message") or {}
if isinstance(draft_message, dict):
message_id = draft_message.get("id") or draft_message.get(
"message_id"
)
thread_id = draft_message.get("threadId") or draft_message.get(
"thread_id"
)
if message_id is None:
message_id = payload.get("message_id") or payload.get("messageId")
if thread_id is None:
thread_id = payload.get("thread_id") or payload.get("threadId")
return draft_id, message_id, thread_id, None
except Exception as e:
logger.error(f"Failed to create Gmail draft: {e!s}")
return None, None, None, str(e)
async def update_gmail_draft(
self,
connected_account_id: str,
entity_id: str,
draft_id: str,
to: str | None = None,
subject: str | None = None,
body: str | None = None,
cc: str | None = None,
bcc: str | None = None,
is_html: bool = False,
) -> tuple[str | None, str | None, str | None]:
"""Update an existing Gmail draft via ``GMAIL_UPDATE_DRAFT``.
Returns:
Tuple of (draft_id, message_id, error).
"""
try:
params: dict[str, Any] = {
"draft_id": draft_id,
"is_html": is_html,
}
if to:
params["recipient_email"] = to
if subject is not None:
params["subject"] = subject
if body is not None:
params["body"] = body
cc_list = self._split_email_csv(cc)
if cc_list:
params["cc"] = cc_list
bcc_list = self._split_email_csv(bcc)
if bcc_list:
params["bcc"] = bcc_list
result = await self.execute_tool(
connected_account_id=connected_account_id,
tool_name="GMAIL_UPDATE_DRAFT",
params=params,
entity_id=entity_id,
)
if not result.get("success"):
return None, None, result.get("error", "Unknown error")
payload = self._unwrap_response_data(result.get("data", {}))
new_draft_id = draft_id
message_id = None
if isinstance(payload, dict):
new_draft_id = payload.get("id") or payload.get("draft_id") or draft_id
draft_message = payload.get("message") or {}
if isinstance(draft_message, dict):
message_id = draft_message.get("id") or draft_message.get(
"message_id"
)
if message_id is None:
message_id = payload.get("message_id") or payload.get("messageId")
return new_draft_id, message_id, None
except Exception as e:
logger.error(f"Failed to update Gmail draft: {e!s}")
return None, None, str(e)
async def trash_gmail_message(
self,
connected_account_id: str,
entity_id: str,
message_id: str,
) -> str | None:
"""Move a Gmail message to trash via ``GMAIL_MOVE_TO_TRASH``.
Returns the error message on failure, ``None`` on success.
"""
try:
result = await self.execute_tool(
connected_account_id=connected_account_id,
tool_name="GMAIL_MOVE_TO_TRASH",
params={"message_id": message_id},
entity_id=entity_id,
)
if not result.get("success"):
return result.get("error", "Unknown error")
return None
except Exception as e:
logger.error(f"Failed to trash Gmail message: {e!s}")
return str(e)
# ===== Google Calendar write methods =====
async def create_calendar_event(
self,
connected_account_id: str,
entity_id: str,
summary: str,
start_datetime: str,
end_datetime: str,
timezone: str | None = None,
description: str | None = None,
location: str | None = None,
attendees: list[str] | None = None,
calendar_id: str = "primary",
) -> tuple[str | None, str | None, str | None]:
"""Create a Google Calendar event via ``GOOGLECALENDAR_CREATE_EVENT``.
Composio strips trailing timezone info on ``start_datetime`` /
``end_datetime`` and uses the ``timezone`` field as the IANA name,
so callers may pass ISO 8601 strings with or without offsets.
Returns:
Tuple of (event_id, html_link, error).
"""
try:
params: dict[str, Any] = {
"summary": summary,
"start_datetime": start_datetime,
"end_datetime": end_datetime,
"calendar_id": calendar_id,
}
if timezone:
params["timezone"] = timezone
if description:
params["description"] = description
if location:
params["location"] = location
if attendees:
params["attendees"] = [a for a in attendees if a]
result = await self.execute_tool(
connected_account_id=connected_account_id,
tool_name="GOOGLECALENDAR_CREATE_EVENT",
params=params,
entity_id=entity_id,
)
if not result.get("success"):
return None, None, result.get("error", "Unknown error")
payload = self._unwrap_response_data(result.get("data", {}))
event_id = None
html_link = None
if isinstance(payload, dict):
event_id = payload.get("id") or payload.get("event_id")
html_link = payload.get("htmlLink") or payload.get("html_link")
return event_id, html_link, None
except Exception as e:
logger.error(f"Failed to create Calendar event: {e!s}")
return None, None, str(e)
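# Hedged usage sketch: per the docstring above, the datetimes may carry an
# offset or not; the IANA `timezone` field is what Composio actually applies.
# `cca_id` / `user_id` are assumed call-site values.
#
#     event_id, link, error = await ComposioService().create_calendar_event(
#         connected_account_id=cca_id,
#         entity_id=f"surfsense_{user_id}",
#         summary="Sync",
#         start_datetime="2026-05-11T10:00:00",
#         end_datetime="2026-05-11T10:30:00",
#         timezone="America/Los_Angeles",
#     )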
async def update_calendar_event(
self,
connected_account_id: str,
entity_id: str,
event_id: str,
summary: str | None = None,
start_time: str | None = None,
end_time: str | None = None,
timezone: str | None = None,
description: str | None = None,
location: str | None = None,
attendees: list[str] | None = None,
calendar_id: str = "primary",
) -> tuple[str | None, str | None, str | None]:
"""Patch an existing Google Calendar event via ``GOOGLECALENDAR_PATCH_EVENT``.
Uses PATCH (not PUT) semantics so omitted fields are preserved.
Returns:
Tuple of (event_id, html_link, error).
"""
try:
params: dict[str, Any] = {
"event_id": event_id,
"calendar_id": calendar_id,
}
if summary is not None:
params["summary"] = summary
if start_time is not None:
params["start_time"] = start_time
if end_time is not None:
params["end_time"] = end_time
if timezone:
params["timezone"] = timezone
if description is not None:
params["description"] = description
if location is not None:
params["location"] = location
if attendees is not None:
params["attendees"] = [a for a in attendees if a]
result = await self.execute_tool(
connected_account_id=connected_account_id,
tool_name="GOOGLECALENDAR_PATCH_EVENT",
params=params,
entity_id=entity_id,
)
if not result.get("success"):
return None, None, result.get("error", "Unknown error")
payload = self._unwrap_response_data(result.get("data", {}))
new_event_id = event_id
html_link = None
if isinstance(payload, dict):
new_event_id = payload.get("id") or payload.get("event_id") or event_id
html_link = payload.get("htmlLink") or payload.get("html_link")
return new_event_id, html_link, None
except Exception as e:
logger.error(f"Failed to patch Calendar event: {e!s}")
return None, None, str(e)
async def delete_calendar_event(
self,
connected_account_id: str,
entity_id: str,
event_id: str,
calendar_id: str = "primary",
) -> str | None:
"""Delete a Google Calendar event via ``GOOGLECALENDAR_DELETE_EVENT``.
Returns the error message on failure, ``None`` on success (idempotent
on already-deleted events).
"""
try:
result = await self.execute_tool(
connected_account_id=connected_account_id,
tool_name="GOOGLECALENDAR_DELETE_EVENT",
params={
"event_id": event_id,
"calendar_id": calendar_id,
},
entity_id=entity_id,
)
if not result.get("success"):
return result.get("error", "Unknown error")
return None
except Exception as e:
logger.error(f"Failed to delete Calendar event: {e!s}")
return str(e)
# ===== Google Drive write methods =====
@staticmethod
def _drive_web_view_link(file_id: str, mime_type: str | None) -> str:
"""Synthesize a Google Drive ``webViewLink`` from id + mimeType.
Composio's ``GOOGLEDRIVE_CREATE_FILE_FROM_TEXT`` returns flat
metadata (id, name, mimeType) but does not always include a
``webViewLink``. We rebuild the canonical UI URL based on the
Workspace MIME type so callers can keep using a single field.
"""
if not file_id:
return ""
mt = (mime_type or "").lower()
if mt == "application/vnd.google-apps.document":
return f"https://docs.google.com/document/d/{file_id}/edit"
if mt == "application/vnd.google-apps.spreadsheet":
return f"https://docs.google.com/spreadsheets/d/{file_id}/edit"
if mt == "application/vnd.google-apps.presentation":
return f"https://docs.google.com/presentation/d/{file_id}/edit"
if mt == "application/vnd.google-apps.folder":
return f"https://drive.google.com/drive/folders/{file_id}"
return f"https://drive.google.com/file/d/{file_id}/view"
async def create_drive_file_from_text(
self,
connected_account_id: str,
entity_id: str,
name: str,
mime_type: str,
content: str | None = None,
parent_id: str | None = None,
) -> tuple[dict[str, Any] | None, str | None]:
"""Create a Google Drive file from text via ``GOOGLEDRIVE_CREATE_FILE_FROM_TEXT``.
Composio's tool requires ``text_content`` even for "empty" files;
an empty string is accepted. Native Workspace types (Docs, Sheets)
are produced by setting ``mime_type`` to the Google Apps MIME, and
Drive auto-converts the text payload (e.g. CSV Sheet).
Returns:
Tuple of (file_meta, error). ``file_meta`` keys:
``id``, ``name``, ``mimeType``, ``webViewLink``.
"""
try:
params: dict[str, Any] = {
"file_name": name,
"mime_type": mime_type,
"text_content": content if content is not None else "",
}
if parent_id:
params["parent_id"] = parent_id
result = await self.execute_tool(
connected_account_id=connected_account_id,
tool_name="GOOGLEDRIVE_CREATE_FILE_FROM_TEXT",
params=params,
entity_id=entity_id,
)
if not result.get("success"):
return None, result.get("error", "Unknown error")
payload = self._unwrap_response_data(result.get("data", {}))
file_id: str | None = None
file_name: str | None = name
mime: str | None = mime_type
web_view_link: str | None = None
if isinstance(payload, dict):
file_id = (
payload.get("id") or payload.get("file_id") or payload.get("fileId")
)
file_name = payload.get("name") or payload.get("file_name") or name
mime = payload.get("mimeType") or payload.get("mime_type") or mime_type
web_view_link = payload.get("webViewLink") or payload.get(
"web_view_link"
)
if not file_id:
return None, "Composio response did not include a file id"
if not web_view_link:
web_view_link = self._drive_web_view_link(file_id, mime)
return (
{
"id": file_id,
"name": file_name,
"mimeType": mime,
"webViewLink": web_view_link,
},
None,
)
except Exception as e:
logger.error(f"Failed to create Drive file: {e!s}")
return None, str(e)
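# Sketch of the auto-conversion path described in the docstring above
# (names and content are illustrative): passing a Google Apps MIME type with
# CSV text yields a native Sheet, and the returned metadata always carries a
# usable webViewLink (synthesized if Composio omits it).
#
#     meta, error = await ComposioService().create_drive_file_from_text(
#         connected_account_id=cca_id,
#         entity_id=f"surfsense_{user_id}",
#         name="Q2 budget",
#         mime_type="application/vnd.google-apps.spreadsheet",
#         content="item,cost\nhosting,120",
#     )
#     # meta -> {"id": ..., "name": ..., "mimeType": ..., "webViewLink": ...}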
async def trash_drive_file(
self,
connected_account_id: str,
entity_id: str,
file_id: str,
) -> str | None:
"""Move a Google Drive file to trash via ``GOOGLEDRIVE_TRASH_FILE``.
Returns the error message on failure, ``None`` on success.
"""
try:
result = await self.execute_tool(
connected_account_id=connected_account_id,
tool_name="GOOGLEDRIVE_TRASH_FILE",
params={"file_id": file_id},
entity_id=entity_id,
)
if not result.get("success"):
return result.get("error", "Unknown error")
return None
except Exception as e:
logger.error(f"Failed to trash Drive file: {e!s}")
return str(e)
# ===== User Info Methods =====
async def get_connected_account_email(


@ -28,9 +28,7 @@ from langchain_core.messages import HumanMessage
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.checkpointer import get_checkpointer
from app.agents.new_chat.context import SurfSenseContextSchema
@ -577,6 +575,43 @@ async def _preflight_llm(llm: Any) -> None:
)
async def _build_main_agent_for_thread(
agent_factory: Any,
*,
llm: Any,
search_space_id: int,
db_session: Any,
connector_service: ConnectorService,
checkpointer: Any,
user_id: str | None,
thread_id: int | None,
agent_config: AgentConfig | None,
firecrawl_api_key: str | None,
thread_visibility: ChatVisibility | None,
filesystem_selection: FilesystemSelection | None,
disabled_tools: list[str] | None = None,
mentioned_document_ids: list[int] | None = None,
) -> Any:
"""Single (re)build path so the agent factory cannot drift across
initial build, preflight repin, and mid-stream 429 recovery for one
``thread_id``: a graph swap mid-turn would corrupt checkpointer state."""
return await agent_factory(
llm=llm,
search_space_id=search_space_id,
db_session=db_session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=thread_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=thread_visibility,
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
)
async def _settle_speculative_agent_build(task: asyncio.Task[Any]) -> None:
"""Wait for a discarded speculative agent build to release shared state.
@ -2767,7 +2802,7 @@ async def stream_new_chat(
_t0 = time.perf_counter()
agent_factory = (
create_multi_agent_chat_deep_agent
if use_multi_agent
else create_surfsense_deep_agent
)
@ -2776,7 +2811,8 @@ async def stream_new_chat(
# if preflight reports 429 we will discard this future and rebuild
# against the freshly pinned config below.
agent_build_task = asyncio.create_task(
_build_main_agent_for_thread(
agent_factory,
llm=llm,
search_space_id=search_space_id,
db_session=session,
@ -2787,9 +2823,9 @@ async def stream_new_chat(
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
filesystem_selection=filesystem_selection,
),
name="agent_build:stream_new_chat",
)
@ -3466,7 +3502,8 @@ async def stream_new_chat(
title_task = None
_t0 = time.perf_counter()
agent = await _build_main_agent_for_thread(
agent_factory,
llm=llm,
search_space_id=search_space_id,
db_session=session,
@ -3477,9 +3514,9 @@ async def stream_new_chat(
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
filesystem_selection=filesystem_selection,
)
_perf_log.info(
"[stream_new_chat] Runtime rate-limit recovery repinned "
@ -4130,12 +4167,13 @@ async def stream_resume_chat(
_t0 = time.perf_counter()
agent_factory = (
create_multi_agent_chat_deep_agent
if _app_config.MULTI_AGENT_CHAT_ENABLED
else create_surfsense_deep_agent
)
agent_build_task = asyncio.create_task(
_build_main_agent_for_thread(
agent_factory,
llm=llm,
search_space_id=search_space_id,
db_session=session,
@ -4224,7 +4262,8 @@ async def stream_resume_chat(
"fallback_config_id": llm_config_id, "fallback_config_id": llm_config_id,
}, },
) )
agent = await agent_factory( agent = await _build_main_agent_for_thread(
agent_factory,
llm=llm, llm=llm,
search_space_id=search_space_id, search_space_id=search_space_id,
db_session=session, db_session=session,
@ -4409,7 +4448,8 @@ async def stream_resume_chat(
raise stream_exc
_t0 = time.perf_counter()
agent = await _build_main_agent_for_thread(
agent_factory,
llm=llm,
search_space_id=search_space_id,
db_session=session,
@ -4421,6 +4461,7 @@ async def stream_resume_chat(
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
)
_perf_log.info(
"[stream_resume] Runtime rate-limit recovery repinned "


@ -1,6 +1,6 @@
[project]
name = "surf-new-backend"
version = "0.0.23"
description = "SurfSense Backend"
requires-python = ">=3.12"
dependencies = [


@ -0,0 +1,208 @@
"""End-to-end resume-bridge tests against a real LangGraph subagent."""
from __future__ import annotations
import ast
import pytest
from langchain.tools import ToolRuntime
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.types import Command, interrupt
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
build_task_tool_with_parent_config,
)
class _SubagentState(TypedDict, total=False):
messages: list
decision_text: str
def _build_single_interrupt_subagent():
def approve_node(state):
from langchain_core.messages import AIMessage
decision = interrupt(
{
"action_requests": [
{
"name": "do_thing",
"args": {"x": 1},
"description": "test action",
}
],
"review_configs": [{}],
}
)
return {
"messages": [AIMessage(content="done")],
"decision_text": repr(decision),
}
graph = StateGraph(_SubagentState)
graph.add_node("approve", approve_node)
graph.add_edge(START, "approve")
graph.add_edge("approve", END)
return graph.compile(checkpointer=InMemorySaver())
def _make_runtime(config: dict) -> ToolRuntime:
return ToolRuntime(
state={"messages": [HumanMessage(content="seed")]},
context=None,
config=config,
stream_writer=None,
tool_call_id="parent-tcid-1",
store=None,
)
@pytest.mark.asyncio
async def test_resume_bridge_dispatches_decision_into_pending_subagent():
"""Side-channel decision must reach the subagent's pending interrupt verbatim."""
subagent = _build_single_interrupt_subagent()
task_tool = build_task_tool_with_parent_config(
[
{
"name": "approver",
"description": "approves things",
"runnable": subagent,
}
]
)
parent_config: dict = {
"configurable": {"thread_id": "shared-thread"},
"recursion_limit": 100,
}
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
snap = await subagent.aget_state(parent_config)
assert snap.tasks and snap.tasks[0].interrupts, (
"fixture broken: subagent should be paused on its interrupt"
)
parent_config["configurable"]["surfsense_resume_value"] = {
"decisions": ["APPROVED"]
}
runtime = _make_runtime(parent_config)
result = await task_tool.coroutine(
description="please approve",
subagent_type="approver",
runtime=runtime,
)
assert isinstance(result, Command)
update = result.update
assert update["decision_text"] == repr({"decisions": ["APPROVED"]})
assert "surfsense_resume_value" not in parent_config["configurable"]
final = await subagent.aget_state(parent_config)
assert not final.tasks or all(not t.interrupts for t in final.tasks)
@pytest.mark.asyncio
async def test_pending_interrupt_without_resume_value_raises_runtime_error():
"""Bridge must fail loud rather than silently replay the user's interrupt."""
subagent = _build_single_interrupt_subagent()
task_tool = build_task_tool_with_parent_config(
[
{
"name": "approver",
"description": "approves things",
"runnable": subagent,
}
]
)
parent_config: dict = {
"configurable": {"thread_id": "guard-thread"},
"recursion_limit": 100,
}
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
snap = await subagent.aget_state(parent_config)
assert snap.tasks and snap.tasks[0].interrupts, "fixture broken"
runtime = _make_runtime(parent_config)
with pytest.raises(RuntimeError, match="resume bridge is broken"):
await task_tool.coroutine(
description="please approve",
subagent_type="approver",
runtime=runtime,
)
def _build_bundle_subagent():
def bundle_node(state):
from langchain_core.messages import AIMessage
decision = interrupt(
{
"action_requests": [
{"name": "create_a", "args": {}, "description": ""},
{"name": "create_b", "args": {}, "description": ""},
{"name": "create_c", "args": {}, "description": ""},
],
"review_configs": [{}, {}, {}],
}
)
return {
"messages": [AIMessage(content="bundle-done")],
"decision_text": repr(decision),
}
graph = StateGraph(_SubagentState)
graph.add_node("bundle", bundle_node)
graph.add_edge(START, "bundle")
graph.add_edge("bundle", END)
return graph.compile(checkpointer=InMemorySaver())
@pytest.mark.asyncio
async def test_bundle_three_mixed_decisions_arrive_in_order():
"""Approve / edit / reject for a 3-action bundle must land at ordinals 0/1/2."""
subagent = _build_bundle_subagent()
task_tool = build_task_tool_with_parent_config(
[
{
"name": "bundler",
"description": "creates a bundle",
"runnable": subagent,
}
]
)
parent_config: dict = {
"configurable": {"thread_id": "bundle-thread"},
"recursion_limit": 100,
}
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
decisions_payload = {
"decisions": [
{"type": "approve", "args": {}},
{"type": "edit", "args": {"args": {"name": "edited-b"}}},
{"type": "reject", "args": {"message": "no thanks"}},
]
}
parent_config["configurable"]["surfsense_resume_value"] = decisions_payload
runtime = _make_runtime(parent_config)
result = await task_tool.coroutine(
description="run bundle",
subagent_type="bundler",
runtime=runtime,
)
assert isinstance(result, Command)
received = ast.literal_eval(result.update["decision_text"])
assert received == decisions_payload
assert received["decisions"][0]["type"] == "approve"
assert received["decisions"][1]["type"] == "edit"
assert received["decisions"][1]["args"] == {"args": {"name": "edited-b"}}
assert received["decisions"][2]["type"] == "reject"


@ -0,0 +1,55 @@
"""Pins the first-wins assumption of ``get_first_pending_subagent_interrupt``.
The bridge currently relies on at-most-one pending interrupt per snapshot
(sequential tool nodes). If parallel tool calls are ever enabled, the bridge
needs an id-aware lookup; these tests will need to be revisited at that point.
"""
from __future__ import annotations
from types import SimpleNamespace
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume import (
get_first_pending_subagent_interrupt,
)
class TestGetFirstPendingSubagentInterrupt:
def test_returns_first_when_multiple_top_level_interrupts_pending(self):
first = SimpleNamespace(id="i-1", value={"decision": "approve"})
second = SimpleNamespace(id="i-2", value={"decision": "reject"})
state = SimpleNamespace(interrupts=(first, second), tasks=())
assert get_first_pending_subagent_interrupt(state) == (
"i-1",
{"decision": "approve"},
)
def test_returns_first_when_multiple_subtask_interrupts_pending(self):
first = SimpleNamespace(id="i-A", value="approve")
second = SimpleNamespace(id="i-B", value="reject")
sub_task = SimpleNamespace(interrupts=(first, second))
state = SimpleNamespace(interrupts=(), tasks=(sub_task,))
assert get_first_pending_subagent_interrupt(state) == ("i-A", "approve")
def test_returns_none_when_no_interrupts(self):
state = SimpleNamespace(interrupts=(), tasks=())
assert get_first_pending_subagent_interrupt(state) == (None, None)
def test_returns_none_when_state_is_none(self):
assert get_first_pending_subagent_interrupt(None) == (None, None)
def test_skips_interrupts_with_none_value(self):
empty = SimpleNamespace(id="i-empty", value=None)
real = SimpleNamespace(id="i-real", value="approve")
state = SimpleNamespace(interrupts=(empty, real), tasks=())
assert get_first_pending_subagent_interrupt(state) == ("i-real", "approve")
def test_normalizes_non_string_id_to_none(self):
interrupt = SimpleNamespace(id=12345, value="approve")
state = SimpleNamespace(interrupts=(interrupt,), tasks=())
assert get_first_pending_subagent_interrupt(state) == (None, "approve")


@ -0,0 +1,71 @@
"""Resume side-channel must be read exactly once per turn."""
from __future__ import annotations
from langchain.tools import ToolRuntime
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import (
consume_surfsense_resume,
has_surfsense_resume,
)
def _runtime_with_config(config: dict) -> ToolRuntime:
return ToolRuntime(
state=None,
context=None,
config=config,
stream_writer=None,
tool_call_id="tcid-test",
store=None,
)
class TestConsumeSurfsenseResume:
def test_pops_value_on_first_call(self):
runtime = _runtime_with_config(
{"configurable": {"surfsense_resume_value": {"decisions": ["approve"]}}}
)
assert consume_surfsense_resume(runtime) == {"decisions": ["approve"]}
def test_second_call_returns_none(self):
configurable: dict = {"surfsense_resume_value": {"decisions": ["approve"]}}
runtime = _runtime_with_config({"configurable": configurable})
consume_surfsense_resume(runtime)
assert consume_surfsense_resume(runtime) is None
assert "surfsense_resume_value" not in configurable
def test_returns_none_when_no_payload_queued(self):
runtime = _runtime_with_config({"configurable": {}})
assert consume_surfsense_resume(runtime) is None
def test_returns_none_when_configurable_missing(self):
runtime = _runtime_with_config({})
assert consume_surfsense_resume(runtime) is None
class TestHasSurfsenseResume:
def test_true_when_payload_queued(self):
runtime = _runtime_with_config(
{"configurable": {"surfsense_resume_value": "approve"}}
)
assert has_surfsense_resume(runtime) is True
def test_does_not_consume_payload(self):
configurable = {"surfsense_resume_value": "approve"}
runtime = _runtime_with_config({"configurable": configurable})
has_surfsense_resume(runtime)
assert configurable == {"surfsense_resume_value": "approve"}
def test_false_when_payload_absent(self):
runtime = _runtime_with_config({"configurable": {}})
assert has_surfsense_resume(runtime) is False


@ -0,0 +1,96 @@
"""Subagent resilience contract: ``extra_middleware`` reaches the agent chain."""
from __future__ import annotations
from collections.abc import AsyncIterator, Iterator
from typing import Any
import pytest
from langchain.agents import create_agent
from langchain.agents.middleware import ModelFallbackMiddleware
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.fake_chat_models import (
FakeMessagesListChatModel,
)
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
pack_subagent,
)
class RateLimitError(Exception):
"""Name matches the scoped-fallback eligibility allowlist."""
class _AlwaysFailingChatModel(BaseChatModel):
@property
def _llm_type(self) -> str:
return "always-failing-test-model"
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
msg = "primary llm exploded"
raise RateLimitError(msg)
async def _agenerate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: AsyncCallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
msg = "primary llm exploded"
raise RateLimitError(msg)
def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]:
msg = "primary llm exploded"
raise RateLimitError(msg)
async def _astream(
self, *args: Any, **kwargs: Any
) -> AsyncIterator[ChatGeneration]:
msg = "primary llm exploded"
raise RateLimitError(msg)
yield # pragma: no cover - unreachable, satisfies async generator typing
@pytest.mark.asyncio
async def test_subagent_recovers_when_primary_llm_fails():
"""Fallback in ``extra_middleware`` must finish the turn when primary raises."""
primary = _AlwaysFailingChatModel()
fallback = FakeMessagesListChatModel(
responses=[AIMessage(content="recovered via fallback")]
)
spec = pack_subagent(
name="resilience_test",
description="test subagent",
system_prompt="be helpful",
tools=[],
model=primary,
extra_middleware=[ModelFallbackMiddleware(fallback)],
)
agent = create_agent(
model=spec["model"],
tools=spec["tools"],
middleware=spec["middleware"],
system_prompt=spec["system_prompt"],
)
result = await agent.ainvoke({"messages": [HumanMessage(content="hi")]})
final = result["messages"][-1]
assert isinstance(final, AIMessage)
assert final.content == "recovered via fallback"


@ -0,0 +1,128 @@
"""``ScopedModelFallbackMiddleware`` triggers fallback only on provider errors."""
from __future__ import annotations
from collections.abc import AsyncIterator, Iterator
from typing import Any
import pytest
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult
class _RaisingChatModel(BaseChatModel):
exc_to_raise: Any
@property
def _llm_type(self) -> str:
return "raising-test-model"
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
raise self.exc_to_raise
async def _agenerate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: AsyncCallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
raise self.exc_to_raise
def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]:
raise self.exc_to_raise
async def _astream(
self, *args: Any, **kwargs: Any
) -> AsyncIterator[ChatGeneration]:
raise self.exc_to_raise
yield # pragma: no cover - unreachable
class _RecordingChatModel(BaseChatModel):
response_text: str = "fallback-ok"
call_count: int = 0
@property
def _llm_type(self) -> str:
return "recording-test-model"
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
self.call_count += 1
return ChatResult(
generations=[ChatGeneration(message=AIMessage(content=self.response_text))]
)
async def _agenerate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: AsyncCallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
return self._generate(messages, stop, None, **kwargs)
class RateLimitError(Exception):
"""Name matches the scoped-fallback eligibility allowlist."""
def _build_agent(primary: BaseChatModel, fallback: BaseChatModel):
from langchain.agents import create_agent
from app.agents.new_chat.middleware.scoped_model_fallback import (
ScopedModelFallbackMiddleware,
)
return create_agent(
model=primary,
tools=[],
middleware=[ScopedModelFallbackMiddleware(fallback)],
system_prompt="be helpful",
)
@pytest.mark.asyncio
async def test_provider_errors_trigger_fallback():
"""Eligible exception names must drive the fallback chain."""
primary = _RaisingChatModel(exc_to_raise=RateLimitError("429 from provider"))
fallback = _RecordingChatModel(response_text="recovered")
agent = _build_agent(primary, fallback)
result = await agent.ainvoke({"messages": [("user", "hi")]})
final = result["messages"][-1]
assert isinstance(final, AIMessage)
assert final.content == "recovered"
assert fallback.call_count == 1
@pytest.mark.asyncio
async def test_programming_errors_propagate_without_invoking_fallback():
"""Non-eligible exceptions must propagate; fallback must not be invoked."""
primary = _RaisingChatModel(exc_to_raise=KeyError("missing_state_field"))
fallback = _RecordingChatModel(response_text="should-never-arrive")
agent = _build_agent(primary, fallback)
with pytest.raises(KeyError, match="missing_state_field"):
await agent.ainvoke({"messages": [("user", "hi")]})
assert fallback.call_count == 0


@ -202,6 +202,15 @@ class FakeBudgetLLM:
class TestKnowledgeBaseSearchMiddlewarePlanner:
@pytest.fixture(autouse=True)
def _disable_planner_runnable(self, monkeypatch):
# ``FakeLLM`` is a duck-typed mock; ``create_agent`` (used when the
# planner Runnable path is enabled) calls ``.bind()`` on the LLM,
# which the mock does not implement. Pin the flag off so the
# planner falls through to the legacy ``self.llm.ainvoke`` path
# these tests assert against (``llm.calls[0]["config"]``).
monkeypatch.setenv("SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "false")
def test_render_recent_conversation_prefers_latest_messages_under_budget(self):
messages = [
HumanMessage(content="old user context " * 40),


@ -7947,7 +7947,7 @@ wheels = [
[[package]]
name = "surf-new-backend"
version = "0.0.23"
source = { editable = "." }
dependencies = [
{ name = "alembic" },


@ -1,7 +1,7 @@
{
"name": "surfsense_browser_extension",
"displayName": "Surfsense Browser Extension",
"version": "0.0.23",
"description": "Extension to collect Browsing History for SurfSense.",
"author": "https://github.com/MODSetter",
"engines": {


@ -1,6 +1,6 @@
{
"name": "surfsense-desktop",
"version": "0.0.23",
"description": "SurfSense Desktop App",
"main": "dist/main.js",
"scripts": {


@ -11,6 +11,7 @@ import { EditorSaveContext } from "@/components/editor/editor-save-context";
import { CitationKit, injectCitationNodes } from "@/components/editor/plugins/citation-kit";
import { type EditorPreset, presetMap } from "@/components/editor/presets";
import { escapeMdxExpressions } from "@/components/editor/utils/escape-mdx";
import { safeDeserializeMarkdown } from "@/components/editor/utils/safe-deserialize";
import { Editor, EditorContainer } from "@/components/ui/editor";
import { preprocessCitationMarkdown } from "@/lib/citations/citation-parser";
@ -169,15 +170,17 @@ export function PlateEditor({
: markdown : markdown
? (editor) => { ? (editor) => {
if (!enableCitations) { if (!enableCitations) {
return editor return safeDeserializeMarkdown(
.getApi(MarkdownPlugin) editor,
.markdown.deserialize(escapeMdxExpressions(markdown)); escapeMdxExpressions(markdown)
) as Value;
} }
const { content: rewritten, urlMap } = preprocessCitationMarkdown(markdown); const { content: rewritten, urlMap } = preprocessCitationMarkdown(markdown);
const value = editor const value = safeDeserializeMarkdown(
.getApi(MarkdownPlugin) editor,
.markdown.deserialize(escapeMdxExpressions(rewritten)); escapeMdxExpressions(rewritten)
return injectCitationNodes(value as Descendant[], urlMap) as Value; );
return injectCitationNodes(value, urlMap) as Value;
} }
: undefined, : undefined,
}); });
@ -200,14 +203,13 @@ export function PlateEditor({
let newValue: Descendant[]; let newValue: Descendant[];
if (enableCitations) { if (enableCitations) {
const { content: rewritten, urlMap } = preprocessCitationMarkdown(markdown); const { content: rewritten, urlMap } = preprocessCitationMarkdown(markdown);
const deserialized = editor const deserialized = safeDeserializeMarkdown(
.getApi(MarkdownPlugin) editor,
.markdown.deserialize(escapeMdxExpressions(rewritten)) as Descendant[]; escapeMdxExpressions(rewritten)
);
newValue = injectCitationNodes(deserialized, urlMap); newValue = injectCitationNodes(deserialized, urlMap);
} else { } else {
newValue = editor newValue = safeDeserializeMarkdown(editor, escapeMdxExpressions(markdown));
.getApi(MarkdownPlugin)
.markdown.deserialize(escapeMdxExpressions(markdown)) as Descendant[];
} }
editor.tf.reset(); editor.tf.reset();
editor.tf.setValue(newValue as Value); editor.tf.setValue(newValue as Value);

View file

@@ -0,0 +1,64 @@
// ---------------------------------------------------------------------------
// Safe markdown deserialization for the Plate editor
// ---------------------------------------------------------------------------
// `remark-mdx` treats any HTML-like tag as JSX, so unbalanced inline HTML
// (very common in GitHub READMEs, web-scraped pages, PDF conversions) makes
// it throw "Expected a closing tag for `<a>`" and crash the editor.
//
// Per the MDX maintainers' guidance (mdx-js/mdx, ipikuka/next-mdx-remote-client
// #14), MDX is the wrong format for untrusted markdown and the recommended
// fix is to fall back to plain markdown parsing. `MarkdownPlugin.deserialize`
// accepts a per-call `remarkPlugins` override, so we can:
//
// 1. Try with `remarkMdx` (rich MDX features, e.g. JSX-style components).
// 2. On failure, retry without `remarkMdx` (lenient HTML, like GitHub).
// 3. As a last resort, render the raw source in a paragraph so the user
//    never sees a crashed editor.
// ---------------------------------------------------------------------------

import { MarkdownPlugin, remarkMdx } from "@platejs/markdown";
import type { Descendant } from "platejs";
import remarkGfm from "remark-gfm";
import remarkMath from "remark-math";

import type { PlateEditorInstance } from "@/components/editor/plate-editor";

const STRICT_PLUGINS = [remarkGfm, remarkMath, remarkMdx];
const LENIENT_PLUGINS = [remarkGfm, remarkMath];

function plainTextFallback(markdown: string): Descendant[] {
  return [
    {
      type: "p",
      children: [{ text: markdown }],
    } as unknown as Descendant,
  ];
}

/**
 * Deserialize markdown into a Plate value, gracefully degrading when the
 * MDX-strict parser rejects raw HTML. Always returns a renderable value;
 * never throws.
 */
export function safeDeserializeMarkdown(
  editor: PlateEditorInstance,
  markdown: string
): Descendant[] {
  const api = editor.getApi(MarkdownPlugin).markdown;
  try {
    return api.deserialize(markdown, { remarkPlugins: STRICT_PLUGINS }) as Descendant[];
  } catch (mdxError) {
    if (process.env.NODE_ENV !== "production") {
      console.warn(
        "[plate-editor] MDX parse failed, retrying without remark-mdx:",
        mdxError
      );
    }
    try {
      return api.deserialize(markdown, { remarkPlugins: LENIENT_PLUGINS }) as Descendant[];
    } catch (fallbackError) {
      console.error("[plate-editor] markdown deserialize failed:", fallbackError);
      return plainTextFallback(markdown);
    }
  }
}
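A minimal sketch of the failure mode this module absorbs — the input string is hypothetical, and `editor` is assumed to be a mounted Plate editor instance in scope:

```ts
// Hypothetical input, typical of web-scraped content: the <a> tag is never
// closed, which is tolerable as loose HTML-in-markdown but invalid as MDX/JSX.
const scraped = 'Star us on <a href="https://github.com/MODSetter/SurfSense">GitHub!';

// Pass 1 (with remarkMdx) throws "Expected a closing tag for `<a>`";
// pass 2 (without remarkMdx) parses it leniently, so the editor still
// receives a renderable tree instead of crashing.
const value = safeDeserializeMarkdown(editor, scraped);
```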

View file

@@ -29,6 +29,13 @@ const nextConfig: NextConfig = {
        hostname: "**",
      },
    ],
+    // Allow remote SVGs (e.g. README badges from img.shields.io, trendshift.io,
+    // etc.) which are otherwise blocked by next/image. The CSP below sandboxes
+    // the SVG and forbids any embedded scripts, which is the mitigation
+    // recommended by Vercel's NEXTJS_SAFE_SVG_IMAGES conformance rule.
+    dangerouslyAllowSVG: true,
+    contentDispositionType: "attachment",
+    contentSecurityPolicy: "default-src 'self'; script-src 'none'; sandbox;",
  },
  experimental: {
    optimizePackageImports: [
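For illustration only (this component is not part of the commit), the settings above are what let a remotely hosted SVG badge flow through `next/image`, which rejects SVG sources unless `dangerouslyAllowSVG` is set:

```tsx
import Image from "next/image";

// Hypothetical usage: shields.io serves this badge as an SVG. The config's
// Content-Disposition: attachment header and sandboxing CSP ensure any
// script embedded in the SVG can never execute in the page's origin.
export function StarsBadge() {
  return (
    <Image
      src="https://img.shields.io/github/stars/MODSetter/SurfSense"
      alt="GitHub stars"
      width={120}
      height={20}
    />
  );
}
```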

View file

@@ -1,6 +1,6 @@
{
  "name": "surfsense_web",
-  "version": "0.0.22",
+  "version": "0.0.23",
  "private": true,
  "description": "SurfSense Frontend",
  "scripts": {