mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-06 14:22:47 +02:00
Add multi_agent_chat shared helpers for prompts, domain agents, deps, and invoke parsing.
This commit is contained in:
parent
b9132f8544
commit
5ff2678253
5 changed files with 94 additions and 0 deletions
|
|
@ -0,0 +1,13 @@
|
|||
"""Cross-cutting helpers: prompt loading, domain agent factory, connector deps."""
|
||||
|
||||
from app.agents.multi_agent_chat.shared.deps import connector_binding
|
||||
from app.agents.multi_agent_chat.shared.domain_agent_factory import build_domain_agent
|
||||
from app.agents.multi_agent_chat.shared.invoke_output import extract_last_assistant_text
|
||||
from app.agents.multi_agent_chat.shared.prompt_loader import read_prompt_md
|
||||
|
||||
__all__ = [
|
||||
"build_domain_agent",
|
||||
"connector_binding",
|
||||
"extract_last_assistant_text",
|
||||
"read_prompt_md",
|
||||
]
|
||||
18
surfsense_backend/app/agents/multi_agent_chat/shared/deps.py
Normal file
18
surfsense_backend/app/agents/multi_agent_chat/shared/deps.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
"""Shared kwargs for ``new_chat`` connector tool factories."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
def connector_binding(
|
||||
*,
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> dict[str, AsyncSession | int | str]:
|
||||
return {
|
||||
"db_session": db_session,
|
||||
"search_space_id": search_space_id,
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
"""Compile a domain agent graph from a co-located prompt + tool list."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.multi_agent_chat.shared.prompt_loader import read_prompt_md
|
||||
|
||||
|
||||
def build_domain_agent(
|
||||
llm: BaseChatModel,
|
||||
tools: Sequence[BaseTool],
|
||||
*,
|
||||
prompt_package: str,
|
||||
prompt_stem: str = "domain_prompt",
|
||||
):
|
||||
"""``create_agent`` + ``{prompt_stem}.md`` loaded from ``prompt_package``."""
|
||||
system_prompt = read_prompt_md(prompt_package, prompt_stem)
|
||||
return create_agent(
|
||||
llm,
|
||||
system_prompt=system_prompt,
|
||||
tools=list(tools),
|
||||
)
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
"""Extract displayable text from a LangGraph agent ``invoke`` / ``ainvoke`` result."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def extract_last_assistant_text(result: dict[str, Any]) -> str:
|
||||
"""Return the last message's string content, or ``\"\"`` if missing."""
|
||||
messages = result.get("messages") or []
|
||||
if not messages:
|
||||
return ""
|
||||
last = messages[-1]
|
||||
content = getattr(last, "content", None)
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return str(last)
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
"""Load ``*.md`` from any package (vertical slices use co-located prompts)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from importlib import resources
|
||||
|
||||
|
||||
def read_prompt_md(package: str, stem: str) -> str:
|
||||
"""Read ``{stem}.md`` from the given import package (e.g. ``app.agents.multi_agent_chat.gmail``)."""
|
||||
try:
|
||||
ref = resources.files(package).joinpath(f"{stem}.md")
|
||||
if not ref.is_file():
|
||||
return ""
|
||||
text = ref.read_text(encoding="utf-8")
|
||||
except (FileNotFoundError, ModuleNotFoundError, OSError, TypeError):
|
||||
return ""
|
||||
if text.endswith("\n"):
|
||||
text = text[:-1]
|
||||
return text
|
||||
Loading…
Add table
Add a link
Reference in a new issue