diff --git a/surfsense_backend/app/agents/multi_agent_chat/shared/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/shared/__init__.py new file mode 100644 index 000000000..1ef1ad771 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/shared/__init__.py @@ -0,0 +1,13 @@ +"""Cross-cutting helpers: prompt loading, domain agent factory, connector deps.""" + +from app.agents.multi_agent_chat.shared.deps import connector_binding +from app.agents.multi_agent_chat.shared.domain_agent_factory import build_domain_agent +from app.agents.multi_agent_chat.shared.invoke_output import extract_last_assistant_text +from app.agents.multi_agent_chat.shared.prompt_loader import read_prompt_md + +__all__ = [ + "build_domain_agent", + "connector_binding", + "extract_last_assistant_text", + "read_prompt_md", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/shared/deps.py b/surfsense_backend/app/agents/multi_agent_chat/shared/deps.py new file mode 100644 index 000000000..c1e18e849 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/shared/deps.py @@ -0,0 +1,18 @@ +"""Shared kwargs for ``new_chat`` connector tool factories.""" + +from __future__ import annotations + +from sqlalchemy.ext.asyncio import AsyncSession + + +def connector_binding( + *, + db_session: AsyncSession, + search_space_id: int, + user_id: str, +) -> dict[str, AsyncSession | int | str]: + return { + "db_session": db_session, + "search_space_id": search_space_id, + "user_id": user_id, + } diff --git a/surfsense_backend/app/agents/multi_agent_chat/shared/domain_agent_factory.py b/surfsense_backend/app/agents/multi_agent_chat/shared/domain_agent_factory.py new file mode 100644 index 000000000..c6c5b061a --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/shared/domain_agent_factory.py @@ -0,0 +1,27 @@ +"""Compile a domain agent graph from a co-located prompt + tool list.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from langchain.agents import create_agent +from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool + +from app.agents.multi_agent_chat.shared.prompt_loader import read_prompt_md + + +def build_domain_agent( + llm: BaseChatModel, + tools: Sequence[BaseTool], + *, + prompt_package: str, + prompt_stem: str = "domain_prompt", +): + """``create_agent`` + ``{prompt_stem}.md`` loaded from ``prompt_package``.""" + system_prompt = read_prompt_md(prompt_package, prompt_stem) + return create_agent( + llm, + system_prompt=system_prompt, + tools=list(tools), + ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/shared/invoke_output.py b/surfsense_backend/app/agents/multi_agent_chat/shared/invoke_output.py new file mode 100644 index 000000000..2bbab6e57 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/shared/invoke_output.py @@ -0,0 +1,17 @@ +"""Extract displayable text from a LangGraph agent ``invoke`` / ``ainvoke`` result.""" + +from __future__ import annotations + +from typing import Any + + +def extract_last_assistant_text(result: dict[str, Any]) -> str: + """Return the last message's string content, or ``\"\"`` if missing.""" + messages = result.get("messages") or [] + if not messages: + return "" + last = messages[-1] + content = getattr(last, "content", None) + if isinstance(content, str): + return content + return str(last) diff --git a/surfsense_backend/app/agents/multi_agent_chat/shared/prompt_loader.py b/surfsense_backend/app/agents/multi_agent_chat/shared/prompt_loader.py new file mode 100644 index 000000000..940647364 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/shared/prompt_loader.py @@ -0,0 +1,19 @@ +"""Load ``*.md`` from any package (vertical slices use co-located prompts).""" + +from __future__ import annotations + +from importlib import resources + + +def read_prompt_md(package: str, stem: str) -> str: + """Read ``{stem}.md`` from the given import package (e.g. ``app.agents.multi_agent_chat.gmail``).""" + try: + ref = resources.files(package).joinpath(f"{stem}.md") + if not ref.is_file(): + return "" + text = ref.read_text(encoding="utf-8") + except (FileNotFoundError, ModuleNotFoundError, OSError, TypeError): + return "" + if text.endswith("\n"): + text = text[:-1] + return text