mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-06 22:32:39 +02:00
Add multi-agent core registry, delegation, and MCP partitioning.
This commit is contained in:
parent
0c8ea2085e
commit
c974fcefe6
16 changed files with 437 additions and 0 deletions
|
|
@ -0,0 +1,25 @@
|
||||||
|
"""Cross-cutting building blocks (prompts, agents, delegation, registry) — not domain logic."""
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.core.agents import build_domain_agent
|
||||||
|
from app.agents.multi_agent_chat.core.bindings import connector_binding
|
||||||
|
from app.agents.multi_agent_chat.core.delegation import compose_child_task
|
||||||
|
from app.agents.multi_agent_chat.core.invocation import extract_last_assistant_text
|
||||||
|
from app.agents.multi_agent_chat.core.prompts import read_prompt_md
|
||||||
|
from app.agents.multi_agent_chat.core.registry import (
|
||||||
|
REGISTRY_ROUTING_CATEGORY_KEYS,
|
||||||
|
TOOL_NAMES_BY_CATEGORY,
|
||||||
|
build_registry_dependencies,
|
||||||
|
build_registry_tools_for_category,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"REGISTRY_ROUTING_CATEGORY_KEYS",
|
||||||
|
"TOOL_NAMES_BY_CATEGORY",
|
||||||
|
"build_domain_agent",
|
||||||
|
"build_registry_dependencies",
|
||||||
|
"build_registry_tools_for_category",
|
||||||
|
"compose_child_task",
|
||||||
|
"connector_binding",
|
||||||
|
"extract_last_assistant_text",
|
||||||
|
"read_prompt_md",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
"""Compiled subgraph factories shared by domain slices."""
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.core.agents.domain_graph import build_domain_agent
|
||||||
|
|
||||||
|
__all__ = ["build_domain_agent"]
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
"""Compile a domain LangGraph agent from a co-located prompt + tool list."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from langchain.agents import create_agent
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.core.prompts import read_prompt_md
|
||||||
|
|
||||||
|
|
||||||
|
def build_domain_agent(
|
||||||
|
llm: BaseChatModel,
|
||||||
|
tools: Sequence[BaseTool],
|
||||||
|
*,
|
||||||
|
prompt_package: str,
|
||||||
|
prompt_stem: str = "domain_prompt",
|
||||||
|
):
|
||||||
|
"""``create_agent`` + ``{prompt_stem}.md`` loaded from ``prompt_package``."""
|
||||||
|
system_prompt = read_prompt_md(prompt_package, prompt_stem)
|
||||||
|
return create_agent(
|
||||||
|
llm,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tools=list(tools),
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
"""Search-space / DB kwargs shared by ``new_chat`` tool factories (distinct from ``expert_agent.connectors`` integrations)."""
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.core.bindings.binding import connector_binding
|
||||||
|
|
||||||
|
__all__ = ["connector_binding"]
|
||||||
|
|
@ -0,0 +1,18 @@
|
||||||
|
"""Shared kwargs dict for ``new_chat`` tool factories (DB session + search space + user)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
|
def connector_binding(
|
||||||
|
*,
|
||||||
|
db_session: AsyncSession,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: str,
|
||||||
|
) -> dict[str, AsyncSession | int | str]:
|
||||||
|
return {
|
||||||
|
"db_session": db_session,
|
||||||
|
"search_space_id": search_space_id,
|
||||||
|
"user_id": user_id,
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
"""Supervisor → domain message shaping."""
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.core.delegation.child_task import compose_child_task
|
||||||
|
|
||||||
|
__all__ = ["compose_child_task"]
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
"""Fold orchestrator-selected context into the single user message sent to a domain agent."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
def compose_child_task(task: str, *, curated_context: str | None = None) -> str:
|
||||||
|
"""Build the domain-agent user message: optional curated KB/context + task.
|
||||||
|
|
||||||
|
When ``curated_context`` is set (from supervisor/KB wiring), it is prepended so the
|
||||||
|
child sees only what orchestration chose — not the full parent transcript.
|
||||||
|
"""
|
||||||
|
task = task.strip()
|
||||||
|
if not curated_context or not curated_context.strip():
|
||||||
|
return task
|
||||||
|
return f"{curated_context.strip()}\n\n---\n\nTask:\n{task}"
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
"""Parsing LangGraph invoke results."""
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.core.invocation.output import extract_last_assistant_text
|
||||||
|
|
||||||
|
__all__ = ["extract_last_assistant_text"]
|
||||||
|
|
@ -0,0 +1,17 @@
|
||||||
|
"""Extract displayable text from a LangGraph agent ``invoke`` / ``ainvoke`` result."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def extract_last_assistant_text(result: dict[str, Any]) -> str:
|
||||||
|
"""Return the last message's string content, or ``\"\"`` if missing."""
|
||||||
|
messages = result.get("messages") or []
|
||||||
|
if not messages:
|
||||||
|
return ""
|
||||||
|
last = messages[-1]
|
||||||
|
content = getattr(last, "content", None)
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
return str(last)
|
||||||
|
|
@ -0,0 +1,128 @@
|
||||||
|
"""Partition MCP tools onto multi-agent expert routes without modifying ``new_chat``.
|
||||||
|
|
||||||
|
Uses the same connector discovery shape as ``load_mcp_tools`` (copied query below). Tools come from
|
||||||
|
``app.agents.new_chat.tools.mcp_tool.load_mcp_tools``; routing uses metadata already set there:
|
||||||
|
|
||||||
|
- HTTP tools: ``metadata["mcp_connector_id"]`` → DB connector row → expert route.
|
||||||
|
- stdio tools: no connector id on the tool; ``metadata["mcp_connector_name"]`` → connector name map
|
||||||
|
(duplicate names: last row wins — rare).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
from sqlalchemy import cast, select
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import SearchSourceConnector
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# SurfSense ``SearchSourceConnectorType`` string → supervisor routing key (must match
|
||||||
|
# ``DomainRoutingSpec.tool_name`` values used in ``supervisor_routing``).
|
||||||
|
_CONNECTOR_TYPE_TO_EXPERT_ROUTE: dict[str, str] = {
|
||||||
|
"GOOGLE_GMAIL_CONNECTOR": "gmail",
|
||||||
|
"COMPOSIO_GMAIL_CONNECTOR": "gmail",
|
||||||
|
"GOOGLE_CALENDAR_CONNECTOR": "calendar",
|
||||||
|
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": "calendar",
|
||||||
|
"DISCORD_CONNECTOR": "discord",
|
||||||
|
"TEAMS_CONNECTOR": "teams",
|
||||||
|
"LUMA_CONNECTOR": "luma",
|
||||||
|
"LINEAR_CONNECTOR": "linear",
|
||||||
|
"JIRA_CONNECTOR": "jira",
|
||||||
|
"CLICKUP_CONNECTOR": "clickup",
|
||||||
|
"SLACK_CONNECTOR": "slack",
|
||||||
|
"AIRTABLE_CONNECTOR": "airtable",
|
||||||
|
"MCP_CONNECTOR": "generic_mcp",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Ordering when appending MCP-only routes (no native registry slice for these types).
|
||||||
|
MCP_ONLY_ROUTE_KEYS_IN_ORDER: tuple[str, ...] = (
|
||||||
|
"linear",
|
||||||
|
"slack",
|
||||||
|
"jira",
|
||||||
|
"clickup",
|
||||||
|
"airtable",
|
||||||
|
"generic_mcp",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_mcp_connector_metadata_maps(
|
||||||
|
session: AsyncSession,
|
||||||
|
search_space_id: int,
|
||||||
|
) -> tuple[dict[int, str], dict[str, str]]:
|
||||||
|
"""Read-only copy of connector discovery used alongside ``load_mcp_tools``.
|
||||||
|
|
||||||
|
Same filter as ``new_chat.tools.mcp_tool.load_mcp_tools`` (connectors with ``server_config``).
|
||||||
|
"""
|
||||||
|
result = await session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
cast(SearchSourceConnector.config, JSONB).has_key("server_config"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
id_to_type: dict[int, str] = {}
|
||||||
|
name_to_type: dict[str, str] = {}
|
||||||
|
for connector in result.scalars():
|
||||||
|
ct = (
|
||||||
|
connector.connector_type.value
|
||||||
|
if hasattr(connector.connector_type, "value")
|
||||||
|
else str(connector.connector_type)
|
||||||
|
)
|
||||||
|
id_to_type[connector.id] = ct
|
||||||
|
if connector.name:
|
||||||
|
name_to_type[connector.name] = ct
|
||||||
|
return id_to_type, name_to_type
|
||||||
|
|
||||||
|
|
||||||
|
def partition_mcp_tools_by_expert_route(
|
||||||
|
tools: Sequence[BaseTool],
|
||||||
|
connector_id_to_type: dict[int, str],
|
||||||
|
connector_name_to_type: dict[str, str],
|
||||||
|
) -> dict[str, list[BaseTool]]:
|
||||||
|
"""Bucket MCP tools by expert route key. Supervisor never receives raw MCP tools."""
|
||||||
|
buckets: dict[str, list[BaseTool]] = defaultdict(list)
|
||||||
|
|
||||||
|
for tool in tools:
|
||||||
|
meta: dict[str, Any] = getattr(tool, "metadata", None) or {}
|
||||||
|
connector_type: str | None = None
|
||||||
|
|
||||||
|
cid = meta.get("mcp_connector_id")
|
||||||
|
if cid is not None:
|
||||||
|
try:
|
||||||
|
cid_int = int(cid)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
cid_int = None
|
||||||
|
if cid_int is not None:
|
||||||
|
connector_type = connector_id_to_type.get(cid_int)
|
||||||
|
|
||||||
|
if connector_type is None and meta.get("mcp_transport") == "stdio":
|
||||||
|
cname = meta.get("mcp_connector_name")
|
||||||
|
if cname:
|
||||||
|
connector_type = connector_name_to_type.get(str(cname))
|
||||||
|
|
||||||
|
if connector_type is None:
|
||||||
|
logger.debug(
|
||||||
|
"Skipping MCP tool %r — could not resolve connector type from metadata",
|
||||||
|
getattr(tool, "name", None),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
route = _CONNECTOR_TYPE_TO_EXPERT_ROUTE.get(connector_type)
|
||||||
|
if route is None:
|
||||||
|
logger.warning(
|
||||||
|
"MCP tool %r has unmapped connector type %s — skipped",
|
||||||
|
getattr(tool, "name", None),
|
||||||
|
connector_type,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
buckets[route].append(tool)
|
||||||
|
|
||||||
|
return dict(buckets)
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
"""Markdown prompt loading for domain and supervisor packages."""
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.core.prompts.load import read_prompt_md
|
||||||
|
|
||||||
|
__all__ = ["read_prompt_md"]
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
"""Load ``*.md`` prompt files from co-located packages (domain slices ship ``domain_prompt.md``)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from importlib import resources
|
||||||
|
|
||||||
|
|
||||||
|
def read_prompt_md(package: str, stem: str) -> str:
|
||||||
|
"""Read ``{stem}.md`` from the given import package (e.g. ``…expert_agent.connectors.gmail``)."""
|
||||||
|
try:
|
||||||
|
ref = resources.files(package).joinpath(f"{stem}.md")
|
||||||
|
if not ref.is_file():
|
||||||
|
return ""
|
||||||
|
text = ref.read_text(encoding="utf-8")
|
||||||
|
except (FileNotFoundError, ModuleNotFoundError, OSError, TypeError):
|
||||||
|
return ""
|
||||||
|
if text.endswith("\n"):
|
||||||
|
text = text[:-1]
|
||||||
|
return text
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
"""``new_chat`` tool registry grouping + dependency bundles for domain slices."""
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.core.registry.categories import (
|
||||||
|
REGISTRY_ROUTING_CATEGORY_KEYS,
|
||||||
|
TOOL_NAMES_BY_CATEGORY,
|
||||||
|
)
|
||||||
|
from app.agents.multi_agent_chat.core.registry.dependencies import build_registry_dependencies
|
||||||
|
from app.agents.multi_agent_chat.core.registry.subset import build_registry_tools_for_category
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"REGISTRY_ROUTING_CATEGORY_KEYS",
|
||||||
|
"TOOL_NAMES_BY_CATEGORY",
|
||||||
|
"build_registry_dependencies",
|
||||||
|
"build_registry_tools_for_category",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,84 @@
|
||||||
|
"""Registry tool names grouped by multi-agent routing category.
|
||||||
|
|
||||||
|
Each string must match ``ToolDefinition.name`` in
|
||||||
|
``app.agents.new_chat.tools.registry.BUILTIN_TOOLS`` — these are **not** guessed or MCP-only:
|
||||||
|
:class:`~app.agents.multi_agent_chat.core.registry.subset.build_registry_tools_for_category`
|
||||||
|
uses synchronous :func:`~app.agents.new_chat.tools.registry.build_tools`, which only instantiates
|
||||||
|
``BUILTIN_TOOLS``. MCP tools are loaded separately and merged in ``supervisor_routing``.
|
||||||
|
|
||||||
|
Connectors that exist for search/indexing but have **no** entry in ``BUILTIN_TOOLS`` correctly have
|
||||||
|
no row here (no chat tools to delegate)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Keys match supervisor routing tool names; values match ``BUILTIN_TOOLS`` names exactly.
|
||||||
|
TOOL_NAMES_BY_CATEGORY: dict[str, list[str]] = {
|
||||||
|
"gmail": [
|
||||||
|
"search_gmail",
|
||||||
|
"read_gmail_email",
|
||||||
|
"create_gmail_draft",
|
||||||
|
"send_gmail_email",
|
||||||
|
"trash_gmail_email",
|
||||||
|
"update_gmail_draft",
|
||||||
|
],
|
||||||
|
"calendar": [
|
||||||
|
"search_calendar_events",
|
||||||
|
"create_calendar_event",
|
||||||
|
"update_calendar_event",
|
||||||
|
"delete_calendar_event",
|
||||||
|
],
|
||||||
|
"research": [
|
||||||
|
"web_search",
|
||||||
|
"scrape_webpage",
|
||||||
|
"search_surfsense_docs",
|
||||||
|
],
|
||||||
|
"deliverables": [
|
||||||
|
"generate_podcast",
|
||||||
|
"generate_video_presentation",
|
||||||
|
"generate_report",
|
||||||
|
"generate_resume",
|
||||||
|
"generate_image",
|
||||||
|
],
|
||||||
|
"memory": [
|
||||||
|
"update_memory",
|
||||||
|
],
|
||||||
|
"discord": [
|
||||||
|
"list_discord_channels",
|
||||||
|
"read_discord_messages",
|
||||||
|
"send_discord_message",
|
||||||
|
],
|
||||||
|
"teams": [
|
||||||
|
"list_teams_channels",
|
||||||
|
"read_teams_messages",
|
||||||
|
"send_teams_message",
|
||||||
|
],
|
||||||
|
"notion": [
|
||||||
|
"create_notion_page",
|
||||||
|
"update_notion_page",
|
||||||
|
"delete_notion_page",
|
||||||
|
],
|
||||||
|
"confluence": [
|
||||||
|
"create_confluence_page",
|
||||||
|
"update_confluence_page",
|
||||||
|
"delete_confluence_page",
|
||||||
|
],
|
||||||
|
"google_drive": [
|
||||||
|
"create_google_drive_file",
|
||||||
|
"delete_google_drive_file",
|
||||||
|
],
|
||||||
|
"dropbox": [
|
||||||
|
"create_dropbox_file",
|
||||||
|
"delete_dropbox_file",
|
||||||
|
],
|
||||||
|
"onedrive": [
|
||||||
|
"create_onedrive_file",
|
||||||
|
"delete_onedrive_file",
|
||||||
|
],
|
||||||
|
"luma": [
|
||||||
|
"list_luma_events",
|
||||||
|
"read_luma_event",
|
||||||
|
"create_luma_event",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTRY_ROUTING_CATEGORY_KEYS: tuple[str, ...] = tuple(TOOL_NAMES_BY_CATEGORY.keys())
|
||||||
|
|
@ -0,0 +1,42 @@
|
||||||
|
"""Dependency dict for :func:`app.agents.new_chat.tools.registry.build_tools` in multi-agent graphs."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import ChatVisibility
|
||||||
|
|
||||||
|
|
||||||
|
def build_registry_dependencies(
|
||||||
|
*,
|
||||||
|
db_session: AsyncSession,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: str,
|
||||||
|
thread_id: str,
|
||||||
|
llm: BaseChatModel | None = None,
|
||||||
|
firecrawl_api_key: str | None = None,
|
||||||
|
connector_service: Any | None = None,
|
||||||
|
available_connectors: list[str] | None = None,
|
||||||
|
available_document_types: list[str] | None = None,
|
||||||
|
thread_visibility: ChatVisibility = ChatVisibility.PRIVATE,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Union of kwargs commonly required by registry factories across category slices.
|
||||||
|
|
||||||
|
Individual categories enable a subset of tools; each tool still validates its own
|
||||||
|
``ToolDefinition.requires`` against this dict.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"db_session": db_session,
|
||||||
|
"search_space_id": search_space_id,
|
||||||
|
"user_id": user_id,
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"llm": llm,
|
||||||
|
"firecrawl_api_key": firecrawl_api_key,
|
||||||
|
"connector_service": connector_service,
|
||||||
|
"available_connectors": available_connectors,
|
||||||
|
"available_document_types": available_document_types,
|
||||||
|
"thread_visibility": thread_visibility,
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,22 @@
|
||||||
|
"""Build :mod:`new_chat` registry tool subsets for multi-agent domain slices."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.registry import build_tools
|
||||||
|
from app.agents.multi_agent_chat.core.registry.categories import TOOL_NAMES_BY_CATEGORY
|
||||||
|
|
||||||
|
|
||||||
|
def build_registry_tools_for_category(
|
||||||
|
dependencies: dict[str, Any],
|
||||||
|
category: str,
|
||||||
|
) -> list[BaseTool]:
|
||||||
|
"""Instantiate only the tools registered for ``category`` (see ``TOOL_NAMES_BY_CATEGORY``)."""
|
||||||
|
names = TOOL_NAMES_BY_CATEGORY.get(category)
|
||||||
|
if not names:
|
||||||
|
msg = f"Unknown registry category: {category!r}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
return build_tools(dependencies, enabled_tools=names)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue