mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
Merge remote-tracking branch 'upstream/dev' into feat/e2e-testing-ci
This commit is contained in:
commit
3520877d80
351 changed files with 14437 additions and 5483 deletions
37
.cursor/skills/improve-codebase-architecture/DEEPENING.md
Normal file
37
.cursor/skills/improve-codebase-architecture/DEEPENING.md
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
# Deepening
|
||||
|
||||
How to deepen a cluster of shallow modules safely, given its dependencies. Assumes the vocabulary in [LANGUAGE.md](LANGUAGE.md) — **module**, **interface**, **seam**, **adapter**.
|
||||
|
||||
## Dependency categories
|
||||
|
||||
When assessing a candidate for deepening, classify its dependencies. The category determines how the deepened module is tested across its seam.
|
||||
|
||||
### 1. In-process
|
||||
|
||||
Pure computation, in-memory state, no I/O. Always deepenable — merge the modules and test through the new interface directly. No adapter needed.
|
||||
|
||||
### 2. Local-substitutable
|
||||
|
||||
Dependencies that have local test stand-ins (PGLite for Postgres, in-memory filesystem). Deepenable if the stand-in exists. The deepened module is tested with the stand-in running in the test suite. The seam is internal; no port at the module's external interface.
|
||||
|
||||
### 3. Remote but owned (Ports & Adapters)
|
||||
|
||||
Your own services across a network boundary (microservices, internal APIs). Define a **port** (interface) at the seam. The deep module owns the logic; the transport is injected as an **adapter**. Tests use an in-memory adapter. Production uses an HTTP/gRPC/queue adapter.
|
||||
|
||||
Recommendation shape: *"Define a port at the seam, implement an HTTP adapter for production and an in-memory adapter for testing, so the logic sits in one deep module even though it's deployed across a network."*
|
||||
|
||||
### 4. True external (Mock)
|
||||
|
||||
Third-party services (Stripe, Twilio, etc.) you don't control. The deepened module takes the external dependency as an injected port; tests provide a mock adapter.
|
||||
|
||||
## Seam discipline
|
||||
|
||||
- **One adapter means a hypothetical seam. Two adapters means a real one.** Don't introduce a port unless at least two adapters are justified (typically production + test). A single-adapter seam is just indirection.
|
||||
- **Internal seams vs external seams.** A deep module can have internal seams (private to its implementation, used by its own tests) as well as the external seam at its interface. Don't expose internal seams through the interface just because tests use them.
|
||||
|
||||
## Testing strategy: replace, don't layer
|
||||
|
||||
- Old unit tests on shallow modules become waste once tests at the deepened module's interface exist — delete them.
|
||||
- Write new tests at the deepened module's interface. The **interface is the test surface**.
|
||||
- Tests assert on observable outcomes through the interface, not internal state.
|
||||
- Tests should survive internal refactors — they describe behaviour, not implementation. If a test has to change when the implementation changes, it's testing past the interface.
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
# Interface Design
|
||||
|
||||
When the user wants to explore alternative interfaces for a chosen deepening candidate, use this parallel sub-agent pattern. Based on "Design It Twice" (Ousterhout) — your first idea is unlikely to be the best.
|
||||
|
||||
Uses the vocabulary in [LANGUAGE.md](LANGUAGE.md) — **module**, **interface**, **seam**, **adapter**, **leverage**.
|
||||
|
||||
## Process
|
||||
|
||||
### 1. Frame the problem space
|
||||
|
||||
Before spawning sub-agents, write a user-facing explanation of the problem space for the chosen candidate:
|
||||
|
||||
- The constraints any new interface would need to satisfy
|
||||
- The dependencies it would rely on, and which category they fall into (see [DEEPENING.md](DEEPENING.md))
|
||||
- A rough illustrative code sketch to ground the constraints — not a proposal, just a way to make the constraints concrete
|
||||
|
||||
Show this to the user, then immediately proceed to Step 2. The user reads and thinks while the sub-agents work in parallel.
|
||||
|
||||
### 2. Spawn sub-agents
|
||||
|
||||
Spawn 3+ sub-agents in parallel using the Agent tool. Each must produce a **radically different** interface for the deepened module.
|
||||
|
||||
Prompt each sub-agent with a separate technical brief (file paths, coupling details, dependency category from [DEEPENING.md](DEEPENING.md), what sits behind the seam). The brief is independent of the user-facing problem-space explanation in Step 1. Give each agent a different design constraint:
|
||||
|
||||
- Agent 1: "Minimize the interface — aim for 1–3 entry points max. Maximise leverage per entry point."
|
||||
- Agent 2: "Maximise flexibility — support many use cases and extension."
|
||||
- Agent 3: "Optimise for the most common caller — make the default case trivial."
|
||||
- Agent 4 (if applicable): "Design around ports & adapters for cross-seam dependencies."
|
||||
|
||||
Include both [LANGUAGE.md](LANGUAGE.md) vocabulary and CONTEXT.md vocabulary in the brief so each sub-agent names things consistently with the architecture language and the project's domain language.
|
||||
|
||||
Each sub-agent outputs:
|
||||
|
||||
1. Interface (types, methods, params — plus invariants, ordering, error modes)
|
||||
2. Usage example showing how callers use it
|
||||
3. What the implementation hides behind the seam
|
||||
4. Dependency strategy and adapters (see [DEEPENING.md](DEEPENING.md))
|
||||
5. Trade-offs — where leverage is high, where it's thin
|
||||
|
||||
### 3. Present and compare
|
||||
|
||||
Present designs sequentially so the user can absorb each one, then compare them in prose. Contrast by **depth** (leverage at the interface), **locality** (where change concentrates), and **seam placement**.
|
||||
|
||||
After comparing, give your own recommendation: which design you think is strongest and why. If elements from different designs would combine well, propose a hybrid. Be opinionated — the user wants a strong read, not a menu.
|
||||
53
.cursor/skills/improve-codebase-architecture/LANGUAGE.md
Normal file
53
.cursor/skills/improve-codebase-architecture/LANGUAGE.md
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
# Language
|
||||
|
||||
Shared vocabulary for every suggestion this skill makes. Use these terms exactly — don't substitute "component," "service," "API," or "boundary." Consistent language is the whole point.
|
||||
|
||||
## Terms
|
||||
|
||||
**Module**
|
||||
Anything with an interface and an implementation. Deliberately scale-agnostic — applies equally to a function, class, package, or tier-spanning slice.
|
||||
_Avoid_: unit, component, service.
|
||||
|
||||
**Interface**
|
||||
Everything a caller must know to use the module correctly. Includes the type signature, but also invariants, ordering constraints, error modes, required configuration, and performance characteristics.
|
||||
_Avoid_: API, signature (too narrow — those refer only to the type-level surface).
|
||||
|
||||
**Implementation**
|
||||
What's inside a module — its body of code. Distinct from **Adapter**: a thing can be a small adapter with a large implementation (a Postgres repo) or a large adapter with a small implementation (an in-memory fake). Reach for "adapter" when the seam is the topic; "implementation" otherwise.
|
||||
|
||||
**Depth**
|
||||
Leverage at the interface — the amount of behaviour a caller (or test) can exercise per unit of interface they have to learn. A module is **deep** when a large amount of behaviour sits behind a small interface. A module is **shallow** when the interface is nearly as complex as the implementation.
|
||||
|
||||
**Seam** _(from Michael Feathers)_
|
||||
A place where you can alter behaviour without editing in that place. The *location* at which a module's interface lives. Choosing where to put the seam is its own design decision, distinct from what goes behind it.
|
||||
_Avoid_: boundary (overloaded with DDD's bounded context).
|
||||
|
||||
**Adapter**
|
||||
A concrete thing that satisfies an interface at a seam. Describes *role* (what slot it fills), not substance (what's inside).
|
||||
|
||||
**Leverage**
|
||||
What callers get from depth. More capability per unit of interface they have to learn. One implementation pays back across N call sites and M tests.
|
||||
|
||||
**Locality**
|
||||
What maintainers get from depth. Change, bugs, knowledge, and verification concentrate at one place rather than spreading across callers. Fix once, fixed everywhere.
|
||||
|
||||
## Principles
|
||||
|
||||
- **Depth is a property of the interface, not the implementation.** A deep module can be internally composed of small, mockable, swappable parts — they just aren't part of the interface. A module can have **internal seams** (private to its implementation, used by its own tests) as well as the **external seam** at its interface.
|
||||
- **The deletion test.** Imagine deleting the module. If complexity vanishes, the module wasn't hiding anything (it was a pass-through). If complexity reappears across N callers, the module was earning its keep.
|
||||
- **The interface is the test surface.** Callers and tests cross the same seam. If you want to test *past* the interface, the module is probably the wrong shape.
|
||||
- **One adapter means a hypothetical seam. Two adapters means a real one.** Don't introduce a seam unless something actually varies across it.
|
||||
|
||||
## Relationships
|
||||
|
||||
- A **Module** has exactly one **Interface** (the surface it presents to callers and tests).
|
||||
- **Depth** is a property of a **Module**, measured against its **Interface**.
|
||||
- A **Seam** is where a **Module**'s **Interface** lives.
|
||||
- An **Adapter** sits at a **Seam** and satisfies the **Interface**.
|
||||
- **Depth** produces **Leverage** for callers and **Locality** for maintainers.
|
||||
|
||||
## Rejected framings
|
||||
|
||||
- **Depth as ratio of implementation-lines to interface-lines** (Ousterhout): rewards padding the implementation. We use depth-as-leverage instead.
|
||||
- **"Interface" as the TypeScript `interface` keyword or a class's public methods**: too narrow — interface here includes every fact a caller must know.
|
||||
- **"Boundary"**: overloaded with DDD's bounded context. Say **seam** or **interface**.
|
||||
71
.cursor/skills/improve-codebase-architecture/SKILL.md
Normal file
71
.cursor/skills/improve-codebase-architecture/SKILL.md
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
---
|
||||
name: improve-codebase-architecture
|
||||
description: Find deepening opportunities in a codebase, informed by the domain language in CONTEXT.md and the decisions in docs/adr/. Use when the user wants to improve architecture, find refactoring opportunities, consolidate tightly-coupled modules, or make a codebase more testable and AI-navigable.
|
||||
---
|
||||
|
||||
# Improve Codebase Architecture
|
||||
|
||||
Surface architectural friction and propose **deepening opportunities** — refactors that turn shallow modules into deep ones. The aim is testability and AI-navigability.
|
||||
|
||||
## Glossary
|
||||
|
||||
Use these terms exactly in every suggestion. Consistent language is the point — don't drift into "component," "service," "API," or "boundary." Full definitions in [LANGUAGE.md](LANGUAGE.md).
|
||||
|
||||
- **Module** — anything with an interface and an implementation (function, class, package, slice).
|
||||
- **Interface** — everything a caller must know to use the module: types, invariants, error modes, ordering, config. Not just the type signature.
|
||||
- **Implementation** — the code inside.
|
||||
- **Depth** — leverage at the interface: a lot of behaviour behind a small interface. **Deep** = high leverage. **Shallow** = interface nearly as complex as the implementation.
|
||||
- **Seam** — where an interface lives; a place behaviour can be altered without editing in place. (Use this, not "boundary.")
|
||||
- **Adapter** — a concrete thing satisfying an interface at a seam.
|
||||
- **Leverage** — what callers get from depth.
|
||||
- **Locality** — what maintainers get from depth: change, bugs, knowledge concentrated in one place.
|
||||
|
||||
Key principles (see [LANGUAGE.md](LANGUAGE.md) for the full list):
|
||||
|
||||
- **Deletion test**: imagine deleting the module. If complexity vanishes, it was a pass-through. If complexity reappears across N callers, it was earning its keep.
|
||||
- **The interface is the test surface.**
|
||||
- **One adapter = hypothetical seam. Two adapters = real seam.**
|
||||
|
||||
This skill is _informed_ by the project's domain model. The domain language gives names to good seams; ADRs record decisions the skill should not re-litigate.
|
||||
|
||||
## Process
|
||||
|
||||
### 1. Explore
|
||||
|
||||
Read the project's domain glossary and any ADRs in the area you're touching first.
|
||||
|
||||
Then use the Agent tool with `subagent_type=Explore` to walk the codebase. Don't follow rigid heuristics — explore organically and note where you experience friction:
|
||||
|
||||
- Where does understanding one concept require bouncing between many small modules?
|
||||
- Where are modules **shallow** — interface nearly as complex as the implementation?
|
||||
- Where have pure functions been extracted just for testability, but the real bugs hide in how they're called (no **locality**)?
|
||||
- Where do tightly-coupled modules leak across their seams?
|
||||
- Which parts of the codebase are untested, or hard to test through their current interface?
|
||||
|
||||
Apply the **deletion test** to anything you suspect is shallow: would deleting it concentrate complexity, or just move it? A "yes, concentrates" is the signal you want.
|
||||
|
||||
### 2. Present candidates
|
||||
|
||||
Present a numbered list of deepening opportunities. For each candidate:
|
||||
|
||||
- **Files** — which files/modules are involved
|
||||
- **Problem** — why the current architecture is causing friction
|
||||
- **Solution** — plain English description of what would change
|
||||
- **Benefits** — explained in terms of locality and leverage, and also in how tests would improve
|
||||
|
||||
**Use CONTEXT.md vocabulary for the domain, and [LANGUAGE.md](LANGUAGE.md) vocabulary for the architecture.** If `CONTEXT.md` defines "Order," talk about "the Order intake module" — not "the FooBarHandler," and not "the Order service."
|
||||
|
||||
**ADR conflicts**: if a candidate contradicts an existing ADR, only surface it when the friction is real enough to warrant revisiting the ADR. Mark it clearly (e.g. _"contradicts ADR-0007 — but worth reopening because…"_). Don't list every theoretical refactor an ADR forbids.
|
||||
|
||||
Do NOT propose interfaces yet. Ask the user: "Which of these would you like to explore?"
|
||||
|
||||
### 3. Grilling loop
|
||||
|
||||
Once the user picks a candidate, drop into a grilling conversation. Walk the design tree with them — constraints, dependencies, the shape of the deepened module, what sits behind the seam, what tests survive.
|
||||
|
||||
Side effects happen inline as decisions crystallize:
|
||||
|
||||
- **Naming a deepened module after a concept not in `CONTEXT.md`?** Add the term to `CONTEXT.md` — same discipline as `/grill-with-docs` (see [CONTEXT-FORMAT.md](../grill-with-docs/CONTEXT-FORMAT.md)). Create the file lazily if it doesn't exist.
|
||||
- **Sharpening a fuzzy term during the conversation?** Update `CONTEXT.md` right there.
|
||||
- **User rejects the candidate with a load-bearing reason?** Offer an ADR, framed as: _"Want me to record this as an ADR so future architecture reviews don't re-suggest it?"_ Only offer when the reason would actually be needed by a future explorer to avoid re-suggesting the same thing — skip ephemeral reasons ("not worth it right now") and self-evident ones. See [ADR-FORMAT.md](../grill-with-docs/ADR-FORMAT.md).
|
||||
- **Want to explore alternative interfaces for the deepened module?** See [INTERFACE-DESIGN.md](INTERFACE-DESIGN.md).
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -15,3 +15,4 @@ surfsense_web/playwright/.auth/
|
|||
surfsense_web/playwright-report/
|
||||
surfsense_web/test-results/
|
||||
surfsense_web/blob-report/
|
||||
hermes-agent/
|
||||
|
|
|
|||
7
.vscode/settings.json
vendored
7
.vscode/settings.json
vendored
|
|
@ -1,4 +1,9 @@
|
|||
{
|
||||
"biome.configurationPath": "./surfsense_web/biome.json",
|
||||
"deepscan.ignoreConfirmWarning": true
|
||||
"deepscan.ignoreConfirmWarning": true,
|
||||
"python.defaultInterpreterPath": "${workspaceFolder}/surfsense_backend/.venv/bin/python",
|
||||
"basedpyright.analysis.extraPaths": [
|
||||
"${workspaceFolder}/surfsense_backend"
|
||||
],
|
||||
"python-envs.pythonProjects": []
|
||||
}
|
||||
2
VERSION
2
VERSION
|
|
@ -1 +1 @@
|
|||
0.0.22
|
||||
0.0.23
|
||||
|
|
|
|||
|
|
@ -324,7 +324,6 @@ SURFSENSE_ENABLE_ACTION_LOG=true
|
|||
SURFSENSE_ENABLE_REVERT_ROUTE=true
|
||||
SURFSENSE_ENABLE_PERMISSION=true
|
||||
SURFSENSE_ENABLE_DOOM_LOOP=true
|
||||
SURFSENSE_ENABLE_STREAM_PARITY_V2=true
|
||||
|
||||
# Periodic connector sync interval (default: 5m)
|
||||
# SCHEDULE_CHECKER_INTERVAL=5m
|
||||
|
|
|
|||
|
|
@ -46,6 +46,12 @@
|
|||
"sourceType": "github",
|
||||
"computedHash": "ddd61f32254be1303ce4b7be5d507c932de4af53489a0ebb1309bf61de99018c"
|
||||
},
|
||||
"improve-codebase-architecture": {
|
||||
"source": "mattpocock/skills",
|
||||
"sourceType": "github",
|
||||
"skillPath": "skills/engineering/improve-codebase-architecture/SKILL.md",
|
||||
"computedHash": "2da1d23b8f53cfe67f2e0b68924ab9f4ec400bb6480de097007eeaeb517d1722"
|
||||
},
|
||||
"internal-linking-optimizer": {
|
||||
"source": "aaron-he-zhu/seo-geo-claude-skills",
|
||||
"sourceType": "github",
|
||||
|
|
|
|||
|
|
@ -315,14 +315,6 @@ LANGSMITH_PROJECT=surfsense
|
|||
# SURFSENSE_ENABLE_ACTION_LOG=false
|
||||
# SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships
|
||||
|
||||
# Streaming parity v2 — opt in to LangChain's structured AIMessageChunk
|
||||
# content (typed reasoning blocks, tool-input deltas) and propagate the
|
||||
# real tool_call_id to the SSE layer. When OFF, the stream falls back to
|
||||
# the str-only text path and synthetic "call_<run_id>" tool-call ids.
|
||||
# Schema migrations 135/136 ship unconditionally because they are
|
||||
# forward-compatible.
|
||||
# SURFSENSE_ENABLE_STREAM_PARITY_V2=false
|
||||
|
||||
# Plugins
|
||||
# SURFSENSE_ENABLE_PLUGIN_LOADER=false
|
||||
# Comma-separated allowlist of plugin entry-point names
|
||||
|
|
|
|||
|
|
@ -2,6 +2,6 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from .main_agent import create_surfsense_deep_agent
|
||||
from .main_agent import create_multi_agent_chat_deep_agent
|
||||
|
||||
__all__ = ["create_surfsense_deep_agent"]
|
||||
__all__ = ["create_multi_agent_chat_deep_agent"]
|
||||
|
|
|
|||
|
|
@ -2,6 +2,6 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from .runtime import create_surfsense_deep_agent
|
||||
from .runtime import create_multi_agent_chat_deep_agent
|
||||
|
||||
__all__ = ["create_surfsense_deep_agent"]
|
||||
__all__ = ["create_multi_agent_chat_deep_agent"]
|
||||
|
|
|
|||
|
|
@ -11,6 +11,9 @@ from langchain_core.language_models import BaseChatModel
|
|||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.multi_agent_chat.middleware import (
|
||||
build_main_agent_deepagent_middleware,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.shared.permissions import (
|
||||
ToolsPermissions,
|
||||
)
|
||||
|
|
@ -19,8 +22,6 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
|||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from .middleware import build_main_agent_deepagent_middleware
|
||||
|
||||
|
||||
def build_compiled_agent_graph_sync(
|
||||
*,
|
||||
|
|
|
|||
|
|
@ -1,7 +0,0 @@
|
|||
"""Main-agent graph middleware assembly (SurfSense + LangChain + deepagents)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .deepagent_stack import build_main_agent_deepagent_middleware
|
||||
|
||||
__all__ = ["build_main_agent_deepagent_middleware"]
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
"""RunnableConfig wiring for nested subagent invocations.
|
||||
|
||||
Forwards the parent's ``runtime.config`` (thread_id, …) into the subagent and
|
||||
exposes the side-channel ``stream_resume_chat`` uses to ferry resume payloads.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from .constants import DEFAULT_SUBAGENT_RECURSION_LIMIT
|
||||
|
||||
|
||||
def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
|
||||
"""RunnableConfig for the nested invoke; raises ``recursion_limit`` to the parent's budget."""
|
||||
merged: dict[str, Any] = dict(runtime.config) if runtime.config else {}
|
||||
current_limit = merged.get("recursion_limit")
|
||||
try:
|
||||
current_int = int(current_limit) if current_limit is not None else 0
|
||||
except (TypeError, ValueError):
|
||||
current_int = 0
|
||||
if current_int < DEFAULT_SUBAGENT_RECURSION_LIMIT:
|
||||
merged["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT
|
||||
return merged
|
||||
|
||||
|
||||
def consume_surfsense_resume(runtime: ToolRuntime) -> Any:
|
||||
"""Pop the resume payload; siblings share ``configurable`` by reference."""
|
||||
cfg = runtime.config or {}
|
||||
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
|
||||
if not isinstance(configurable, dict):
|
||||
return None
|
||||
return configurable.pop("surfsense_resume_value", None)
|
||||
|
||||
|
||||
def has_surfsense_resume(runtime: ToolRuntime) -> bool:
|
||||
"""True iff a resume payload is queued on this runtime (non-destructive)."""
|
||||
cfg = runtime.config or {}
|
||||
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
|
||||
if not isinstance(configurable, dict):
|
||||
return False
|
||||
return "surfsense_resume_value" in configurable
|
||||
|
|
@ -1,506 +0,0 @@
|
|||
"""Assemble the main-agent deep-agent middleware list (LangChain + SurfSense + deepagents)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from deepagents import SubAgent
|
||||
from deepagents.backends import StateBackend
|
||||
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
from deepagents.middleware.skills import SkillsMiddleware
|
||||
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
|
||||
from langchain.agents.middleware import (
|
||||
LLMToolSelectorMiddleware,
|
||||
ModelCallLimitMiddleware,
|
||||
ModelFallbackMiddleware,
|
||||
TodoListMiddleware,
|
||||
ToolCallLimitMiddleware,
|
||||
)
|
||||
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.multi_agent_chat.subagents import (
|
||||
build_subagents,
|
||||
get_subagents_to_exclude,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.shared.permissions import (
|
||||
ToolsPermissions,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import (
|
||||
ActionLogMiddleware,
|
||||
AnonymousDocumentMiddleware,
|
||||
BusyMutexMiddleware,
|
||||
ClearToolUsesEdit,
|
||||
DedupHITLToolCallsMiddleware,
|
||||
DoomLoopMiddleware,
|
||||
FileIntentMiddleware,
|
||||
KnowledgeBasePersistenceMiddleware,
|
||||
KnowledgePriorityMiddleware,
|
||||
KnowledgeTreeMiddleware,
|
||||
MemoryInjectionMiddleware,
|
||||
NoopInjectionMiddleware,
|
||||
OtelSpanMiddleware,
|
||||
PermissionMiddleware,
|
||||
RetryAfterMiddleware,
|
||||
SpillingContextEditingMiddleware,
|
||||
SpillToBackendEdit,
|
||||
SurfSenseFilesystemMiddleware,
|
||||
ToolCallNameRepairMiddleware,
|
||||
build_skills_backend_factory,
|
||||
create_surfsense_compaction_middleware,
|
||||
default_skills_sources,
|
||||
)
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
from app.agents.new_chat.plugin_loader import (
|
||||
PluginContext,
|
||||
load_allowed_plugin_names_from_env,
|
||||
load_plugin_middlewares,
|
||||
)
|
||||
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from ...context_prune.prune_tool_names import safe_exclude_tools
|
||||
from .checkpointed_subagent_middleware import SurfSenseCheckpointedSubAgentMiddleware
|
||||
|
||||
|
||||
def build_main_agent_deepagent_middleware(
|
||||
*,
|
||||
llm: BaseChatModel,
|
||||
tools: Sequence[BaseTool],
|
||||
backend_resolver: Any,
|
||||
filesystem_mode: FilesystemMode,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
thread_id: int | None,
|
||||
visibility: ChatVisibility,
|
||||
anon_session_id: str | None,
|
||||
available_connectors: list[str] | None,
|
||||
available_document_types: list[str] | None,
|
||||
mentioned_document_ids: list[int] | None,
|
||||
max_input_tokens: int | None,
|
||||
flags: AgentFeatureFlags,
|
||||
subagent_dependencies: dict[str, Any],
|
||||
checkpointer: Checkpointer,
|
||||
mcp_tools_by_agent: dict[str, ToolsPermissions] | None = None,
|
||||
disabled_tools: list[str] | None = None,
|
||||
) -> list[Any]:
|
||||
"""Build ordered middleware for ``create_agent`` (Nones already stripped)."""
|
||||
_memory_middleware = MemoryInjectionMiddleware(
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
thread_visibility=visibility,
|
||||
)
|
||||
|
||||
gp_middleware = [
|
||||
TodoListMiddleware(),
|
||||
_memory_middleware,
|
||||
FileIntentMiddleware(llm=llm),
|
||||
SurfSenseFilesystemMiddleware(
|
||||
backend=backend_resolver,
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=user_id,
|
||||
thread_id=thread_id,
|
||||
),
|
||||
create_surfsense_compaction_middleware(llm, StateBackend),
|
||||
PatchToolCallsMiddleware(),
|
||||
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
||||
]
|
||||
|
||||
# Build permission rulesets up front so the GP subagent can mirror ``ask``
|
||||
# rules into ``interrupt_on``: tool calls emitted from within ``task`` runs
|
||||
# never reach the parent's ``PermissionMiddleware``.
|
||||
is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
|
||||
permission_enabled = flags.enable_permission and not flags.disable_new_agent_stack
|
||||
permission_rulesets: list[Ruleset] = []
|
||||
if permission_enabled or is_desktop_fs:
|
||||
permission_rulesets.append(
|
||||
Ruleset(
|
||||
rules=[Rule(permission="*", pattern="*", action="allow")],
|
||||
origin="surfsense_defaults",
|
||||
)
|
||||
)
|
||||
if is_desktop_fs:
|
||||
permission_rulesets.append(
|
||||
Ruleset(
|
||||
rules=[
|
||||
Rule(permission="rm", pattern="*", action="ask"),
|
||||
Rule(permission="rmdir", pattern="*", action="ask"),
|
||||
Rule(permission="move_file", pattern="*", action="ask"),
|
||||
Rule(permission="edit_file", pattern="*", action="ask"),
|
||||
Rule(permission="write_file", pattern="*", action="ask"),
|
||||
],
|
||||
origin="desktop_safety",
|
||||
)
|
||||
)
|
||||
|
||||
# Tools that self-prompt via ``request_approval`` must not also appear
|
||||
# as ``ask`` rules — that would double-prompt the user for one call.
|
||||
_tool_names_in_use = {t.name for t in tools}
|
||||
|
||||
# Deny parent-bound tools whose ``required_connector`` is missing.
|
||||
# No-op today (connector subagents are pruned upstream); guards future
|
||||
# additions to the parent's tool list.
|
||||
if permission_enabled:
|
||||
_available_set = set(available_connectors or [])
|
||||
_synthesized: list[Rule] = []
|
||||
for tool_def in BUILTIN_TOOLS:
|
||||
if tool_def.name not in _tool_names_in_use:
|
||||
continue
|
||||
rc = tool_def.required_connector
|
||||
if rc and rc not in _available_set:
|
||||
_synthesized.append(
|
||||
Rule(permission=tool_def.name, pattern="*", action="deny")
|
||||
)
|
||||
if _synthesized:
|
||||
permission_rulesets.append(
|
||||
Ruleset(rules=_synthesized, origin="connector_synthesized")
|
||||
)
|
||||
gp_interrupt_on: dict[str, bool] = {
|
||||
rule.permission: True
|
||||
for rs in permission_rulesets
|
||||
for rule in rs.rules
|
||||
if rule.action == "ask" and rule.permission in _tool_names_in_use
|
||||
}
|
||||
|
||||
general_purpose_spec: SubAgent = { # type: ignore[typeddict-unknown-key]
|
||||
**GENERAL_PURPOSE_SUBAGENT,
|
||||
"model": llm,
|
||||
"tools": tools,
|
||||
"middleware": gp_middleware,
|
||||
}
|
||||
if gp_interrupt_on:
|
||||
general_purpose_spec["interrupt_on"] = gp_interrupt_on
|
||||
|
||||
# Deny-only on subagents: ``task`` runs bypass the parent's
|
||||
# PermissionMiddleware, while bucket-based ask gates own the ask path.
|
||||
subagent_deny_rulesets: list[Ruleset] = [
|
||||
Ruleset(
|
||||
rules=[r for r in rs.rules if r.action == "deny"],
|
||||
origin=rs.origin,
|
||||
)
|
||||
for rs in permission_rulesets
|
||||
]
|
||||
subagent_deny_rulesets = [rs for rs in subagent_deny_rulesets if rs.rules]
|
||||
|
||||
subagent_deny_permission_mw: PermissionMiddleware | None = (
|
||||
PermissionMiddleware(rulesets=subagent_deny_rulesets)
|
||||
if subagent_deny_rulesets
|
||||
else None
|
||||
)
|
||||
|
||||
if subagent_deny_permission_mw is not None:
|
||||
# Run deny check on already-repaired tool calls; insert before
|
||||
# PatchToolCallsMiddleware (append if the slot moves).
|
||||
_patch_idx = next(
|
||||
(
|
||||
i
|
||||
for i, m in enumerate(gp_middleware)
|
||||
if isinstance(m, PatchToolCallsMiddleware)
|
||||
),
|
||||
len(gp_middleware),
|
||||
)
|
||||
gp_middleware.insert(_patch_idx, subagent_deny_permission_mw)
|
||||
|
||||
registry_subagents: list[SubAgent] = []
|
||||
try:
|
||||
subagent_extra_middleware: list[Any] = [
|
||||
TodoListMiddleware(),
|
||||
SurfSenseFilesystemMiddleware(
|
||||
backend=backend_resolver,
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=user_id,
|
||||
thread_id=thread_id,
|
||||
),
|
||||
]
|
||||
if subagent_deny_permission_mw is not None:
|
||||
subagent_extra_middleware.append(subagent_deny_permission_mw)
|
||||
registry_subagents = build_subagents(
|
||||
dependencies=subagent_dependencies,
|
||||
model=llm,
|
||||
extra_middleware=subagent_extra_middleware,
|
||||
mcp_tools_by_agent=mcp_tools_by_agent or {},
|
||||
exclude=get_subagents_to_exclude(available_connectors),
|
||||
disabled_tools=disabled_tools,
|
||||
)
|
||||
logging.info(
|
||||
"Registry subagents: %s",
|
||||
[s["name"] for s in registry_subagents],
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("Registry subagent build failed")
|
||||
raise
|
||||
|
||||
subagent_specs: list[SubAgent] = [general_purpose_spec, *registry_subagents]
|
||||
|
||||
summarization_mw = create_surfsense_compaction_middleware(llm, StateBackend)
|
||||
|
||||
context_edit_mw = None
|
||||
if (
|
||||
flags.enable_context_editing
|
||||
and not flags.disable_new_agent_stack
|
||||
and max_input_tokens
|
||||
):
|
||||
spill_edit = SpillToBackendEdit(
|
||||
trigger=int(max_input_tokens * 0.55),
|
||||
clear_at_least=int(max_input_tokens * 0.15),
|
||||
keep=5,
|
||||
exclude_tools=safe_exclude_tools(tools),
|
||||
clear_tool_inputs=True,
|
||||
)
|
||||
clear_edit = ClearToolUsesEdit(
|
||||
trigger=int(max_input_tokens * 0.55),
|
||||
clear_at_least=int(max_input_tokens * 0.15),
|
||||
keep=5,
|
||||
exclude_tools=safe_exclude_tools(tools),
|
||||
clear_tool_inputs=True,
|
||||
placeholder="[cleared - older tool output trimmed for context]",
|
||||
)
|
||||
context_edit_mw = SpillingContextEditingMiddleware(
|
||||
edits=[spill_edit, clear_edit],
|
||||
backend_resolver=backend_resolver,
|
||||
)
|
||||
|
||||
retry_mw = (
|
||||
RetryAfterMiddleware(max_retries=3)
|
||||
if flags.enable_retry_after and not flags.disable_new_agent_stack
|
||||
else None
|
||||
)
|
||||
fallback_mw: ModelFallbackMiddleware | None = None
|
||||
if flags.enable_model_fallback and not flags.disable_new_agent_stack:
|
||||
try:
|
||||
fallback_mw = ModelFallbackMiddleware(
|
||||
"openai:gpt-4o-mini",
|
||||
"anthropic:claude-3-5-haiku-20241022",
|
||||
)
|
||||
except Exception:
|
||||
logging.warning("ModelFallbackMiddleware init failed; skipping.")
|
||||
fallback_mw = None
|
||||
model_call_limit_mw = (
|
||||
ModelCallLimitMiddleware(
|
||||
thread_limit=120,
|
||||
run_limit=80,
|
||||
exit_behavior="end",
|
||||
)
|
||||
if flags.enable_model_call_limit and not flags.disable_new_agent_stack
|
||||
else None
|
||||
)
|
||||
tool_call_limit_mw = (
|
||||
ToolCallLimitMiddleware(
|
||||
thread_limit=300, run_limit=80, exit_behavior="continue"
|
||||
)
|
||||
if flags.enable_tool_call_limit and not flags.disable_new_agent_stack
|
||||
else None
|
||||
)
|
||||
|
||||
noop_mw = (
|
||||
NoopInjectionMiddleware()
|
||||
if flags.enable_compaction_v2 and not flags.disable_new_agent_stack
|
||||
else None
|
||||
)
|
||||
|
||||
repair_mw = None
|
||||
if flags.enable_tool_call_repair and not flags.disable_new_agent_stack:
|
||||
registered_names: set[str] = {t.name for t in tools}
|
||||
registered_names |= {
|
||||
"write_todos",
|
||||
"ls",
|
||||
"read_file",
|
||||
"write_file",
|
||||
"edit_file",
|
||||
"glob",
|
||||
"grep",
|
||||
"execute",
|
||||
"task",
|
||||
"mkdir",
|
||||
"cd",
|
||||
"pwd",
|
||||
"move_file",
|
||||
"rm",
|
||||
"rmdir",
|
||||
"list_tree",
|
||||
"execute_code",
|
||||
}
|
||||
repair_mw = ToolCallNameRepairMiddleware(
|
||||
registered_tool_names=registered_names,
|
||||
fuzzy_match_threshold=None,
|
||||
)
|
||||
|
||||
doom_loop_mw = (
|
||||
DoomLoopMiddleware(threshold=3)
|
||||
if flags.enable_doom_loop and not flags.disable_new_agent_stack
|
||||
else None
|
||||
)
|
||||
|
||||
permission_mw: PermissionMiddleware | None = (
|
||||
PermissionMiddleware(rulesets=permission_rulesets)
|
||||
if permission_rulesets
|
||||
else None
|
||||
)
|
||||
|
||||
action_log_mw: ActionLogMiddleware | None = None
|
||||
if (
|
||||
flags.enable_action_log
|
||||
and not flags.disable_new_agent_stack
|
||||
and thread_id is not None
|
||||
):
|
||||
try:
|
||||
tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS}
|
||||
action_log_mw = ActionLogMiddleware(
|
||||
thread_id=thread_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
tool_definitions=tool_defs_by_name,
|
||||
)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logging.warning(
|
||||
"ActionLogMiddleware init failed; running without it.",
|
||||
exc_info=True,
|
||||
)
|
||||
action_log_mw = None
|
||||
|
||||
busy_mutex_mw: BusyMutexMiddleware | None = (
|
||||
BusyMutexMiddleware()
|
||||
if flags.enable_busy_mutex and not flags.disable_new_agent_stack
|
||||
else None
|
||||
)
|
||||
|
||||
otel_mw: OtelSpanMiddleware | None = (
|
||||
OtelSpanMiddleware()
|
||||
if flags.enable_otel and not flags.disable_new_agent_stack
|
||||
else None
|
||||
)
|
||||
|
||||
plugin_middlewares: list[Any] = []
|
||||
if flags.enable_plugin_loader and not flags.disable_new_agent_stack:
|
||||
try:
|
||||
allowed_names = load_allowed_plugin_names_from_env()
|
||||
if allowed_names:
|
||||
plugin_middlewares = load_plugin_middlewares(
|
||||
PluginContext.build(
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
thread_visibility=visibility,
|
||||
llm=llm,
|
||||
),
|
||||
allowed_plugin_names=allowed_names,
|
||||
)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logging.warning(
|
||||
"Plugin loader failed; continuing without plugins.",
|
||||
exc_info=True,
|
||||
)
|
||||
plugin_middlewares = []
|
||||
|
||||
skills_mw: SkillsMiddleware | None = None
|
||||
if flags.enable_skills and not flags.disable_new_agent_stack:
|
||||
try:
|
||||
skills_factory = build_skills_backend_factory(
|
||||
search_space_id=search_space_id
|
||||
if filesystem_mode == FilesystemMode.CLOUD
|
||||
else None,
|
||||
)
|
||||
skills_mw = SkillsMiddleware(
|
||||
backend=skills_factory,
|
||||
sources=default_skills_sources(),
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logging.warning("SkillsMiddleware init failed; skipping: %s", exc)
|
||||
skills_mw = None
|
||||
|
||||
selector_mw: LLMToolSelectorMiddleware | None = None
|
||||
if (
|
||||
flags.enable_llm_tool_selector
|
||||
and not flags.disable_new_agent_stack
|
||||
and len(tools) > 30
|
||||
):
|
||||
try:
|
||||
selector_mw = LLMToolSelectorMiddleware(
|
||||
model="openai:gpt-4o-mini",
|
||||
max_tools=12,
|
||||
always_include=[
|
||||
name
|
||||
for name in (
|
||||
"update_memory",
|
||||
"get_connected_accounts",
|
||||
"scrape_webpage",
|
||||
)
|
||||
if name in {t.name for t in tools}
|
||||
],
|
||||
)
|
||||
except Exception:
|
||||
logging.warning("LLMToolSelectorMiddleware init failed; skipping.")
|
||||
selector_mw = None
|
||||
|
||||
deepagent_middleware = [
|
||||
busy_mutex_mw,
|
||||
otel_mw,
|
||||
TodoListMiddleware(),
|
||||
_memory_middleware,
|
||||
AnonymousDocumentMiddleware(
|
||||
anon_session_id=anon_session_id,
|
||||
)
|
||||
if filesystem_mode == FilesystemMode.CLOUD
|
||||
else None,
|
||||
KnowledgeTreeMiddleware(
|
||||
search_space_id=search_space_id,
|
||||
filesystem_mode=filesystem_mode,
|
||||
llm=llm,
|
||||
)
|
||||
if filesystem_mode == FilesystemMode.CLOUD
|
||||
else None,
|
||||
KnowledgePriorityMiddleware(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
filesystem_mode=filesystem_mode,
|
||||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
),
|
||||
FileIntentMiddleware(llm=llm),
|
||||
SurfSenseFilesystemMiddleware(
|
||||
backend=backend_resolver,
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=user_id,
|
||||
thread_id=thread_id,
|
||||
),
|
||||
KnowledgeBasePersistenceMiddleware(
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=user_id,
|
||||
filesystem_mode=filesystem_mode,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
if filesystem_mode == FilesystemMode.CLOUD
|
||||
else None,
|
||||
skills_mw,
|
||||
SurfSenseCheckpointedSubAgentMiddleware(
|
||||
checkpointer=checkpointer,
|
||||
backend=StateBackend,
|
||||
subagents=subagent_specs,
|
||||
),
|
||||
selector_mw,
|
||||
model_call_limit_mw,
|
||||
tool_call_limit_mw,
|
||||
context_edit_mw,
|
||||
summarization_mw,
|
||||
noop_mw,
|
||||
retry_mw,
|
||||
fallback_mw,
|
||||
repair_mw,
|
||||
permission_mw,
|
||||
doom_loop_mw,
|
||||
action_log_mw,
|
||||
PatchToolCallsMiddleware(),
|
||||
DedupHITLToolCallsMiddleware(agent_tools=list(tools)),
|
||||
*plugin_middlewares,
|
||||
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
||||
]
|
||||
return [m for m in deepagent_middleware if m is not None]
|
||||
|
|
@ -2,6 +2,6 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from .factory import create_surfsense_deep_agent
|
||||
from .factory import create_multi_agent_chat_deep_agent
|
||||
|
||||
__all__ = ["create_surfsense_deep_agent"]
|
||||
__all__ = ["create_multi_agent_chat_deep_agent"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,117 @@
|
|||
"""Compiled agent graph caching for the multi-agent path."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.multi_agent_chat.subagents.shared.permissions import ToolsPermissions
|
||||
from app.agents.new_chat.agent_cache import (
|
||||
flags_signature,
|
||||
get_cache,
|
||||
stable_hash,
|
||||
system_prompt_hash,
|
||||
tools_signature,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
|
||||
|
||||
|
||||
def mcp_signature(mcp_tools_by_agent: dict[str, ToolsPermissions]) -> str:
|
||||
"""Hash the per-agent MCP tool surface so a change rotates the cache key."""
|
||||
rows = []
|
||||
for agent_name in sorted(mcp_tools_by_agent.keys()):
|
||||
perms = mcp_tools_by_agent[agent_name]
|
||||
allow_names = sorted(item.get("name", "") for item in perms.get("allow", []))
|
||||
ask_names = sorted(item.get("name", "") for item in perms.get("ask", []))
|
||||
rows.append((agent_name, allow_names, ask_names))
|
||||
return stable_hash(rows)
|
||||
|
||||
|
||||
async def build_agent_with_cache(
|
||||
*,
|
||||
llm: BaseChatModel,
|
||||
tools: Sequence[BaseTool],
|
||||
final_system_prompt: str,
|
||||
backend_resolver: Any,
|
||||
filesystem_mode: FilesystemMode,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
thread_id: int | None,
|
||||
visibility: ChatVisibility,
|
||||
anon_session_id: str | None,
|
||||
available_connectors: list[str],
|
||||
available_document_types: list[str],
|
||||
mentioned_document_ids: list[int] | None,
|
||||
max_input_tokens: int | None,
|
||||
flags: AgentFeatureFlags,
|
||||
checkpointer: Checkpointer,
|
||||
subagent_dependencies: dict[str, Any],
|
||||
mcp_tools_by_agent: dict[str, ToolsPermissions],
|
||||
disabled_tools: list[str] | None,
|
||||
config_id: str | None,
|
||||
) -> Any:
|
||||
"""Compile the multi-agent graph, serving from cache when key components are stable."""
|
||||
|
||||
async def _build() -> Any:
|
||||
return await asyncio.to_thread(
|
||||
build_compiled_agent_graph_sync,
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
final_system_prompt=final_system_prompt,
|
||||
backend_resolver=backend_resolver,
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
visibility=visibility,
|
||||
anon_session_id=anon_session_id,
|
||||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
max_input_tokens=max_input_tokens,
|
||||
flags=flags,
|
||||
checkpointer=checkpointer,
|
||||
subagent_dependencies=subagent_dependencies,
|
||||
mcp_tools_by_agent=mcp_tools_by_agent,
|
||||
disabled_tools=disabled_tools,
|
||||
)
|
||||
|
||||
if not (flags.enable_agent_cache and not flags.disable_new_agent_stack):
|
||||
return await _build()
|
||||
|
||||
# Every per-request value any middleware closes over at __init__ must be in
|
||||
# the key, otherwise a hit will leak state across threads. Bump the schema
|
||||
# version when the component list changes shape.
|
||||
cache_key = stable_hash(
|
||||
"multi-agent-v1",
|
||||
config_id,
|
||||
thread_id,
|
||||
user_id,
|
||||
search_space_id,
|
||||
visibility,
|
||||
filesystem_mode,
|
||||
anon_session_id,
|
||||
tools_signature(
|
||||
tools,
|
||||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
),
|
||||
mcp_signature(mcp_tools_by_agent),
|
||||
flags_signature(flags),
|
||||
system_prompt_hash(final_system_prompt),
|
||||
max_input_tokens,
|
||||
sorted(disabled_tools) if disabled_tools else None,
|
||||
)
|
||||
return await get_cache().get_or_build(cache_key, builder=_build)
|
||||
|
||||
|
||||
__all__ = ["build_agent_with_cache", "mcp_signature"]
|
||||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
|
|
@ -26,23 +25,24 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
|
|||
from app.agents.new_chat.filesystem_backends import build_backend_resolver
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
|
||||
from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_tool
|
||||
from app.agents.new_chat.tools.registry import build_tools_async
|
||||
from app.db import ChatVisibility
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
|
||||
from ..system_prompt import build_main_agent_system_prompt
|
||||
from ..tools import (
|
||||
MAIN_AGENT_SURFSENSE_TOOL_NAMES,
|
||||
MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED,
|
||||
)
|
||||
from .agent_cache import build_agent_with_cache
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
async def create_surfsense_deep_agent(
|
||||
async def create_multi_agent_chat_deep_agent(
|
||||
llm: BaseChatModel,
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
|
|
@ -62,6 +62,9 @@ async def create_surfsense_deep_agent(
|
|||
):
|
||||
"""Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled."""
|
||||
_t_agent_total = time.perf_counter()
|
||||
|
||||
apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id)
|
||||
|
||||
filesystem_selection = filesystem_selection or FilesystemSelection()
|
||||
backend_resolver = build_backend_resolver(
|
||||
filesystem_selection,
|
||||
|
|
@ -85,7 +88,18 @@ async def create_surfsense_deep_agent(
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.warning("Failed to discover available connectors/document types: %s", e)
|
||||
logging.warning(
|
||||
"Connector/doc-type discovery failed; excluding connector subagents this turn: %s",
|
||||
e,
|
||||
)
|
||||
|
||||
# Fail closed: a None list short-circuits ``get_subagents_to_exclude`` to "exclude
|
||||
# nothing", which would silently advertise every connector specialist on a flaky
|
||||
# discovery call. Empty list excludes connector-gated subagents while keeping builtins.
|
||||
if available_connectors is None:
|
||||
available_connectors = []
|
||||
if available_document_types is None:
|
||||
available_document_types = []
|
||||
_perf_log.info(
|
||||
"[create_agent] Connector/doc-type discovery in %.3fs",
|
||||
time.perf_counter() - _t0,
|
||||
|
|
@ -115,7 +129,18 @@ async def create_surfsense_deep_agent(
|
|||
}
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
mcp_tools_by_agent = await load_mcp_tools_by_connector(db_session, search_space_id)
|
||||
try:
|
||||
mcp_tools_by_agent = await load_mcp_tools_by_connector(
|
||||
db_session, search_space_id
|
||||
)
|
||||
except Exception as e:
|
||||
# Degrade to builtins-only rather than aborting the turn: a transient
|
||||
# DB or MCP-server hiccup should not deny the user a response.
|
||||
logging.warning(
|
||||
"MCP tool discovery failed; subagents will run without MCP tools this turn: %s",
|
||||
e,
|
||||
)
|
||||
mcp_tools_by_agent = {}
|
||||
_perf_log.info(
|
||||
"[create_agent] load_mcp_tools_by_connector in %.3fs (%d buckets)",
|
||||
time.perf_counter() - _t0,
|
||||
|
|
@ -195,9 +220,10 @@ async def create_surfsense_deep_agent(
|
|||
|
||||
final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT
|
||||
|
||||
config_id = agent_config.config_id if agent_config is not None else None
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
agent = await asyncio.to_thread(
|
||||
build_compiled_agent_graph_sync,
|
||||
agent = await build_agent_with_cache(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
final_system_prompt=final_system_prompt,
|
||||
|
|
@ -217,6 +243,7 @@ async def create_surfsense_deep_agent(
|
|||
subagent_dependencies=dependencies,
|
||||
mcp_tools_by_agent=mcp_tools_by_agent,
|
||||
disabled_tools=disabled_tools,
|
||||
config_id=config_id,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[create_agent] Middleware stack + graph compiled in %.3fs",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,7 @@
|
|||
"""Multi-agent middleware stack assembly."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .stack import build_main_agent_deepagent_middleware
|
||||
|
||||
__all__ = ["build_main_agent_deepagent_middleware"]
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
"""Audit row per tool call (reversibility metadata)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import ActionLogMiddleware
|
||||
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_action_log_mw(
|
||||
*,
|
||||
flags: AgentFeatureFlags,
|
||||
thread_id: int | None,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
) -> ActionLogMiddleware | None:
|
||||
if not enabled(flags, "enable_action_log") or thread_id is None:
|
||||
return None
|
||||
try:
|
||||
tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS}
|
||||
return ActionLogMiddleware(
|
||||
thread_id=thread_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
tool_definitions=tool_defs_by_name,
|
||||
)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logging.warning(
|
||||
"ActionLogMiddleware init failed; running without it.",
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
"""Anonymous document hydration from Redis (cloud only)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import AnonymousDocumentMiddleware
|
||||
|
||||
|
||||
def build_anonymous_doc_mw(
|
||||
*,
|
||||
filesystem_mode: FilesystemMode,
|
||||
anon_session_id: str | None,
|
||||
) -> AnonymousDocumentMiddleware | None:
|
||||
if filesystem_mode != FilesystemMode.CLOUD:
|
||||
return None
|
||||
return AnonymousDocumentMiddleware(anon_session_id=anon_session_id)
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""Per-thread cooperative lock around the whole turn."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import BusyMutexMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_busy_mutex_mw(flags: AgentFeatureFlags) -> BusyMutexMiddleware | None:
|
||||
return BusyMutexMiddleware() if enabled(flags, "enable_busy_mutex") else None
|
||||
|
|
@ -0,0 +1,90 @@
|
|||
"""RunnableConfig wiring for nested subagent invocations.
|
||||
|
||||
Forwards the parent's ``runtime.config`` (thread_id, …) into the subagent and
|
||||
exposes the side-channel ``stream_resume_chat`` uses to ferry resume payloads.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from .constants import DEFAULT_SUBAGENT_RECURSION_LIMIT
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# langgraph stores the parent task's scratchpad under this configurable key;
|
||||
# subagents inherit the chain via ``parent_scratchpad`` fallback.
|
||||
_LANGGRAPH_SCRATCHPAD_KEY = "__pregel_scratchpad"
|
||||
|
||||
|
||||
def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
|
||||
"""RunnableConfig for the nested invoke; raises ``recursion_limit`` to the parent's budget."""
|
||||
merged: dict[str, Any] = dict(runtime.config) if runtime.config else {}
|
||||
current_limit = merged.get("recursion_limit")
|
||||
try:
|
||||
current_int = int(current_limit) if current_limit is not None else 0
|
||||
except (TypeError, ValueError):
|
||||
current_int = 0
|
||||
if current_int < DEFAULT_SUBAGENT_RECURSION_LIMIT:
|
||||
merged["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT
|
||||
return merged
|
||||
|
||||
|
||||
def consume_surfsense_resume(runtime: ToolRuntime) -> Any:
|
||||
"""Pop the resume payload; siblings share ``configurable`` by reference."""
|
||||
cfg = runtime.config or {}
|
||||
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
|
||||
if not isinstance(configurable, dict):
|
||||
return None
|
||||
return configurable.pop("surfsense_resume_value", None)
|
||||
|
||||
|
||||
def has_surfsense_resume(runtime: ToolRuntime) -> bool:
|
||||
"""True iff a resume payload is queued on this runtime (non-destructive)."""
|
||||
cfg = runtime.config or {}
|
||||
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
|
||||
if not isinstance(configurable, dict):
|
||||
return False
|
||||
return "surfsense_resume_value" in configurable
|
||||
|
||||
|
||||
def drain_parent_null_resume(runtime: ToolRuntime) -> None:
|
||||
"""Consume the parent's lingering ``NULL_TASK_ID/RESUME`` write before delegating.
|
||||
|
||||
``stream_resume_chat`` wakes the main agent with
|
||||
``Command(resume={"decisions": [...]})`` so the propagated
|
||||
``_lg_interrupt(...)`` can return. langgraph stores that payload as the
|
||||
parent task's ``null_resume`` pending write, which only gets consumed
|
||||
*after* ``subagent.[a]invoke`` returns (when the post-call propagation
|
||||
re-fires). While the subagent is mid-execution, any *new* ``interrupt()``
|
||||
inside it (e.g. a follow-up tool call after a mixed approve/reject) walks
|
||||
``subagent_scratchpad → parent_scratchpad.get_null_resume`` and picks up
|
||||
the parent's still-live decisions — mismatching against a different number
|
||||
of hanging tool calls and crashing ``HumanInTheLoopMiddleware``.
|
||||
|
||||
Draining the write here closes that cross-graph leak so subagent
|
||||
interrupts pause cleanly and re-propagate as a fresh approval card.
|
||||
"""
|
||||
cfg = runtime.config or {}
|
||||
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
|
||||
if not isinstance(configurable, dict):
|
||||
return
|
||||
scratchpad = configurable.get(_LANGGRAPH_SCRATCHPAD_KEY)
|
||||
if scratchpad is None:
|
||||
return
|
||||
consume = getattr(scratchpad, "get_null_resume", None)
|
||||
if not callable(consume):
|
||||
return
|
||||
try:
|
||||
consume(True)
|
||||
except Exception:
|
||||
# Defensive: if langgraph's internal scratchpad shape changes we don't
|
||||
# want to break the resume path. Worst case the original ValueError
|
||||
# still surfaces — same behavior as before this fix.
|
||||
logger.debug(
|
||||
"drain_parent_null_resume: scratchpad.get_null_resume raised",
|
||||
exc_info=True,
|
||||
)
|
||||
|
|
@ -20,6 +20,7 @@ from langgraph.types import Command
|
|||
|
||||
from .config import (
|
||||
consume_surfsense_resume,
|
||||
drain_parent_null_resume,
|
||||
has_surfsense_resume,
|
||||
subagent_invoke_config,
|
||||
)
|
||||
|
|
@ -69,9 +70,16 @@ def build_task_tool_with_parent_config(
|
|||
raise ValueError(msg)
|
||||
|
||||
state_update = {k: v for k, v in result.items() if k not in EXCLUDED_STATE_KEYS}
|
||||
message_text = (
|
||||
result["messages"][-1].text.rstrip() if result["messages"][-1].text else ""
|
||||
)
|
||||
messages = result["messages"]
|
||||
if not messages:
|
||||
msg = (
|
||||
"CompiledSubAgent returned an empty 'messages' list. "
|
||||
"Subagents must produce at least one message so the parent has "
|
||||
"output to forward back to the user."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
last_text = getattr(messages[-1], "text", None) or ""
|
||||
message_text = last_text.rstrip()
|
||||
return Command(
|
||||
update={
|
||||
**state_update,
|
||||
|
|
@ -150,6 +158,9 @@ def build_task_tool_with_parent_config(
|
|||
)
|
||||
expected = hitlrequest_action_count(pending_value)
|
||||
resume_value = fan_out_decisions_to_match(resume_value, expected)
|
||||
# Prevent the parent's resume payload from leaking into subagent
|
||||
# interrupts via langgraph's parent_scratchpad fallback.
|
||||
drain_parent_null_resume(runtime)
|
||||
result = subagent.invoke(
|
||||
build_resume_command(resume_value, pending_id),
|
||||
config=sub_config,
|
||||
|
|
@ -214,6 +225,9 @@ def build_task_tool_with_parent_config(
|
|||
)
|
||||
expected = hitlrequest_action_count(pending_value)
|
||||
resume_value = fan_out_decisions_to_match(resume_value, expected)
|
||||
# Prevent the parent's resume payload from leaking into subagent
|
||||
# interrupts via langgraph's parent_scratchpad fallback.
|
||||
drain_parent_null_resume(runtime)
|
||||
result = await subagent.ainvoke(
|
||||
build_resume_command(resume_value, pending_id),
|
||||
config=sub_config,
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
"""Spill + clear-tool-uses passes to keep payloads under budget."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.multi_agent_chat.main_agent.context_prune.prune_tool_names import (
|
||||
safe_exclude_tools,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import (
|
||||
ClearToolUsesEdit,
|
||||
SpillingContextEditingMiddleware,
|
||||
SpillToBackendEdit,
|
||||
)
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_context_editing_mw(
|
||||
*,
|
||||
flags: AgentFeatureFlags,
|
||||
max_input_tokens: int | None,
|
||||
tools: Sequence[BaseTool],
|
||||
backend_resolver: Any,
|
||||
) -> SpillingContextEditingMiddleware | None:
|
||||
if not enabled(flags, "enable_context_editing") or not max_input_tokens:
|
||||
return None
|
||||
spill_edit = SpillToBackendEdit(
|
||||
trigger=int(max_input_tokens * 0.55),
|
||||
clear_at_least=int(max_input_tokens * 0.15),
|
||||
keep=5,
|
||||
exclude_tools=safe_exclude_tools(tools),
|
||||
clear_tool_inputs=True,
|
||||
)
|
||||
clear_edit = ClearToolUsesEdit(
|
||||
trigger=int(max_input_tokens * 0.55),
|
||||
clear_at_least=int(max_input_tokens * 0.15),
|
||||
keep=5,
|
||||
exclude_tools=safe_exclude_tools(tools),
|
||||
clear_tool_inputs=True,
|
||||
placeholder="[cleared - older tool output trimmed for context]",
|
||||
)
|
||||
return SpillingContextEditingMiddleware(
|
||||
edits=[spill_edit, clear_edit],
|
||||
backend_resolver=backend_resolver,
|
||||
)
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
"""Drop duplicate HITL tool calls before execution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware
|
||||
|
||||
|
||||
def build_dedup_hitl_mw(tools: Sequence[BaseTool]) -> DedupHITLToolCallsMiddleware:
|
||||
return DedupHITLToolCallsMiddleware(agent_tools=list(tools))
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
"""Stop N identical tool calls in a row via interrupt."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import DoomLoopMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_doom_loop_mw(flags: AgentFeatureFlags) -> DoomLoopMiddleware | None:
|
||||
return (
|
||||
DoomLoopMiddleware(threshold=3) if enabled(flags, "enable_doom_loop") else None
|
||||
)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
"""Commit staged cloud filesystem mutations to Postgres at end of turn."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import KnowledgeBasePersistenceMiddleware
|
||||
|
||||
|
||||
def build_kb_persistence_mw(
|
||||
*,
|
||||
filesystem_mode: FilesystemMode,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
thread_id: int | None,
|
||||
) -> KnowledgeBasePersistenceMiddleware | None:
|
||||
if filesystem_mode != FilesystemMode.CLOUD:
|
||||
return None
|
||||
return KnowledgeBasePersistenceMiddleware(
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=user_id,
|
||||
filesystem_mode=filesystem_mode,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
"""KB priority planner: <priority_documents> injection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import KnowledgePriorityMiddleware
|
||||
|
||||
|
||||
def build_knowledge_priority_mw(
|
||||
*,
|
||||
llm: BaseChatModel,
|
||||
search_space_id: int,
|
||||
filesystem_mode: FilesystemMode,
|
||||
available_connectors: list[str] | None,
|
||||
available_document_types: list[str] | None,
|
||||
mentioned_document_ids: list[int] | None,
|
||||
) -> KnowledgePriorityMiddleware:
|
||||
return KnowledgePriorityMiddleware(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
filesystem_mode=filesystem_mode,
|
||||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
"""<workspace_tree> injection (cloud only)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import KnowledgeTreeMiddleware
|
||||
|
||||
|
||||
def build_knowledge_tree_mw(
|
||||
*,
|
||||
filesystem_mode: FilesystemMode,
|
||||
search_space_id: int,
|
||||
llm: BaseChatModel,
|
||||
) -> KnowledgeTreeMiddleware | None:
|
||||
if filesystem_mode != FilesystemMode.CLOUD:
|
||||
return None
|
||||
return KnowledgeTreeMiddleware(
|
||||
search_space_id=search_space_id,
|
||||
filesystem_mode=filesystem_mode,
|
||||
llm=llm,
|
||||
)
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""Provider-compat: append a `_noop` tool when tools=[] but history has tool calls."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import NoopInjectionMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_noop_injection_mw(flags: AgentFeatureFlags) -> NoopInjectionMiddleware | None:
|
||||
return NoopInjectionMiddleware() if enabled(flags, "enable_compaction_v2") else None
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""OTel spans on model and tool calls."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import OtelSpanMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_otel_mw(flags: AgentFeatureFlags) -> OtelSpanMiddleware | None:
|
||||
return OtelSpanMiddleware() if enabled(flags, "enable_otel") else None
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
"""Tail-of-stack plugin slot driven by env allowlist."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.plugin_loader import (
|
||||
PluginContext,
|
||||
load_allowed_plugin_names_from_env,
|
||||
load_plugin_middlewares,
|
||||
)
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_plugin_middlewares(
|
||||
*,
|
||||
flags: AgentFeatureFlags,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
visibility: ChatVisibility,
|
||||
llm: BaseChatModel,
|
||||
) -> list[Any]:
|
||||
if not enabled(flags, "enable_plugin_loader"):
|
||||
return []
|
||||
try:
|
||||
allowed_names = load_allowed_plugin_names_from_env()
|
||||
if not allowed_names:
|
||||
return []
|
||||
return load_plugin_middlewares(
|
||||
PluginContext.build(
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
thread_visibility=visibility,
|
||||
llm=llm,
|
||||
),
|
||||
allowed_plugin_names=allowed_names,
|
||||
)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logging.warning(
|
||||
"Plugin loader failed; continuing without plugins.",
|
||||
exc_info=True,
|
||||
)
|
||||
return []
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
"""Repair miscased / unknown tool names to the registered set or invalid_tool."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import ToolCallNameRepairMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
# deepagents-built-in tool names the repair pass treats as known.
|
||||
_DEEPAGENT_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset(
|
||||
{
|
||||
"write_todos",
|
||||
"ls",
|
||||
"read_file",
|
||||
"write_file",
|
||||
"edit_file",
|
||||
"glob",
|
||||
"grep",
|
||||
"execute",
|
||||
"task",
|
||||
"mkdir",
|
||||
"cd",
|
||||
"pwd",
|
||||
"move_file",
|
||||
"rm",
|
||||
"rmdir",
|
||||
"list_tree",
|
||||
"execute_code",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def build_repair_mw(
|
||||
*,
|
||||
flags: AgentFeatureFlags,
|
||||
tools: Sequence[BaseTool],
|
||||
) -> ToolCallNameRepairMiddleware | None:
|
||||
if not enabled(flags, "enable_tool_call_repair"):
|
||||
return None
|
||||
registered_names: set[str] = {t.name for t in tools}
|
||||
registered_names |= _DEEPAGENT_BUILTIN_TOOL_NAMES
|
||||
return ToolCallNameRepairMiddleware(
|
||||
registered_tool_names=registered_names,
|
||||
fuzzy_match_threshold=None,
|
||||
)
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
"""LLM-based tool subset selection (only when >30 tools)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain.agents.middleware import LLMToolSelectorMiddleware
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_selector_mw(
|
||||
*,
|
||||
flags: AgentFeatureFlags,
|
||||
tools: Sequence[BaseTool],
|
||||
) -> LLMToolSelectorMiddleware | None:
|
||||
if not enabled(flags, "enable_llm_tool_selector") or len(tools) <= 30:
|
||||
return None
|
||||
try:
|
||||
return LLMToolSelectorMiddleware(
|
||||
model="openai:gpt-4o-mini",
|
||||
max_tools=12,
|
||||
always_include=[
|
||||
name
|
||||
for name in (
|
||||
"update_memory",
|
||||
"get_connected_accounts",
|
||||
"scrape_webpage",
|
||||
)
|
||||
if name in {t.name for t in tools}
|
||||
],
|
||||
)
|
||||
except Exception:
|
||||
logging.warning("LLMToolSelectorMiddleware init failed; skipping.")
|
||||
return None
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
"""Skill discovery + injection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from deepagents.middleware.skills import SkillsMiddleware
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import (
|
||||
build_skills_backend_factory,
|
||||
default_skills_sources,
|
||||
)
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_skills_mw(
|
||||
*,
|
||||
flags: AgentFeatureFlags,
|
||||
filesystem_mode: FilesystemMode,
|
||||
search_space_id: int,
|
||||
) -> SkillsMiddleware | None:
|
||||
if not enabled(flags, "enable_skills"):
|
||||
return None
|
||||
try:
|
||||
skills_factory = build_skills_backend_factory(
|
||||
search_space_id=search_space_id
|
||||
if filesystem_mode == FilesystemMode.CLOUD
|
||||
else None,
|
||||
)
|
||||
return SkillsMiddleware(
|
||||
backend=skills_factory,
|
||||
sources=default_skills_sources(),
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logging.warning("SkillsMiddleware init failed; skipping: %s", exc)
|
||||
return None
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
"""Anthropic prompt caching annotations on system/tool/message blocks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||
|
||||
|
||||
def build_anthropic_cache_mw() -> AnthropicPromptCachingMiddleware:
|
||||
return AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
"""Context-window summarization with SurfSense protected sections."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from deepagents.backends import StateBackend
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.new_chat.middleware import create_surfsense_compaction_middleware
|
||||
|
||||
|
||||
def build_compaction_mw(llm: BaseChatModel) -> Any:
|
||||
return create_surfsense_compaction_middleware(llm, StateBackend)
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
"""File-intent classifier that gates strict write contracts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.new_chat.middleware import FileIntentMiddleware
|
||||
|
||||
|
||||
def build_file_intent_mw(llm: BaseChatModel) -> FileIntentMiddleware:
|
||||
return FileIntentMiddleware(llm=llm)
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
"""SurfSense filesystem tools/middleware."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import SurfSenseFilesystemMiddleware
|
||||
|
||||
|
||||
def build_filesystem_mw(
|
||||
*,
|
||||
backend_resolver: Any,
|
||||
filesystem_mode: FilesystemMode,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
thread_id: int | None,
|
||||
) -> SurfSenseFilesystemMiddleware:
|
||||
return SurfSenseFilesystemMiddleware(
|
||||
backend=backend_resolver,
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=user_id,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
"""Single source of truth for the feature-flag predicate."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
|
||||
|
||||
def enabled(flags: AgentFeatureFlags, attr: str) -> bool:
|
||||
"""``flags.<attr>`` is on AND the new-agent-stack kill switch is off."""
|
||||
return getattr(flags, attr) and not flags.disable_new_agent_stack
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
"""User/team memory injection prepended to the conversation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.middleware import MemoryInjectionMiddleware
|
||||
from app.db import ChatVisibility
|
||||
|
||||
|
||||
def build_memory_mw(
|
||||
*,
|
||||
user_id: str | None,
|
||||
search_space_id: int,
|
||||
visibility: ChatVisibility,
|
||||
) -> MemoryInjectionMiddleware:
|
||||
return MemoryInjectionMiddleware(
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
thread_visibility=visibility,
|
||||
)
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
"""Repair dangling tool-call sequences before each agent turn."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
|
||||
|
||||
def build_patch_tool_calls_mw() -> PatchToolCallsMiddleware:
|
||||
return PatchToolCallsMiddleware()
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""Permission rulesets fanned out to parent / general-purpose / subagent stacks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .context import PermissionContext, build_permission_context
|
||||
from .middleware import build_full_permission_mw
|
||||
|
||||
__all__ = [
|
||||
"PermissionContext",
|
||||
"build_full_permission_mw",
|
||||
"build_permission_context",
|
||||
]
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
"""Derive shared permission context once; fan out to all three stack layers.
|
||||
|
||||
The context carries:
|
||||
- ``rulesets``: full ask/deny/allow rules for the main-agent permission middleware.
|
||||
- ``general_purpose_interrupt_on``: ``ask`` rules mirrored as deepagents
|
||||
``interrupt_on`` so HITL still triggers from inside ``task`` runs (subagents
|
||||
bypass the main-agent permission middleware).
|
||||
- ``subagent_deny_mw``: a deny-only ``PermissionMiddleware`` instance shared
|
||||
across the general-purpose and registry subagent stacks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import PermissionMiddleware
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
|
||||
|
||||
from ..flags import enabled
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PermissionContext:
|
||||
rulesets: list[Ruleset]
|
||||
general_purpose_interrupt_on: dict[str, bool]
|
||||
subagent_deny_mw: PermissionMiddleware | None
|
||||
|
||||
|
||||
def build_permission_context(
|
||||
*,
|
||||
flags: AgentFeatureFlags,
|
||||
filesystem_mode: FilesystemMode,
|
||||
tools: Sequence[BaseTool],
|
||||
available_connectors: list[str] | None,
|
||||
) -> PermissionContext:
|
||||
is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
|
||||
permission_enabled = enabled(flags, "enable_permission")
|
||||
|
||||
rulesets: list[Ruleset] = []
|
||||
if permission_enabled or is_desktop_fs:
|
||||
rulesets.append(
|
||||
Ruleset(
|
||||
rules=[Rule(permission="*", pattern="*", action="allow")],
|
||||
origin="surfsense_defaults",
|
||||
)
|
||||
)
|
||||
if is_desktop_fs:
|
||||
rulesets.append(
|
||||
Ruleset(
|
||||
rules=[
|
||||
Rule(permission="rm", pattern="*", action="ask"),
|
||||
Rule(permission="rmdir", pattern="*", action="ask"),
|
||||
Rule(permission="move_file", pattern="*", action="ask"),
|
||||
Rule(permission="edit_file", pattern="*", action="ask"),
|
||||
Rule(permission="write_file", pattern="*", action="ask"),
|
||||
],
|
||||
origin="desktop_safety",
|
||||
)
|
||||
)
|
||||
|
||||
tool_names_in_use = {t.name for t in tools}
|
||||
|
||||
if permission_enabled:
|
||||
available_set = set(available_connectors or [])
|
||||
synthesized: list[Rule] = []
|
||||
for tool_def in BUILTIN_TOOLS:
|
||||
if tool_def.name not in tool_names_in_use:
|
||||
continue
|
||||
rc = tool_def.required_connector
|
||||
if rc and rc not in available_set:
|
||||
synthesized.append(
|
||||
Rule(permission=tool_def.name, pattern="*", action="deny")
|
||||
)
|
||||
if synthesized:
|
||||
rulesets.append(Ruleset(rules=synthesized, origin="connector_synthesized"))
|
||||
|
||||
general_purpose_interrupt_on: dict[str, bool] = {
|
||||
rule.permission: True
|
||||
for rs in rulesets
|
||||
for rule in rs.rules
|
||||
if rule.action == "ask" and rule.permission in tool_names_in_use
|
||||
}
|
||||
|
||||
deny_rulesets = [
|
||||
Ruleset(
|
||||
rules=[r for r in rs.rules if r.action == "deny"],
|
||||
origin=rs.origin,
|
||||
)
|
||||
for rs in rulesets
|
||||
]
|
||||
deny_rulesets = [rs for rs in deny_rulesets if rs.rules]
|
||||
|
||||
subagent_deny_mw: PermissionMiddleware | None = (
|
||||
PermissionMiddleware(rulesets=deny_rulesets) if deny_rulesets else None
|
||||
)
|
||||
|
||||
return PermissionContext(
|
||||
rulesets=rulesets,
|
||||
general_purpose_interrupt_on=general_purpose_interrupt_on,
|
||||
subagent_deny_mw=subagent_deny_mw,
|
||||
)
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
"""Main-agent permission middleware (full ask/deny/allow rules)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.middleware import PermissionMiddleware
|
||||
from app.agents.new_chat.permissions import Ruleset
|
||||
|
||||
|
||||
def build_full_permission_mw(rulesets: list[Ruleset]) -> PermissionMiddleware | None:
|
||||
return PermissionMiddleware(rulesets=rulesets) if rulesets else None
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
"""Resilience middleware shared as the same instances across parent / general-purpose / registry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .bundle import ResilienceBundle, build_resilience_bundle
|
||||
|
||||
__all__ = ["ResilienceBundle", "build_resilience_bundle"]
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
"""Construct each resilience middleware once; same instances flow into every consumer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import (
|
||||
ModelCallLimitMiddleware,
|
||||
ToolCallLimitMiddleware,
|
||||
)
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import RetryAfterMiddleware
|
||||
from app.agents.new_chat.middleware.scoped_model_fallback import (
|
||||
ScopedModelFallbackMiddleware,
|
||||
)
|
||||
|
||||
from .fallback import build_fallback_mw
|
||||
from .model_call_limit import build_model_call_limit_mw
|
||||
from .retry import build_retry_mw
|
||||
from .tool_call_limit import build_tool_call_limit_mw
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResilienceBundle:
|
||||
retry: RetryAfterMiddleware | None
|
||||
fallback: ScopedModelFallbackMiddleware | None
|
||||
model_call_limit: ModelCallLimitMiddleware | None
|
||||
tool_call_limit: ToolCallLimitMiddleware | None
|
||||
|
||||
def as_list(self) -> list[Any]:
|
||||
return [
|
||||
m
|
||||
for m in (
|
||||
self.retry,
|
||||
self.fallback,
|
||||
self.model_call_limit,
|
||||
self.tool_call_limit,
|
||||
)
|
||||
if m is not None
|
||||
]
|
||||
|
||||
|
||||
def build_resilience_bundle(flags: AgentFeatureFlags) -> ResilienceBundle:
|
||||
return ResilienceBundle(
|
||||
retry=build_retry_mw(flags),
|
||||
fallback=build_fallback_mw(flags),
|
||||
model_call_limit=build_model_call_limit_mw(flags),
|
||||
tool_call_limit=build_tool_call_limit_mw(flags),
|
||||
)
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
"""Switch to a fallback model on provider/network errors only."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware.scoped_model_fallback import (
|
||||
ScopedModelFallbackMiddleware,
|
||||
)
|
||||
|
||||
from ..flags import enabled
|
||||
|
||||
|
||||
def build_fallback_mw(
|
||||
flags: AgentFeatureFlags,
|
||||
) -> ScopedModelFallbackMiddleware | None:
|
||||
if not enabled(flags, "enable_model_fallback"):
|
||||
return None
|
||||
try:
|
||||
return ScopedModelFallbackMiddleware(
|
||||
"openai:gpt-4o-mini",
|
||||
"anthropic:claude-3-5-haiku-20241022",
|
||||
)
|
||||
except Exception:
|
||||
logging.warning("ScopedModelFallbackMiddleware init failed; skipping.")
|
||||
return None
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
"""Cap model calls per thread / per run to prevent runaway cost."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain.agents.middleware import ModelCallLimitMiddleware
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
|
||||
from ..flags import enabled
|
||||
|
||||
|
||||
def build_model_call_limit_mw(
|
||||
flags: AgentFeatureFlags,
|
||||
) -> ModelCallLimitMiddleware | None:
|
||||
if not enabled(flags, "enable_model_call_limit"):
|
||||
return None
|
||||
return ModelCallLimitMiddleware(
|
||||
thread_limit=120,
|
||||
run_limit=80,
|
||||
exit_behavior="end",
|
||||
)
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
"""Retry on transient model errors (e.g. Retry-After-bearing 429s)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import RetryAfterMiddleware
|
||||
|
||||
from ..flags import enabled
|
||||
|
||||
|
||||
def build_retry_mw(flags: AgentFeatureFlags) -> RetryAfterMiddleware | None:
|
||||
return (
|
||||
RetryAfterMiddleware(max_retries=3)
|
||||
if enabled(flags, "enable_retry_after")
|
||||
else None
|
||||
)
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
"""Cap tool calls per thread / per run to bound infinite-loop blast radius."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain.agents.middleware import ToolCallLimitMiddleware
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
|
||||
from ..flags import enabled
|
||||
|
||||
|
||||
def build_tool_call_limit_mw(
|
||||
flags: AgentFeatureFlags,
|
||||
) -> ToolCallLimitMiddleware | None:
|
||||
if not enabled(flags, "enable_tool_call_limit"):
|
||||
return None
|
||||
return ToolCallLimitMiddleware(
|
||||
thread_limit=300,
|
||||
run_limit=80,
|
||||
exit_behavior="continue",
|
||||
)
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
"""Todo-list middleware (each consumer needs its own instance)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain.agents.middleware import TodoListMiddleware
|
||||
|
||||
|
||||
def build_todos_mw() -> TodoListMiddleware:
|
||||
return TodoListMiddleware()
|
||||
|
|
@ -0,0 +1,216 @@
|
|||
"""Main-agent middleware list assembly: one line per slot."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from deepagents import SubAgent
|
||||
from deepagents.backends import StateBackend
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.multi_agent_chat.subagents import (
|
||||
build_subagents,
|
||||
get_subagents_to_exclude,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.builtins.general_purpose.agent import (
|
||||
build_subagent as build_general_purpose_subagent,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.shared.permissions import ToolsPermissions
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from .main_agent.action_log import build_action_log_mw
|
||||
from .main_agent.anonymous_doc import build_anonymous_doc_mw
|
||||
from .main_agent.busy_mutex import build_busy_mutex_mw
|
||||
from .main_agent.checkpointed_subagent_middleware import (
|
||||
SurfSenseCheckpointedSubAgentMiddleware,
|
||||
)
|
||||
from .main_agent.context_editing import build_context_editing_mw
|
||||
from .main_agent.dedup_hitl import build_dedup_hitl_mw
|
||||
from .main_agent.doom_loop import build_doom_loop_mw
|
||||
from .main_agent.kb_persistence import build_kb_persistence_mw
|
||||
from .main_agent.knowledge_priority import build_knowledge_priority_mw
|
||||
from .main_agent.knowledge_tree import build_knowledge_tree_mw
|
||||
from .main_agent.noop_injection import build_noop_injection_mw
|
||||
from .main_agent.otel import build_otel_mw
|
||||
from .main_agent.plugins import build_plugin_middlewares
|
||||
from .main_agent.repair import build_repair_mw
|
||||
from .main_agent.selector import build_selector_mw
|
||||
from .main_agent.skills import build_skills_mw
|
||||
from .shared.anthropic_cache import build_anthropic_cache_mw
|
||||
from .shared.compaction import build_compaction_mw
|
||||
from .shared.file_intent import build_file_intent_mw
|
||||
from .shared.filesystem import build_filesystem_mw
|
||||
from .shared.memory import build_memory_mw
|
||||
from .shared.patch_tool_calls import build_patch_tool_calls_mw
|
||||
from .shared.permissions import (
|
||||
build_full_permission_mw,
|
||||
build_permission_context,
|
||||
)
|
||||
from .shared.resilience import build_resilience_bundle
|
||||
from .shared.todos import build_todos_mw
|
||||
from .subagent.extras import build_subagent_extras
|
||||
|
||||
|
||||
def build_main_agent_deepagent_middleware(
|
||||
*,
|
||||
llm: BaseChatModel,
|
||||
tools: Sequence[BaseTool],
|
||||
backend_resolver: Any,
|
||||
filesystem_mode: FilesystemMode,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
thread_id: int | None,
|
||||
visibility: ChatVisibility,
|
||||
anon_session_id: str | None,
|
||||
available_connectors: list[str] | None,
|
||||
available_document_types: list[str] | None,
|
||||
mentioned_document_ids: list[int] | None,
|
||||
max_input_tokens: int | None,
|
||||
flags: AgentFeatureFlags,
|
||||
subagent_dependencies: dict[str, Any],
|
||||
checkpointer: Checkpointer,
|
||||
mcp_tools_by_agent: dict[str, ToolsPermissions] | None = None,
|
||||
disabled_tools: list[str] | None = None,
|
||||
) -> list[Any]:
|
||||
"""Ordered middleware for ``create_agent`` (None entries already stripped)."""
|
||||
permissions = build_permission_context(
|
||||
flags=flags,
|
||||
filesystem_mode=filesystem_mode,
|
||||
tools=tools,
|
||||
available_connectors=available_connectors,
|
||||
)
|
||||
resilience = build_resilience_bundle(flags)
|
||||
|
||||
# Single instance threaded into both the main-agent stack and the general-purpose subagent.
|
||||
memory_mw = build_memory_mw(
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
visibility=visibility,
|
||||
)
|
||||
|
||||
general_purpose_subagent = build_general_purpose_subagent(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
backend_resolver=backend_resolver,
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
permissions=permissions,
|
||||
resilience=resilience,
|
||||
memory_mw=memory_mw,
|
||||
)
|
||||
|
||||
subagents_registry: list[SubAgent] = []
|
||||
try:
|
||||
subagent_extras = build_subagent_extras(
|
||||
permissions=permissions,
|
||||
resilience=resilience,
|
||||
)
|
||||
subagents_registry = build_subagents(
|
||||
dependencies=subagent_dependencies,
|
||||
model=llm,
|
||||
extra_middleware=subagent_extras,
|
||||
mcp_tools_by_agent=mcp_tools_by_agent or {},
|
||||
exclude=get_subagents_to_exclude(available_connectors),
|
||||
disabled_tools=disabled_tools,
|
||||
)
|
||||
logging.debug(
|
||||
"Subagents registry: %s",
|
||||
[s["name"] for s in subagents_registry],
|
||||
)
|
||||
except Exception:
|
||||
# Degrade to general-purpose-only rather than aborting the turn:
|
||||
# one bad subagent dep should not deny the user a response.
|
||||
logging.exception(
|
||||
"Subagents registry build failed; falling back to general-purpose only"
|
||||
)
|
||||
subagents_registry = []
|
||||
|
||||
subagents: list[SubAgent] = [general_purpose_subagent, *subagents_registry]
|
||||
|
||||
stack: list[Any] = [
|
||||
build_busy_mutex_mw(flags),
|
||||
build_otel_mw(flags),
|
||||
build_todos_mw(),
|
||||
memory_mw,
|
||||
build_anonymous_doc_mw(
|
||||
filesystem_mode=filesystem_mode, anon_session_id=anon_session_id
|
||||
),
|
||||
build_knowledge_tree_mw(
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=search_space_id,
|
||||
llm=llm,
|
||||
),
|
||||
build_knowledge_priority_mw(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
filesystem_mode=filesystem_mode,
|
||||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
),
|
||||
build_file_intent_mw(llm),
|
||||
build_filesystem_mw(
|
||||
backend_resolver=backend_resolver,
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
),
|
||||
build_kb_persistence_mw(
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
),
|
||||
build_skills_mw(
|
||||
flags=flags,
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=search_space_id,
|
||||
),
|
||||
SurfSenseCheckpointedSubAgentMiddleware(
|
||||
checkpointer=checkpointer,
|
||||
backend=StateBackend,
|
||||
subagents=subagents,
|
||||
),
|
||||
build_selector_mw(flags=flags, tools=tools),
|
||||
resilience.model_call_limit,
|
||||
resilience.tool_call_limit,
|
||||
build_context_editing_mw(
|
||||
flags=flags,
|
||||
max_input_tokens=max_input_tokens,
|
||||
tools=tools,
|
||||
backend_resolver=backend_resolver,
|
||||
),
|
||||
build_compaction_mw(llm),
|
||||
build_noop_injection_mw(flags),
|
||||
resilience.retry,
|
||||
resilience.fallback,
|
||||
build_repair_mw(flags=flags, tools=tools),
|
||||
build_full_permission_mw(permissions.rulesets),
|
||||
build_doom_loop_mw(flags),
|
||||
build_action_log_mw(
|
||||
flags=flags,
|
||||
thread_id=thread_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
),
|
||||
build_patch_tool_calls_mw(),
|
||||
build_dedup_hitl_mw(tools),
|
||||
*build_plugin_middlewares(
|
||||
flags=flags,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
visibility=visibility,
|
||||
llm=llm,
|
||||
),
|
||||
build_anthropic_cache_mw(),
|
||||
]
|
||||
return [m for m in stack if m is not None]
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
"""Extra middleware threaded into every registry subagent's stack.
|
||||
|
||||
Registry subagents are scoped to one domain (deliverables, research, memory,
|
||||
connectors, MCP) and never read or write the SurfSense filesystem — that
|
||||
capability belongs to the main agent and is delegated to the general-purpose
|
||||
subagent as an escape hatch. Keeping FS off the registry stacks avoids
|
||||
polluting their tool surface with FS tools they never act on.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ..shared.permissions import PermissionContext
|
||||
from ..shared.resilience import ResilienceBundle
|
||||
from ..shared.todos import build_todos_mw
|
||||
|
||||
|
||||
def build_subagent_extras(
|
||||
*,
|
||||
permissions: PermissionContext,
|
||||
resilience: ResilienceBundle,
|
||||
) -> list[Any]:
|
||||
extras: list[Any] = [build_todos_mw()]
|
||||
if permissions.subagent_deny_mw is not None:
|
||||
extras.append(permissions.subagent_deny_mw)
|
||||
extras.extend(resilience.as_list())
|
||||
return extras
|
||||
|
|
@ -0,0 +1,105 @@
|
|||
"""General-purpose subagent for the multi-agent main agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from deepagents import SubAgent
|
||||
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
|
||||
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.shared.anthropic_cache import (
|
||||
build_anthropic_cache_mw,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.compaction import (
|
||||
build_compaction_mw,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.file_intent import (
|
||||
build_file_intent_mw,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.filesystem import (
|
||||
build_filesystem_mw,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.patch_tool_calls import (
|
||||
build_patch_tool_calls_mw,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.permissions import (
|
||||
PermissionContext,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.resilience import (
|
||||
ResilienceBundle,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.todos import build_todos_mw
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import MemoryInjectionMiddleware
|
||||
|
||||
NAME = "general-purpose"
|
||||
|
||||
|
||||
def build_subagent(
|
||||
*,
|
||||
llm: BaseChatModel,
|
||||
tools: Sequence[BaseTool],
|
||||
backend_resolver: Any,
|
||||
filesystem_mode: FilesystemMode,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
thread_id: int | None,
|
||||
permissions: PermissionContext,
|
||||
resilience: ResilienceBundle,
|
||||
memory_mw: MemoryInjectionMiddleware,
|
||||
) -> SubAgent:
|
||||
"""Deny + resilience inserts encapsulated here so the orchestrator never mutates the list."""
|
||||
middleware: list[Any] = [
|
||||
build_todos_mw(),
|
||||
memory_mw,
|
||||
build_file_intent_mw(llm),
|
||||
build_filesystem_mw(
|
||||
backend_resolver=backend_resolver,
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
),
|
||||
build_compaction_mw(llm),
|
||||
build_patch_tool_calls_mw(),
|
||||
build_anthropic_cache_mw(),
|
||||
]
|
||||
|
||||
if permissions.subagent_deny_mw is not None:
|
||||
patch_idx = next(
|
||||
(
|
||||
i
|
||||
for i, m in enumerate(middleware)
|
||||
if isinstance(m, PatchToolCallsMiddleware)
|
||||
),
|
||||
len(middleware),
|
||||
)
|
||||
middleware.insert(patch_idx, permissions.subagent_deny_mw)
|
||||
|
||||
resilience_mws = resilience.as_list()
|
||||
if resilience_mws:
|
||||
cache_idx = next(
|
||||
(
|
||||
i
|
||||
for i, m in enumerate(middleware)
|
||||
if isinstance(m, AnthropicPromptCachingMiddleware)
|
||||
),
|
||||
len(middleware),
|
||||
)
|
||||
for offset, mw in enumerate(resilience_mws):
|
||||
middleware.insert(cache_idx + offset, mw)
|
||||
|
||||
spec: dict[str, Any] = {
|
||||
**GENERAL_PURPOSE_SUBAGENT,
|
||||
"model": llm,
|
||||
"tools": tools,
|
||||
"middleware": middleware,
|
||||
}
|
||||
if permissions.general_purpose_interrupt_on:
|
||||
spec["interrupt_on"] = permissions.general_purpose_interrupt_on
|
||||
return cast(SubAgent, spec)
|
||||
|
|
@ -168,20 +168,46 @@ def create_create_calendar_event_tool(
|
|||
f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
tz = context.get("timezone", "UTC")
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this connector.",
|
||||
}
|
||||
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
(
|
||||
event_id,
|
||||
html_link,
|
||||
error,
|
||||
) = await ComposioService().create_calendar_event(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
summary=final_summary,
|
||||
start_datetime=final_start_datetime,
|
||||
end_datetime=final_end_datetime,
|
||||
timezone=tz,
|
||||
description=final_description,
|
||||
location=final_location,
|
||||
attendees=final_attendees,
|
||||
)
|
||||
if error:
|
||||
return {"status": "error", "message": error}
|
||||
created = {
|
||||
"id": event_id,
|
||||
"summary": final_summary,
|
||||
"htmlLink": html_link,
|
||||
}
|
||||
logger.info(
|
||||
f"Calendar event created via Composio: id={event_id}, summary={final_summary}"
|
||||
)
|
||||
else:
|
||||
config_data = dict(connector.config)
|
||||
|
||||
|
|
@ -211,70 +237,69 @@ def create_create_calendar_event_tool(
|
|||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
|
||||
tz = context.get("timezone", "UTC")
|
||||
event_body: dict[str, Any] = {
|
||||
"summary": final_summary,
|
||||
"start": {"dateTime": final_start_datetime, "timeZone": tz},
|
||||
"end": {"dateTime": final_end_datetime, "timeZone": tz},
|
||||
}
|
||||
if final_description:
|
||||
event_body["description"] = final_description
|
||||
if final_location:
|
||||
event_body["location"] = final_location
|
||||
if final_attendees:
|
||||
event_body["attendees"] = [
|
||||
{"email": e.strip()} for e in final_attendees if e.strip()
|
||||
]
|
||||
|
||||
try:
|
||||
created = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.insert(calendarId="primary", body=event_body)
|
||||
.execute()
|
||||
),
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
event_body: dict[str, Any] = {
|
||||
"summary": final_summary,
|
||||
"start": {"dateTime": final_start_datetime, "timeZone": tz},
|
||||
"end": {"dateTime": final_end_datetime, "timeZone": tz},
|
||||
}
|
||||
if final_description:
|
||||
event_body["description"] = final_description
|
||||
if final_location:
|
||||
event_body["location"] = final_location
|
||||
if final_attendees:
|
||||
event_body["attendees"] = [
|
||||
{"email": e.strip()} for e in final_attendees if e.strip()
|
||||
]
|
||||
|
||||
try:
|
||||
created = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.insert(calendarId="primary", body=event_body)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
logger.info(
|
||||
f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}"
|
||||
)
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Calendar event created via Google API: id={created.get('id')}, summary={created.get('summary')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -163,16 +163,22 @@ def create_delete_calendar_event_tool(
|
|||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this connector.",
|
||||
}
|
||||
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
error = await ComposioService().delete_calendar_event(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
event_id=final_event_id,
|
||||
)
|
||||
if error:
|
||||
return {"status": "error", "message": error}
|
||||
else:
|
||||
config_data = dict(connector.config)
|
||||
|
||||
|
|
@ -202,51 +208,51 @@ def create_delete_calendar_event_tool(
|
|||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
|
||||
try:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.delete(calendarId="primary", eventId=final_event_id)
|
||||
.execute()
|
||||
),
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
try:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.delete(calendarId="primary", eventId=final_event_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Calendar event deleted: event_id={final_event_id}")
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,14 @@ _CALENDAR_TYPES = [
|
|||
]
|
||||
|
||||
|
||||
def _to_calendar_boundary(value: str, *, is_end: bool) -> str:
|
||||
"""Promote a bare YYYY-MM-DD to RFC3339 with a day-edge time, leave full datetimes alone."""
|
||||
if "T" in value:
|
||||
return value
|
||||
time = "23:59:59" if is_end else "00:00:00"
|
||||
return f"{value}T{time}Z"
|
||||
|
||||
|
||||
def create_search_calendar_events_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
|
|
@ -61,22 +69,47 @@ def create_search_calendar_events_tool(
|
|||
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
|
||||
}
|
||||
|
||||
creds = _build_credentials(connector)
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this connector.",
|
||||
}
|
||||
|
||||
from app.connectors.google_calendar_connector import GoogleCalendarConnector
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
cal = GoogleCalendarConnector(
|
||||
credentials=creds,
|
||||
session=db_session,
|
||||
user_id=user_id,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
events_raw, error = await ComposioService().get_calendar_events(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
time_min=_to_calendar_boundary(start_date, is_end=False),
|
||||
time_max=_to_calendar_boundary(end_date, is_end=True),
|
||||
max_results=max_results,
|
||||
)
|
||||
if not events_raw and not error:
|
||||
error = "No events found in the specified date range."
|
||||
else:
|
||||
creds = _build_credentials(connector)
|
||||
|
||||
events_raw, error = await cal.get_all_primary_calendar_events(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
max_results=max_results,
|
||||
)
|
||||
from app.connectors.google_calendar_connector import (
|
||||
GoogleCalendarConnector,
|
||||
)
|
||||
|
||||
cal = GoogleCalendarConnector(
|
||||
credentials=creds,
|
||||
session=db_session,
|
||||
user_id=user_id,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
|
||||
events_raw, error = await cal.get_all_primary_calendar_events(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
if error:
|
||||
if (
|
||||
|
|
|
|||
|
|
@ -192,20 +192,62 @@ def create_update_calendar_event_tool(
|
|||
f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
has_changes = any(
|
||||
v is not None
|
||||
for v in (
|
||||
final_new_summary,
|
||||
final_new_start_datetime,
|
||||
final_new_end_datetime,
|
||||
final_new_description,
|
||||
final_new_location,
|
||||
final_new_attendees,
|
||||
)
|
||||
)
|
||||
if not has_changes:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No changes specified. Please provide at least one field to update.",
|
||||
}
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this connector.",
|
||||
}
|
||||
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
tz_for_composio: str | None = None
|
||||
if final_new_start_datetime is not None and not _is_date_only(
|
||||
final_new_start_datetime
|
||||
):
|
||||
tz_for_composio = (
|
||||
context.get("timezone") if isinstance(context, dict) else None
|
||||
)
|
||||
|
||||
_, html_link, error = await ComposioService().update_calendar_event(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
event_id=final_event_id,
|
||||
summary=final_new_summary,
|
||||
start_time=final_new_start_datetime,
|
||||
end_time=final_new_end_datetime,
|
||||
timezone=tz_for_composio,
|
||||
description=final_new_description,
|
||||
location=final_new_location,
|
||||
attendees=final_new_attendees,
|
||||
)
|
||||
if error:
|
||||
return {"status": "error", "message": error}
|
||||
updated = {"htmlLink": html_link}
|
||||
logger.info(
|
||||
f"Calendar event updated via Composio: event_id={final_event_id}"
|
||||
)
|
||||
else:
|
||||
config_data = dict(connector.config)
|
||||
|
||||
|
|
@ -235,81 +277,79 @@ def create_update_calendar_event_tool(
|
|||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
|
||||
update_body: dict[str, Any] = {}
|
||||
if final_new_summary is not None:
|
||||
update_body["summary"] = final_new_summary
|
||||
if final_new_start_datetime is not None:
|
||||
update_body["start"] = _build_time_body(
|
||||
final_new_start_datetime, context
|
||||
service = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: build("calendar", "v3", credentials=creds)
|
||||
)
|
||||
if final_new_end_datetime is not None:
|
||||
update_body["end"] = _build_time_body(final_new_end_datetime, context)
|
||||
if final_new_description is not None:
|
||||
update_body["description"] = final_new_description
|
||||
if final_new_location is not None:
|
||||
update_body["location"] = final_new_location
|
||||
if final_new_attendees is not None:
|
||||
update_body["attendees"] = [
|
||||
{"email": e.strip()} for e in final_new_attendees if e.strip()
|
||||
]
|
||||
|
||||
if not update_body:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No changes specified. Please provide at least one field to update.",
|
||||
}
|
||||
|
||||
try:
|
||||
updated = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.patch(
|
||||
calendarId="primary",
|
||||
eventId=final_event_id,
|
||||
body=update_body,
|
||||
)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
update_body: dict[str, Any] = {}
|
||||
if final_new_summary is not None:
|
||||
update_body["summary"] = final_new_summary
|
||||
if final_new_start_datetime is not None:
|
||||
update_body["start"] = _build_time_body(
|
||||
final_new_start_datetime, context
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
if final_new_end_datetime is not None:
|
||||
update_body["end"] = _build_time_body(
|
||||
final_new_end_datetime, context
|
||||
)
|
||||
if final_new_description is not None:
|
||||
update_body["description"] = final_new_description
|
||||
if final_new_location is not None:
|
||||
update_body["location"] = final_new_location
|
||||
if final_new_attendees is not None:
|
||||
update_body["attendees"] = [
|
||||
{"email": e.strip()} for e in final_new_attendees if e.strip()
|
||||
]
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
try:
|
||||
updated = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
service.events()
|
||||
.patch(
|
||||
calendarId="primary",
|
||||
eventId=final_event_id,
|
||||
body=update_body,
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
logger.info(f"Calendar event updated: event_id={final_event_id}")
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Calendar event updated via Google API: event_id={final_event_id}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
if document_id is not None:
|
||||
|
|
|
|||
|
|
@ -161,16 +161,39 @@ def create_create_gmail_draft_tool(
|
|||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
(
|
||||
draft_id,
|
||||
draft_message_id,
|
||||
draft_thread_id,
|
||||
error,
|
||||
) = await ComposioService().create_gmail_draft(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
to=final_to,
|
||||
subject=final_subject,
|
||||
body=final_body,
|
||||
cc=final_cc,
|
||||
bcc=final_bcc,
|
||||
)
|
||||
if error:
|
||||
return {"status": "error", "message": error}
|
||||
created = {
|
||||
"id": draft_id,
|
||||
"message": {
|
||||
"id": draft_message_id,
|
||||
"threadId": draft_thread_id,
|
||||
},
|
||||
}
|
||||
logger.info(f"Gmail draft created via Composio: id={draft_id}")
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
|
|
@ -208,63 +231,65 @@ def create_create_gmail_draft_tool(
|
|||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
message = MIMEText(final_body)
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
message = MIMEText(final_body)
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
|
||||
try:
|
||||
created = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.drafts()
|
||||
.create(userId="me", body={"message": {"raw": raw}})
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
try:
|
||||
created = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.drafts()
|
||||
.create(userId="me", body={"message": {"raw": raw}})
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
logger.info(f"Gmail draft created: id={created.get('id')}")
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Gmail draft created via Google API: id={created.get('id')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -50,7 +50,56 @@ def create_read_gmail_email_tool(
|
|||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
}
|
||||
|
||||
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
|
||||
from app.agents.new_chat.tools.gmail.search_emails import (
|
||||
_format_gmail_summary,
|
||||
)
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
detail, error = await ComposioService().get_gmail_message_detail(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
message_id=message_id,
|
||||
)
|
||||
if error:
|
||||
return {"status": "error", "message": error}
|
||||
if not detail:
|
||||
return {
|
||||
"status": "not_found",
|
||||
"message": f"Email with ID '{message_id}' not found.",
|
||||
}
|
||||
|
||||
summary = _format_gmail_summary(detail)
|
||||
content = (
|
||||
f"# {summary['subject']}\n\n"
|
||||
f"**From:** {summary['from']}\n"
|
||||
f"**To:** {summary['to']}\n"
|
||||
f"**Date:** {summary['date']}\n\n"
|
||||
f"## Message Content\n\n"
|
||||
f"{detail.get('messageText') or detail.get('snippet') or ''}\n\n"
|
||||
f"## Message Details\n\n"
|
||||
f"- **Message ID:** {summary['message_id']}\n"
|
||||
f"- **Thread ID:** {summary['thread_id']}\n"
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"message_id": summary["message_id"] or message_id,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
from app.agents.new_chat.tools.gmail.search_emails import (
|
||||
_build_credentials,
|
||||
)
|
||||
|
||||
creds = _build_credentials(connector)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
|
@ -15,57 +14,6 @@ _GMAIL_TYPES = [
|
|||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
_token_encryption_cache: object | None = None
|
||||
|
||||
|
||||
def _get_token_encryption():
|
||||
global _token_encryption_cache
|
||||
if _token_encryption_cache is None:
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise RuntimeError("SECRET_KEY not configured for token decryption.")
|
||||
_token_encryption_cache = TokenEncryption(config.SECRET_KEY)
|
||||
return _token_encryption_cache
|
||||
|
||||
|
||||
def _build_credentials(connector: SearchSourceConnector):
|
||||
"""Build Google OAuth Credentials from a connector's stored config.
|
||||
|
||||
Handles both native OAuth connectors (with encrypted tokens) and
|
||||
Composio-backed connectors. Shared by Gmail and Calendar tools.
|
||||
"""
|
||||
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
||||
|
||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
raise ValueError("Composio connected account ID not found.")
|
||||
return build_composio_credentials(cca_id)
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
cfg = dict(connector.config)
|
||||
if cfg.get("_token_encrypted"):
|
||||
enc = _get_token_encryption()
|
||||
for key in ("token", "refresh_token", "client_secret"):
|
||||
if cfg.get(key):
|
||||
cfg[key] = enc.decrypt_token(cfg[key])
|
||||
|
||||
exp = (cfg.get("expiry") or "").replace("Z", "")
|
||||
return Credentials(
|
||||
token=cfg.get("token"),
|
||||
refresh_token=cfg.get("refresh_token"),
|
||||
token_uri=cfg.get("token_uri"),
|
||||
client_id=cfg.get("client_id"),
|
||||
client_secret=cfg.get("client_secret"),
|
||||
scopes=cfg.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
|
||||
def create_search_gmail_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
|
|
@ -110,6 +58,50 @@ def create_search_gmail_tool(
|
|||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
}
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
|
||||
from app.agents.new_chat.tools.gmail.search_emails import (
|
||||
_format_gmail_summary,
|
||||
)
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
(
|
||||
messages,
|
||||
_next,
|
||||
_estimate,
|
||||
error,
|
||||
) = await ComposioService().get_gmail_messages(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
query=query,
|
||||
max_results=max_results,
|
||||
)
|
||||
if error:
|
||||
return {"status": "error", "message": error}
|
||||
|
||||
emails = [_format_gmail_summary(m) for m in messages]
|
||||
if not emails:
|
||||
return {
|
||||
"status": "success",
|
||||
"emails": [],
|
||||
"total": 0,
|
||||
"message": "No emails found.",
|
||||
}
|
||||
return {"status": "success", "emails": emails, "total": len(emails)}
|
||||
|
||||
from app.agents.new_chat.tools.gmail.search_emails import (
|
||||
_build_credentials,
|
||||
)
|
||||
|
||||
creds = _build_credentials(connector)
|
||||
|
||||
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
||||
|
|
|
|||
|
|
@ -162,16 +162,31 @@ def create_send_gmail_email_tool(
|
|||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
(
|
||||
sent_message_id,
|
||||
sent_thread_id,
|
||||
error,
|
||||
) = await ComposioService().send_gmail_email(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
to=final_to,
|
||||
subject=final_subject,
|
||||
body=final_body,
|
||||
cc=final_cc,
|
||||
bcc=final_bcc,
|
||||
)
|
||||
if error:
|
||||
return {"status": "error", "message": error}
|
||||
sent = {"id": sent_message_id, "threadId": sent_thread_id}
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
|
|
@ -209,61 +224,61 @@ def create_send_gmail_email_tool(
|
|||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
message = MIMEText(final_body)
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
message = MIMEText(final_body)
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
|
||||
try:
|
||||
sent = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.messages()
|
||||
.send(userId="me", body={"raw": raw})
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
try:
|
||||
sent = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.messages()
|
||||
.send(userId="me", body={"raw": raw})
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}"
|
||||
|
|
|
|||
|
|
@ -162,16 +162,22 @@ def create_trash_gmail_email_tool(
|
|||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
error = await ComposioService().trash_gmail_message(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
message_id=final_message_id,
|
||||
)
|
||||
if error:
|
||||
return {"status": "error", "message": error}
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
|
|
@ -209,49 +215,49 @@ def create_trash_gmail_email_tool(
|
|||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
try:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.messages()
|
||||
.trash(userId="me", id=final_message_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {api_err}"
|
||||
try:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.messages()
|
||||
.trash(userId="me", id=final_message_id)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
f"Insufficient permissions for connector {connector.id}: {api_err}"
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Gmail email trashed: message_id={final_message_id}")
|
||||
|
||||
|
|
|
|||
|
|
@ -192,16 +192,51 @@ def create_update_gmail_draft_tool(
|
|||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
|
||||
if not final_draft_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
"Could not find this draft in Gmail. "
|
||||
"It may have already been sent or deleted."
|
||||
),
|
||||
}
|
||||
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
(
|
||||
new_draft_id,
|
||||
new_message_id,
|
||||
error,
|
||||
) = await ComposioService().update_gmail_draft(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
draft_id=final_draft_id,
|
||||
to=final_to or None,
|
||||
subject=final_subject,
|
||||
body=final_body,
|
||||
cc=final_cc,
|
||||
bcc=final_bcc,
|
||||
)
|
||||
if error:
|
||||
if "not found" in error.lower() or "no longer" in error.lower():
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
|
||||
}
|
||||
return {"status": "error", "message": error}
|
||||
|
||||
updated = {
|
||||
"id": new_draft_id or final_draft_id,
|
||||
"message": {"id": new_message_id} if new_message_id else {},
|
||||
}
|
||||
logger.info(f"Gmail draft updated via Composio: id={updated.get('id')}")
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
|
|
@ -239,88 +274,90 @@ def create_update_gmail_draft_tool(
|
|||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
# Resolve draft_id if not already available
|
||||
if not final_draft_id:
|
||||
logger.info(
|
||||
f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
|
||||
)
|
||||
final_draft_id = await _find_draft_id_by_message(
|
||||
gmail_service, message_id
|
||||
)
|
||||
|
||||
if not final_draft_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
"Could not find this draft in Gmail. "
|
||||
"It may have already been sent or deleted."
|
||||
),
|
||||
}
|
||||
|
||||
message = MIMEText(final_body)
|
||||
if final_to:
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
|
||||
try:
|
||||
updated = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.drafts()
|
||||
.update(
|
||||
userId="me",
|
||||
id=final_draft_id,
|
||||
body={"message": {"raw": raw}},
|
||||
)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {api_err}"
|
||||
# Resolve draft_id if not already available
|
||||
if not final_draft_id:
|
||||
logger.info(
|
||||
f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
|
||||
)
|
||||
final_draft_id = await _find_draft_id_by_message(
|
||||
gmail_service, message_id
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 404:
|
||||
if not final_draft_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
|
||||
"message": (
|
||||
"Could not find this draft in Gmail. "
|
||||
"It may have already been sent or deleted."
|
||||
),
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(f"Gmail draft updated: id={updated.get('id')}")
|
||||
message = MIMEText(final_body)
|
||||
if final_to:
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
|
||||
try:
|
||||
updated = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: (
|
||||
gmail_service.users()
|
||||
.drafts()
|
||||
.update(
|
||||
userId="me",
|
||||
id=final_draft_id,
|
||||
body={"message": {"raw": raw}},
|
||||
)
|
||||
.execute()
|
||||
),
|
||||
)
|
||||
except Exception as api_err:
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {api_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
if isinstance(api_err, HttpError) and api_err.resp.status == 404:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Gmail draft updated via Google API: id={updated.get('id')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
if document_id:
|
||||
|
|
|
|||
|
|
@ -179,59 +179,96 @@ def create_create_google_drive_file_tool(
|
|||
f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
|
||||
)
|
||||
|
||||
pre_built_creds = None
|
||||
async def _flag_auth_expired() -> None:
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
pre_built_creds = build_composio_credentials(cca_id)
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Google Drive connector.",
|
||||
}
|
||||
|
||||
client = GoogleDriveClient(
|
||||
session=db_session,
|
||||
connector_id=actual_connector_id,
|
||||
credentials=pre_built_creds,
|
||||
)
|
||||
try:
|
||||
created = await client.create_file(
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
created, error = await ComposioService().create_drive_file_from_text(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
name=final_name,
|
||||
mime_type=mime_type,
|
||||
parent_folder_id=final_parent_folder_id,
|
||||
content=final_content,
|
||||
parent_id=final_parent_folder_id,
|
||||
)
|
||||
except HttpError as http_err:
|
||||
if http_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
_res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
if error or not created:
|
||||
err_lower = (error or "").lower()
|
||||
if (
|
||||
"insufficient" in err_lower
|
||||
or "permission" in err_lower
|
||||
or "403" in err_lower
|
||||
):
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
f"Insufficient permissions for Composio Drive connector {actual_connector_id}: {error}"
|
||||
)
|
||||
await _flag_auth_expired()
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
logger.error(
|
||||
f"Composio Drive create_file failed for connector {actual_connector_id}: {error}"
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
"status": "error",
|
||||
"message": "Something went wrong while creating the file. Please try again.",
|
||||
}
|
||||
raise
|
||||
else:
|
||||
client = GoogleDriveClient(
|
||||
session=db_session,
|
||||
connector_id=actual_connector_id,
|
||||
)
|
||||
try:
|
||||
created = await client.create_file(
|
||||
name=final_name,
|
||||
mime_type=mime_type,
|
||||
parent_folder_id=final_parent_folder_id,
|
||||
content=final_content,
|
||||
)
|
||||
except HttpError as http_err:
|
||||
if http_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
|
||||
)
|
||||
await _flag_auth_expired()
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
|
||||
|
|
|
|||
|
|
@ -158,51 +158,84 @@ def create_delete_google_drive_file_tool(
|
|||
f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
|
||||
)
|
||||
|
||||
pre_built_creds = None
|
||||
async def _flag_auth_expired() -> None:
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
pre_built_creds = build_composio_credentials(cca_id)
|
||||
|
||||
client = GoogleDriveClient(
|
||||
session=db_session,
|
||||
connector_id=connector.id,
|
||||
credentials=pre_built_creds,
|
||||
)
|
||||
try:
|
||||
await client.trash_file(file_id=final_file_id)
|
||||
except HttpError as http_err:
|
||||
if http_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {http_err}"
|
||||
)
|
||||
try:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
if not connector.config.get("auth_expired"):
|
||||
connector.config = {
|
||||
**connector.config,
|
||||
"auth_expired": True,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
if not cca_id:
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Google Drive connector.",
|
||||
}
|
||||
raise
|
||||
|
||||
from app.services.composio_service import ComposioService
|
||||
|
||||
error = await ComposioService().trash_drive_file(
|
||||
connected_account_id=cca_id,
|
||||
entity_id=f"surfsense_{user_id}",
|
||||
file_id=final_file_id,
|
||||
)
|
||||
if error:
|
||||
err_lower = error.lower()
|
||||
if (
|
||||
"insufficient" in err_lower
|
||||
or "permission" in err_lower
|
||||
or "403" in err_lower
|
||||
):
|
||||
logger.warning(
|
||||
f"Insufficient permissions for Composio Drive connector {connector.id}: {error}"
|
||||
)
|
||||
await _flag_auth_expired()
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
logger.error(
|
||||
f"Composio Drive trash_file failed for connector {connector.id}: {error}"
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while trashing the file. Please try again.",
|
||||
}
|
||||
else:
|
||||
client = GoogleDriveClient(
|
||||
session=db_session,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
try:
|
||||
await client.trash_file(file_id=final_file_id)
|
||||
except HttpError as http_err:
|
||||
if http_err.resp.status == 403:
|
||||
logger.warning(
|
||||
f"Insufficient permissions for connector {connector.id}: {http_err}"
|
||||
)
|
||||
await _flag_auth_expired()
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": connector.id,
|
||||
"message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Google Drive file deleted (moved to trash): file_id={final_file_id}"
|
||||
|
|
|
|||
|
|
@ -1,11 +1,3 @@
|
|||
"""Jira tools for creating, updating, and deleting issues."""
|
||||
"""Jira route: native tool factories are empty; MCP supplies tools when configured."""
|
||||
|
||||
from .create_issue import create_create_jira_issue_tool
|
||||
from .delete_issue import create_delete_jira_issue_tool
|
||||
from .update_issue import create_update_jira_issue_tool
|
||||
|
||||
__all__ = [
|
||||
"create_create_jira_issue_tool",
|
||||
"create_delete_jira_issue_tool",
|
||||
"create_update_jira_issue_tool",
|
||||
]
|
||||
__all__: list[str] = []
|
||||
|
|
|
|||
|
|
@ -6,29 +6,9 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import (
|
|||
ToolsPermissions,
|
||||
)
|
||||
|
||||
from .create_issue import create_create_jira_issue_tool
|
||||
from .delete_issue import create_delete_jira_issue_tool
|
||||
from .update_issue import create_update_jira_issue_tool
|
||||
|
||||
|
||||
def load_tools(
|
||||
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
|
||||
) -> ToolsPermissions:
|
||||
d = {**(dependencies or {}), **kwargs}
|
||||
common = {
|
||||
"db_session": d["db_session"],
|
||||
"search_space_id": d["search_space_id"],
|
||||
"user_id": d["user_id"],
|
||||
"connector_id": d.get("connector_id"),
|
||||
}
|
||||
create = create_create_jira_issue_tool(**common)
|
||||
update = create_update_jira_issue_tool(**common)
|
||||
delete = create_delete_jira_issue_tool(**common)
|
||||
return {
|
||||
"allow": [],
|
||||
"ask": [
|
||||
{"name": getattr(create, "name", "") or "", "tool": create},
|
||||
{"name": getattr(update, "name", "") or "", "tool": update},
|
||||
{"name": getattr(delete, "name", "") or "", "tool": delete},
|
||||
],
|
||||
}
|
||||
_ = {**(dependencies or {}), **kwargs}
|
||||
return {"allow": [], "ask": []}
|
||||
|
|
|
|||
|
|
@ -1,11 +1,3 @@
|
|||
"""Linear tools for creating, updating, and deleting issues."""
|
||||
"""Linear route: native tool factories are empty; MCP supplies tools when configured."""
|
||||
|
||||
from .create_issue import create_create_linear_issue_tool
|
||||
from .delete_issue import create_delete_linear_issue_tool
|
||||
from .update_issue import create_update_linear_issue_tool
|
||||
|
||||
__all__ = [
|
||||
"create_create_linear_issue_tool",
|
||||
"create_delete_linear_issue_tool",
|
||||
"create_update_linear_issue_tool",
|
||||
]
|
||||
__all__: list[str] = []
|
||||
|
|
|
|||
|
|
@ -1,248 +0,0 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.linear_connector import LinearAPIError, LinearConnector
|
||||
from app.services.linear import LinearToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_create_linear_issue_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the create_linear_issue tool.
|
||||
|
||||
Args:
|
||||
db_session: Database session for accessing the Linear connector
|
||||
search_space_id: Search space ID to find the Linear connector
|
||||
user_id: User ID for fetching user-specific context
|
||||
connector_id: Optional specific connector ID (if known)
|
||||
|
||||
Returns:
|
||||
Configured create_linear_issue tool
|
||||
"""
|
||||
|
||||
@tool
|
||||
async def create_linear_issue(
|
||||
title: str,
|
||||
description: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new issue in Linear.
|
||||
|
||||
Use this tool when the user explicitly asks to create, add, or file
|
||||
a new issue / ticket / task in Linear. The user MUST describe the issue
|
||||
before you call this tool. If the request is vague, ask what the issue
|
||||
should be about. Never call this tool without a clear topic from the user.
|
||||
|
||||
Args:
|
||||
title: Short, descriptive issue title. Infer from the user's request.
|
||||
description: Optional markdown body for the issue. Generate from context.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "success", "rejected", or "error"
|
||||
- issue_id: Linear issue UUID (if success)
|
||||
- identifier: Human-readable ID like "ENG-42" (if success)
|
||||
- url: URL to the created issue (if success)
|
||||
- message: Result message
|
||||
|
||||
IMPORTANT: If status is "rejected", the user explicitly declined the action.
|
||||
Respond with a brief acknowledgment (e.g., "Understood, I won't create the issue.")
|
||||
and move on. Do NOT retry, troubleshoot, or suggest alternatives.
|
||||
|
||||
Examples:
|
||||
- "Create a Linear issue for the login bug"
|
||||
- "File a ticket about the payment timeout problem"
|
||||
- "Add an issue for the broken search feature"
|
||||
"""
|
||||
logger.info(f"create_linear_issue called: title='{title}'")
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
logger.error(
|
||||
"Linear tool not properly configured - missing required parameters"
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Linear tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = LinearToolMetadataService(db_session)
|
||||
context = await metadata_service.get_creation_context(
|
||||
search_space_id, user_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
logger.error(f"Failed to fetch creation context: {context['error']}")
|
||||
return {"status": "error", "message": context["error"]}
|
||||
|
||||
workspaces = context.get("workspaces", [])
|
||||
if workspaces and all(w.get("auth_expired") for w in workspaces):
|
||||
logger.warning("All Linear accounts have expired authentication")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "All connected Linear accounts need re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "linear",
|
||||
}
|
||||
|
||||
logger.info(f"Requesting approval for creating Linear issue: '{title}'")
|
||||
result = request_approval(
|
||||
action_type="linear_issue_creation",
|
||||
tool_name="create_linear_issue",
|
||||
params={
|
||||
"title": title,
|
||||
"description": description,
|
||||
"team_id": None,
|
||||
"state_id": None,
|
||||
"assignee_id": None,
|
||||
"priority": None,
|
||||
"label_ids": [],
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
logger.info("Linear issue creation rejected by user")
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_title = result.params.get("title", title)
|
||||
final_description = result.params.get("description", description)
|
||||
final_team_id = result.params.get("team_id")
|
||||
final_state_id = result.params.get("state_id")
|
||||
final_assignee_id = result.params.get("assignee_id")
|
||||
final_priority = result.params.get("priority")
|
||||
final_label_ids = result.params.get("label_ids") or []
|
||||
final_connector_id = result.params.get("connector_id", connector_id)
|
||||
|
||||
if not final_title or not final_title.strip():
|
||||
logger.error("Title is empty or contains only whitespace")
|
||||
return {"status": "error", "message": "Issue title cannot be empty."}
|
||||
if not final_team_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "A team must be selected to create an issue.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
actual_connector_id = final_connector_id
|
||||
if actual_connector_id is None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Linear connector found. Please connect Linear in your workspace settings.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
logger.info(f"Found Linear connector: id={actual_connector_id}")
|
||||
else:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == actual_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Linear connector is invalid or has been disconnected.",
|
||||
}
|
||||
logger.info(f"Validated Linear connector: id={actual_connector_id}")
|
||||
|
||||
logger.info(
|
||||
f"Creating Linear issue with final params: title='{final_title}'"
|
||||
)
|
||||
linear_client = LinearConnector(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
)
|
||||
result = await linear_client.create_issue(
|
||||
team_id=final_team_id,
|
||||
title=final_title,
|
||||
description=final_description,
|
||||
state_id=final_state_id,
|
||||
assignee_id=final_assignee_id,
|
||||
priority=final_priority,
|
||||
label_ids=final_label_ids if final_label_ids else None,
|
||||
)
|
||||
|
||||
if result.get("status") == "error":
|
||||
logger.error(f"Failed to create Linear issue: {result.get('message')}")
|
||||
return {"status": "error", "message": result.get("message")}
|
||||
|
||||
logger.info(
|
||||
f"Linear issue created: {result.get('identifier')} - {result.get('title')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.linear import LinearKBSyncService
|
||||
|
||||
kb_service = LinearKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
issue_id=result.get("id"),
|
||||
issue_identifier=result.get("identifier", ""),
|
||||
issue_title=result.get("title", final_title),
|
||||
issue_url=result.get("url"),
|
||||
description=final_description,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"issue_id": result.get("id"),
|
||||
"identifier": result.get("identifier"),
|
||||
"url": result.get("url"),
|
||||
"message": (result.get("message", "") + kb_message_suffix),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
|
||||
logger.error(f"Error creating Linear issue: {e}", exc_info=True)
|
||||
if isinstance(e, ValueError | LinearAPIError):
|
||||
message = str(e)
|
||||
else:
|
||||
message = (
|
||||
"Something went wrong while creating the issue. Please try again."
|
||||
)
|
||||
return {"status": "error", "message": message}
|
||||
|
||||
return create_linear_issue
|
||||
|
|
@ -1,245 +0,0 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.linear_connector import LinearAPIError, LinearConnector
|
||||
from app.services.linear import LinearToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_delete_linear_issue_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
connector_id: int | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the delete_linear_issue tool.
|
||||
|
||||
Args:
|
||||
db_session: Database session for accessing the Linear connector
|
||||
search_space_id: Search space ID to find the Linear connector
|
||||
user_id: User ID for finding the correct Linear connector
|
||||
connector_id: Optional specific connector ID (if known)
|
||||
|
||||
Returns:
|
||||
Configured delete_linear_issue tool
|
||||
"""
|
||||
|
||||
@tool
|
||||
async def delete_linear_issue(
|
||||
issue_ref: str,
|
||||
delete_from_kb: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Archive (delete) a Linear issue.
|
||||
|
||||
Use this tool when the user asks to delete, remove, or archive a Linear issue.
|
||||
Note that Linear archives issues rather than permanently deleting them
|
||||
(they can be restored from the archive).
|
||||
|
||||
|
||||
Args:
|
||||
issue_ref: The issue to delete. Can be the issue title (e.g. "Fix login bug"),
|
||||
the identifier (e.g. "ENG-42"), or the full document title
|
||||
(e.g. "ENG-42: Fix login bug").
|
||||
delete_from_kb: Whether to also remove the issue from the knowledge base.
|
||||
Default is False. Set to True to remove from both Linear
|
||||
and the knowledge base.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "success", "rejected", "not_found", or "error"
|
||||
- identifier: Human-readable ID like "ENG-42" (if success)
|
||||
- message: Success or error message
|
||||
- deleted_from_kb: Whether the issue was also removed from the knowledge base (if success)
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined the action.
|
||||
Respond with a brief acknowledgment (e.g., "Understood, I won't delete the issue.")
|
||||
and move on. Do NOT ask for alternatives or troubleshoot.
|
||||
- If status is "not_found", inform the user conversationally using the exact message
|
||||
provided. Do NOT treat this as an error. Simply relay the message and ask the user
|
||||
to verify the issue title or identifier, or check if it has been indexed.
|
||||
Examples:
|
||||
- "Delete the 'Fix login bug' Linear issue"
|
||||
- "Archive ENG-42"
|
||||
- "Remove the 'Old payment flow' issue from Linear"
|
||||
"""
|
||||
logger.info(
|
||||
f"delete_linear_issue called: issue_ref='{issue_ref}', delete_from_kb={delete_from_kb}"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
logger.error(
|
||||
"Linear tool not properly configured - missing required parameters"
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Linear tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = LinearToolMetadataService(db_session)
|
||||
context = await metadata_service.get_delete_context(
|
||||
search_space_id, user_id, issue_ref
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if context.get("auth_expired"):
|
||||
logger.warning(f"Auth expired for delete context: {error_msg}")
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": error_msg,
|
||||
"connector_id": context.get("connector_id"),
|
||||
"connector_type": "linear",
|
||||
}
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Issue not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
else:
|
||||
logger.error(f"Failed to fetch delete context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
issue_id = context["issue"]["id"]
|
||||
issue_identifier = context["issue"].get("identifier", "")
|
||||
document_id = context["issue"]["document_id"]
|
||||
connector_id_from_context = context.get("workspace", {}).get("id")
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for deleting Linear issue: '{issue_ref}' "
|
||||
f"(id={issue_id}, delete_from_kb={delete_from_kb})"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="linear_issue_deletion",
|
||||
tool_name="delete_linear_issue",
|
||||
params={
|
||||
"issue_id": issue_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
logger.info("Linear issue deletion rejected by user")
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_issue_id = result.params.get("issue_id", issue_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
logger.info(
|
||||
f"Deleting Linear issue with final params: issue_id={final_issue_id}, "
|
||||
f"connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
|
||||
)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
if final_connector_id:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
logger.error(
|
||||
f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Linear connector is invalid or has been disconnected.",
|
||||
}
|
||||
actual_connector_id = connector.id
|
||||
logger.info(f"Validated Linear connector: id={actual_connector_id}")
|
||||
else:
|
||||
logger.error("No connector found for this issue")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this issue.",
|
||||
}
|
||||
|
||||
linear_client = LinearConnector(
|
||||
session=db_session, connector_id=actual_connector_id
|
||||
)
|
||||
|
||||
result = await linear_client.archive_issue(issue_id=final_issue_id)
|
||||
|
||||
logger.info(
|
||||
f"archive_issue result: {result.get('status')} - {result.get('message', '')}"
|
||||
)
|
||||
|
||||
deleted_from_kb = False
|
||||
if (
|
||||
result.get("status") == "success"
|
||||
and final_delete_from_kb
|
||||
and document_id
|
||||
):
|
||||
try:
|
||||
from app.db import Document
|
||||
|
||||
doc_result = await db_session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
if document:
|
||||
await db_session.delete(document)
|
||||
await db_session.commit()
|
||||
deleted_from_kb = True
|
||||
logger.info(
|
||||
f"Deleted document {document_id} from knowledge base"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Document {document_id} not found in KB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document from KB: {e}")
|
||||
await db_session.rollback()
|
||||
result["warning"] = (
|
||||
f"Issue archived in Linear, but failed to remove from knowledge base: {e!s}"
|
||||
)
|
||||
|
||||
if result.get("status") == "success":
|
||||
result["deleted_from_kb"] = deleted_from_kb
|
||||
if issue_identifier:
|
||||
result["message"] = (
|
||||
f"Issue {issue_identifier} archived successfully."
|
||||
)
|
||||
if deleted_from_kb:
|
||||
result["message"] = (
|
||||
f"{result.get('message', '')} Also removed from the knowledge base."
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
|
||||
logger.error(f"Error deleting Linear issue: {e}", exc_info=True)
|
||||
if isinstance(e, ValueError | LinearAPIError):
|
||||
message = str(e)
|
||||
else:
|
||||
message = (
|
||||
"Something went wrong while deleting the issue. Please try again."
|
||||
)
|
||||
return {"status": "error", "message": message}
|
||||
|
||||
return delete_linear_issue
|
||||
|
|
@ -6,29 +6,9 @@ from app.agents.multi_agent_chat.subagents.shared.permissions import (
|
|||
ToolsPermissions,
|
||||
)
|
||||
|
||||
from .create_issue import create_create_linear_issue_tool
|
||||
from .delete_issue import create_delete_linear_issue_tool
|
||||
from .update_issue import create_update_linear_issue_tool
|
||||
|
||||
|
||||
def load_tools(
|
||||
*, dependencies: dict[str, Any] | None = None, **kwargs: Any
|
||||
) -> ToolsPermissions:
|
||||
d = {**(dependencies or {}), **kwargs}
|
||||
common = {
|
||||
"db_session": d["db_session"],
|
||||
"search_space_id": d["search_space_id"],
|
||||
"user_id": d["user_id"],
|
||||
"connector_id": d.get("connector_id"),
|
||||
}
|
||||
create = create_create_linear_issue_tool(**common)
|
||||
update = create_update_linear_issue_tool(**common)
|
||||
delete = create_delete_linear_issue_tool(**common)
|
||||
return {
|
||||
"allow": [],
|
||||
"ask": [
|
||||
{"name": getattr(create, "name", "") or "", "tool": create},
|
||||
{"name": getattr(update, "name", "") or "", "tool": update},
|
||||
{"name": getattr(delete, "name", "") or "", "tool": delete},
|
||||
],
|
||||
}
|
||||
_ = {**(dependencies or {}), **kwargs}
|
||||
return {"allow": [], "ask": []}
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ from langchain.agents import create_agent
|
|||
from langchain.agents.middleware import (
|
||||
LLMToolSelectorMiddleware,
|
||||
ModelCallLimitMiddleware,
|
||||
ModelFallbackMiddleware,
|
||||
TodoListMiddleware,
|
||||
ToolCallLimitMiddleware,
|
||||
)
|
||||
|
|
@ -77,6 +76,9 @@ from app.agents.new_chat.middleware import (
|
|||
create_surfsense_compaction_middleware,
|
||||
default_skills_sources,
|
||||
)
|
||||
from app.agents.new_chat.middleware.scoped_model_fallback import (
|
||||
ScopedModelFallbackMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
from app.agents.new_chat.plugin_loader import (
|
||||
PluginContext,
|
||||
|
|
@ -792,15 +794,15 @@ def _build_compiled_agent_blocking(
|
|||
# Fallback chain — primary is the agent's own model; we add cheap
|
||||
# alternatives. Off by default; only the first call site that
|
||||
# configures the chain via env should enable it.
|
||||
fallback_mw: ModelFallbackMiddleware | None = None
|
||||
fallback_mw: ScopedModelFallbackMiddleware | None = None
|
||||
if flags.enable_model_fallback and not flags.disable_new_agent_stack:
|
||||
try:
|
||||
fallback_mw = ModelFallbackMiddleware(
|
||||
fallback_mw = ScopedModelFallbackMiddleware(
|
||||
"openai:gpt-4o-mini",
|
||||
"anthropic:claude-3-5-haiku-20241022",
|
||||
)
|
||||
except Exception:
|
||||
logging.warning("ModelFallbackMiddleware init failed; skipping.")
|
||||
logging.warning("ScopedModelFallbackMiddleware init failed; skipping.")
|
||||
fallback_mw = None
|
||||
model_call_limit_mw = (
|
||||
ModelCallLimitMiddleware(
|
||||
|
|
|
|||
|
|
@ -46,6 +46,10 @@ class SurfSenseContextSchema:
|
|||
Read by ``KnowledgePriorityMiddleware`` to seed its priority
|
||||
list. Stays out of the compiled-agent cache key — that's the
|
||||
whole point of putting it here.
|
||||
mentioned_folder_ids: KB folders the user @-mentioned this turn
|
||||
(cloud filesystem mode). Surfaced as ``[USER-MENTIONED]``
|
||||
entries in ``<priority_documents>`` so the agent prioritises
|
||||
walking those folders with ``ls`` / ``find_documents``.
|
||||
file_operation_contract: One-shot file operation contract emitted
|
||||
by ``FileIntentMiddleware`` for the upcoming turn.
|
||||
turn_id / request_id: Correlation IDs surfaced by the streaming
|
||||
|
|
@ -59,6 +63,7 @@ class SurfSenseContextSchema:
|
|||
|
||||
search_space_id: int | None = None
|
||||
mentioned_document_ids: list[int] = field(default_factory=list)
|
||||
mentioned_folder_ids: list[int] = field(default_factory=list)
|
||||
file_operation_contract: FileOperationContractState | None = None
|
||||
turn_id: str | None = None
|
||||
request_id: str | None = None
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ Defaults:
|
|||
SURFSENSE_ENABLE_PERMISSION=true
|
||||
SURFSENSE_ENABLE_DOOM_LOOP=true
|
||||
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call
|
||||
SURFSENSE_ENABLE_STREAM_PARITY_V2=true
|
||||
|
||||
Master kill-switch (overrides everything else):
|
||||
|
||||
|
|
@ -88,15 +87,6 @@ class AgentFeatureFlags:
|
|||
enable_action_log: bool = True
|
||||
enable_revert_route: bool = True
|
||||
|
||||
# Streaming parity v2 — opt in to LangChain's structured
|
||||
# ``AIMessageChunk`` content (typed reasoning blocks, tool-input
|
||||
# deltas) and propagate the real ``tool_call_id`` to the SSE layer.
|
||||
# When OFF the ``stream_new_chat`` task falls back to the str-only
|
||||
# text path and the synthetic ``call_<run_id>`` tool-call id (no
|
||||
# ``langchainToolCallId`` propagation). Schema migrations 135/136
|
||||
# ship unconditionally because they're forward-compatible.
|
||||
enable_stream_parity_v2: bool = True
|
||||
|
||||
# Plugins
|
||||
enable_plugin_loader: bool = False
|
||||
|
||||
|
|
@ -169,7 +159,6 @@ class AgentFeatureFlags:
|
|||
enable_kb_planner_runnable=False,
|
||||
enable_action_log=False,
|
||||
enable_revert_route=False,
|
||||
enable_stream_parity_v2=False,
|
||||
enable_plugin_loader=False,
|
||||
enable_otel=False,
|
||||
enable_agent_cache=False,
|
||||
|
|
@ -208,10 +197,6 @@ class AgentFeatureFlags:
|
|||
# Snapshot / revert
|
||||
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", True),
|
||||
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", True),
|
||||
# Streaming parity v2
|
||||
enable_stream_parity_v2=_env_bool(
|
||||
"SURFSENSE_ENABLE_STREAM_PARITY_V2", True
|
||||
),
|
||||
# Plugins
|
||||
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
|
||||
# Observability
|
||||
|
|
|
|||
281
surfsense_backend/app/agents/new_chat/mention_resolver.py
Normal file
281
surfsense_backend/app/agents/new_chat/mention_resolver.py
Normal file
|
|
@ -0,0 +1,281 @@
|
|||
"""Resolve @-mention chips to canonical virtual paths and substitute the
|
||||
user-visible ``@title`` tokens with backtick-wrapped paths in the prompt
|
||||
the agent sees.
|
||||
|
||||
The frontend's mention seam is a single discriminated-union list of
|
||||
``{kind: "doc" | "folder", id, title, document_type?}`` chips (see
|
||||
``surfsense_web/atoms/chat/mentioned-documents.atom.ts``). When a turn
|
||||
reaches the backend stream task we have three needs that this module
|
||||
centralises:
|
||||
|
||||
1. Map each chip to its canonical virtual path
|
||||
(``/documents/.../file.xml`` for docs, ``/documents/MyFolder/`` for
|
||||
folders) so the agent sees concrete filesystem locations instead of
|
||||
ambiguous ``@``-titles.
|
||||
2. Substitute ``@title`` tokens in the user-typed text with backtick-
|
||||
wrapped paths so the path becomes part of the ``HumanMessage`` body
|
||||
the LLM consumes — without rewriting the persisted user message
|
||||
text (which keeps ``@title`` so chip rendering on reload is
|
||||
unchanged).
|
||||
3. Surface the resolved id sets (docs + folders) to the priority
|
||||
middleware so it can render ``[USER-MENTIONED]`` priority entries
|
||||
without re-doing path resolution.
|
||||
|
||||
This is intentionally one module — see the architectural note in
|
||||
``mention-paths-and-folders`` plan: previously the doc-resolution lived
|
||||
inline in ``stream_new_chat`` and the folder mention had no resolution
|
||||
at all. Centralising both behind a single ``resolve_mentions`` call
|
||||
turns a leaky multi-field seam into a single deeper interface.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.path_resolver import (
|
||||
DOCUMENTS_ROOT,
|
||||
build_path_index,
|
||||
doc_to_virtual_path,
|
||||
)
|
||||
from app.db import Document, Folder
|
||||
from app.schemas.new_chat import MentionedDocumentInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResolvedMention:
|
||||
"""Canonical view of a single @-mention chip.
|
||||
|
||||
``virtual_path`` is the path the agent will see (no trailing slash
|
||||
for documents, trailing ``/`` for folders to match the convention
|
||||
used by ``KnowledgeTreeMiddleware``).
|
||||
"""
|
||||
|
||||
kind: str # "doc" | "folder"
|
||||
id: int
|
||||
title: str
|
||||
virtual_path: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResolvedMentionSet:
|
||||
"""Aggregate result of resolving a turn's mention chips.
|
||||
|
||||
``token_to_path`` maps ``@title`` (the literal token the user typed
|
||||
and the editor emitted) to the canonical virtual path for that
|
||||
chip. It is produced longest-token-first so substitution mirrors
|
||||
``parseMentionSegments`` on the frontend (a longer title like
|
||||
``@Project Roadmap`` is never shadowed by a shorter prefix
|
||||
``@Project``).
|
||||
|
||||
``mentioned_document_ids`` collapses doc + surfsense_doc chips into
|
||||
a single ordered, deduped list because the priority middleware
|
||||
treats them uniformly downstream — see
|
||||
``KnowledgePriorityMiddleware._compute_priority_paths``.
|
||||
"""
|
||||
|
||||
mentions: list[ResolvedMention] = field(default_factory=list)
|
||||
token_to_path: list[tuple[str, str]] = field(default_factory=list)
|
||||
mentioned_document_ids: list[int] = field(default_factory=list)
|
||||
mentioned_folder_ids: list[int] = field(default_factory=list)
|
||||
|
||||
|
||||
def _folder_virtual_path(folder_id: int, folder_paths: dict[int, str]) -> str:
|
||||
"""Return ``/documents/Folder/Sub/`` for a folder id.
|
||||
|
||||
Falls back to the documents root when the folder is missing from
|
||||
the index (deleted or in a different search space). Trailing slash
|
||||
matches ``KnowledgeTreeMiddleware`` (``/documents/MyFolder/``) so
|
||||
the agent's ``ls`` can dispatch on it as a directory.
|
||||
"""
|
||||
base = folder_paths.get(folder_id, DOCUMENTS_ROOT)
|
||||
return f"{base}/" if not base.endswith("/") else base
|
||||
|
||||
|
||||
async def resolve_mentions(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
search_space_id: int,
|
||||
mentioned_documents: list[MentionedDocumentInfo] | None,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
mentioned_surfsense_doc_ids: list[int] | None = None,
|
||||
mentioned_folder_ids: list[int] | None = None,
|
||||
) -> ResolvedMentionSet:
|
||||
"""Resolve every @-mention chip on a turn into virtual paths.
|
||||
|
||||
The function takes both the ``mentioned_documents`` discriminated
|
||||
list (chip metadata used for substitution + persistence) and the
|
||||
parallel id arrays (``mentioned_document_ids``,
|
||||
``mentioned_surfsense_doc_ids``, ``mentioned_folder_ids``) for two
|
||||
reasons:
|
||||
|
||||
* Legacy clients that haven't migrated to the unified chip list
|
||||
still send the id arrays — we treat the union as authoritative.
|
||||
* The id arrays are the canonical input to
|
||||
``KnowledgePriorityMiddleware`` (via ``SurfSenseContextSchema``);
|
||||
returning the deduped, validated lists lets the route forward
|
||||
them unchanged.
|
||||
|
||||
Resolution is best-effort: a chip whose id no longer exists (e.g.
|
||||
document was deleted between mention and submit) is silently
|
||||
dropped. The agent still sees the user's original text, just
|
||||
without a backtick-path substitution for that chip.
|
||||
"""
|
||||
chip_doc_ids: list[int] = []
|
||||
chip_folder_ids: list[int] = []
|
||||
chip_titles_by_id: dict[tuple[str, int], str] = {}
|
||||
if mentioned_documents:
|
||||
for chip in mentioned_documents:
|
||||
kind = chip.kind
|
||||
if kind == "folder":
|
||||
chip_folder_ids.append(chip.id)
|
||||
else:
|
||||
chip_doc_ids.append(chip.id)
|
||||
chip_titles_by_id[(kind, chip.id)] = chip.title
|
||||
|
||||
doc_id_pool: list[int] = list(
|
||||
dict.fromkeys(
|
||||
[
|
||||
*(mentioned_document_ids or []),
|
||||
*(mentioned_surfsense_doc_ids or []),
|
||||
*chip_doc_ids,
|
||||
]
|
||||
)
|
||||
)
|
||||
folder_id_pool: list[int] = list(
|
||||
dict.fromkeys([*(mentioned_folder_ids or []), *chip_folder_ids])
|
||||
)
|
||||
|
||||
if not doc_id_pool and not folder_id_pool:
|
||||
return ResolvedMentionSet()
|
||||
|
||||
index = await build_path_index(session, search_space_id)
|
||||
|
||||
doc_rows: dict[int, Document] = {}
|
||||
if doc_id_pool:
|
||||
result = await session.execute(
|
||||
select(Document).where(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.id.in_(doc_id_pool),
|
||||
)
|
||||
)
|
||||
for row in result.scalars().all():
|
||||
doc_rows[row.id] = row
|
||||
|
||||
folder_rows: dict[int, Folder] = {}
|
||||
if folder_id_pool:
|
||||
result = await session.execute(
|
||||
select(Folder).where(
|
||||
Folder.search_space_id == search_space_id,
|
||||
Folder.id.in_(folder_id_pool),
|
||||
)
|
||||
)
|
||||
for row in result.scalars().all():
|
||||
folder_rows[row.id] = row
|
||||
|
||||
resolved: list[ResolvedMention] = []
|
||||
accepted_doc_ids: list[int] = []
|
||||
accepted_folder_ids: list[int] = []
|
||||
|
||||
for doc_id in doc_id_pool:
|
||||
row = doc_rows.get(doc_id)
|
||||
if row is None:
|
||||
logger.debug(
|
||||
"mention_resolver: dropping doc id=%s (not found in space=%s)",
|
||||
doc_id,
|
||||
search_space_id,
|
||||
)
|
||||
continue
|
||||
title = chip_titles_by_id.get(("doc", doc_id), str(row.title or ""))
|
||||
path = doc_to_virtual_path(
|
||||
doc_id=row.id,
|
||||
title=str(row.title or "untitled"),
|
||||
folder_id=row.folder_id,
|
||||
index=index,
|
||||
)
|
||||
resolved.append(
|
||||
ResolvedMention(kind="doc", id=row.id, title=title, virtual_path=path)
|
||||
)
|
||||
accepted_doc_ids.append(row.id)
|
||||
|
||||
for folder_id in folder_id_pool:
|
||||
row = folder_rows.get(folder_id)
|
||||
if row is None:
|
||||
logger.debug(
|
||||
"mention_resolver: dropping folder id=%s (not found in space=%s)",
|
||||
folder_id,
|
||||
search_space_id,
|
||||
)
|
||||
continue
|
||||
title = chip_titles_by_id.get(("folder", folder_id), str(row.name or ""))
|
||||
path = _folder_virtual_path(row.id, index.folder_paths)
|
||||
resolved.append(
|
||||
ResolvedMention(kind="folder", id=row.id, title=title, virtual_path=path)
|
||||
)
|
||||
accepted_folder_ids.append(row.id)
|
||||
|
||||
token_to_path: list[tuple[str, str]] = []
|
||||
seen_tokens: set[str] = set()
|
||||
for mention in resolved:
|
||||
if not mention.title:
|
||||
continue
|
||||
token = f"@{mention.title}"
|
||||
if token in seen_tokens:
|
||||
continue
|
||||
seen_tokens.add(token)
|
||||
token_to_path.append((token, mention.virtual_path))
|
||||
token_to_path.sort(key=lambda pair: len(pair[0]), reverse=True)
|
||||
|
||||
return ResolvedMentionSet(
|
||||
mentions=resolved,
|
||||
token_to_path=token_to_path,
|
||||
mentioned_document_ids=accepted_doc_ids,
|
||||
mentioned_folder_ids=accepted_folder_ids,
|
||||
)
|
||||
|
||||
|
||||
def substitute_in_text(text: str, token_to_path: list[tuple[str, str]]) -> str:
|
||||
"""Replace each ``@title`` token with a backtick-wrapped virtual path.
|
||||
|
||||
Mirrors ``parseMentionSegments`` on the frontend: longest token
|
||||
first, single forward pass, no regex (titles can contain regex
|
||||
metacharacters). The substitution is idempotent for already-
|
||||
substituted text because the backtick-wrapped path no longer
|
||||
starts with ``@``.
|
||||
|
||||
Empty / no-op cases short-circuit so callers can pass this through
|
||||
unconditionally without paying for a scan.
|
||||
"""
|
||||
if not text or not token_to_path:
|
||||
return text
|
||||
|
||||
out: list[str] = []
|
||||
i = 0
|
||||
n = len(text)
|
||||
while i < n:
|
||||
matched: tuple[str, str] | None = None
|
||||
for token, path in token_to_path:
|
||||
if text.startswith(token, i):
|
||||
matched = (token, path)
|
||||
break
|
||||
if matched is None:
|
||||
out.append(text[i])
|
||||
i += 1
|
||||
continue
|
||||
token, path = matched
|
||||
out.append(f"`{path}`")
|
||||
i += len(token)
|
||||
return "".join(out)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ResolvedMention",
|
||||
"ResolvedMentionSet",
|
||||
"resolve_mentions",
|
||||
"substitute_in_text",
|
||||
]
|
||||
|
|
@ -54,6 +54,7 @@ from app.db import (
|
|||
NATIVE_TO_LEGACY_DOCTYPE,
|
||||
Chunk,
|
||||
Document,
|
||||
Folder,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||
|
|
@ -832,6 +833,22 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
mention_ids = list(self.mentioned_document_ids)
|
||||
self.mentioned_document_ids = []
|
||||
|
||||
# Folder mentions live alongside doc mentions on the runtime
|
||||
# context. They never feed hybrid search (folders aren't
|
||||
# embedded) — they're surfaced purely as ``[USER-MENTIONED]``
|
||||
# priority entries so the agent walks the folder with ``ls`` /
|
||||
# ``find_documents`` instead of ignoring it. Cloud filesystem
|
||||
# mode only.
|
||||
folder_mention_ids: list[int] = []
|
||||
if (
|
||||
ctx is not None
|
||||
and getattr(self, "filesystem_mode", FilesystemMode.CLOUD)
|
||||
== FilesystemMode.CLOUD
|
||||
):
|
||||
ctx_folders = getattr(ctx, "mentioned_folder_ids", None)
|
||||
if ctx_folders:
|
||||
folder_mention_ids = list(ctx_folders)
|
||||
|
||||
mentioned_results: list[dict[str, Any]] = []
|
||||
if mention_ids:
|
||||
mentioned_results = await fetch_mentioned_documents(
|
||||
|
|
@ -876,16 +893,21 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
|
||||
priority, matched_chunk_ids = await self._materialize_priority(merged)
|
||||
|
||||
if folder_mention_ids:
|
||||
folder_entries = await self._materialize_folder_priority(folder_mention_ids)
|
||||
priority = folder_entries + priority
|
||||
|
||||
new_messages = list(messages)
|
||||
insert_at = max(len(new_messages) - 1, 0)
|
||||
new_messages.insert(insert_at, _render_priority_message(priority))
|
||||
|
||||
_perf_log.info(
|
||||
"[kb_priority] completed in %.3fs query=%r priority=%d mentioned=%d",
|
||||
"[kb_priority] completed in %.3fs query=%r priority=%d mentioned=%d folders=%d",
|
||||
asyncio.get_event_loop().time() - t0,
|
||||
user_text[:80],
|
||||
len(priority),
|
||||
len(mentioned_results),
|
||||
len(folder_mention_ids),
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
@ -894,6 +916,58 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
"messages": new_messages,
|
||||
}
|
||||
|
||||
async def _materialize_folder_priority(
|
||||
self, folder_ids: list[int]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Resolve user-mentioned folder ids to ``<priority_documents>`` entries.
|
||||
|
||||
Each entry uses the canonical ``/documents/Folder/Sub/`` virtual
|
||||
path (matching ``KnowledgeTreeMiddleware`` and the agent's
|
||||
``ls`` adapter) and is flagged ``mentioned=True`` so the
|
||||
rendered line carries ``[USER-MENTIONED]``. ``score`` is left
|
||||
``None`` so the renderer prints ``n/a`` — folders aren't
|
||||
ranked, the agent decides which children to read.
|
||||
"""
|
||||
if not folder_ids:
|
||||
return []
|
||||
async with shielded_async_session() as session:
|
||||
index: PathIndex = await build_path_index(session, self.search_space_id)
|
||||
folder_rows = await session.execute(
|
||||
select(Folder.id, Folder.name).where(
|
||||
Folder.search_space_id == self.search_space_id,
|
||||
Folder.id.in_(folder_ids),
|
||||
)
|
||||
)
|
||||
folder_titles: dict[int, str] = {
|
||||
row.id: row.name for row in folder_rows.all()
|
||||
}
|
||||
|
||||
entries: list[dict[str, Any]] = []
|
||||
seen: set[int] = set()
|
||||
for folder_id in folder_ids:
|
||||
if folder_id in seen:
|
||||
continue
|
||||
seen.add(folder_id)
|
||||
base = index.folder_paths.get(folder_id)
|
||||
if base is None:
|
||||
logger.debug(
|
||||
"kb_priority: dropping folder id=%s (missing from path index)",
|
||||
folder_id,
|
||||
)
|
||||
continue
|
||||
path = base if base.endswith("/") else f"{base}/"
|
||||
entries.append(
|
||||
{
|
||||
"path": path,
|
||||
"score": None,
|
||||
"document_id": None,
|
||||
"folder_id": folder_id,
|
||||
"title": folder_titles.get(folder_id, ""),
|
||||
"mentioned": True,
|
||||
}
|
||||
)
|
||||
return entries
|
||||
|
||||
async def _materialize_priority(
|
||||
self, merged: list[dict[str, Any]]
|
||||
) -> tuple[list[dict[str, Any]], dict[int, list[int]]]:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,91 @@
|
|||
"""Fallback only on provider/network errors; let programming bugs raise."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.agents.middleware import ModelFallbackMiddleware
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
# Matched by class name across the MRO so we don't have to import every
|
||||
# provider SDK (openai/anthropic/google/...). Extend as new providers ship.
|
||||
_FALLBACK_ELIGIBLE_NAMES: frozenset[str] = frozenset(
|
||||
{
|
||||
"RateLimitError",
|
||||
"APIStatusError",
|
||||
"InternalServerError",
|
||||
"ServiceUnavailableError",
|
||||
"BadGatewayError",
|
||||
"GatewayTimeoutError",
|
||||
"APIConnectionError",
|
||||
"APITimeoutError",
|
||||
"ConnectError",
|
||||
"ConnectTimeout",
|
||||
"ReadTimeout",
|
||||
"RemoteProtocolError",
|
||||
"TimeoutError",
|
||||
"TimeoutException",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _is_fallback_eligible(exc: BaseException) -> bool:
|
||||
return any(cls.__name__ in _FALLBACK_ELIGIBLE_NAMES for cls in type(exc).__mro__)
|
||||
|
||||
|
||||
class ScopedModelFallbackMiddleware(ModelFallbackMiddleware):
|
||||
"""Re-raise non-provider exceptions instead of walking the fallback chain."""
|
||||
|
||||
def wrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[Any],
|
||||
handler: Callable[[ModelRequest[Any]], ModelResponse[Any]],
|
||||
) -> ModelResponse[Any] | AIMessage:
|
||||
last_exception: Exception
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception as e:
|
||||
if not _is_fallback_eligible(e):
|
||||
raise
|
||||
last_exception = e
|
||||
|
||||
for fallback_model in self.models:
|
||||
try:
|
||||
return handler(request.override(model=fallback_model))
|
||||
except Exception as e:
|
||||
if not _is_fallback_eligible(e):
|
||||
raise
|
||||
last_exception = e
|
||||
continue
|
||||
|
||||
raise last_exception
|
||||
|
||||
async def awrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[Any],
|
||||
handler: Callable[[ModelRequest[Any]], Awaitable[ModelResponse[Any]]],
|
||||
) -> ModelResponse[Any] | AIMessage:
|
||||
last_exception: Exception
|
||||
try:
|
||||
return await handler(request)
|
||||
except Exception as e:
|
||||
if not _is_fallback_eligible(e):
|
||||
raise
|
||||
last_exception = e
|
||||
|
||||
for fallback_model in self.models:
|
||||
try:
|
||||
return await handler(request.override(model=fallback_model))
|
||||
except Exception as e:
|
||||
if not _is_fallback_eligible(e):
|
||||
raise
|
||||
last_exception = e
|
||||
continue
|
||||
|
||||
raise last_exception
|
||||
|
|
@ -1,5 +1,15 @@
|
|||
import re
|
||||
|
||||
from app.config import config
|
||||
|
||||
# Regex that matches a Markdown table block (header + separator + one or more rows)
|
||||
# A table block starts with a | at the beginning of a line and ends when a
|
||||
# non-table line (or end of string) is encountered.
|
||||
_TABLE_BLOCK_RE = re.compile(
|
||||
r"(?:(?:^|\n)(?=[ \t]*\|)(?:[ \t]*\|[^\n]*\n)+)",
|
||||
re.MULTILINE,
|
||||
)
|
||||
|
||||
|
||||
def chunk_text(text: str, use_code_chunker: bool = False) -> list[str]:
|
||||
"""Chunk a text string using the configured chunker and return the chunk texts."""
|
||||
|
|
@ -7,3 +17,43 @@ def chunk_text(text: str, use_code_chunker: bool = False) -> list[str]:
|
|||
config.code_chunker_instance if use_code_chunker else config.chunker_instance
|
||||
)
|
||||
return [c.text for c in chunker.chunk(text)]
|
||||
|
||||
|
||||
def chunk_text_hybrid(text: str) -> list[str]:
|
||||
"""Table-aware chunker that prevents Markdown tables from being split mid-row.
|
||||
|
||||
Algorithm:
|
||||
1. Scan the document for Markdown table blocks.
|
||||
2. Each table block is emitted as a single, unmodified chunk so that its
|
||||
header, separator row, and data rows always stay together.
|
||||
3. The non-table prose segments between (and around) tables are passed through
|
||||
the normal ``chunk_text`` chunker and their sub-chunks are interleaved in
|
||||
document order.
|
||||
|
||||
This ensures that table data is never sliced in the middle by the token-based
|
||||
chunker, which would otherwise produce garbled rows that are useless for RAG.
|
||||
|
||||
Fixes #1334.
|
||||
"""
|
||||
chunks: list[str] = []
|
||||
cursor = 0
|
||||
|
||||
for match in _TABLE_BLOCK_RE.finditer(text):
|
||||
# Prose before this table
|
||||
prose = text[cursor : match.start()].strip()
|
||||
if prose:
|
||||
chunks.extend(chunk_text(prose))
|
||||
|
||||
# The table itself is kept as one indivisible chunk
|
||||
table_block = match.group(0).strip()
|
||||
if table_block:
|
||||
chunks.append(table_block)
|
||||
|
||||
cursor = match.end()
|
||||
|
||||
# Remaining prose after the last table (or entire text if no tables)
|
||||
trailing = text[cursor:].strip()
|
||||
if trailing:
|
||||
chunks.extend(chunk_text(trailing))
|
||||
|
||||
return chunks
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from app.db import (
|
|||
DocumentType,
|
||||
)
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_chunker import chunk_text
|
||||
from app.indexing_pipeline.document_chunker import chunk_text, chunk_text_hybrid
|
||||
from app.indexing_pipeline.document_embedder import embed_texts
|
||||
from app.indexing_pipeline.document_hashing import (
|
||||
compute_content_hash,
|
||||
|
|
@ -387,11 +387,19 @@ class IndexingPipelineService:
|
|||
)
|
||||
|
||||
t_step = time.perf_counter()
|
||||
chunk_texts = await asyncio.to_thread(
|
||||
chunk_text,
|
||||
connector_doc.source_markdown,
|
||||
use_code_chunker=connector_doc.should_use_code_chunker,
|
||||
)
|
||||
if connector_doc.should_use_code_chunker:
|
||||
chunk_texts = await asyncio.to_thread(
|
||||
chunk_text,
|
||||
connector_doc.source_markdown,
|
||||
use_code_chunker=True,
|
||||
)
|
||||
else:
|
||||
# Use the table-aware hybrid chunker so Markdown tables are not
|
||||
# split mid-row (see issue #1334).
|
||||
chunk_texts = await asyncio.to_thread(
|
||||
chunk_text_hybrid,
|
||||
connector_doc.source_markdown,
|
||||
)
|
||||
|
||||
texts_to_embed = [content, *chunk_texts]
|
||||
embeddings = await asyncio.to_thread(embed_texts, texts_to_embed)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.agents.new_chat.path_resolver import virtual_path_to_doc
|
||||
from app.db import (
|
||||
Chunk,
|
||||
Document,
|
||||
|
|
@ -752,7 +753,24 @@ async def get_document_by_virtual_path(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Resolve a knowledge-base document id by exact virtual path."""
|
||||
"""Resolve a knowledge-base document by its agent-facing virtual path.
|
||||
|
||||
The agent renders every document under ``/documents/...`` with a
|
||||
``.xml`` extension appended via ``safe_filename`` (so a PDF titled
|
||||
``2025-W2.pdf`` becomes ``/documents/2025-W2.pdf.xml``). When the user
|
||||
clicks that path in an answer, this endpoint must round-trip back to
|
||||
the underlying ``Document`` row regardless of its type — agent-created
|
||||
NOTE docs (which carry ``virtual_path`` in metadata), uploaded PDFs,
|
||||
and connector docs all flow through here.
|
||||
|
||||
Resolution is delegated to :func:`virtual_path_to_doc`, the single
|
||||
source of truth that handles:
|
||||
|
||||
* ``unique_identifier_hash`` lookup (agent NOTE fast path)
|
||||
* ``" (<doc_id>).xml"`` disambiguation suffixes
|
||||
* ``.xml`` extension stripping for title-based fallback
|
||||
* ``safe_filename`` round-trip for connector titles with lossy chars
|
||||
"""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
|
|
@ -762,24 +780,19 @@ async def get_document_by_virtual_path(
|
|||
"You don't have permission to read documents in this search space",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(
|
||||
Document.id,
|
||||
Document.title,
|
||||
Document.document_type,
|
||||
).filter(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_metadata["virtual_path"].as_string() == virtual_path,
|
||||
)
|
||||
document = await virtual_path_to_doc(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
virtual_path=virtual_path,
|
||||
)
|
||||
row = result.first()
|
||||
if row is None:
|
||||
if document is None:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
return DocumentTitleRead(
|
||||
id=row.id,
|
||||
title=row.title,
|
||||
document_type=row.document_type,
|
||||
id=document.id,
|
||||
title=document.title,
|
||||
document_type=document.document_type,
|
||||
folder_id=document.folder_id,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -71,7 +71,10 @@ from app.schemas.new_chat import (
|
|||
TokenUsageSummary,
|
||||
TurnStatusResponse,
|
||||
)
|
||||
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
||||
from app.tasks.chat.stream_new_chat import (
|
||||
stream_new_chat,
|
||||
stream_resume_chat,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.perf import get_perf_logger
|
||||
from app.utils.rbac import check_permission
|
||||
|
|
@ -1778,6 +1781,7 @@ async def handle_new_chat(
|
|||
llm_config_id=llm_config_id,
|
||||
mentioned_document_ids=request.mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
||||
mentioned_folder_ids=request.mentioned_folder_ids,
|
||||
mentioned_documents=mentioned_documents_payload,
|
||||
needs_history_bootstrap=thread.needs_history_bootstrap,
|
||||
thread_visibility=thread.visibility,
|
||||
|
|
@ -2263,6 +2267,7 @@ async def regenerate_response(
|
|||
llm_config_id=llm_config_id,
|
||||
mentioned_document_ids=request.mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
||||
mentioned_folder_ids=request.mentioned_folder_ids,
|
||||
mentioned_documents=mentioned_documents_payload,
|
||||
checkpoint_id=target_checkpoint_id,
|
||||
needs_history_bootstrap=thread.needs_history_bootstrap,
|
||||
|
|
|
|||
|
|
@ -201,18 +201,34 @@ class NewChatUserImagePart(BaseModel):
|
|||
|
||||
|
||||
class MentionedDocumentInfo(BaseModel):
|
||||
"""Display metadata for a single ``@``-mentioned document.
|
||||
"""Display metadata for a single ``@``-mention chip.
|
||||
|
||||
The full triple ``{id, title, document_type}`` is forwarded by the
|
||||
frontend mention chip so the server can embed it in the persisted
|
||||
user message ``ContentPart[]`` (single ``mentioned-documents`` part).
|
||||
The history loader then renders the chips on reload without an extra
|
||||
Carries either a knowledge-base document or a knowledge-base folder
|
||||
(discriminated by ``kind``). The full triple
|
||||
``{id, title, document_type}`` is forwarded by the frontend mention
|
||||
chip so the server can embed it in the persisted user message
|
||||
``ContentPart[]`` (single ``mentioned-documents`` part). The
|
||||
history loader then renders the chips on reload without an extra
|
||||
fetch — mirrors the pre-refactor frontend ``persistUserTurn`` shape.
|
||||
|
||||
``kind`` defaults to ``"doc"`` so legacy clients and persisted rows
|
||||
that predate folder mentions deserialise unchanged.
|
||||
"""
|
||||
|
||||
id: int
|
||||
title: str = Field(..., min_length=1, max_length=500)
|
||||
document_type: str = Field(..., min_length=1, max_length=100)
|
||||
kind: Literal["doc", "folder"] = Field(
|
||||
default="doc",
|
||||
description=(
|
||||
"Discriminator for the chip's referent: ``doc`` is a "
|
||||
"knowledge-base ``Document`` row, ``folder`` is a "
|
||||
"knowledge-base ``Folder`` row. Folders carry the sentinel "
|
||||
"``document_type='FOLDER'`` to keep the frontend dedup key "
|
||||
"``(kind:document_type:id)`` from colliding doc and folder "
|
||||
"ids that happen to share an integer value."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class NewChatRequest(BaseModel):
|
||||
|
|
@ -228,15 +244,26 @@ class NewChatRequest(BaseModel):
|
|||
mentioned_surfsense_doc_ids: list[int] | None = (
|
||||
None # Optional SurfSense documentation IDs mentioned with @ in the chat
|
||||
)
|
||||
mentioned_folder_ids: list[int] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional knowledge-base folder IDs the user mentioned with "
|
||||
"@. Resolved to virtual paths (``/documents/.../``) by "
|
||||
"``mention_resolver`` and surfaced to the agent via "
|
||||
"(a) backtick-wrapped substitution in ``user_query`` and "
|
||||
"(b) a ``[USER-MENTIONED]`` entry in ``<priority_documents>``. "
|
||||
"The agent's ``ls`` tool can then walk the folder itself."
|
||||
),
|
||||
)
|
||||
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Display metadata (id, title, document_type) for every "
|
||||
"@-mentioned document. Persisted as a ``mentioned-documents`` "
|
||||
"ContentPart on the user message so reload renders chips "
|
||||
"without an extra fetch. Optional and additive — when None "
|
||||
"the user message is persisted without a mentioned-documents "
|
||||
"part."
|
||||
"Display metadata (id, title, document_type, kind) for every "
|
||||
"@-mention chip — both documents and folders. Persisted as a "
|
||||
"``mentioned-documents`` ContentPart on the user message so "
|
||||
"reload renders chips without an extra fetch. Optional and "
|
||||
"additive — when None the user message is persisted without "
|
||||
"a mentioned-documents part."
|
||||
),
|
||||
)
|
||||
disabled_tools: list[str] | None = (
|
||||
|
|
@ -290,14 +317,22 @@ class RegenerateRequest(BaseModel):
|
|||
)
|
||||
mentioned_document_ids: list[int] | None = None
|
||||
mentioned_surfsense_doc_ids: list[int] | None = None
|
||||
mentioned_folder_ids: list[int] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional knowledge-base folder IDs the user mentioned with "
|
||||
"@ on the edited user turn. Only used when ``user_query`` is "
|
||||
"non-None (edit). Mirrors ``NewChatRequest.mentioned_folder_ids``."
|
||||
),
|
||||
)
|
||||
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Display metadata (id, title, document_type) for every "
|
||||
"@-mentioned document on the edited user turn. Only used "
|
||||
"when ``user_query`` is non-None (edit). Persisted as a "
|
||||
"``mentioned-documents`` ContentPart on the new user "
|
||||
"message. None means no chip metadata."
|
||||
"Display metadata (id, title, document_type, kind) for every "
|
||||
"@-mention chip on the edited user turn — both documents and "
|
||||
"folders. Only used when ``user_query`` is non-None (edit). "
|
||||
"Persisted as a ``mentioned-documents`` ContentPart on the "
|
||||
"new user message. None means no chip metadata."
|
||||
),
|
||||
)
|
||||
disabled_tools: list[str] | None = None
|
||||
|
|
@ -373,6 +408,16 @@ class ResumeRequest(BaseModel):
|
|||
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
|
||||
client_platform: Literal["web", "desktop"] = "web"
|
||||
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
|
||||
mentioned_folder_ids: list[int] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Forwarded for symmetry with /new_chat and /regenerate. "
|
||||
"Resume reuses the original interrupted user turn so this "
|
||||
"field is informational only — the originating turn's "
|
||||
"folder mentions already shaped the priority hints baked "
|
||||
"into the agent's checkpoint."
|
||||
),
|
||||
)
|
||||
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
|
|
@ -380,7 +425,7 @@ class ResumeRequest(BaseModel):
|
|||
"/regenerate. Resume reuses the original interrupted user "
|
||||
"turn so the server does not write a new user message. "
|
||||
"Currently unused but accepted to keep request bodies "
|
||||
"uniform across the three streaming entrypoints."
|
||||
"uniform across new-message, regenerate, and resume stream routes."
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1027,6 +1027,505 @@ class ComposioService:
|
|||
logger.error(f"Failed to list Calendar events: {e!s}")
|
||||
return [], str(e)
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_response_data(data: Any) -> Any:
|
||||
"""Composio responses often nest the meaningful payload under
|
||||
``data.data.response_data``. Walk that envelope safely and return
|
||||
whichever inner dict actually has the result keys."""
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
inner = data.get("data", data)
|
||||
if isinstance(inner, dict):
|
||||
return inner.get("response_data", inner)
|
||||
return inner
|
||||
|
||||
@staticmethod
|
||||
def _split_email_csv(value: str | None) -> list[str] | None:
|
||||
"""Tools accept comma-separated cc/bcc strings; Composio expects an array."""
|
||||
if not value:
|
||||
return None
|
||||
addrs = [e.strip() for e in value.split(",") if e.strip()]
|
||||
return addrs or None
|
||||
|
||||
# ===== Gmail write methods =====
|
||||
|
||||
async def send_gmail_email(
|
||||
self,
|
||||
connected_account_id: str,
|
||||
entity_id: str,
|
||||
to: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
cc: str | None = None,
|
||||
bcc: str | None = None,
|
||||
is_html: bool = False,
|
||||
) -> tuple[str | None, str | None, str | None]:
|
||||
"""Send a Gmail message via the Composio ``GMAIL_SEND_EMAIL`` toolkit.
|
||||
|
||||
Returns:
|
||||
Tuple of (message_id, thread_id, error). On success ``error`` is
|
||||
None and at least one of the IDs is populated when Composio
|
||||
returns them; on failure both IDs are None.
|
||||
"""
|
||||
try:
|
||||
params: dict[str, Any] = {
|
||||
"recipient_email": to,
|
||||
"subject": subject,
|
||||
"body": body,
|
||||
"is_html": is_html,
|
||||
}
|
||||
if cc:
|
||||
cc_list = self._split_email_csv(cc)
|
||||
if cc_list:
|
||||
params["cc"] = cc_list
|
||||
if bcc:
|
||||
bcc_list = self._split_email_csv(bcc)
|
||||
if bcc_list:
|
||||
params["bcc"] = bcc_list
|
||||
|
||||
result = await self.execute_tool(
|
||||
connected_account_id=connected_account_id,
|
||||
tool_name="GMAIL_SEND_EMAIL",
|
||||
params=params,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return None, None, result.get("error", "Unknown error")
|
||||
|
||||
payload = self._unwrap_response_data(result.get("data", {}))
|
||||
message_id = None
|
||||
thread_id = None
|
||||
if isinstance(payload, dict):
|
||||
message_id = (
|
||||
payload.get("id")
|
||||
or payload.get("message_id")
|
||||
or payload.get("messageId")
|
||||
)
|
||||
thread_id = payload.get("threadId") or payload.get("thread_id")
|
||||
return message_id, thread_id, None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send Gmail email: {e!s}")
|
||||
return None, None, str(e)
|
||||
|
||||
async def create_gmail_draft(
|
||||
self,
|
||||
connected_account_id: str,
|
||||
entity_id: str,
|
||||
to: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
cc: str | None = None,
|
||||
bcc: str | None = None,
|
||||
is_html: bool = False,
|
||||
) -> tuple[str | None, str | None, str | None, str | None]:
|
||||
"""Create a Gmail draft via the Composio ``GMAIL_CREATE_EMAIL_DRAFT`` toolkit.
|
||||
|
||||
Returns:
|
||||
Tuple of (draft_id, message_id, thread_id, error). On success
|
||||
``error`` is None and ``draft_id`` is populated.
|
||||
"""
|
||||
try:
|
||||
params: dict[str, Any] = {
|
||||
"recipient_email": to,
|
||||
"subject": subject,
|
||||
"body": body,
|
||||
"is_html": is_html,
|
||||
}
|
||||
cc_list = self._split_email_csv(cc)
|
||||
if cc_list:
|
||||
params["cc"] = cc_list
|
||||
bcc_list = self._split_email_csv(bcc)
|
||||
if bcc_list:
|
||||
params["bcc"] = bcc_list
|
||||
|
||||
result = await self.execute_tool(
|
||||
connected_account_id=connected_account_id,
|
||||
tool_name="GMAIL_CREATE_EMAIL_DRAFT",
|
||||
params=params,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return None, None, None, result.get("error", "Unknown error")
|
||||
|
||||
payload = self._unwrap_response_data(result.get("data", {}))
|
||||
draft_id = None
|
||||
message_id = None
|
||||
thread_id = None
|
||||
if isinstance(payload, dict):
|
||||
draft_id = payload.get("id") or payload.get("draft_id")
|
||||
draft_message = payload.get("message") or {}
|
||||
if isinstance(draft_message, dict):
|
||||
message_id = draft_message.get("id") or draft_message.get(
|
||||
"message_id"
|
||||
)
|
||||
thread_id = draft_message.get("threadId") or draft_message.get(
|
||||
"thread_id"
|
||||
)
|
||||
if message_id is None:
|
||||
message_id = payload.get("message_id") or payload.get("messageId")
|
||||
if thread_id is None:
|
||||
thread_id = payload.get("thread_id") or payload.get("threadId")
|
||||
return draft_id, message_id, thread_id, None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create Gmail draft: {e!s}")
|
||||
return None, None, None, str(e)
|
||||
|
||||
async def update_gmail_draft(
|
||||
self,
|
||||
connected_account_id: str,
|
||||
entity_id: str,
|
||||
draft_id: str,
|
||||
to: str | None = None,
|
||||
subject: str | None = None,
|
||||
body: str | None = None,
|
||||
cc: str | None = None,
|
||||
bcc: str | None = None,
|
||||
is_html: bool = False,
|
||||
) -> tuple[str | None, str | None, str | None]:
|
||||
"""Update an existing Gmail draft via ``GMAIL_UPDATE_DRAFT``.
|
||||
|
||||
Returns:
|
||||
Tuple of (draft_id, message_id, error).
|
||||
"""
|
||||
try:
|
||||
params: dict[str, Any] = {
|
||||
"draft_id": draft_id,
|
||||
"is_html": is_html,
|
||||
}
|
||||
if to:
|
||||
params["recipient_email"] = to
|
||||
if subject is not None:
|
||||
params["subject"] = subject
|
||||
if body is not None:
|
||||
params["body"] = body
|
||||
cc_list = self._split_email_csv(cc)
|
||||
if cc_list:
|
||||
params["cc"] = cc_list
|
||||
bcc_list = self._split_email_csv(bcc)
|
||||
if bcc_list:
|
||||
params["bcc"] = bcc_list
|
||||
|
||||
result = await self.execute_tool(
|
||||
connected_account_id=connected_account_id,
|
||||
tool_name="GMAIL_UPDATE_DRAFT",
|
||||
params=params,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return None, None, result.get("error", "Unknown error")
|
||||
|
||||
payload = self._unwrap_response_data(result.get("data", {}))
|
||||
new_draft_id = draft_id
|
||||
message_id = None
|
||||
if isinstance(payload, dict):
|
||||
new_draft_id = payload.get("id") or payload.get("draft_id") or draft_id
|
||||
draft_message = payload.get("message") or {}
|
||||
if isinstance(draft_message, dict):
|
||||
message_id = draft_message.get("id") or draft_message.get(
|
||||
"message_id"
|
||||
)
|
||||
if message_id is None:
|
||||
message_id = payload.get("message_id") or payload.get("messageId")
|
||||
return new_draft_id, message_id, None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update Gmail draft: {e!s}")
|
||||
return None, None, str(e)
|
||||
|
||||
async def trash_gmail_message(
|
||||
self,
|
||||
connected_account_id: str,
|
||||
entity_id: str,
|
||||
message_id: str,
|
||||
) -> str | None:
|
||||
"""Move a Gmail message to trash via ``GMAIL_MOVE_TO_TRASH``.
|
||||
|
||||
Returns the error message on failure, ``None`` on success.
|
||||
"""
|
||||
try:
|
||||
result = await self.execute_tool(
|
||||
connected_account_id=connected_account_id,
|
||||
tool_name="GMAIL_MOVE_TO_TRASH",
|
||||
params={"message_id": message_id},
|
||||
entity_id=entity_id,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return result.get("error", "Unknown error")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to trash Gmail message: {e!s}")
|
||||
return str(e)
|
||||
|
||||
# ===== Google Calendar write methods =====
|
||||
|
||||
async def create_calendar_event(
|
||||
self,
|
||||
connected_account_id: str,
|
||||
entity_id: str,
|
||||
summary: str,
|
||||
start_datetime: str,
|
||||
end_datetime: str,
|
||||
timezone: str | None = None,
|
||||
description: str | None = None,
|
||||
location: str | None = None,
|
||||
attendees: list[str] | None = None,
|
||||
calendar_id: str = "primary",
|
||||
) -> tuple[str | None, str | None, str | None]:
|
||||
"""Create a Google Calendar event via ``GOOGLECALENDAR_CREATE_EVENT``.
|
||||
|
||||
Composio strips trailing timezone info on ``start_datetime`` /
|
||||
``end_datetime`` and uses the ``timezone`` field as the IANA name,
|
||||
so callers may pass ISO 8601 strings with or without offsets.
|
||||
|
||||
Returns:
|
||||
Tuple of (event_id, html_link, error).
|
||||
"""
|
||||
try:
|
||||
params: dict[str, Any] = {
|
||||
"summary": summary,
|
||||
"start_datetime": start_datetime,
|
||||
"end_datetime": end_datetime,
|
||||
"calendar_id": calendar_id,
|
||||
}
|
||||
if timezone:
|
||||
params["timezone"] = timezone
|
||||
if description:
|
||||
params["description"] = description
|
||||
if location:
|
||||
params["location"] = location
|
||||
if attendees:
|
||||
params["attendees"] = [a for a in attendees if a]
|
||||
|
||||
result = await self.execute_tool(
|
||||
connected_account_id=connected_account_id,
|
||||
tool_name="GOOGLECALENDAR_CREATE_EVENT",
|
||||
params=params,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return None, None, result.get("error", "Unknown error")
|
||||
|
||||
payload = self._unwrap_response_data(result.get("data", {}))
|
||||
event_id = None
|
||||
html_link = None
|
||||
if isinstance(payload, dict):
|
||||
event_id = payload.get("id") or payload.get("event_id")
|
||||
html_link = payload.get("htmlLink") or payload.get("html_link")
|
||||
return event_id, html_link, None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create Calendar event: {e!s}")
|
||||
return None, None, str(e)
|
||||
|
||||
async def update_calendar_event(
|
||||
self,
|
||||
connected_account_id: str,
|
||||
entity_id: str,
|
||||
event_id: str,
|
||||
summary: str | None = None,
|
||||
start_time: str | None = None,
|
||||
end_time: str | None = None,
|
||||
timezone: str | None = None,
|
||||
description: str | None = None,
|
||||
location: str | None = None,
|
||||
attendees: list[str] | None = None,
|
||||
calendar_id: str = "primary",
|
||||
) -> tuple[str | None, str | None, str | None]:
|
||||
"""Patch an existing Google Calendar event via ``GOOGLECALENDAR_PATCH_EVENT``.
|
||||
|
||||
Uses PATCH (not PUT) semantics so omitted fields are preserved.
|
||||
|
||||
Returns:
|
||||
Tuple of (event_id, html_link, error).
|
||||
"""
|
||||
try:
|
||||
params: dict[str, Any] = {
|
||||
"event_id": event_id,
|
||||
"calendar_id": calendar_id,
|
||||
}
|
||||
if summary is not None:
|
||||
params["summary"] = summary
|
||||
if start_time is not None:
|
||||
params["start_time"] = start_time
|
||||
if end_time is not None:
|
||||
params["end_time"] = end_time
|
||||
if timezone:
|
||||
params["timezone"] = timezone
|
||||
if description is not None:
|
||||
params["description"] = description
|
||||
if location is not None:
|
||||
params["location"] = location
|
||||
if attendees is not None:
|
||||
params["attendees"] = [a for a in attendees if a]
|
||||
|
||||
result = await self.execute_tool(
|
||||
connected_account_id=connected_account_id,
|
||||
tool_name="GOOGLECALENDAR_PATCH_EVENT",
|
||||
params=params,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return None, None, result.get("error", "Unknown error")
|
||||
|
||||
payload = self._unwrap_response_data(result.get("data", {}))
|
||||
new_event_id = event_id
|
||||
html_link = None
|
||||
if isinstance(payload, dict):
|
||||
new_event_id = payload.get("id") or payload.get("event_id") or event_id
|
||||
html_link = payload.get("htmlLink") or payload.get("html_link")
|
||||
return new_event_id, html_link, None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to patch Calendar event: {e!s}")
|
||||
return None, None, str(e)
|
||||
|
||||
async def delete_calendar_event(
|
||||
self,
|
||||
connected_account_id: str,
|
||||
entity_id: str,
|
||||
event_id: str,
|
||||
calendar_id: str = "primary",
|
||||
) -> str | None:
|
||||
"""Delete a Google Calendar event via ``GOOGLECALENDAR_DELETE_EVENT``.
|
||||
|
||||
Returns the error message on failure, ``None`` on success (idempotent
|
||||
on already-deleted events).
|
||||
"""
|
||||
try:
|
||||
result = await self.execute_tool(
|
||||
connected_account_id=connected_account_id,
|
||||
tool_name="GOOGLECALENDAR_DELETE_EVENT",
|
||||
params={
|
||||
"event_id": event_id,
|
||||
"calendar_id": calendar_id,
|
||||
},
|
||||
entity_id=entity_id,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return result.get("error", "Unknown error")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete Calendar event: {e!s}")
|
||||
return str(e)
|
||||
|
||||
# ===== Google Drive write methods =====
|
||||
|
||||
@staticmethod
|
||||
def _drive_web_view_link(file_id: str, mime_type: str | None) -> str:
|
||||
"""Synthesize a Google Drive ``webViewLink`` from id + mimeType.
|
||||
|
||||
Composio's ``GOOGLEDRIVE_CREATE_FILE_FROM_TEXT`` returns flat
|
||||
metadata (id, name, mimeType) but does not always include a
|
||||
``webViewLink``. We rebuild the canonical UI URL based on the
|
||||
Workspace MIME type so callers can keep using a single field.
|
||||
"""
|
||||
if not file_id:
|
||||
return ""
|
||||
mt = (mime_type or "").lower()
|
||||
if mt == "application/vnd.google-apps.document":
|
||||
return f"https://docs.google.com/document/d/{file_id}/edit"
|
||||
if mt == "application/vnd.google-apps.spreadsheet":
|
||||
return f"https://docs.google.com/spreadsheets/d/{file_id}/edit"
|
||||
if mt == "application/vnd.google-apps.presentation":
|
||||
return f"https://docs.google.com/presentation/d/{file_id}/edit"
|
||||
if mt == "application/vnd.google-apps.folder":
|
||||
return f"https://drive.google.com/drive/folders/{file_id}"
|
||||
return f"https://drive.google.com/file/d/{file_id}/view"
|
||||
|
||||
async def create_drive_file_from_text(
|
||||
self,
|
||||
connected_account_id: str,
|
||||
entity_id: str,
|
||||
name: str,
|
||||
mime_type: str,
|
||||
content: str | None = None,
|
||||
parent_id: str | None = None,
|
||||
) -> tuple[dict[str, Any] | None, str | None]:
|
||||
"""Create a Google Drive file from text via ``GOOGLEDRIVE_CREATE_FILE_FROM_TEXT``.
|
||||
|
||||
Composio's tool requires ``text_content`` even for "empty" files;
|
||||
an empty string is accepted. Native Workspace types (Docs, Sheets)
|
||||
are produced by setting ``mime_type`` to the Google Apps MIME, and
|
||||
Drive auto-converts the text payload (e.g. CSV → Sheet).
|
||||
|
||||
Returns:
|
||||
Tuple of (file_meta, error). ``file_meta`` keys:
|
||||
``id``, ``name``, ``mimeType``, ``webViewLink``.
|
||||
"""
|
||||
try:
|
||||
params: dict[str, Any] = {
|
||||
"file_name": name,
|
||||
"mime_type": mime_type,
|
||||
"text_content": content if content is not None else "",
|
||||
}
|
||||
if parent_id:
|
||||
params["parent_id"] = parent_id
|
||||
|
||||
result = await self.execute_tool(
|
||||
connected_account_id=connected_account_id,
|
||||
tool_name="GOOGLEDRIVE_CREATE_FILE_FROM_TEXT",
|
||||
params=params,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return None, result.get("error", "Unknown error")
|
||||
|
||||
payload = self._unwrap_response_data(result.get("data", {}))
|
||||
file_id: str | None = None
|
||||
file_name: str | None = name
|
||||
mime: str | None = mime_type
|
||||
web_view_link: str | None = None
|
||||
|
||||
if isinstance(payload, dict):
|
||||
file_id = (
|
||||
payload.get("id") or payload.get("file_id") or payload.get("fileId")
|
||||
)
|
||||
file_name = payload.get("name") or payload.get("file_name") or name
|
||||
mime = payload.get("mimeType") or payload.get("mime_type") or mime_type
|
||||
web_view_link = payload.get("webViewLink") or payload.get(
|
||||
"web_view_link"
|
||||
)
|
||||
|
||||
if not file_id:
|
||||
return None, "Composio response did not include a file id"
|
||||
|
||||
if not web_view_link:
|
||||
web_view_link = self._drive_web_view_link(file_id, mime)
|
||||
|
||||
return (
|
||||
{
|
||||
"id": file_id,
|
||||
"name": file_name,
|
||||
"mimeType": mime,
|
||||
"webViewLink": web_view_link,
|
||||
},
|
||||
None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create Drive file: {e!s}")
|
||||
return None, str(e)
|
||||
|
||||
async def trash_drive_file(
|
||||
self,
|
||||
connected_account_id: str,
|
||||
entity_id: str,
|
||||
file_id: str,
|
||||
) -> str | None:
|
||||
"""Move a Google Drive file to trash via ``GOOGLEDRIVE_TRASH_FILE``.
|
||||
|
||||
Returns the error message on failure, ``None`` on success.
|
||||
"""
|
||||
try:
|
||||
result = await self.execute_tool(
|
||||
connected_account_id=connected_account_id,
|
||||
tool_name="GOOGLEDRIVE_TRASH_FILE",
|
||||
params={"file_id": file_id},
|
||||
entity_id=entity_id,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return result.get("error", "Unknown error")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to trash Drive file: {e!s}")
|
||||
return str(e)
|
||||
|
||||
# ===== User Info Methods =====
|
||||
|
||||
async def get_connected_account_email(
|
||||
|
|
|
|||
|
|
@ -456,6 +456,8 @@ class VercelStreamingService:
|
|||
title: str,
|
||||
status: str = "in_progress",
|
||||
items: list[str] | None = None,
|
||||
*,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Format a thinking step for chain-of-thought display (SurfSense specific).
|
||||
|
|
@ -469,15 +471,15 @@ class VercelStreamingService:
|
|||
Returns:
|
||||
str: SSE formatted thinking step data part
|
||||
"""
|
||||
return self.format_data(
|
||||
"thinking-step",
|
||||
{
|
||||
"id": step_id,
|
||||
"title": title,
|
||||
"status": status,
|
||||
"items": items or [],
|
||||
},
|
||||
)
|
||||
payload: dict[str, Any] = {
|
||||
"id": step_id,
|
||||
"title": title,
|
||||
"status": status,
|
||||
"items": items or [],
|
||||
}
|
||||
if metadata:
|
||||
payload["metadata"] = metadata
|
||||
return self.format_data("thinking-step", payload)
|
||||
|
||||
def format_thread_title_update(self, thread_id: int, title: str) -> str:
|
||||
"""
|
||||
|
|
@ -601,6 +603,7 @@ class VercelStreamingService:
|
|||
tool_name: str,
|
||||
*,
|
||||
langchain_tool_call_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Format the start of tool input streaming.
|
||||
|
|
@ -608,15 +611,14 @@ class VercelStreamingService:
|
|||
Args:
|
||||
tool_call_id: The unique tool call identifier. May be EITHER the
|
||||
synthetic ``call_<run_id>`` id derived from LangGraph
|
||||
``run_id`` (legacy / ``SURFSENSE_ENABLE_STREAM_PARITY_V2``
|
||||
OFF, or the unmatched-fallback path under parity_v2) OR
|
||||
the authoritative LangChain ``tool_call.id`` (parity_v2
|
||||
path: when the provider streams ``tool_call_chunks`` we
|
||||
register the ``index`` and reuse the lc-id as the card
|
||||
id so live ``tool-input-delta`` events can be routed
|
||||
without a downstream join). Either way, the same id is
|
||||
preserved across ``tool-input-start`` / ``-delta`` /
|
||||
``-available`` / ``tool-output-available`` for one call.
|
||||
``run_id`` (unmatched chunk fallback when no ``index`` was
|
||||
registered) OR the authoritative LangChain ``tool_call.id``
|
||||
(when the provider streams ``tool_call_chunks`` we register
|
||||
the ``index`` and reuse the lc-id as the card id so live
|
||||
``tool-input-delta`` events route without a downstream join).
|
||||
Either way, the same id is preserved across
|
||||
``tool-input-start`` / ``-delta`` / ``-available`` /
|
||||
``tool-output-available`` for one call.
|
||||
tool_name: The name of the tool being called.
|
||||
langchain_tool_call_id: Optional authoritative LangChain
|
||||
``tool_call.id``. When set, surfaces as
|
||||
|
|
@ -636,6 +638,8 @@ class VercelStreamingService:
|
|||
}
|
||||
if langchain_tool_call_id:
|
||||
payload["langchainToolCallId"] = langchain_tool_call_id
|
||||
if metadata:
|
||||
payload["metadata"] = metadata
|
||||
return self._format_sse(payload)
|
||||
|
||||
def format_tool_input_delta(self, tool_call_id: str, input_text_delta: str) -> str:
|
||||
|
|
@ -667,6 +671,7 @@ class VercelStreamingService:
|
|||
input_data: dict[str, Any],
|
||||
*,
|
||||
langchain_tool_call_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Format the completion of tool input.
|
||||
|
|
@ -692,6 +697,8 @@ class VercelStreamingService:
|
|||
}
|
||||
if langchain_tool_call_id:
|
||||
payload["langchainToolCallId"] = langchain_tool_call_id
|
||||
if metadata:
|
||||
payload["metadata"] = metadata
|
||||
return self._format_sse(payload)
|
||||
|
||||
def format_tool_output_available(
|
||||
|
|
@ -700,6 +707,7 @@ class VercelStreamingService:
|
|||
output: Any,
|
||||
*,
|
||||
langchain_tool_call_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Format tool execution output.
|
||||
|
|
@ -726,6 +734,8 @@ class VercelStreamingService:
|
|||
}
|
||||
if langchain_tool_call_id:
|
||||
payload["langchainToolCallId"] = langchain_tool_call_id
|
||||
if metadata:
|
||||
payload["metadata"] = metadata
|
||||
return self._format_sse(payload)
|
||||
|
||||
# =========================================================================
|
||||
|
|
|
|||
20
surfsense_backend/app/services/streaming/__init__.py
Normal file
20
surfsense_backend/app/services/streaming/__init__.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
"""Single-responsibility split of the streaming SSE protocol.
|
||||
|
||||
Layout:
|
||||
* ``envelope/`` - SSE wire framing + ID generators
|
||||
* ``emitter/`` - identity of the agent that emitted an event + runtime registry
|
||||
* ``events/`` - one module per SSE event family
|
||||
* ``service.py`` - composition root used when emitting chat SSE
|
||||
* ``interrupt_correlation.py`` - id-aware lookup over LangGraph state
|
||||
|
||||
Naming on the wire:
|
||||
* AI SDK protocol fields keep their existing camelCase
|
||||
(``toolCallId``, ``messageId``, ``inputTextDelta``, ``langchainToolCallId``).
|
||||
* Every SurfSense-added field uses ``snake_case``, including the
|
||||
top-level ``emitted_by`` envelope and all inner ``data`` payloads.
|
||||
|
||||
Production chat uses ``app.services.new_streaming_service`` from
|
||||
``app.tasks.chat.stream_new_chat`` and related routes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
29
surfsense_backend/app/services/streaming/emitter/__init__.py
Normal file
29
surfsense_backend/app/services/streaming/emitter/__init__.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
"""Identity of the agent that emitted a streamed event.
|
||||
|
||||
The wire field is ``emitted_by``; the Python identity is :class:`Emitter`.
|
||||
``EmitterRegistry`` resolves which emitter owns a LangGraph event, with
|
||||
LangGraph's own namespace metadata as the primary key and a parent_ids
|
||||
walk as a fallback for cases where context vars don't propagate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .emitter import (
|
||||
MAIN_EMITTER,
|
||||
Emitter,
|
||||
EmitterLevel,
|
||||
attach_emitted_by,
|
||||
main_emitter,
|
||||
subagent_emitter,
|
||||
)
|
||||
from .registry import EmitterRegistry
|
||||
|
||||
__all__ = [
|
||||
"MAIN_EMITTER",
|
||||
"Emitter",
|
||||
"EmitterLevel",
|
||||
"EmitterRegistry",
|
||||
"attach_emitted_by",
|
||||
"main_emitter",
|
||||
"subagent_emitter",
|
||||
]
|
||||
61
surfsense_backend/app/services/streaming/emitter/emitter.py
Normal file
61
surfsense_backend/app/services/streaming/emitter/emitter.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
"""Identity payload describing which agent produced a stream event."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal
|
||||
|
||||
EmitterLevel = Literal["main", "subagent"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Emitter:
|
||||
level: EmitterLevel
|
||||
subagent_type: str | None = None
|
||||
subagent_run_id: str | None = None
|
||||
parent_tool_call_id: str | None = None
|
||||
extra: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_payload(self) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {"level": self.level}
|
||||
if self.subagent_type is not None:
|
||||
payload["subagent_type"] = self.subagent_type
|
||||
if self.subagent_run_id is not None:
|
||||
payload["subagent_run_id"] = self.subagent_run_id
|
||||
if self.parent_tool_call_id is not None:
|
||||
payload["parent_tool_call_id"] = self.parent_tool_call_id
|
||||
if self.extra:
|
||||
payload.update(self.extra)
|
||||
return payload
|
||||
|
||||
|
||||
MAIN_EMITTER = Emitter(level="main")
|
||||
|
||||
|
||||
def main_emitter() -> Emitter:
|
||||
return MAIN_EMITTER
|
||||
|
||||
|
||||
def subagent_emitter(
|
||||
*,
|
||||
subagent_type: str,
|
||||
subagent_run_id: str,
|
||||
parent_tool_call_id: str | None = None,
|
||||
extra: dict[str, Any] | None = None,
|
||||
) -> Emitter:
|
||||
return Emitter(
|
||||
level="subagent",
|
||||
subagent_type=subagent_type,
|
||||
subagent_run_id=subagent_run_id,
|
||||
parent_tool_call_id=parent_tool_call_id,
|
||||
extra=dict(extra or {}),
|
||||
)
|
||||
|
||||
|
||||
def attach_emitted_by(
|
||||
payload: dict[str, Any], emitter: Emitter | None
|
||||
) -> dict[str, Any]:
|
||||
if emitter is None:
|
||||
return payload
|
||||
payload["emitted_by"] = emitter.to_payload()
|
||||
return payload
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue