mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-06 14:22:47 +02:00
Coerce deliverables thread_id and invoke domain agents asynchronously.
This commit is contained in:
parent
5bc33626b9
commit
8f8d7540f0
3 changed files with 30 additions and 8 deletions
|
|
@ -10,12 +10,31 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from app.db import ChatVisibility
|
||||
|
||||
|
||||
def coerce_thread_id_for_registry(thread_id: str | int | None) -> int | None:
|
||||
"""Normalize chat thread id for registry tools that FK to ``new_chat_threads.id``.
|
||||
|
||||
``create_surfsense_deep_agent`` passes an ``int``; multi-agent wiring may pass
|
||||
``str(chat_id)`` for LangGraph/checkpointer consistency. AsyncPG requires ``int``
|
||||
for integer columns.
|
||||
"""
|
||||
if thread_id is None:
|
||||
return None
|
||||
if isinstance(thread_id, int):
|
||||
return thread_id
|
||||
s = str(thread_id).strip()
|
||||
if not s:
|
||||
return None
|
||||
if s.isdigit():
|
||||
return int(s)
|
||||
return None
|
||||
|
||||
|
||||
def build_registry_dependencies(
|
||||
*,
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
thread_id: str,
|
||||
thread_id: str | int | None,
|
||||
llm: BaseChatModel | None = None,
|
||||
firecrawl_api_key: str | None = None,
|
||||
connector_service: Any | None = None,
|
||||
|
|
@ -32,7 +51,7 @@ def build_registry_dependencies(
|
|||
"db_session": db_session,
|
||||
"search_space_id": search_space_id,
|
||||
"user_id": user_id,
|
||||
"thread_id": thread_id,
|
||||
"thread_id": coerce_thread_id_for_registry(thread_id),
|
||||
"llm": llm,
|
||||
"firecrawl_api_key": firecrawl_api_key,
|
||||
"connector_service": connector_service,
|
||||
|
|
|
|||
|
|
@ -23,7 +23,10 @@ from app.agents.multi_agent_chat.core.mcp_partition import (
|
|||
fetch_mcp_connector_metadata_maps,
|
||||
partition_mcp_tools_by_expert_route,
|
||||
)
|
||||
from app.agents.multi_agent_chat.core.registry import build_registry_dependencies
|
||||
from app.agents.multi_agent_chat.core.registry.dependencies import (
|
||||
build_registry_dependencies,
|
||||
coerce_thread_id_for_registry,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.supervisor_stack import build_supervisor_middleware_stack
|
||||
from app.agents.multi_agent_chat.routing.supervisor_routing import build_supervisor_routing_tools
|
||||
from app.agents.multi_agent_chat.supervisor import build_supervisor_agent
|
||||
|
|
@ -83,7 +86,7 @@ async def create_multi_agent_chat(
|
|||
search_space_id: int,
|
||||
user_id: str,
|
||||
checkpointer: Checkpointer | None = None,
|
||||
thread_id: str | None = None,
|
||||
thread_id: str | int | None = None,
|
||||
firecrawl_api_key: str | None = None,
|
||||
connector_service: Any | None = None,
|
||||
available_connectors: list[str] | None = None,
|
||||
|
|
@ -148,7 +151,7 @@ async def create_multi_agent_chat(
|
|||
db_session=db_session,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id or "",
|
||||
thread_id=thread_id,
|
||||
llm=llm,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
connector_service=connector_service,
|
||||
|
|
@ -159,7 +162,7 @@ async def create_multi_agent_chat(
|
|||
routing_tools = build_supervisor_routing_tools(
|
||||
llm,
|
||||
registry_dependencies=registry_dependencies,
|
||||
include_deliverables=thread_id is not None,
|
||||
include_deliverables=coerce_thread_id_for_registry(thread_id) is not None,
|
||||
mcp_tools_by_route=mcp_tools_by_route,
|
||||
available_connectors=resolved_connectors,
|
||||
thread_visibility=thread_visibility,
|
||||
|
|
|
|||
|
|
@ -107,10 +107,10 @@ def _normalize_domain_output(spec: DomainRoutingSpec, raw_text: str) -> str:
|
|||
|
||||
def _routing_tool_for_spec(spec: DomainRoutingSpec) -> BaseTool:
|
||||
@tool(spec.tool_name, description=spec.description)
|
||||
def _route(task: str) -> str:
|
||||
async def _route(task: str) -> str:
|
||||
curated = spec.curated_context(task) if spec.curated_context else None
|
||||
content = compose_child_task(task, curated_context=curated)
|
||||
result = spec.domain_agent.invoke(
|
||||
result = await spec.domain_agent.ainvoke(
|
||||
{"messages": [{"role": "user", "content": content}]},
|
||||
)
|
||||
return _normalize_domain_output(spec, extract_last_assistant_text(result))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue