mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-08 07:12:39 +02:00
Coerce deliverables thread_id and invoke domain agents asynchronously.
This commit is contained in:
parent
5bc33626b9
commit
8f8d7540f0
3 changed files with 30 additions and 8 deletions
|
|
@ -10,12 +10,31 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from app.db import ChatVisibility
|
from app.db import ChatVisibility
|
||||||
|
|
||||||
|
|
||||||
|
def coerce_thread_id_for_registry(thread_id: str | int | None) -> int | None:
|
||||||
|
"""Normalize chat thread id for registry tools that FK to ``new_chat_threads.id``.
|
||||||
|
|
||||||
|
``create_surfsense_deep_agent`` passes an ``int``; multi-agent wiring may pass
|
||||||
|
``str(chat_id)`` for LangGraph/checkpointer consistency. AsyncPG requires ``int``
|
||||||
|
for integer columns.
|
||||||
|
"""
|
||||||
|
if thread_id is None:
|
||||||
|
return None
|
||||||
|
if isinstance(thread_id, int):
|
||||||
|
return thread_id
|
||||||
|
s = str(thread_id).strip()
|
||||||
|
if not s:
|
||||||
|
return None
|
||||||
|
if s.isdigit():
|
||||||
|
return int(s)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def build_registry_dependencies(
|
def build_registry_dependencies(
|
||||||
*,
|
*,
|
||||||
db_session: AsyncSession,
|
db_session: AsyncSession,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
thread_id: str,
|
thread_id: str | int | None,
|
||||||
llm: BaseChatModel | None = None,
|
llm: BaseChatModel | None = None,
|
||||||
firecrawl_api_key: str | None = None,
|
firecrawl_api_key: str | None = None,
|
||||||
connector_service: Any | None = None,
|
connector_service: Any | None = None,
|
||||||
|
|
@ -32,7 +51,7 @@ def build_registry_dependencies(
|
||||||
"db_session": db_session,
|
"db_session": db_session,
|
||||||
"search_space_id": search_space_id,
|
"search_space_id": search_space_id,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"thread_id": thread_id,
|
"thread_id": coerce_thread_id_for_registry(thread_id),
|
||||||
"llm": llm,
|
"llm": llm,
|
||||||
"firecrawl_api_key": firecrawl_api_key,
|
"firecrawl_api_key": firecrawl_api_key,
|
||||||
"connector_service": connector_service,
|
"connector_service": connector_service,
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,10 @@ from app.agents.multi_agent_chat.core.mcp_partition import (
|
||||||
fetch_mcp_connector_metadata_maps,
|
fetch_mcp_connector_metadata_maps,
|
||||||
partition_mcp_tools_by_expert_route,
|
partition_mcp_tools_by_expert_route,
|
||||||
)
|
)
|
||||||
from app.agents.multi_agent_chat.core.registry import build_registry_dependencies
|
from app.agents.multi_agent_chat.core.registry.dependencies import (
|
||||||
|
build_registry_dependencies,
|
||||||
|
coerce_thread_id_for_registry,
|
||||||
|
)
|
||||||
from app.agents.multi_agent_chat.middleware.supervisor_stack import build_supervisor_middleware_stack
|
from app.agents.multi_agent_chat.middleware.supervisor_stack import build_supervisor_middleware_stack
|
||||||
from app.agents.multi_agent_chat.routing.supervisor_routing import build_supervisor_routing_tools
|
from app.agents.multi_agent_chat.routing.supervisor_routing import build_supervisor_routing_tools
|
||||||
from app.agents.multi_agent_chat.supervisor import build_supervisor_agent
|
from app.agents.multi_agent_chat.supervisor import build_supervisor_agent
|
||||||
|
|
@ -83,7 +86,7 @@ async def create_multi_agent_chat(
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
checkpointer: Checkpointer | None = None,
|
checkpointer: Checkpointer | None = None,
|
||||||
thread_id: str | None = None,
|
thread_id: str | int | None = None,
|
||||||
firecrawl_api_key: str | None = None,
|
firecrawl_api_key: str | None = None,
|
||||||
connector_service: Any | None = None,
|
connector_service: Any | None = None,
|
||||||
available_connectors: list[str] | None = None,
|
available_connectors: list[str] | None = None,
|
||||||
|
|
@ -148,7 +151,7 @@ async def create_multi_agent_chat(
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
thread_id=thread_id or "",
|
thread_id=thread_id,
|
||||||
llm=llm,
|
llm=llm,
|
||||||
firecrawl_api_key=firecrawl_api_key,
|
firecrawl_api_key=firecrawl_api_key,
|
||||||
connector_service=connector_service,
|
connector_service=connector_service,
|
||||||
|
|
@ -159,7 +162,7 @@ async def create_multi_agent_chat(
|
||||||
routing_tools = build_supervisor_routing_tools(
|
routing_tools = build_supervisor_routing_tools(
|
||||||
llm,
|
llm,
|
||||||
registry_dependencies=registry_dependencies,
|
registry_dependencies=registry_dependencies,
|
||||||
include_deliverables=thread_id is not None,
|
include_deliverables=coerce_thread_id_for_registry(thread_id) is not None,
|
||||||
mcp_tools_by_route=mcp_tools_by_route,
|
mcp_tools_by_route=mcp_tools_by_route,
|
||||||
available_connectors=resolved_connectors,
|
available_connectors=resolved_connectors,
|
||||||
thread_visibility=thread_visibility,
|
thread_visibility=thread_visibility,
|
||||||
|
|
|
||||||
|
|
@ -107,10 +107,10 @@ def _normalize_domain_output(spec: DomainRoutingSpec, raw_text: str) -> str:
|
||||||
|
|
||||||
def _routing_tool_for_spec(spec: DomainRoutingSpec) -> BaseTool:
|
def _routing_tool_for_spec(spec: DomainRoutingSpec) -> BaseTool:
|
||||||
@tool(spec.tool_name, description=spec.description)
|
@tool(spec.tool_name, description=spec.description)
|
||||||
def _route(task: str) -> str:
|
async def _route(task: str) -> str:
|
||||||
curated = spec.curated_context(task) if spec.curated_context else None
|
curated = spec.curated_context(task) if spec.curated_context else None
|
||||||
content = compose_child_task(task, curated_context=curated)
|
content = compose_child_task(task, curated_context=curated)
|
||||||
result = spec.domain_agent.invoke(
|
result = await spec.domain_agent.ainvoke(
|
||||||
{"messages": [{"role": "user", "content": content}]},
|
{"messages": [{"role": "user", "content": content}]},
|
||||||
)
|
)
|
||||||
return _normalize_domain_output(spec, extract_last_assistant_text(result))
|
return _normalize_domain_output(spec, extract_last_assistant_text(result))
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue