mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-06 14:22:47 +02:00
Add multi-agent core registry, delegation, and MCP partitioning.
This commit is contained in:
parent
0c8ea2085e
commit
c974fcefe6
16 changed files with 437 additions and 0 deletions
|
|
@ -0,0 +1,25 @@
|
|||
"""Cross-cutting building blocks (prompts, agents, delegation, registry) — not domain logic."""
|
||||
|
||||
from app.agents.multi_agent_chat.core.agents import build_domain_agent
|
||||
from app.agents.multi_agent_chat.core.bindings import connector_binding
|
||||
from app.agents.multi_agent_chat.core.delegation import compose_child_task
|
||||
from app.agents.multi_agent_chat.core.invocation import extract_last_assistant_text
|
||||
from app.agents.multi_agent_chat.core.prompts import read_prompt_md
|
||||
from app.agents.multi_agent_chat.core.registry import (
|
||||
REGISTRY_ROUTING_CATEGORY_KEYS,
|
||||
TOOL_NAMES_BY_CATEGORY,
|
||||
build_registry_dependencies,
|
||||
build_registry_tools_for_category,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"REGISTRY_ROUTING_CATEGORY_KEYS",
|
||||
"TOOL_NAMES_BY_CATEGORY",
|
||||
"build_domain_agent",
|
||||
"build_registry_dependencies",
|
||||
"build_registry_tools_for_category",
|
||||
"compose_child_task",
|
||||
"connector_binding",
|
||||
"extract_last_assistant_text",
|
||||
"read_prompt_md",
|
||||
]
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
"""Compiled subgraph factories shared by domain slices."""
|
||||
|
||||
from app.agents.multi_agent_chat.core.agents.domain_graph import build_domain_agent
|
||||
|
||||
__all__ = ["build_domain_agent"]
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
"""Compile a domain LangGraph agent from a co-located prompt + tool list."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.multi_agent_chat.core.prompts import read_prompt_md
|
||||
|
||||
|
||||
def build_domain_agent(
    llm: BaseChatModel,
    tools: Sequence[BaseTool],
    *,
    prompt_package: str,
    prompt_stem: str = "domain_prompt",
):
    """Compile a domain agent from a packaged markdown prompt and a tool list.

    Args:
        llm: Chat model handed to ``create_agent`` as its first argument.
        tools: Tools exposed to the agent; copied into a fresh list before use.
        prompt_package: Import path of the package that ships the prompt file.
        prompt_stem: File name without the ``.md`` suffix.

    Returns:
        Whatever ``create_agent`` returns for the given model, prompt and tools.
    """
    # The prompt is resolved first so the tool list is only copied afterwards,
    # mirroring the order in which the two inputs are consumed.
    prompt_text = read_prompt_md(prompt_package, prompt_stem)
    tool_list = [*tools]
    return create_agent(llm, system_prompt=prompt_text, tools=tool_list)
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
"""Search-space / DB kwargs shared by ``new_chat`` tool factories (distinct from ``expert_agent.connectors`` integrations)."""
|
||||
|
||||
from app.agents.multi_agent_chat.core.bindings.binding import connector_binding
|
||||
|
||||
__all__ = ["connector_binding"]
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
"""Shared kwargs dict for ``new_chat`` tool factories (DB session + search space + user)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
def connector_binding(
    *,
    db_session: AsyncSession,
    search_space_id: int,
    user_id: str,
) -> dict[str, AsyncSession | int | str]:
    """Bundle the three keyword arguments into one kwargs dict.

    Returns:
        A new dict with the keys ``db_session``, ``search_space_id`` and
        ``user_id`` (in that order), each holding the argument of the same name.
    """
    return dict(
        db_session=db_session,
        search_space_id=search_space_id,
        user_id=user_id,
    )
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
"""Supervisor → domain message shaping."""
|
||||
|
||||
from app.agents.multi_agent_chat.core.delegation.child_task import compose_child_task
|
||||
|
||||
__all__ = ["compose_child_task"]
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
"""Fold orchestrator-selected context into the single user message sent to a domain agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def compose_child_task(task: str, *, curated_context: str | None = None) -> str:
    """Assemble the single user message handed to a domain agent.

    Args:
        task: Instruction for the child agent; surrounding whitespace is removed.
        curated_context: Optional context chosen by the orchestrator. ``None``,
            empty, or whitespace-only values are treated as "no context".

    Returns:
        The stripped task on its own, or the stripped context followed by a
        ``---`` separator and a ``Task:`` section when context is present.
    """
    instruction = task.strip()
    preamble = (curated_context or "").strip()
    if preamble:
        return f"{preamble}\n\n---\n\nTask:\n{instruction}"
    return instruction
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
"""Parsing LangGraph invoke results."""
|
||||
|
||||
from app.agents.multi_agent_chat.core.invocation.output import extract_last_assistant_text
|
||||
|
||||
__all__ = ["extract_last_assistant_text"]
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
"""Extract displayable text from a LangGraph agent ``invoke`` / ``ainvoke`` result."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def extract_last_assistant_text(result: dict[str, Any]) -> str:
    """Return the text of the last message in an agent ``invoke`` result.

    Args:
        result: Mapping returned by the agent; only its ``"messages"`` entry is read.

    Returns:
        ``""`` when ``"messages"`` is missing or empty. Otherwise, for the last
        message:

        * string content is returned unchanged;
        * list/tuple content is flattened by concatenating bare strings and the
          ``"text"`` of ``{"type": "text", ...}`` blocks, skipping every other
          block (so a message with only non-text blocks yields ``""``);
        * any other content falls back to ``str(last)``.
    """
    messages = result.get("messages") or []
    if not messages:
        return ""

    last = messages[-1]
    # Accept both message objects (``.content``) and plain dict messages.
    if isinstance(last, dict):
        content = last.get("content")
    else:
        content = getattr(last, "content", None)

    if isinstance(content, str):
        return content

    if isinstance(content, (list, tuple)):
        # Block-structured content: keep only the textual parts so callers get
        # displayable text rather than the repr of the whole message.
        pieces: list[str] = []
        for block in content:
            if isinstance(block, str):
                pieces.append(block)
            elif (
                isinstance(block, dict)
                and block.get("type") == "text"
                and isinstance(block.get("text"), str)
            ):
                pieces.append(block["text"])
        return "".join(pieces)

    return str(last)
|
||||
|
|
@ -0,0 +1,128 @@
|
|||
"""Partition MCP tools onto multi-agent expert routes without modifying ``new_chat``.
|
||||
|
||||
Uses the same connector discovery shape as ``load_mcp_tools`` (copied query below). Tools come from
|
||||
``app.agents.new_chat.tools.mcp_tool.load_mcp_tools``; routing uses metadata already set there:
|
||||
|
||||
- HTTP tools: ``metadata["mcp_connector_id"]`` → DB connector row → expert route.
|
||||
- stdio tools: no connector id on the tool; ``metadata["mcp_connector_name"]`` → connector name map
|
||||
(duplicate names: last row wins — rare).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
from sqlalchemy import cast, select
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import SearchSourceConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maps a SurfSense ``SearchSourceConnectorType`` string to the supervisor routing key.
# Each value must match a ``DomainRoutingSpec.tool_name`` used in ``supervisor_routing``.
_CONNECTOR_TYPE_TO_EXPERT_ROUTE: dict[str, str] = {
    # Google and Composio variants of the same product share one route.
    "GOOGLE_GMAIL_CONNECTOR": "gmail",
    "COMPOSIO_GMAIL_CONNECTOR": "gmail",
    "GOOGLE_CALENDAR_CONNECTOR": "calendar",
    "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": "calendar",
    # One connector type per route.
    "DISCORD_CONNECTOR": "discord",
    "TEAMS_CONNECTOR": "teams",
    "LUMA_CONNECTOR": "luma",
    "LINEAR_CONNECTOR": "linear",
    "JIRA_CONNECTOR": "jira",
    "CLICKUP_CONNECTOR": "clickup",
    "SLACK_CONNECTOR": "slack",
    "AIRTABLE_CONNECTOR": "airtable",
    # Plain MCP connectors all land on the generic route.
    "MCP_CONNECTOR": "generic_mcp",
}

# Order in which MCP-only routes are appended; these types have no native registry slice.
MCP_ONLY_ROUTE_KEYS_IN_ORDER: tuple[str, ...] = (
    "linear",
    "slack",
    "jira",
    "clickup",
    "airtable",
    "generic_mcp",
)
|
||||
|
||||
|
||||
async def fetch_mcp_connector_metadata_maps(
    session: AsyncSession,
    search_space_id: int,
) -> tuple[dict[int, str], dict[str, str]]:
    """Look up connector types for the MCP-capable connectors of a search space.

    Selects ``SearchSourceConnector`` rows that belong to ``search_space_id`` and
    whose ``config`` (cast to JSONB) has a ``server_config`` key.

    Args:
        session: Async SQLAlchemy session used to run the query.
        search_space_id: Search space whose connectors are inspected.

    Returns:
        ``(id_to_type, name_to_type)``: connector id → type string, and connector
        name → type string. Rows with a falsy name are left out of the name map;
        when two rows share a name, the one iterated later overwrites the earlier.
    """
    has_server_config = cast(SearchSourceConnector.config, JSONB).has_key("server_config")
    statement = select(SearchSourceConnector).filter(
        SearchSourceConnector.search_space_id == search_space_id,
        has_server_config,
    )
    rows = await session.execute(statement)

    types_by_id: dict[int, str] = {}
    types_by_name: dict[str, str] = {}
    for row in rows.scalars():
        raw_type = row.connector_type
        # Enum-like values expose ``.value``; anything else is stringified.
        type_name = raw_type.value if hasattr(raw_type, "value") else str(raw_type)
        types_by_id[row.id] = type_name
        if row.name:
            types_by_name[row.name] = type_name
    return types_by_id, types_by_name
|
||||
|
||||
|
||||
def _resolve_mcp_connector_type(
    metadata: dict[str, Any],
    connector_id_to_type: dict[int, str],
    connector_name_to_type: dict[str, str],
) -> str | None:
    """Return the connector type for one tool's metadata, or ``None`` if unresolved."""
    raw_id = metadata.get("mcp_connector_id")
    if raw_id is not None:
        try:
            numeric_id = int(raw_id)
        except (TypeError, ValueError):
            pass
        else:
            by_id = connector_id_to_type.get(numeric_id)
            if by_id is not None:
                return by_id

    # The name-based lookup is only attempted for stdio-transport tools.
    if metadata.get("mcp_transport") != "stdio":
        return None
    raw_name = metadata.get("mcp_connector_name")
    if not raw_name:
        return None
    return connector_name_to_type.get(str(raw_name))


def partition_mcp_tools_by_expert_route(
    tools: Sequence[BaseTool],
    connector_id_to_type: dict[int, str],
    connector_name_to_type: dict[str, str],
) -> dict[str, list[BaseTool]]:
    """Group MCP tools under the expert route key of their originating connector.

    Args:
        tools: MCP tools whose ``metadata`` carries ``mcp_connector_id`` and/or
            ``mcp_transport`` + ``mcp_connector_name``.
        connector_id_to_type: Connector id → connector type string.
        connector_name_to_type: Connector name → connector type string.

    Returns:
        A plain dict of route key → tools, keeping the input order inside each
        bucket. Tools whose connector type cannot be resolved (logged at debug)
        or has no route mapping (logged as a warning) are omitted.
    """
    routed: dict[str, list[BaseTool]] = defaultdict(list)

    for candidate in tools:
        metadata: dict[str, Any] = getattr(candidate, "metadata", None) or {}
        type_name = _resolve_mcp_connector_type(
            metadata,
            connector_id_to_type,
            connector_name_to_type,
        )

        if type_name is None:
            logger.debug(
                "Skipping MCP tool %r — could not resolve connector type from metadata",
                getattr(candidate, "name", None),
            )
            continue

        route_key = _CONNECTOR_TYPE_TO_EXPERT_ROUTE.get(type_name)
        if route_key is None:
            logger.warning(
                "MCP tool %r has unmapped connector type %s — skipped",
                getattr(candidate, "name", None),
                type_name,
            )
            continue

        routed[route_key].append(candidate)

    return dict(routed)
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
"""Markdown prompt loading for domain and supervisor packages."""
|
||||
|
||||
from app.agents.multi_agent_chat.core.prompts.load import read_prompt_md
|
||||
|
||||
__all__ = ["read_prompt_md"]
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
"""Load ``*.md`` prompt files from co-located packages (domain slices ship ``domain_prompt.md``)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from importlib import resources
|
||||
|
||||
|
||||
def read_prompt_md(package: str, stem: str) -> str:
    """Load ``{stem}.md`` shipped inside the import package ``package``.

    Args:
        package: Dotted import path of the package holding the prompt file.
        stem: Prompt file name without the ``.md`` extension.

    Returns:
        The file's UTF-8 text with at most one trailing newline removed, or
        ``""`` when the resource is not a file or the lookup/read raises
        ``FileNotFoundError``, ``ModuleNotFoundError``, ``OSError`` or ``TypeError``.
    """
    try:
        resource = resources.files(package).joinpath(f"{stem}.md")
        contents = resource.read_text(encoding="utf-8") if resource.is_file() else None
    except (FileNotFoundError, ModuleNotFoundError, OSError, TypeError):
        return ""

    if contents is None:
        return ""
    # Drop exactly one trailing newline; any further blank lines are kept.
    return contents.removesuffix("\n")
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
"""``new_chat`` tool registry grouping + dependency bundles for domain slices."""
|
||||
|
||||
from app.agents.multi_agent_chat.core.registry.categories import (
|
||||
REGISTRY_ROUTING_CATEGORY_KEYS,
|
||||
TOOL_NAMES_BY_CATEGORY,
|
||||
)
|
||||
from app.agents.multi_agent_chat.core.registry.dependencies import build_registry_dependencies
|
||||
from app.agents.multi_agent_chat.core.registry.subset import build_registry_tools_for_category
|
||||
|
||||
__all__ = [
|
||||
"REGISTRY_ROUTING_CATEGORY_KEYS",
|
||||
"TOOL_NAMES_BY_CATEGORY",
|
||||
"build_registry_dependencies",
|
||||
"build_registry_tools_for_category",
|
||||
]
|
||||
|
|
@ -0,0 +1,84 @@
|
|||
"""Registry tool names grouped by multi-agent routing category.
|
||||
|
||||
Each string must match ``ToolDefinition.name`` in
|
||||
``app.agents.new_chat.tools.registry.BUILTIN_TOOLS`` — these are **not** guessed or MCP-only:
|
||||
:class:`~app.agents.multi_agent_chat.core.registry.subset.build_registry_tools_for_category`
|
||||
uses synchronous :func:`~app.agents.new_chat.tools.registry.build_tools`, which only instantiates
|
||||
``BUILTIN_TOOLS``. MCP tools are loaded separately and merged in ``supervisor_routing``.
|
||||
|
||||
Connectors that exist for search/indexing but have **no** entry in ``BUILTIN_TOOLS`` correctly have
|
||||
no row here (no chat tools to delegate)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Category key (a supervisor routing tool name) → names of the ``BUILTIN_TOOLS``
# entries that belong to it. Names must match ``BUILTIN_TOOLS`` exactly.
TOOL_NAMES_BY_CATEGORY: dict[str, list[str]] = {
    "gmail": [
        "search_gmail",
        "read_gmail_email",
        "create_gmail_draft",
        "send_gmail_email",
        "trash_gmail_email",
        "update_gmail_draft",
    ],
    "calendar": [
        "search_calendar_events",
        "create_calendar_event",
        "update_calendar_event",
        "delete_calendar_event",
    ],
    "research": ["web_search", "scrape_webpage", "search_surfsense_docs"],
    "deliverables": [
        "generate_podcast",
        "generate_video_presentation",
        "generate_report",
        "generate_resume",
        "generate_image",
    ],
    "memory": ["update_memory"],
    "discord": ["list_discord_channels", "read_discord_messages", "send_discord_message"],
    "teams": ["list_teams_channels", "read_teams_messages", "send_teams_message"],
    "notion": ["create_notion_page", "update_notion_page", "delete_notion_page"],
    "confluence": [
        "create_confluence_page",
        "update_confluence_page",
        "delete_confluence_page",
    ],
    "google_drive": ["create_google_drive_file", "delete_google_drive_file"],
    "dropbox": ["create_dropbox_file", "delete_dropbox_file"],
    "onedrive": ["create_onedrive_file", "delete_onedrive_file"],
    "luma": ["list_luma_events", "read_luma_event", "create_luma_event"],
}

# Category keys in declaration order, frozen as a tuple.
REGISTRY_ROUTING_CATEGORY_KEYS: tuple[str, ...] = tuple(TOOL_NAMES_BY_CATEGORY)
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
"""Dependency dict for :func:`app.agents.new_chat.tools.registry.build_tools` in multi-agent graphs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import ChatVisibility
|
||||
|
||||
|
||||
def build_registry_dependencies(
    *,
    db_session: AsyncSession,
    search_space_id: int,
    user_id: str,
    thread_id: str,
    llm: BaseChatModel | None = None,
    firecrawl_api_key: str | None = None,
    connector_service: Any | None = None,
    available_connectors: list[str] | None = None,
    available_document_types: list[str] | None = None,
    thread_visibility: ChatVisibility = ChatVisibility.PRIVATE,
) -> dict[str, Any]:
    """Collect every keyword argument into one dependency dict for registry factories.

    The dict is a superset: a category slice enables only some tools, and each
    tool still checks its own ``ToolDefinition.requires`` against it.

    Returns:
        A new dict keyed by parameter name, in signature order. Optional
        arguments that were not supplied appear with their default value.
    """
    # Always-required request scope.
    request_scope: dict[str, Any] = {
        "db_session": db_session,
        "search_space_id": search_space_id,
        "user_id": user_id,
        "thread_id": thread_id,
    }
    # Optional collaborators and settings; defaults are passed through as-is.
    optional_inputs: dict[str, Any] = {
        "llm": llm,
        "firecrawl_api_key": firecrawl_api_key,
        "connector_service": connector_service,
        "available_connectors": available_connectors,
        "available_document_types": available_document_types,
        "thread_visibility": thread_visibility,
    }
    return {**request_scope, **optional_inputs}
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
"""Build :mod:`new_chat` registry tool subsets for multi-agent domain slices."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.new_chat.tools.registry import build_tools
|
||||
from app.agents.multi_agent_chat.core.registry.categories import TOOL_NAMES_BY_CATEGORY
|
||||
|
||||
|
||||
def build_registry_tools_for_category(
    dependencies: dict[str, Any],
    category: str,
) -> list[BaseTool]:
    """Build only the registry tools listed for ``category``.

    Args:
        dependencies: Dependency dict forwarded to ``build_tools``.
        category: Key into ``TOOL_NAMES_BY_CATEGORY``.

    Returns:
        The result of ``build_tools`` restricted to the category's tool names.

    Raises:
        ValueError: If ``category`` has no entry, or its entry is empty.
    """
    tool_names = TOOL_NAMES_BY_CATEGORY.get(category)
    if tool_names:
        return build_tools(dependencies, enabled_tools=tool_names)
    raise ValueError(f"Unknown registry category: {category!r}")
|
||||
Loading…
Add table
Add a link
Reference in a new issue