mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-07-02 22:01:05 +02:00
Compare commits
44 commits
7c4d1a6af6
...
91c2c06108
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
91c2c06108 | ||
|
|
888e1d9cec | ||
|
|
94dbbfa7e4 | ||
|
|
68f3f9313c | ||
|
|
451868906b | ||
|
|
4bee367d4a | ||
|
|
0bdc08162e | ||
|
|
d37417cbe9 | ||
|
|
fa0b47dfca | ||
|
|
5d3142332b | ||
|
|
e1e4bb4706 | ||
|
|
1188550bf6 | ||
|
|
b5301fa438 | ||
|
|
38b9e8dcc5 | ||
|
|
136901276a | ||
|
|
3f68962772 | ||
|
|
26695b949e | ||
|
|
43744713f7 | ||
|
|
3d1a504395 | ||
|
|
2b40592f0b | ||
|
|
f22d7434ce | ||
|
|
ae3d254a2c | ||
|
|
ec27807644 | ||
|
|
635cdde0eb | ||
|
|
25644e1c0b | ||
|
|
b6e2510e55 | ||
|
|
fce465a40f | ||
|
|
71f4f77f26 | ||
|
|
041af34820 | ||
|
|
5169d3d56c | ||
|
|
e6065b6793 | ||
|
|
1c9c496e01 | ||
|
|
71cd04b05e | ||
|
|
f844c3288c | ||
|
|
ea7bcebcd0 | ||
|
|
b3a8364fbd | ||
|
|
8d8ba6cbe8 | ||
|
|
4875fd9211 | ||
|
|
2f59fc9c72 | ||
|
|
85baaacd0a | ||
|
|
3eb448ec8d | ||
|
|
82c7d4a2ab | ||
|
|
0c4fd30cce | ||
|
|
7c61668823 |
129 changed files with 5680 additions and 4060 deletions
|
|
@ -0,0 +1,44 @@
|
|||
"""124_add_ai_file_sort_enabled
|
||||
|
||||
Revision ID: 124
|
||||
Revises: 123
|
||||
Create Date: 2026-04-14
|
||||
|
||||
Adds ai_file_sort_enabled boolean column to searchspaces.
|
||||
Defaults to False so AI file sorting is opt-in per search space.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "124"
|
||||
down_revision: str | None = "123"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
existing_columns = [
|
||||
col["name"] for col in sa.inspect(conn).get_columns("searchspaces")
|
||||
]
|
||||
|
||||
if "ai_file_sort_enabled" not in existing_columns:
|
||||
op.add_column(
|
||||
"searchspaces",
|
||||
sa.Column(
|
||||
"ai_file_sort_enabled",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("searchspaces", "ai_file_sort_enabled")
|
||||
|
|
@ -472,7 +472,7 @@ async def create_surfsense_deep_agent(
|
|||
SubAgentMiddleware(backend=StateBackend, subagents=[general_purpose_spec]),
|
||||
create_summarization_middleware(llm, StateBackend),
|
||||
PatchToolCallsMiddleware(),
|
||||
DedupHITLToolCallsMiddleware(),
|
||||
DedupHITLToolCallsMiddleware(agent_tools=tools),
|
||||
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -20,19 +20,39 @@ from langgraph.runtime import Runtime
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_HITL_TOOL_DEDUP_KEYS: dict[str, str] = {
|
||||
"delete_calendar_event": "event_title_or_id",
|
||||
"update_calendar_event": "event_title_or_id",
|
||||
"trash_gmail_email": "email_subject_or_id",
|
||||
_NATIVE_HITL_TOOL_DEDUP_KEYS: dict[str, str] = {
|
||||
# Gmail
|
||||
"send_gmail_email": "subject",
|
||||
"create_gmail_draft": "subject",
|
||||
"update_gmail_draft": "draft_subject_or_id",
|
||||
"trash_gmail_email": "email_subject_or_id",
|
||||
# Google Calendar
|
||||
"create_calendar_event": "title",
|
||||
"update_calendar_event": "event_title_or_id",
|
||||
"delete_calendar_event": "event_title_or_id",
|
||||
# Google Drive
|
||||
"create_google_drive_file": "file_name",
|
||||
"delete_google_drive_file": "file_name",
|
||||
# OneDrive
|
||||
"create_onedrive_file": "file_name",
|
||||
"delete_onedrive_file": "file_name",
|
||||
"delete_notion_page": "page_title",
|
||||
# Dropbox
|
||||
"create_dropbox_file": "file_name",
|
||||
"delete_dropbox_file": "file_name",
|
||||
# Notion
|
||||
"create_notion_page": "title",
|
||||
"update_notion_page": "page_title",
|
||||
"delete_linear_issue": "issue_ref",
|
||||
"delete_notion_page": "page_title",
|
||||
# Linear
|
||||
"create_linear_issue": "title",
|
||||
"update_linear_issue": "issue_ref",
|
||||
"delete_linear_issue": "issue_ref",
|
||||
# Jira
|
||||
"create_jira_issue": "summary",
|
||||
"update_jira_issue": "issue_title_or_key",
|
||||
"delete_jira_issue": "issue_title_or_key",
|
||||
# Confluence
|
||||
"create_confluence_page": "title",
|
||||
"update_confluence_page": "page_title_or_id",
|
||||
"delete_confluence_page": "page_title_or_id",
|
||||
}
|
||||
|
|
@ -43,22 +63,39 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
|
||||
Only the **first** occurrence of each (tool-name, primary-arg-value)
|
||||
pair is kept; subsequent duplicates are silently dropped.
|
||||
|
||||
The dedup map is built from two sources:
|
||||
|
||||
1. A comprehensive list of native HITL tools (hardcoded above).
|
||||
2. Any ``StructuredTool`` instances passed via *agent_tools* whose
|
||||
``metadata`` contains ``{"hitl": True, "hitl_dedup_key": "..."}``.
|
||||
This is how MCP tools automatically get dedup support.
|
||||
"""
|
||||
|
||||
tools = ()
|
||||
|
||||
def __init__(self, *, agent_tools: list[Any] | None = None) -> None:
|
||||
self._dedup_keys: dict[str, str] = dict(_NATIVE_HITL_TOOL_DEDUP_KEYS)
|
||||
for t in agent_tools or []:
|
||||
meta = getattr(t, "metadata", None) or {}
|
||||
if meta.get("hitl") and meta.get("hitl_dedup_key"):
|
||||
self._dedup_keys[t.name] = meta["hitl_dedup_key"]
|
||||
|
||||
def after_model(
|
||||
self, state: AgentState, runtime: Runtime[Any]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._dedup(state)
|
||||
return self._dedup(state, self._dedup_keys)
|
||||
|
||||
async def aafter_model(
|
||||
self, state: AgentState, runtime: Runtime[Any]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._dedup(state)
|
||||
return self._dedup(state, self._dedup_keys)
|
||||
|
||||
@staticmethod
|
||||
def _dedup(state: AgentState) -> dict[str, Any] | None: # type: ignore[type-arg]
|
||||
def _dedup(
|
||||
state: AgentState,
|
||||
dedup_keys: dict[str, str], # type: ignore[type-arg]
|
||||
) -> dict[str, Any] | None:
|
||||
messages = state.get("messages")
|
||||
if not messages:
|
||||
return None
|
||||
|
|
@ -73,7 +110,7 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
|
||||
for tc in tool_calls:
|
||||
name = tc.get("name", "")
|
||||
dedup_key_arg = _HITL_TOOL_DEDUP_KEYS.get(name)
|
||||
dedup_key_arg = dedup_keys.get(name)
|
||||
if dedup_key_arg is not None:
|
||||
arg_val = str(tc.get("args", {}).get(dedup_key_arg, "")).lower()
|
||||
key = (name, arg_val)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import secrets
|
||||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Any
|
||||
|
||||
|
|
@ -27,6 +28,7 @@ from sqlalchemy import delete, select
|
|||
|
||||
from app.agents.new_chat.sandbox import (
|
||||
_evict_sandbox_cache,
|
||||
delete_sandbox,
|
||||
get_or_create_sandbox,
|
||||
is_sandbox_enabled,
|
||||
)
|
||||
|
|
@ -552,7 +554,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
@staticmethod
|
||||
def _wrap_as_python(code: str) -> str:
|
||||
"""Wrap Python code in a shell invocation for the sandbox."""
|
||||
return f"python3 << 'PYEOF'\n{code}\nPYEOF"
|
||||
sentinel = f"_PYEOF_{secrets.token_hex(8)}"
|
||||
return f"python3 << '{sentinel}'\n{code}\n{sentinel}"
|
||||
|
||||
async def _execute_in_sandbox(
|
||||
self,
|
||||
|
|
@ -572,7 +575,10 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
self._thread_id,
|
||||
first_err,
|
||||
)
|
||||
_evict_sandbox_cache(self._thread_id)
|
||||
try:
|
||||
await delete_sandbox(self._thread_id)
|
||||
except Exception:
|
||||
_evict_sandbox_cache(self._thread_id)
|
||||
try:
|
||||
return await self._try_sandbox_execute(command, runtime, timeout)
|
||||
except Exception:
|
||||
|
|
@ -587,7 +593,14 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
runtime: ToolRuntime[None, FilesystemState],
|
||||
timeout: int | None,
|
||||
) -> str:
|
||||
sandbox, is_new = await get_or_create_sandbox(self._thread_id)
|
||||
sandbox, _is_new = await get_or_create_sandbox(self._thread_id)
|
||||
# NOTE: sync_files_to_sandbox is intentionally disabled.
|
||||
# The virtual FS contains XML-wrapped KB documents whose paths
|
||||
# would double-nest under SANDBOX_DOCUMENTS_ROOT (e.g.
|
||||
# /home/daytona/documents/documents/Report.xml) and uploading
|
||||
# all KB docs on the first execute_code call adds significant
|
||||
# latency. Re-enable once path mapping is fixed and upload is
|
||||
# limited to user-created scratch files.
|
||||
# files = runtime.state.get("files") or {}
|
||||
# await sync_files_to_sandbox(self._thread_id, files, sandbox, is_new)
|
||||
result = await sandbox.aexecute(command, timeout=timeout)
|
||||
|
|
|
|||
|
|
@ -58,6 +58,14 @@ class KBSearchPlan(BaseModel):
|
|||
default=None,
|
||||
description="Optional ISO end date or datetime for KB search filtering.",
|
||||
)
|
||||
is_recency_query: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"True when the user's intent is primarily about recency or temporal "
|
||||
"ordering (e.g. 'latest', 'newest', 'most recent', 'last uploaded') "
|
||||
"rather than topical relevance."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _extract_text_from_message(message: BaseMessage) -> str:
|
||||
|
|
@ -245,7 +253,7 @@ def _build_kb_planner_prompt(
|
|||
return (
|
||||
"You optimize internal knowledge-base search inputs for document retrieval.\n"
|
||||
"Return JSON only with this exact shape:\n"
|
||||
'{"optimized_query":"string","start_date":"ISO string or null","end_date":"ISO string or null"}\n\n'
|
||||
'{"optimized_query":"string","start_date":"ISO string or null","end_date":"ISO string or null","is_recency_query":bool}\n\n'
|
||||
"Rules:\n"
|
||||
"- Preserve the user's intent.\n"
|
||||
"- Rewrite the query to improve retrieval using concrete entities, acronyms, projects, tools, people, and document-specific terms when helpful.\n"
|
||||
|
|
@ -253,6 +261,11 @@ def _build_kb_planner_prompt(
|
|||
"- Only use date filters when the latest user request or recent dialogue clearly implies a time range.\n"
|
||||
"- If you use date filters, prefer returning both bounds.\n"
|
||||
"- If no date filter is useful, return null for both dates.\n"
|
||||
'- Set "is_recency_query" to true ONLY when the user\'s primary intent is about '
|
||||
"recency or temporal ordering rather than topical relevance. Examples: "
|
||||
'"latest file", "newest upload", "most recent document", "what did I save last", '
|
||||
'"show me files from today", "last thing I added". '
|
||||
"When true, results will be sorted by date instead of relevance.\n"
|
||||
"- Do not include markdown, prose, or explanations.\n\n"
|
||||
f"Today's UTC date: {today}\n\n"
|
||||
f"Recent conversation:\n{recent_conversation or '(none)'}\n\n"
|
||||
|
|
@ -506,6 +519,135 @@ def _resolve_search_types(
|
|||
return list(expanded) if expanded else None
|
||||
|
||||
|
||||
_RECENCY_MAX_CHUNKS_PER_DOC = 5
|
||||
|
||||
|
||||
async def browse_recent_documents(
|
||||
*,
|
||||
search_space_id: int,
|
||||
document_type: list[str] | None = None,
|
||||
top_k: int = 10,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return documents ordered by recency (newest first), no relevance ranking.
|
||||
|
||||
Used when the user's intent is temporal ("latest file", "most recent upload")
|
||||
and hybrid search would produce poor results because the query has no
|
||||
meaningful topical signal.
|
||||
"""
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from app.db import DocumentType
|
||||
|
||||
async with shielded_async_session() as session:
|
||||
base_conditions = [
|
||||
Document.search_space_id == search_space_id,
|
||||
func.coalesce(Document.status["state"].astext, "ready") != "deleting",
|
||||
]
|
||||
|
||||
if document_type is not None:
|
||||
import contextlib
|
||||
|
||||
doc_type_enums = []
|
||||
for dt in document_type:
|
||||
if isinstance(dt, str):
|
||||
with contextlib.suppress(KeyError):
|
||||
doc_type_enums.append(DocumentType[dt])
|
||||
else:
|
||||
doc_type_enums.append(dt)
|
||||
if doc_type_enums:
|
||||
if len(doc_type_enums) == 1:
|
||||
base_conditions.append(Document.document_type == doc_type_enums[0])
|
||||
else:
|
||||
base_conditions.append(Document.document_type.in_(doc_type_enums))
|
||||
|
||||
if start_date is not None:
|
||||
base_conditions.append(Document.updated_at >= start_date)
|
||||
if end_date is not None:
|
||||
base_conditions.append(Document.updated_at <= end_date)
|
||||
|
||||
doc_query = (
|
||||
select(Document)
|
||||
.where(*base_conditions)
|
||||
.order_by(Document.updated_at.desc())
|
||||
.limit(top_k)
|
||||
)
|
||||
result = await session.execute(doc_query)
|
||||
documents = result.scalars().unique().all()
|
||||
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
doc_ids = [d.id for d in documents]
|
||||
|
||||
numbered = (
|
||||
select(
|
||||
Chunk.id.label("chunk_id"),
|
||||
Chunk.document_id,
|
||||
Chunk.content,
|
||||
func.row_number()
|
||||
.over(partition_by=Chunk.document_id, order_by=Chunk.id)
|
||||
.label("rn"),
|
||||
)
|
||||
.where(Chunk.document_id.in_(doc_ids))
|
||||
.subquery("numbered")
|
||||
)
|
||||
|
||||
chunk_query = (
|
||||
select(numbered.c.chunk_id, numbered.c.document_id, numbered.c.content)
|
||||
.where(numbered.c.rn <= _RECENCY_MAX_CHUNKS_PER_DOC)
|
||||
.order_by(numbered.c.document_id, numbered.c.chunk_id)
|
||||
)
|
||||
chunk_result = await session.execute(chunk_query)
|
||||
fetched_chunks = chunk_result.all()
|
||||
|
||||
doc_chunks: dict[int, list[dict[str, Any]]] = {d.id: [] for d in documents}
|
||||
for row in fetched_chunks:
|
||||
if row.document_id in doc_chunks:
|
||||
doc_chunks[row.document_id].append(
|
||||
{"chunk_id": row.chunk_id, "content": row.content}
|
||||
)
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
for doc in documents:
|
||||
chunks_list = doc_chunks.get(doc.id, [])
|
||||
metadata = doc.document_metadata or {}
|
||||
results.append(
|
||||
{
|
||||
"document_id": doc.id,
|
||||
"content": "\n\n".join(
|
||||
c["content"] for c in chunks_list if c.get("content")
|
||||
),
|
||||
"score": 0.0,
|
||||
"chunks": chunks_list,
|
||||
"matched_chunk_ids": [],
|
||||
"document": {
|
||||
"id": doc.id,
|
||||
"title": doc.title,
|
||||
"document_type": (
|
||||
doc.document_type.value
|
||||
if getattr(doc, "document_type", None)
|
||||
else None
|
||||
),
|
||||
"metadata": metadata,
|
||||
},
|
||||
"source": (
|
||||
doc.document_type.value
|
||||
if getattr(doc, "document_type", None)
|
||||
else None
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"browse_recent_documents: %d docs returned for space=%d",
|
||||
len(results),
|
||||
search_space_id,
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
async def search_knowledge_base(
|
||||
*,
|
||||
query: str,
|
||||
|
|
@ -704,10 +846,13 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
*,
|
||||
messages: Sequence[BaseMessage],
|
||||
user_text: str,
|
||||
) -> tuple[str, datetime | None, datetime | None]:
|
||||
"""Rewrite the KB query and infer optional date filters with the LLM."""
|
||||
) -> tuple[str, datetime | None, datetime | None, bool]:
|
||||
"""Rewrite the KB query and infer optional date filters with the LLM.
|
||||
|
||||
Returns (optimized_query, start_date, end_date, is_recency_query).
|
||||
"""
|
||||
if self.llm is None:
|
||||
return user_text, None, None
|
||||
return user_text, None, None, False
|
||||
|
||||
recent_conversation = _render_recent_conversation(
|
||||
messages,
|
||||
|
|
@ -734,15 +879,18 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
plan.start_date,
|
||||
plan.end_date,
|
||||
)
|
||||
is_recency = plan.is_recency_query
|
||||
_perf_log.info(
|
||||
"[kb_fs_middleware] planner in %.3fs query=%r optimized=%r start=%s end=%s",
|
||||
"[kb_fs_middleware] planner in %.3fs query=%r optimized=%r "
|
||||
"start=%s end=%s recency=%s",
|
||||
loop.time() - t0,
|
||||
user_text[:80],
|
||||
optimized_query[:120],
|
||||
start_date.isoformat() if start_date else None,
|
||||
end_date.isoformat() if end_date else None,
|
||||
is_recency,
|
||||
)
|
||||
return optimized_query, start_date, end_date
|
||||
return optimized_query, start_date, end_date, is_recency
|
||||
except (json.JSONDecodeError, ValidationError, ValueError) as exc:
|
||||
logger.warning(
|
||||
"KB planner returned invalid output, using raw query: %s", exc
|
||||
|
|
@ -750,7 +898,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
except Exception as exc: # pragma: no cover - defensive fallback
|
||||
logger.warning("KB planner failed, using raw query: %s", exc)
|
||||
|
||||
return user_text, None, None
|
||||
return user_text, None, None, False
|
||||
|
||||
def before_agent( # type: ignore[override]
|
||||
self,
|
||||
|
|
@ -789,7 +937,12 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
|
||||
t0 = _perf_log and asyncio.get_event_loop().time()
|
||||
existing_files = state.get("files")
|
||||
planned_query, start_date, end_date = await self._plan_search_inputs(
|
||||
(
|
||||
planned_query,
|
||||
start_date,
|
||||
end_date,
|
||||
is_recency,
|
||||
) = await self._plan_search_inputs(
|
||||
messages=messages,
|
||||
user_text=user_text,
|
||||
)
|
||||
|
|
@ -805,16 +958,28 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
# messages within the same agent instance.
|
||||
self.mentioned_document_ids = []
|
||||
|
||||
# --- 2. Run KB hybrid search ---
|
||||
search_results = await search_knowledge_base(
|
||||
query=planned_query,
|
||||
search_space_id=self.search_space_id,
|
||||
available_connectors=self.available_connectors,
|
||||
available_document_types=self.available_document_types,
|
||||
top_k=self.top_k,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
# --- 2. Run KB search (recency browse or hybrid) ---
|
||||
if is_recency:
|
||||
doc_types = _resolve_search_types(
|
||||
self.available_connectors, self.available_document_types
|
||||
)
|
||||
search_results = await browse_recent_documents(
|
||||
search_space_id=self.search_space_id,
|
||||
document_type=doc_types,
|
||||
top_k=self.top_k,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
else:
|
||||
search_results = await search_knowledge_base(
|
||||
query=planned_query,
|
||||
search_space_id=self.search_space_id,
|
||||
available_connectors=self.available_connectors,
|
||||
available_document_types=self.available_document_types,
|
||||
top_k=self.top_k,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
# --- 3. Merge: mentioned first, then search (dedup by doc id) ---
|
||||
seen_doc_ids: set[int] = set()
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import contextlib
|
|||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
from daytona import (
|
||||
|
|
@ -55,9 +56,16 @@ class _TimeoutAwareSandbox(DaytonaSandbox):
|
|||
) -> ExecuteResponse: # type: ignore[override]
|
||||
return await asyncio.to_thread(self.execute, command, timeout=timeout)
|
||||
|
||||
def download_file(self, path: str) -> bytes:
|
||||
"""Download a file from the sandbox filesystem."""
|
||||
return self._sandbox.fs.download_file(path)
|
||||
|
||||
|
||||
_daytona_client: Daytona | None = None
|
||||
_client_lock = threading.Lock()
|
||||
_sandbox_cache: dict[str, _TimeoutAwareSandbox] = {}
|
||||
_sandbox_locks: dict[str, asyncio.Lock] = {}
|
||||
_sandbox_locks_mu = asyncio.Lock()
|
||||
_seeded_files: dict[str, dict[str, str]] = {}
|
||||
_SANDBOX_CACHE_MAX_SIZE = 20
|
||||
THREAD_LABEL_KEY = "surfsense_thread"
|
||||
|
|
@ -70,14 +78,15 @@ def is_sandbox_enabled() -> bool:
|
|||
|
||||
def _get_client() -> Daytona:
|
||||
global _daytona_client
|
||||
if _daytona_client is None:
|
||||
config = DaytonaConfig(
|
||||
api_key=os.environ.get("DAYTONA_API_KEY", ""),
|
||||
api_url=os.environ.get("DAYTONA_API_URL", "https://app.daytona.io/api"),
|
||||
target=os.environ.get("DAYTONA_TARGET", "us"),
|
||||
)
|
||||
_daytona_client = Daytona(config)
|
||||
return _daytona_client
|
||||
with _client_lock:
|
||||
if _daytona_client is None:
|
||||
config = DaytonaConfig(
|
||||
api_key=os.environ.get("DAYTONA_API_KEY", ""),
|
||||
api_url=os.environ.get("DAYTONA_API_URL", "https://app.daytona.io/api"),
|
||||
target=os.environ.get("DAYTONA_TARGET", "us"),
|
||||
)
|
||||
_daytona_client = Daytona(config)
|
||||
return _daytona_client
|
||||
|
||||
|
||||
def _sandbox_create_params(
|
||||
|
|
@ -129,14 +138,16 @@ def _find_or_create(thread_id: str) -> tuple[_TimeoutAwareSandbox, bool]:
|
|||
try:
|
||||
client.delete(sandbox)
|
||||
except Exception:
|
||||
logger.debug("Could not delete broken sandbox %s", sandbox.id, exc_info=True)
|
||||
logger.debug(
|
||||
"Could not delete broken sandbox %s", sandbox.id, exc_info=True
|
||||
)
|
||||
sandbox = client.create(_sandbox_create_params(labels))
|
||||
is_new = True
|
||||
logger.info("Created replacement sandbox: %s", sandbox.id)
|
||||
elif sandbox.state != SandboxState.STARTED:
|
||||
sandbox.wait_for_sandbox_start(timeout=60)
|
||||
|
||||
except Exception:
|
||||
except DaytonaError:
|
||||
logger.info("No existing sandbox for thread %s — creating one", thread_id)
|
||||
sandbox = client.create(_sandbox_create_params(labels))
|
||||
is_new = True
|
||||
|
|
@ -145,6 +156,16 @@ def _find_or_create(thread_id: str) -> tuple[_TimeoutAwareSandbox, bool]:
|
|||
return _TimeoutAwareSandbox(sandbox=sandbox), is_new
|
||||
|
||||
|
||||
async def _get_thread_lock(key: str) -> asyncio.Lock:
|
||||
"""Return a per-thread asyncio lock, creating one if needed."""
|
||||
async with _sandbox_locks_mu:
|
||||
lock = _sandbox_locks.get(key)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_sandbox_locks[key] = lock
|
||||
return lock
|
||||
|
||||
|
||||
async def get_or_create_sandbox(
|
||||
thread_id: int | str,
|
||||
) -> tuple[_TimeoutAwareSandbox, bool]:
|
||||
|
|
@ -152,25 +173,52 @@ async def get_or_create_sandbox(
|
|||
|
||||
Uses an in-process cache keyed by thread_id so subsequent messages
|
||||
in the same conversation reuse the sandbox object without an API call.
|
||||
A per-thread async lock prevents duplicate sandbox creation from
|
||||
concurrent requests.
|
||||
|
||||
Returns:
|
||||
Tuple of (sandbox, is_new). *is_new* is True when a fresh sandbox
|
||||
was created, signalling that file tracking should be reset.
|
||||
"""
|
||||
key = str(thread_id)
|
||||
cached = _sandbox_cache.get(key)
|
||||
if cached is not None:
|
||||
logger.info("Reusing cached sandbox for thread %s", key)
|
||||
return cached, False
|
||||
sandbox, is_new = await asyncio.to_thread(_find_or_create, key)
|
||||
_sandbox_cache[key] = sandbox
|
||||
lock = await _get_thread_lock(key)
|
||||
|
||||
if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE:
|
||||
oldest_key = next(iter(_sandbox_cache))
|
||||
_sandbox_cache.pop(oldest_key, None)
|
||||
logger.debug("Evicted oldest sandbox cache entry: %s", oldest_key)
|
||||
async with lock:
|
||||
cached = _sandbox_cache.get(key)
|
||||
if cached is not None:
|
||||
logger.info("Reusing cached sandbox for thread %s", key)
|
||||
return cached, False
|
||||
sandbox, is_new = await asyncio.to_thread(_find_or_create, key)
|
||||
_sandbox_cache[key] = sandbox
|
||||
|
||||
return sandbox, is_new
|
||||
if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE:
|
||||
oldest_key = next(iter(_sandbox_cache))
|
||||
if oldest_key != key:
|
||||
evicted = _sandbox_cache.pop(oldest_key, None)
|
||||
_seeded_files.pop(oldest_key, None)
|
||||
logger.debug("Evicted sandbox cache entry: %s", oldest_key)
|
||||
if evicted is not None:
|
||||
_schedule_sandbox_delete(evicted)
|
||||
|
||||
return sandbox, is_new
|
||||
|
||||
|
||||
def _schedule_sandbox_delete(sandbox: _TimeoutAwareSandbox) -> None:
|
||||
"""Best-effort background deletion of an evicted sandbox."""
|
||||
|
||||
def _delete() -> None:
|
||||
try:
|
||||
client = _get_client()
|
||||
client.delete(sandbox._sandbox)
|
||||
logger.info("Deleted evicted sandbox: %s", sandbox._sandbox.id)
|
||||
except Exception:
|
||||
logger.debug("Could not delete evicted sandbox", exc_info=True)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.run_in_executor(None, _delete)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
async def sync_files_to_sandbox(
|
||||
|
|
|
|||
|
|
@ -2,10 +2,10 @@ import logging
|
|||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||
from app.services.confluence import ConfluenceToolMetadataService
|
||||
|
||||
|
|
@ -65,54 +65,28 @@ def create_create_confluence_page_tool(
|
|||
"connector_type": "confluence",
|
||||
}
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "confluence_page_creation",
|
||||
"action": {
|
||||
"tool": "create_confluence_page",
|
||||
"params": {
|
||||
"title": title,
|
||||
"content": content,
|
||||
"space_id": space_id,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="confluence_page_creation",
|
||||
tool_name="create_confluence_page",
|
||||
params={
|
||||
"title": title,
|
||||
"content": content,
|
||||
"space_id": space_id,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The page was not created.",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_title = final_params.get("title", title)
|
||||
final_content = final_params.get("content", content) or ""
|
||||
final_space_id = final_params.get("space_id", space_id)
|
||||
final_connector_id = final_params.get("connector_id", connector_id)
|
||||
final_title = result.params.get("title", title)
|
||||
final_content = result.params.get("content", content) or ""
|
||||
final_space_id = result.params.get("space_id", space_id)
|
||||
final_connector_id = result.params.get("connector_id", connector_id)
|
||||
|
||||
if not final_title or not final_title.strip():
|
||||
return {"status": "error", "message": "Page title cannot be empty."}
|
||||
|
|
|
|||
|
|
@ -2,10 +2,10 @@ import logging
|
|||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||
from app.services.confluence import ConfluenceToolMetadataService
|
||||
|
||||
|
|
@ -74,54 +74,28 @@ def create_delete_confluence_page_tool(
|
|||
document_id = page_data["document_id"]
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "confluence_page_deletion",
|
||||
"action": {
|
||||
"tool": "delete_confluence_page",
|
||||
"params": {
|
||||
"page_id": page_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="confluence_page_deletion",
|
||||
tool_name="delete_confluence_page",
|
||||
params={
|
||||
"page_id": page_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The page was not deleted.",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_page_id = final_params.get("page_id", page_id)
|
||||
final_connector_id = final_params.get(
|
||||
final_page_id = result.params.get("page_id", page_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
|
||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
|
|
|
|||
|
|
@ -2,10 +2,10 @@ import logging
|
|||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||
from app.services.confluence import ConfluenceToolMetadataService
|
||||
|
||||
|
|
@ -78,62 +78,36 @@ def create_update_confluence_page_tool(
|
|||
document_id = page_data.get("document_id")
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "confluence_page_update",
|
||||
"action": {
|
||||
"tool": "update_confluence_page",
|
||||
"params": {
|
||||
"page_id": page_id,
|
||||
"document_id": document_id,
|
||||
"new_title": new_title,
|
||||
"new_content": new_content,
|
||||
"version": current_version,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="confluence_page_update",
|
||||
tool_name="update_confluence_page",
|
||||
params={
|
||||
"page_id": page_id,
|
||||
"document_id": document_id,
|
||||
"new_title": new_title,
|
||||
"new_content": new_content,
|
||||
"version": current_version,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The page was not updated.",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_page_id = final_params.get("page_id", page_id)
|
||||
final_title = final_params.get("new_title", new_title) or current_title
|
||||
final_content = final_params.get("new_content", new_content)
|
||||
final_page_id = result.params.get("page_id", page_id)
|
||||
final_title = result.params.get("new_title", new_title) or current_title
|
||||
final_content = result.params.get("new_content", new_content)
|
||||
if final_content is None:
|
||||
final_content = current_body
|
||||
final_version = final_params.get("version", current_version)
|
||||
final_connector_id = final_params.get(
|
||||
final_version = result.params.get("version", current_version)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_document_id = final_params.get("document_id", document_id)
|
||||
final_document_id = result.params.get("document_id", document_id)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
|
|
|
|||
|
|
@ -5,10 +5,10 @@ from pathlib import Path
|
|||
from typing import Any, Literal
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.dropbox.client import DropboxClient
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
|
|
@ -159,56 +159,30 @@ def create_create_dropbox_file_tool(
|
|||
"supported_types": _SUPPORTED_TYPES,
|
||||
}
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "dropbox_file_creation",
|
||||
"action": {
|
||||
"tool": "create_dropbox_file",
|
||||
"params": {
|
||||
"name": name,
|
||||
"file_type": file_type,
|
||||
"content": content,
|
||||
"connector_id": None,
|
||||
"parent_folder_path": None,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="dropbox_file_creation",
|
||||
tool_name="create_dropbox_file",
|
||||
params={
|
||||
"name": name,
|
||||
"file_type": file_type,
|
||||
"content": content,
|
||||
"connector_id": None,
|
||||
"parent_folder_path": None,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The file was not created.",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_name = final_params.get("name", name)
|
||||
final_file_type = final_params.get("file_type", file_type)
|
||||
final_content = final_params.get("content", content)
|
||||
final_connector_id = final_params.get("connector_id")
|
||||
final_parent_folder_path = final_params.get("parent_folder_path")
|
||||
final_name = result.params.get("name", name)
|
||||
final_file_type = result.params.get("file_type", file_type)
|
||||
final_content = result.params.get("content", content)
|
||||
final_connector_id = result.params.get("connector_id")
|
||||
final_parent_folder_path = result.params.get("parent_folder_path")
|
||||
|
||||
if not final_name or not final_name.strip():
|
||||
return {"status": "error", "message": "File name cannot be empty."}
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ import logging
|
|||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy import String, and_, cast, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.dropbox.client import DropboxClient
|
||||
from app.db import (
|
||||
Document,
|
||||
|
|
@ -174,53 +174,26 @@ def create_delete_dropbox_file_tool(
|
|||
},
|
||||
}
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "dropbox_file_trash",
|
||||
"action": {
|
||||
"tool": "delete_dropbox_file",
|
||||
"params": {
|
||||
"file_path": file_path,
|
||||
"connector_id": connector.id,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="dropbox_file_trash",
|
||||
tool_name="delete_dropbox_file",
|
||||
params={
|
||||
"file_path": file_path,
|
||||
"connector_id": connector.id,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The file was not deleted. Do not ask again or suggest alternatives.",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_file_path = final_params.get("file_path", file_path)
|
||||
final_connector_id = final_params.get("connector_id", connector.id)
|
||||
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
|
||||
final_file_path = result.params.get("file_path", file_path)
|
||||
final_connector_id = result.params.get("connector_id", connector.id)
|
||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
if final_connector_id != connector.id:
|
||||
result = await db_session.execute(
|
||||
|
|
|
|||
|
|
@ -6,9 +6,9 @@ from email.mime.text import MIMEText
|
|||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.services.gmail import GmailToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -85,60 +85,32 @@ def create_create_gmail_draft_tool(
|
|||
logger.info(
|
||||
f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "gmail_draft_creation",
|
||||
"action": {
|
||||
"tool": "create_gmail_draft",
|
||||
"params": {
|
||||
"to": to,
|
||||
"subject": subject,
|
||||
"body": body,
|
||||
"cc": cc,
|
||||
"bcc": bcc,
|
||||
"connector_id": None,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="gmail_draft_creation",
|
||||
tool_name="create_gmail_draft",
|
||||
params={
|
||||
"to": to,
|
||||
"subject": subject,
|
||||
"body": body,
|
||||
"cc": cc,
|
||||
"bcc": bcc,
|
||||
"connector_id": None,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The draft was not created. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_to = final_params.get("to", to)
|
||||
final_subject = final_params.get("subject", subject)
|
||||
final_body = final_params.get("body", body)
|
||||
final_cc = final_params.get("cc", cc)
|
||||
final_bcc = final_params.get("bcc", bcc)
|
||||
final_connector_id = final_params.get("connector_id")
|
||||
final_to = result.params.get("to", to)
|
||||
final_subject = result.params.get("subject", subject)
|
||||
final_body = result.params.get("body", body)
|
||||
final_cc = result.params.get("cc", cc)
|
||||
final_bcc = result.params.get("bcc", bcc)
|
||||
final_connector_id = result.params.get("connector_id")
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
|
|
|
|||
|
|
@ -6,9 +6,9 @@ from email.mime.text import MIMEText
|
|||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.services.gmail import GmailToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -86,60 +86,32 @@ def create_send_gmail_email_tool(
|
|||
logger.info(
|
||||
f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "gmail_email_send",
|
||||
"action": {
|
||||
"tool": "send_gmail_email",
|
||||
"params": {
|
||||
"to": to,
|
||||
"subject": subject,
|
||||
"body": body,
|
||||
"cc": cc,
|
||||
"bcc": bcc,
|
||||
"connector_id": None,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="gmail_email_send",
|
||||
tool_name="send_gmail_email",
|
||||
params={
|
||||
"to": to,
|
||||
"subject": subject,
|
||||
"body": body,
|
||||
"cc": cc,
|
||||
"bcc": bcc,
|
||||
"connector_id": None,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The email was not sent. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_to = final_params.get("to", to)
|
||||
final_subject = final_params.get("subject", subject)
|
||||
final_body = final_params.get("body", body)
|
||||
final_cc = final_params.get("cc", cc)
|
||||
final_bcc = final_params.get("bcc", bcc)
|
||||
final_connector_id = final_params.get("connector_id")
|
||||
final_to = result.params.get("to", to)
|
||||
final_subject = result.params.get("subject", subject)
|
||||
final_body = result.params.get("body", body)
|
||||
final_cc = result.params.get("cc", cc)
|
||||
final_bcc = result.params.get("bcc", bcc)
|
||||
final_connector_id = result.params.get("connector_id")
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@ from datetime import datetime
|
|||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.services.gmail import GmailToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -101,56 +101,28 @@ def create_trash_gmail_email_tool(
|
|||
logger.info(
|
||||
f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "gmail_email_trash",
|
||||
"action": {
|
||||
"tool": "trash_gmail_email",
|
||||
"params": {
|
||||
"message_id": message_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="gmail_email_trash",
|
||||
tool_name="trash_gmail_email",
|
||||
params={
|
||||
"message_id": message_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
edited_action = decision.get("edited_action")
|
||||
final_params: dict[str, Any] = {}
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_message_id = final_params.get("message_id", message_id)
|
||||
final_connector_id = final_params.get(
|
||||
final_message_id = result.params.get("message_id", message_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
|
||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -6,9 +6,9 @@ from email.mime.text import MIMEText
|
|||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.services.gmail import GmailToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -122,65 +122,37 @@ def create_update_gmail_draft_tool(
|
|||
f"Requesting approval for updating Gmail draft: '{original_subject}' "
|
||||
f"(message_id={message_id}, draft_id={draft_id_from_context})"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "gmail_draft_update",
|
||||
"action": {
|
||||
"tool": "update_gmail_draft",
|
||||
"params": {
|
||||
"message_id": message_id,
|
||||
"draft_id": draft_id_from_context,
|
||||
"to": final_to_default,
|
||||
"subject": final_subject_default,
|
||||
"body": body,
|
||||
"cc": cc,
|
||||
"bcc": bcc,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="gmail_draft_update",
|
||||
tool_name="update_gmail_draft",
|
||||
params={
|
||||
"message_id": message_id,
|
||||
"draft_id": draft_id_from_context,
|
||||
"to": final_to_default,
|
||||
"subject": final_subject_default,
|
||||
"body": body,
|
||||
"cc": cc,
|
||||
"bcc": bcc,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_to = final_params.get("to", final_to_default)
|
||||
final_subject = final_params.get("subject", final_subject_default)
|
||||
final_body = final_params.get("body", body)
|
||||
final_cc = final_params.get("cc", cc)
|
||||
final_bcc = final_params.get("bcc", bcc)
|
||||
final_connector_id = final_params.get(
|
||||
final_to = result.params.get("to", final_to_default)
|
||||
final_subject = result.params.get("subject", final_subject_default)
|
||||
final_body = result.params.get("body", body)
|
||||
final_cc = result.params.get("cc", cc)
|
||||
final_bcc = result.params.get("bcc", bcc)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_draft_id = final_params.get("draft_id", draft_id_from_context)
|
||||
final_draft_id = result.params.get("draft_id", draft_id_from_context)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -6,9 +6,9 @@ from typing import Any
|
|||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.services.google_calendar import GoogleCalendarToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -90,63 +90,35 @@ def create_create_calendar_event_tool(
|
|||
logger.info(
|
||||
f"Requesting approval for creating calendar event: summary='{summary}'"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "google_calendar_event_creation",
|
||||
"action": {
|
||||
"tool": "create_calendar_event",
|
||||
"params": {
|
||||
"summary": summary,
|
||||
"start_datetime": start_datetime,
|
||||
"end_datetime": end_datetime,
|
||||
"description": description,
|
||||
"location": location,
|
||||
"attendees": attendees,
|
||||
"timezone": context.get("timezone"),
|
||||
"connector_id": None,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="google_calendar_event_creation",
|
||||
tool_name="create_calendar_event",
|
||||
params={
|
||||
"summary": summary,
|
||||
"start_datetime": start_datetime,
|
||||
"end_datetime": end_datetime,
|
||||
"description": description,
|
||||
"location": location,
|
||||
"attendees": attendees,
|
||||
"timezone": context.get("timezone"),
|
||||
"connector_id": None,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The event was not created. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_summary = final_params.get("summary", summary)
|
||||
final_start_datetime = final_params.get("start_datetime", start_datetime)
|
||||
final_end_datetime = final_params.get("end_datetime", end_datetime)
|
||||
final_description = final_params.get("description", description)
|
||||
final_location = final_params.get("location", location)
|
||||
final_attendees = final_params.get("attendees", attendees)
|
||||
final_connector_id = final_params.get("connector_id")
|
||||
final_summary = result.params.get("summary", summary)
|
||||
final_start_datetime = result.params.get("start_datetime", start_datetime)
|
||||
final_end_datetime = result.params.get("end_datetime", end_datetime)
|
||||
final_description = result.params.get("description", description)
|
||||
final_location = result.params.get("location", location)
|
||||
final_attendees = result.params.get("attendees", attendees)
|
||||
final_connector_id = result.params.get("connector_id")
|
||||
|
||||
if not final_summary or not final_summary.strip():
|
||||
return {"status": "error", "message": "Event summary cannot be empty."}
|
||||
|
|
|
|||
|
|
@ -6,9 +6,9 @@ from typing import Any
|
|||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.services.google_calendar import GoogleCalendarToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -100,56 +100,28 @@ def create_delete_calendar_event_tool(
|
|||
logger.info(
|
||||
f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "google_calendar_event_deletion",
|
||||
"action": {
|
||||
"tool": "delete_calendar_event",
|
||||
"params": {
|
||||
"event_id": event_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="google_calendar_event_deletion",
|
||||
tool_name="delete_calendar_event",
|
||||
params={
|
||||
"event_id": event_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
edited_action = decision.get("edited_action")
|
||||
final_params: dict[str, Any] = {}
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_event_id = final_params.get("event_id", event_id)
|
||||
final_connector_id = final_params.get(
|
||||
final_event_id = result.params.get("event_id", event_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
|
||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -6,9 +6,9 @@ from typing import Any
|
|||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.services.google_calendar import GoogleCalendarToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -116,71 +116,45 @@ def create_update_calendar_event_tool(
|
|||
logger.info(
|
||||
f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "google_calendar_event_update",
|
||||
"action": {
|
||||
"tool": "update_calendar_event",
|
||||
"params": {
|
||||
"event_id": event_id,
|
||||
"document_id": document_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"new_summary": new_summary,
|
||||
"new_start_datetime": new_start_datetime,
|
||||
"new_end_datetime": new_end_datetime,
|
||||
"new_description": new_description,
|
||||
"new_location": new_location,
|
||||
"new_attendees": new_attendees,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="google_calendar_event_update",
|
||||
tool_name="update_calendar_event",
|
||||
params={
|
||||
"event_id": event_id,
|
||||
"document_id": document_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"new_summary": new_summary,
|
||||
"new_start_datetime": new_start_datetime,
|
||||
"new_end_datetime": new_end_datetime,
|
||||
"new_description": new_description,
|
||||
"new_location": new_location,
|
||||
"new_attendees": new_attendees,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The event was not updated. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
edited_action = decision.get("edited_action")
|
||||
final_params: dict[str, Any] = {}
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_event_id = final_params.get("event_id", event_id)
|
||||
final_connector_id = final_params.get(
|
||||
final_event_id = result.params.get("event_id", event_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_new_summary = final_params.get("new_summary", new_summary)
|
||||
final_new_start_datetime = final_params.get(
|
||||
final_new_summary = result.params.get("new_summary", new_summary)
|
||||
final_new_start_datetime = result.params.get(
|
||||
"new_start_datetime", new_start_datetime
|
||||
)
|
||||
final_new_end_datetime = final_params.get(
|
||||
final_new_end_datetime = result.params.get(
|
||||
"new_end_datetime", new_end_datetime
|
||||
)
|
||||
final_new_description = final_params.get("new_description", new_description)
|
||||
final_new_location = final_params.get("new_location", new_location)
|
||||
final_new_attendees = final_params.get("new_attendees", new_attendees)
|
||||
final_new_description = result.params.get(
|
||||
"new_description", new_description
|
||||
)
|
||||
final_new_location = result.params.get("new_location", new_location)
|
||||
final_new_attendees = result.params.get("new_attendees", new_attendees)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -3,9 +3,9 @@ from typing import Any, Literal
|
|||
|
||||
from googleapiclient.errors import HttpError
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.google_drive.client import GoogleDriveClient
|
||||
from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET
|
||||
from app.services.google_drive import GoogleDriveToolMetadataService
|
||||
|
|
@ -99,58 +99,30 @@ def create_create_google_drive_file_tool(
|
|||
logger.info(
|
||||
f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "google_drive_file_creation",
|
||||
"action": {
|
||||
"tool": "create_google_drive_file",
|
||||
"params": {
|
||||
"name": name,
|
||||
"file_type": file_type,
|
||||
"content": content,
|
||||
"connector_id": None,
|
||||
"parent_folder_id": None,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="google_drive_file_creation",
|
||||
tool_name="create_google_drive_file",
|
||||
params={
|
||||
"name": name,
|
||||
"file_type": file_type,
|
||||
"content": content,
|
||||
"connector_id": None,
|
||||
"parent_folder_id": None,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The file was not created. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_name = final_params.get("name", name)
|
||||
final_file_type = final_params.get("file_type", file_type)
|
||||
final_content = final_params.get("content", content)
|
||||
final_connector_id = final_params.get("connector_id")
|
||||
final_parent_folder_id = final_params.get("parent_folder_id")
|
||||
final_name = result.params.get("name", name)
|
||||
final_file_type = result.params.get("file_type", file_type)
|
||||
final_content = result.params.get("content", content)
|
||||
final_connector_id = result.params.get("connector_id")
|
||||
final_parent_folder_id = result.params.get("parent_folder_id")
|
||||
|
||||
if not final_name or not final_name.strip():
|
||||
return {"status": "error", "message": "File name cannot be empty."}
|
||||
|
|
|
|||
|
|
@ -3,9 +3,9 @@ from typing import Any
|
|||
|
||||
from googleapiclient.errors import HttpError
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.google_drive.client import GoogleDriveClient
|
||||
from app.services.google_drive import GoogleDriveToolMetadataService
|
||||
|
||||
|
|
@ -101,56 +101,28 @@ def create_delete_google_drive_file_tool(
|
|||
logger.info(
|
||||
f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "google_drive_file_trash",
|
||||
"action": {
|
||||
"tool": "delete_google_drive_file",
|
||||
"params": {
|
||||
"file_id": file_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="google_drive_file_trash",
|
||||
tool_name="delete_google_drive_file",
|
||||
params={
|
||||
"file_id": file_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
edited_action = decision.get("edited_action")
|
||||
final_params: dict[str, Any] = {}
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_file_id = final_params.get("file_id", file_id)
|
||||
final_connector_id = final_params.get(
|
||||
final_file_id = result.params.get("file_id", file_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
|
||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
|
|
|
|||
142
surfsense_backend/app/agents/new_chat/tools/hitl.py
Normal file
142
surfsense_backend/app/agents/new_chat/tools/hitl.py
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
"""Unified HITL (Human-in-the-Loop) approval utility.
|
||||
|
||||
Provides a single ``request_approval()`` function that encapsulates the
|
||||
interrupt payload creation, decision parsing, and parameter merging logic
|
||||
shared by every sensitive tool (native connectors and MCP tools alike).
|
||||
|
||||
Usage inside a tool::
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
|
||||
result = request_approval(
|
||||
action_type="gmail_email_send",
|
||||
tool_name="send_gmail_email",
|
||||
params={"to": to, "subject": subject, "body": body},
|
||||
context=context,
|
||||
)
|
||||
if result.rejected:
|
||||
return {"status": "rejected", "message": "User declined."}
|
||||
# result.params contains the final (possibly edited) parameters
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from langgraph.types import interrupt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class HITLResult:
|
||||
"""Outcome of a human-in-the-loop approval request."""
|
||||
|
||||
rejected: bool
|
||||
decision_type: str
|
||||
params: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _parse_decision(approval: Any) -> tuple[str, dict[str, Any]]:
|
||||
"""Extract the first valid decision and its edited parameters.
|
||||
|
||||
Returns:
|
||||
(decision_type, edited_params) where *decision_type* is one of
|
||||
``"approve"``, ``"edit"``, or ``"reject"`` and *edited_params* is
|
||||
the dict of user-modified arguments (empty when there are none).
|
||||
|
||||
Raises:
|
||||
ValueError: when no usable decision dict can be found.
|
||||
"""
|
||||
decisions_raw = approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
decisions = decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
|
||||
if not decisions:
|
||||
raise ValueError("No approval decision received")
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type: str = (
|
||||
decision.get("type") or decision.get("decision_type") or "approve"
|
||||
)
|
||||
|
||||
edited_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
edited_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
edited_params = decision["args"]
|
||||
|
||||
return decision_type, edited_params
|
||||
|
||||
|
||||
def request_approval(
|
||||
*,
|
||||
action_type: str,
|
||||
tool_name: str,
|
||||
params: dict[str, Any],
|
||||
context: dict[str, Any] | None = None,
|
||||
trusted_tools: list[str] | None = None,
|
||||
) -> HITLResult:
|
||||
"""Pause the graph for user approval and return the decision.
|
||||
|
||||
This is a **synchronous** helper (not ``async``) because
|
||||
``langgraph.types.interrupt`` is itself synchronous — it raises a
|
||||
``GraphInterrupt`` exception that the LangGraph runtime catches.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
action_type:
|
||||
A label that the frontend uses to select the correct approval card
|
||||
(e.g. ``"gmail_email_send"``, ``"mcp_tool_call"``).
|
||||
tool_name:
|
||||
The registered LangChain tool name (e.g. ``"send_gmail_email"``).
|
||||
params:
|
||||
The original tool arguments. These are shown in the approval card
|
||||
and used as defaults when the user does not edit anything.
|
||||
context:
|
||||
Rich metadata from a ``*ToolMetadataService`` (accounts, folders,
|
||||
labels, etc.). For MCP tools this can hold the server name and
|
||||
tool description.
|
||||
trusted_tools:
|
||||
An allow-list of tool names the user has previously marked as
|
||||
"Always Allow". If *tool_name* appears in this list, HITL is
|
||||
skipped and the tool executes immediately.
|
||||
|
||||
Returns
|
||||
-------
|
||||
HITLResult
|
||||
``result.rejected`` is ``True`` when the user chose to deny the
|
||||
action. Otherwise ``result.params`` contains the final parameter
|
||||
dict — either the originals or the user-edited version merged on
|
||||
top.
|
||||
"""
|
||||
if trusted_tools and tool_name in trusted_tools:
|
||||
logger.info("Tool '%s' is user-trusted — skipping HITL", tool_name)
|
||||
return HITLResult(rejected=False, decision_type="trusted", params=dict(params))
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": action_type,
|
||||
"action": {"tool": tool_name, "params": params},
|
||||
"context": context or {},
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
decision_type, edited_params = _parse_decision(approval)
|
||||
except ValueError:
|
||||
logger.warning("No approval decision received for %s", tool_name)
|
||||
return HITLResult(rejected=False, decision_type="error", params=params)
|
||||
|
||||
logger.info("User decision for %s: %s", tool_name, decision_type)
|
||||
|
||||
if decision_type == "reject":
|
||||
return HITLResult(rejected=True, decision_type="reject", params=params)
|
||||
|
||||
final_params = {**params, **edited_params} if edited_params else dict(params)
|
||||
return HITLResult(rejected=False, decision_type=decision_type, params=final_params)
|
||||
|
|
@ -3,10 +3,10 @@ import logging
|
|||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.jira_history import JiraHistoryConnector
|
||||
from app.services.jira import JiraToolMetadataService
|
||||
|
||||
|
|
@ -69,58 +69,32 @@ def create_create_jira_issue_tool(
|
|||
"connector_type": "jira",
|
||||
}
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "jira_issue_creation",
|
||||
"action": {
|
||||
"tool": "create_jira_issue",
|
||||
"params": {
|
||||
"project_key": project_key,
|
||||
"summary": summary,
|
||||
"issue_type": issue_type,
|
||||
"description": description,
|
||||
"priority": priority,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="jira_issue_creation",
|
||||
tool_name="create_jira_issue",
|
||||
params={
|
||||
"project_key": project_key,
|
||||
"summary": summary,
|
||||
"issue_type": issue_type,
|
||||
"description": description,
|
||||
"priority": priority,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The issue was not created.",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_project_key = final_params.get("project_key", project_key)
|
||||
final_summary = final_params.get("summary", summary)
|
||||
final_issue_type = final_params.get("issue_type", issue_type)
|
||||
final_description = final_params.get("description", description)
|
||||
final_priority = final_params.get("priority", priority)
|
||||
final_connector_id = final_params.get("connector_id", connector_id)
|
||||
final_project_key = result.params.get("project_key", project_key)
|
||||
final_summary = result.params.get("summary", summary)
|
||||
final_issue_type = result.params.get("issue_type", issue_type)
|
||||
final_description = result.params.get("description", description)
|
||||
final_priority = result.params.get("priority", priority)
|
||||
final_connector_id = result.params.get("connector_id", connector_id)
|
||||
|
||||
if not final_summary or not final_summary.strip():
|
||||
return {"status": "error", "message": "Issue summary cannot be empty."}
|
||||
|
|
|
|||
|
|
@ -3,10 +3,10 @@ import logging
|
|||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.jira_history import JiraHistoryConnector
|
||||
from app.services.jira import JiraToolMetadataService
|
||||
|
||||
|
|
@ -71,54 +71,28 @@ def create_delete_jira_issue_tool(
|
|||
document_id = issue_data["document_id"]
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "jira_issue_deletion",
|
||||
"action": {
|
||||
"tool": "delete_jira_issue",
|
||||
"params": {
|
||||
"issue_key": issue_key,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="jira_issue_deletion",
|
||||
tool_name="delete_jira_issue",
|
||||
params={
|
||||
"issue_key": issue_key,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The issue was not deleted.",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_issue_key = final_params.get("issue_key", issue_key)
|
||||
final_connector_id = final_params.get(
|
||||
final_issue_key = result.params.get("issue_key", issue_key)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
|
||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
|
|
|
|||
|
|
@ -3,10 +3,10 @@ import logging
|
|||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.jira_history import JiraHistoryConnector
|
||||
from app.services.jira import JiraToolMetadataService
|
||||
|
||||
|
|
@ -75,60 +75,34 @@ def create_update_jira_issue_tool(
|
|||
document_id = issue_data.get("document_id")
|
||||
connector_id_from_context = context.get("account", {}).get("id")
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "jira_issue_update",
|
||||
"action": {
|
||||
"tool": "update_jira_issue",
|
||||
"params": {
|
||||
"issue_key": issue_key,
|
||||
"document_id": document_id,
|
||||
"new_summary": new_summary,
|
||||
"new_description": new_description,
|
||||
"new_priority": new_priority,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="jira_issue_update",
|
||||
tool_name="update_jira_issue",
|
||||
params={
|
||||
"issue_key": issue_key,
|
||||
"document_id": document_id,
|
||||
"new_summary": new_summary,
|
||||
"new_description": new_description,
|
||||
"new_priority": new_priority,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The issue was not updated.",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_issue_key = final_params.get("issue_key", issue_key)
|
||||
final_summary = final_params.get("new_summary", new_summary)
|
||||
final_description = final_params.get("new_description", new_description)
|
||||
final_priority = final_params.get("new_priority", new_priority)
|
||||
final_connector_id = final_params.get(
|
||||
final_issue_key = result.params.get("issue_key", issue_key)
|
||||
final_summary = result.params.get("new_summary", new_summary)
|
||||
final_description = result.params.get("new_description", new_description)
|
||||
final_priority = result.params.get("new_priority", new_priority)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_document_id = final_params.get("document_id", document_id)
|
||||
final_document_id = result.params.get("document_id", document_id)
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@ import logging
|
|||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.linear_connector import LinearAPIError, LinearConnector
|
||||
from app.services.linear import LinearToolMetadataService
|
||||
|
||||
|
|
@ -94,65 +94,37 @@ def create_create_linear_issue_tool(
|
|||
}
|
||||
|
||||
logger.info(f"Requesting approval for creating Linear issue: '{title}'")
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "linear_issue_creation",
|
||||
"action": {
|
||||
"tool": "create_linear_issue",
|
||||
"params": {
|
||||
"title": title,
|
||||
"description": description,
|
||||
"team_id": None,
|
||||
"state_id": None,
|
||||
"assignee_id": None,
|
||||
"priority": None,
|
||||
"label_ids": [],
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="linear_issue_creation",
|
||||
tool_name="create_linear_issue",
|
||||
params={
|
||||
"title": title,
|
||||
"description": description,
|
||||
"team_id": None,
|
||||
"state_id": None,
|
||||
"assignee_id": None,
|
||||
"priority": None,
|
||||
"label_ids": [],
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
logger.info("Linear issue creation rejected by user")
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The issue was not created. Do not ask again or suggest alternatives.",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_title = final_params.get("title", title)
|
||||
final_description = final_params.get("description", description)
|
||||
final_team_id = final_params.get("team_id")
|
||||
final_state_id = final_params.get("state_id")
|
||||
final_assignee_id = final_params.get("assignee_id")
|
||||
final_priority = final_params.get("priority")
|
||||
final_label_ids = final_params.get("label_ids") or []
|
||||
final_connector_id = final_params.get("connector_id", connector_id)
|
||||
final_title = result.params.get("title", title)
|
||||
final_description = result.params.get("description", description)
|
||||
final_team_id = result.params.get("team_id")
|
||||
final_state_id = result.params.get("state_id")
|
||||
final_assignee_id = result.params.get("assignee_id")
|
||||
final_priority = result.params.get("priority")
|
||||
final_label_ids = result.params.get("label_ids") or []
|
||||
final_connector_id = result.params.get("connector_id", connector_id)
|
||||
|
||||
if not final_title or not final_title.strip():
|
||||
logger.error("Title is empty or contains only whitespace")
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@ import logging
|
|||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.linear_connector import LinearAPIError, LinearConnector
|
||||
from app.services.linear import LinearToolMetadataService
|
||||
|
||||
|
|
@ -114,57 +114,29 @@ def create_delete_linear_issue_tool(
|
|||
f"Requesting approval for deleting Linear issue: '{issue_ref}' "
|
||||
f"(id={issue_id}, delete_from_kb={delete_from_kb})"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "linear_issue_deletion",
|
||||
"action": {
|
||||
"tool": "delete_linear_issue",
|
||||
"params": {
|
||||
"issue_id": issue_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="linear_issue_deletion",
|
||||
tool_name="delete_linear_issue",
|
||||
params={
|
||||
"issue_id": issue_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
logger.info("Linear issue deletion rejected by user")
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The issue was not deleted. Do not ask again or suggest alternatives.",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
edited_action = decision.get("edited_action")
|
||||
final_params: dict[str, Any] = {}
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_issue_id = final_params.get("issue_id", issue_id)
|
||||
final_connector_id = final_params.get(
|
||||
final_issue_id = result.params.get("issue_id", issue_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
|
||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
logger.info(
|
||||
f"Deleting Linear issue with final params: issue_id={final_issue_id}, "
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@ import logging
|
|||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.linear_connector import LinearAPIError, LinearConnector
|
||||
from app.services.linear import LinearKBSyncService, LinearToolMetadataService
|
||||
|
||||
|
|
@ -130,69 +130,45 @@ def create_update_linear_issue_tool(
|
|||
logger.info(
|
||||
f"Requesting approval for updating Linear issue: '{issue_ref}' (id={issue_id})"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "linear_issue_update",
|
||||
"action": {
|
||||
"tool": "update_linear_issue",
|
||||
"params": {
|
||||
"issue_id": issue_id,
|
||||
"document_id": document_id,
|
||||
"new_title": new_title,
|
||||
"new_description": new_description,
|
||||
"new_state_id": new_state_id,
|
||||
"new_assignee_id": new_assignee_id,
|
||||
"new_priority": new_priority,
|
||||
"new_label_ids": new_label_ids,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="linear_issue_update",
|
||||
tool_name="update_linear_issue",
|
||||
params={
|
||||
"issue_id": issue_id,
|
||||
"document_id": document_id,
|
||||
"new_title": new_title,
|
||||
"new_description": new_description,
|
||||
"new_state_id": new_state_id,
|
||||
"new_assignee_id": new_assignee_id,
|
||||
"new_priority": new_priority,
|
||||
"new_label_ids": new_label_ids,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
logger.info("Linear issue update rejected by user")
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The issue was not updated. Do not ask again or suggest alternatives.",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
edited_action = decision.get("edited_action")
|
||||
final_params: dict[str, Any] = {}
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_issue_id = final_params.get("issue_id", issue_id)
|
||||
final_document_id = final_params.get("document_id", document_id)
|
||||
final_new_title = final_params.get("new_title", new_title)
|
||||
final_new_description = final_params.get("new_description", new_description)
|
||||
final_new_state_id = final_params.get("new_state_id", new_state_id)
|
||||
final_new_assignee_id = final_params.get("new_assignee_id", new_assignee_id)
|
||||
final_new_priority = final_params.get("new_priority", new_priority)
|
||||
final_new_label_ids: list[str] | None = final_params.get(
|
||||
final_issue_id = result.params.get("issue_id", issue_id)
|
||||
final_document_id = result.params.get("document_id", document_id)
|
||||
final_new_title = result.params.get("new_title", new_title)
|
||||
final_new_description = result.params.get(
|
||||
"new_description", new_description
|
||||
)
|
||||
final_new_state_id = result.params.get("new_state_id", new_state_id)
|
||||
final_new_assignee_id = result.params.get(
|
||||
"new_assignee_id", new_assignee_id
|
||||
)
|
||||
final_new_priority = result.params.get("new_priority", new_priority)
|
||||
final_new_label_ids: list[str] | None = result.params.get(
|
||||
"new_label_ids", new_label_ids
|
||||
)
|
||||
final_connector_id = final_params.get(
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,11 @@ Supports both transport types:
|
|||
- stdio: Local process-based MCP servers (command, args, env)
|
||||
- streamable-http/http/sse: Remote HTTP-based MCP servers (url, headers)
|
||||
|
||||
This implements real MCP protocol support similar to Cursor's implementation.
|
||||
All MCP tools are unconditionally gated by HITL (Human-in-the-Loop) approval.
|
||||
Per the MCP spec: "Clients MUST consider tool annotations to be untrusted unless
|
||||
they come from trusted servers." Users can bypass HITL for specific tools by
|
||||
clicking "Always Allow", which adds the tool name to the connector's
|
||||
``config.trusted_tools`` allow-list.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -21,6 +25,7 @@ from pydantic import BaseModel, create_model
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.agents.new_chat.tools.mcp_client import MCPClient
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
|
|
@ -49,27 +54,15 @@ def _create_dynamic_input_model_from_schema(
|
|||
tool_name: str,
|
||||
input_schema: dict[str, Any],
|
||||
) -> type[BaseModel]:
|
||||
"""Create a Pydantic model from MCP tool's JSON schema.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool (used for model class name)
|
||||
input_schema: JSON schema from MCP server
|
||||
|
||||
Returns:
|
||||
Pydantic model class for tool input validation
|
||||
|
||||
"""
|
||||
"""Create a Pydantic model from MCP tool's JSON schema."""
|
||||
properties = input_schema.get("properties", {})
|
||||
required_fields = input_schema.get("required", [])
|
||||
|
||||
# Build Pydantic field definitions
|
||||
field_definitions = {}
|
||||
for param_name, param_schema in properties.items():
|
||||
param_description = param_schema.get("description", "")
|
||||
is_required = param_name in required_fields
|
||||
|
||||
# Use Any type for complex schemas to preserve structure
|
||||
# This allows the MCP server to do its own validation
|
||||
from typing import Any as AnyType
|
||||
|
||||
from pydantic import Field
|
||||
|
|
@ -85,7 +78,6 @@ def _create_dynamic_input_model_from_schema(
|
|||
Field(None, description=param_description),
|
||||
)
|
||||
|
||||
# Create dynamic model
|
||||
model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input"
|
||||
return create_model(model_name, **field_definitions)
|
||||
|
||||
|
|
@ -93,55 +85,70 @@ def _create_dynamic_input_model_from_schema(
|
|||
async def _create_mcp_tool_from_definition_stdio(
|
||||
tool_def: dict[str, Any],
|
||||
mcp_client: MCPClient,
|
||||
*,
|
||||
connector_name: str = "",
|
||||
connector_id: int | None = None,
|
||||
trusted_tools: list[str] | None = None,
|
||||
) -> StructuredTool:
|
||||
"""Create a LangChain tool from an MCP tool definition (stdio transport).
|
||||
|
||||
Args:
|
||||
tool_def: Tool definition from MCP server with name, description, input_schema
|
||||
mcp_client: MCP client instance for calling the tool
|
||||
|
||||
Returns:
|
||||
LangChain StructuredTool instance
|
||||
|
||||
All MCP tools are unconditionally wrapped with HITL approval.
|
||||
``request_approval()`` is called OUTSIDE the try/except so that
|
||||
``GraphInterrupt`` propagates cleanly to LangGraph.
|
||||
"""
|
||||
tool_name = tool_def.get("name", "unnamed_tool")
|
||||
tool_description = tool_def.get("description", "No description provided")
|
||||
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
||||
|
||||
# Log the actual schema for debugging
|
||||
logger.info(f"MCP tool '{tool_name}' input schema: {input_schema}")
|
||||
|
||||
# Create dynamic input model from schema
|
||||
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
|
||||
|
||||
async def mcp_tool_call(**kwargs) -> str:
|
||||
"""Execute the MCP tool call via the client with retry support."""
|
||||
logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")
|
||||
|
||||
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph
|
||||
hitl_result = request_approval(
|
||||
action_type="mcp_tool_call",
|
||||
tool_name=tool_name,
|
||||
params=kwargs,
|
||||
context={
|
||||
"mcp_server": connector_name,
|
||||
"tool_description": tool_description,
|
||||
"mcp_transport": "stdio",
|
||||
"mcp_connector_id": connector_id,
|
||||
},
|
||||
trusted_tools=trusted_tools,
|
||||
)
|
||||
if hitl_result.rejected:
|
||||
return "Tool call rejected by user."
|
||||
call_kwargs = hitl_result.params
|
||||
|
||||
try:
|
||||
# Connect to server and call tool (connect has built-in retry logic)
|
||||
async with mcp_client.connect():
|
||||
result = await mcp_client.call_tool(tool_name, kwargs)
|
||||
result = await mcp_client.call_tool(tool_name, call_kwargs)
|
||||
return str(result)
|
||||
except RuntimeError as e:
|
||||
# Connection failures after all retries
|
||||
error_msg = f"MCP tool '{tool_name}' connection failed after retries: {e!s}"
|
||||
logger.error(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
except Exception as e:
|
||||
# Tool execution or other errors
|
||||
error_msg = f"MCP tool '{tool_name}' execution failed: {e!s}"
|
||||
logger.exception(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
|
||||
# Create StructuredTool with response_format to preserve exact schema
|
||||
tool = StructuredTool(
|
||||
name=tool_name,
|
||||
description=tool_description,
|
||||
coroutine=mcp_tool_call,
|
||||
args_schema=input_model,
|
||||
# Store the original MCP schema as metadata so we can access it later
|
||||
metadata={"mcp_input_schema": input_schema, "mcp_transport": "stdio"},
|
||||
metadata={
|
||||
"mcp_input_schema": input_schema,
|
||||
"mcp_transport": "stdio",
|
||||
"hitl": True,
|
||||
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Created MCP tool (stdio): '{tool_name}'")
|
||||
|
|
@ -152,43 +159,54 @@ async def _create_mcp_tool_from_definition_http(
|
|||
tool_def: dict[str, Any],
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
*,
|
||||
connector_name: str = "",
|
||||
connector_id: int | None = None,
|
||||
trusted_tools: list[str] | None = None,
|
||||
) -> StructuredTool:
|
||||
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
|
||||
|
||||
Args:
|
||||
tool_def: Tool definition from MCP server with name, description, input_schema
|
||||
url: URL of the MCP server
|
||||
headers: HTTP headers for authentication
|
||||
|
||||
Returns:
|
||||
LangChain StructuredTool instance
|
||||
|
||||
All MCP tools are unconditionally wrapped with HITL approval.
|
||||
``request_approval()`` is called OUTSIDE the try/except so that
|
||||
``GraphInterrupt`` propagates cleanly to LangGraph.
|
||||
"""
|
||||
tool_name = tool_def.get("name", "unnamed_tool")
|
||||
tool_description = tool_def.get("description", "No description provided")
|
||||
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
||||
|
||||
# Log the actual schema for debugging
|
||||
logger.info(f"MCP HTTP tool '{tool_name}' input schema: {input_schema}")
|
||||
|
||||
# Create dynamic input model from schema
|
||||
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
|
||||
|
||||
async def mcp_http_tool_call(**kwargs) -> str:
|
||||
"""Execute the MCP tool call via HTTP transport."""
|
||||
logger.info(f"MCP HTTP tool '{tool_name}' called with params: {kwargs}")
|
||||
|
||||
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph
|
||||
hitl_result = request_approval(
|
||||
action_type="mcp_tool_call",
|
||||
tool_name=tool_name,
|
||||
params=kwargs,
|
||||
context={
|
||||
"mcp_server": connector_name,
|
||||
"tool_description": tool_description,
|
||||
"mcp_transport": "http",
|
||||
"mcp_connector_id": connector_id,
|
||||
},
|
||||
trusted_tools=trusted_tools,
|
||||
)
|
||||
if hitl_result.rejected:
|
||||
return "Tool call rejected by user."
|
||||
call_kwargs = hitl_result.params
|
||||
|
||||
try:
|
||||
async with (
|
||||
streamablehttp_client(url, headers=headers) as (read, write, _),
|
||||
ClientSession(read, write) as session,
|
||||
):
|
||||
await session.initialize()
|
||||
response = await session.call_tool(tool_name, arguments=call_kwargs)
|
||||
|
||||
# Call the tool
|
||||
response = await session.call_tool(tool_name, arguments=kwargs)
|
||||
|
||||
# Extract content from response
|
||||
result = []
|
||||
for content in response.content:
|
||||
if hasattr(content, "text"):
|
||||
|
|
@ -209,7 +227,6 @@ async def _create_mcp_tool_from_definition_http(
|
|||
logger.exception(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
|
||||
# Create StructuredTool
|
||||
tool = StructuredTool(
|
||||
name=tool_name,
|
||||
description=tool_description,
|
||||
|
|
@ -219,6 +236,8 @@ async def _create_mcp_tool_from_definition_http(
|
|||
"mcp_input_schema": input_schema,
|
||||
"mcp_transport": "http",
|
||||
"mcp_url": url,
|
||||
"hitl": True,
|
||||
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -230,20 +249,11 @@ async def _load_stdio_mcp_tools(
|
|||
connector_id: int,
|
||||
connector_name: str,
|
||||
server_config: dict[str, Any],
|
||||
trusted_tools: list[str] | None = None,
|
||||
) -> list[StructuredTool]:
|
||||
"""Load tools from a stdio-based MCP server.
|
||||
|
||||
Args:
|
||||
connector_id: Connector ID for logging
|
||||
connector_name: Connector name for logging
|
||||
server_config: Server configuration with command, args, env
|
||||
|
||||
Returns:
|
||||
List of tools from the MCP server
|
||||
"""
|
||||
"""Load tools from a stdio-based MCP server."""
|
||||
tools: list[StructuredTool] = []
|
||||
|
||||
# Validate required command field
|
||||
command = server_config.get("command")
|
||||
if not command or not isinstance(command, str):
|
||||
logger.warning(
|
||||
|
|
@ -251,7 +261,6 @@ async def _load_stdio_mcp_tools(
|
|||
)
|
||||
return tools
|
||||
|
||||
# Validate args field (must be list if present)
|
||||
args = server_config.get("args", [])
|
||||
if not isinstance(args, list):
|
||||
logger.warning(
|
||||
|
|
@ -259,7 +268,6 @@ async def _load_stdio_mcp_tools(
|
|||
)
|
||||
return tools
|
||||
|
||||
# Validate env field (must be dict if present)
|
||||
env = server_config.get("env", {})
|
||||
if not isinstance(env, dict):
|
||||
logger.warning(
|
||||
|
|
@ -267,10 +275,8 @@ async def _load_stdio_mcp_tools(
|
|||
)
|
||||
return tools
|
||||
|
||||
# Create MCP client
|
||||
mcp_client = MCPClient(command, args, env)
|
||||
|
||||
# Connect and discover tools
|
||||
async with mcp_client.connect():
|
||||
tool_definitions = await mcp_client.list_tools()
|
||||
|
||||
|
|
@ -279,10 +285,15 @@ async def _load_stdio_mcp_tools(
|
|||
f"'{command}' (connector {connector_id})"
|
||||
)
|
||||
|
||||
# Create LangChain tools from definitions
|
||||
for tool_def in tool_definitions:
|
||||
try:
|
||||
tool = await _create_mcp_tool_from_definition_stdio(tool_def, mcp_client)
|
||||
tool = await _create_mcp_tool_from_definition_stdio(
|
||||
tool_def,
|
||||
mcp_client,
|
||||
connector_name=connector_name,
|
||||
connector_id=connector_id,
|
||||
trusted_tools=trusted_tools,
|
||||
)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
|
|
@ -297,20 +308,11 @@ async def _load_http_mcp_tools(
|
|||
connector_id: int,
|
||||
connector_name: str,
|
||||
server_config: dict[str, Any],
|
||||
trusted_tools: list[str] | None = None,
|
||||
) -> list[StructuredTool]:
|
||||
"""Load tools from an HTTP-based MCP server.
|
||||
|
||||
Args:
|
||||
connector_id: Connector ID for logging
|
||||
connector_name: Connector name for logging
|
||||
server_config: Server configuration with url, headers
|
||||
|
||||
Returns:
|
||||
List of tools from the MCP server
|
||||
"""
|
||||
"""Load tools from an HTTP-based MCP server."""
|
||||
tools: list[StructuredTool] = []
|
||||
|
||||
# Validate required url field
|
||||
url = server_config.get("url")
|
||||
if not url or not isinstance(url, str):
|
||||
logger.warning(
|
||||
|
|
@ -318,7 +320,6 @@ async def _load_http_mcp_tools(
|
|||
)
|
||||
return tools
|
||||
|
||||
# Validate headers field (must be dict if present)
|
||||
headers = server_config.get("headers", {})
|
||||
if not isinstance(headers, dict):
|
||||
logger.warning(
|
||||
|
|
@ -326,7 +327,6 @@ async def _load_http_mcp_tools(
|
|||
)
|
||||
return tools
|
||||
|
||||
# Connect and discover tools via HTTP
|
||||
try:
|
||||
async with (
|
||||
streamablehttp_client(url, headers=headers) as (read, write, _),
|
||||
|
|
@ -334,7 +334,6 @@ async def _load_http_mcp_tools(
|
|||
):
|
||||
await session.initialize()
|
||||
|
||||
# List available tools
|
||||
response = await session.list_tools()
|
||||
tool_definitions = []
|
||||
for tool in response.tools:
|
||||
|
|
@ -353,11 +352,15 @@ async def _load_http_mcp_tools(
|
|||
f"'{url}' (connector {connector_id})"
|
||||
)
|
||||
|
||||
# Create LangChain tools from definitions
|
||||
for tool_def in tool_definitions:
|
||||
try:
|
||||
tool = await _create_mcp_tool_from_definition_http(
|
||||
tool_def, url, headers
|
||||
tool_def,
|
||||
url,
|
||||
headers,
|
||||
connector_name=connector_name,
|
||||
connector_id=connector_id,
|
||||
trusted_tools=trusted_tools,
|
||||
)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
|
|
@ -398,14 +401,6 @@ async def load_mcp_tools(
|
|||
|
||||
Results are cached per search space for up to 5 minutes to avoid
|
||||
re-spawning MCP server processes on every chat message.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
search_space_id: User's search space ID
|
||||
|
||||
Returns:
|
||||
List of LangChain StructuredTool instances
|
||||
|
||||
"""
|
||||
_evict_expired_mcp_cache()
|
||||
|
||||
|
|
@ -436,6 +431,7 @@ async def load_mcp_tools(
|
|||
try:
|
||||
config = connector.config or {}
|
||||
server_config = config.get("server_config", {})
|
||||
trusted_tools = config.get("trusted_tools", [])
|
||||
|
||||
if not server_config or not isinstance(server_config, dict):
|
||||
logger.warning(
|
||||
|
|
@ -447,11 +443,17 @@ async def load_mcp_tools(
|
|||
|
||||
if transport in ("streamable-http", "http", "sse"):
|
||||
connector_tools = await _load_http_mcp_tools(
|
||||
connector.id, connector.name, server_config
|
||||
connector.id,
|
||||
connector.name,
|
||||
server_config,
|
||||
trusted_tools=trusted_tools,
|
||||
)
|
||||
else:
|
||||
connector_tools = await _load_stdio_mcp_tools(
|
||||
connector.id, connector.name, server_config
|
||||
connector.id,
|
||||
connector.name,
|
||||
server_config,
|
||||
trusted_tools=trusted_tools,
|
||||
)
|
||||
|
||||
tools.extend(connector_tools)
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@ import logging
|
|||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
|
||||
from app.services.notion import NotionToolMetadataService
|
||||
|
||||
|
|
@ -99,61 +99,29 @@ def create_create_notion_page_tool(
|
|||
}
|
||||
|
||||
logger.info(f"Requesting approval for creating Notion page: '{title}'")
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "notion_page_creation",
|
||||
"action": {
|
||||
"tool": "create_notion_page",
|
||||
"params": {
|
||||
"title": title,
|
||||
"content": content,
|
||||
"parent_page_id": None,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="notion_page_creation",
|
||||
tool_name="create_notion_page",
|
||||
params={
|
||||
"title": title,
|
||||
"content": content,
|
||||
"parent_page_id": None,
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No approval decision received",
|
||||
}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
logger.info("Notion page creation rejected by user")
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The page was not created. Do not ask again or suggest alternatives.",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
edited_action = decision.get("edited_action")
|
||||
final_params: dict[str, Any] = {}
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
# Some interrupt payloads place args directly on the decision.
|
||||
final_params = decision["args"]
|
||||
|
||||
final_title = final_params.get("title", title)
|
||||
final_content = final_params.get("content", content)
|
||||
final_parent_page_id = final_params.get("parent_page_id")
|
||||
final_connector_id = final_params.get("connector_id", connector_id)
|
||||
final_title = result.params.get("title", title)
|
||||
final_content = result.params.get("content", content)
|
||||
final_parent_page_id = result.params.get("parent_page_id")
|
||||
final_connector_id = result.params.get("connector_id", connector_id)
|
||||
|
||||
if not final_title or not final_title.strip():
|
||||
logger.error("Title is empty or contains only whitespace")
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@ import logging
|
|||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
|
||||
from app.services.notion.tool_metadata_service import NotionToolMetadataService
|
||||
|
||||
|
|
@ -114,63 +114,29 @@ def create_delete_notion_page_tool(
|
|||
f"Requesting approval for deleting Notion page: '{page_title}' (page_id={page_id}, delete_from_kb={delete_from_kb})"
|
||||
)
|
||||
|
||||
# Request approval before deleting
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "notion_page_deletion",
|
||||
"action": {
|
||||
"tool": "delete_notion_page",
|
||||
"params": {
|
||||
"page_id": page_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="notion_page_deletion",
|
||||
tool_name="delete_notion_page",
|
||||
params={
|
||||
"page_id": page_id,
|
||||
"connector_id": connector_id_from_context,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No approval decision received",
|
||||
}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
logger.info("Notion page deletion rejected by user")
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The page was not deleted. Do not ask again or suggest alternatives.",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
# Extract edited action arguments (if user modified the checkbox)
|
||||
edited_action = decision.get("edited_action")
|
||||
final_params: dict[str, Any] = {}
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
# Some interrupt payloads place args directly on the decision.
|
||||
final_params = decision["args"]
|
||||
|
||||
final_page_id = final_params.get("page_id", page_id)
|
||||
final_connector_id = final_params.get(
|
||||
final_page_id = result.params.get("page_id", page_id)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
|
||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
logger.info(
|
||||
f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@ import logging
|
|||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
|
||||
from app.services.notion import NotionToolMetadataService
|
||||
|
||||
|
|
@ -127,59 +127,27 @@ def create_update_notion_page_tool(
|
|||
logger.info(
|
||||
f"Requesting approval for updating Notion page: '{page_title}' (page_id={page_id})"
|
||||
)
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "notion_page_update",
|
||||
"action": {
|
||||
"tool": "update_notion_page",
|
||||
"params": {
|
||||
"page_id": page_id,
|
||||
"content": content,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="notion_page_update",
|
||||
tool_name="update_notion_page",
|
||||
params={
|
||||
"page_id": page_id,
|
||||
"content": content,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
logger.warning("No approval decision received")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No approval decision received",
|
||||
}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
logger.info("Notion page update rejected by user")
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The page was not updated. Do not ask again or suggest alternatives.",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
edited_action = decision.get("edited_action")
|
||||
final_params: dict[str, Any] = {}
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
# Some interrupt payloads place args directly on the decision.
|
||||
final_params = decision["args"]
|
||||
|
||||
final_page_id = final_params.get("page_id", page_id)
|
||||
final_content = final_params.get("content", content)
|
||||
final_connector_id = final_params.get(
|
||||
final_page_id = result.params.get("page_id", page_id)
|
||||
final_content = result.params.get("content", content)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,10 +5,10 @@ from pathlib import Path
|
|||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.onedrive.client import OneDriveClient
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
|
|
@ -145,54 +145,28 @@ def create_create_onedrive_file_tool(
|
|||
"parent_folders": parent_folders,
|
||||
}
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "onedrive_file_creation",
|
||||
"action": {
|
||||
"tool": "create_onedrive_file",
|
||||
"params": {
|
||||
"name": name,
|
||||
"content": content,
|
||||
"connector_id": None,
|
||||
"parent_folder_id": None,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="onedrive_file_creation",
|
||||
tool_name="create_onedrive_file",
|
||||
params={
|
||||
"name": name,
|
||||
"content": content,
|
||||
"connector_id": None,
|
||||
"parent_folder_id": None,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The file was not created.",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_name = final_params.get("name", name)
|
||||
final_content = final_params.get("content", content)
|
||||
final_connector_id = final_params.get("connector_id")
|
||||
final_parent_folder_id = final_params.get("parent_folder_id")
|
||||
final_name = result.params.get("name", name)
|
||||
final_content = result.params.get("content", content)
|
||||
final_connector_id = result.params.get("connector_id")
|
||||
final_parent_folder_id = result.params.get("parent_folder_id")
|
||||
|
||||
if not final_name or not final_name.strip():
|
||||
return {"status": "error", "message": "File name cannot be empty."}
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ import logging
|
|||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import interrupt
|
||||
from sqlalchemy import String, and_, cast, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.onedrive.client import OneDriveClient
|
||||
from app.db import (
|
||||
Document,
|
||||
|
|
@ -174,53 +174,26 @@ def create_delete_onedrive_file_tool(
|
|||
},
|
||||
}
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": "onedrive_file_trash",
|
||||
"action": {
|
||||
"tool": "delete_onedrive_file",
|
||||
"params": {
|
||||
"file_id": file_id,
|
||||
"connector_id": connector.id,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
},
|
||||
"context": context,
|
||||
}
|
||||
result = request_approval(
|
||||
action_type="onedrive_file_trash",
|
||||
tool_name="delete_onedrive_file",
|
||||
params={
|
||||
"file_id": file_id,
|
||||
"connector_id": connector.id,
|
||||
"delete_from_kb": delete_from_kb,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
decisions_raw = (
|
||||
approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
)
|
||||
decisions = (
|
||||
decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
)
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
if not decisions:
|
||||
return {"status": "error", "message": "No approval decision received"}
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type = decision.get("type") or decision.get("decision_type")
|
||||
logger.info(f"User decision: {decision_type}")
|
||||
|
||||
if decision_type == "reject":
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.",
|
||||
"message": "User declined. Do not retry or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
final_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
final_params = decision["args"]
|
||||
|
||||
final_file_id = final_params.get("file_id", file_id)
|
||||
final_connector_id = final_params.get("connector_id", connector.id)
|
||||
final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb)
|
||||
final_file_id = result.params.get("file_id", file_id)
|
||||
final_connector_id = result.params.get("connector_id", connector.id)
|
||||
final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)
|
||||
|
||||
if final_connector_id != connector.id:
|
||||
result = await db_session.execute(
|
||||
|
|
|
|||
|
|
@ -1321,6 +1321,10 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
Integer, nullable=True, default=0
|
||||
) # For vision/screenshot analysis, defaults to Auto mode
|
||||
|
||||
ai_file_sort_enabled = Column(
|
||||
Boolean, nullable=False, default=False, server_default="false"
|
||||
)
|
||||
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
|
|
|||
|
|
@ -422,6 +422,8 @@ class IndexingPipelineService:
|
|||
)
|
||||
log_index_success(ctx, chunk_count=len(chunks))
|
||||
|
||||
await self._enqueue_ai_sort_if_enabled(document)
|
||||
|
||||
except RETRYABLE_LLM_ERRORS as e:
|
||||
log_retryable_llm_error(ctx, e)
|
||||
await rollback_and_persist_failure(
|
||||
|
|
@ -457,6 +459,29 @@ class IndexingPipelineService:
|
|||
|
||||
return document
|
||||
|
||||
async def _enqueue_ai_sort_if_enabled(self, document: Document) -> None:
|
||||
"""Fire-and-forget: enqueue incremental AI sort if the search space has it enabled."""
|
||||
try:
|
||||
from app.db import SearchSpace
|
||||
|
||||
result = await self.session.execute(
|
||||
select(SearchSpace.ai_file_sort_enabled).where(
|
||||
SearchSpace.id == document.search_space_id
|
||||
)
|
||||
)
|
||||
enabled = result.scalar()
|
||||
if not enabled:
|
||||
return
|
||||
|
||||
from app.tasks.celery_tasks.document_tasks import ai_sort_document_task
|
||||
|
||||
user_id = str(document.created_by_id) if document.created_by_id else ""
|
||||
ai_sort_document_task.delay(document.search_space_id, user_id, document.id)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to enqueue AI sort for document %s", document.id, exc_info=True
|
||||
)
|
||||
|
||||
async def index_batch_parallel(
|
||||
self,
|
||||
connector_docs: list[ConnectorDocument],
|
||||
|
|
|
|||
|
|
@ -139,9 +139,19 @@ async def get_editor_content(
|
|||
status_code=409,
|
||||
detail="This document is still being processed. Please wait a moment and try again.",
|
||||
)
|
||||
if state == "failed":
|
||||
reason = (
|
||||
doc_status.get("reason", "Unknown error")
|
||||
if isinstance(doc_status, dict)
|
||||
else "Unknown error"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Processing failed: {reason}. You can delete this document and re-upload it.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="This document has no viewable content yet. It may still be syncing. Try again in a few seconds, or re-upload if the issue persists.",
|
||||
detail="This document has no content. It may not have been processed correctly. Try deleting and re-uploading it.",
|
||||
)
|
||||
|
||||
markdown_content = "\n\n".join(chunk_contents)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,9 @@ router = APIRouter()
|
|||
@router.get("/search-spaces/{search_space_id}/export")
|
||||
async def export_knowledge_base(
|
||||
search_space_id: int,
|
||||
folder_id: int | None = Query(None, description="Export only this folder's subtree"),
|
||||
folder_id: int | None = Query(
|
||||
None, description="Export only this folder's subtree"
|
||||
),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
|
|
|
|||
|
|
@ -86,9 +86,8 @@ async def download_sandbox_file(
|
|||
|
||||
# Fall back to live sandbox download
|
||||
try:
|
||||
sandbox = await get_or_create_sandbox(thread_id)
|
||||
raw_sandbox = sandbox._sandbox
|
||||
content: bytes = await asyncio.to_thread(raw_sandbox.fs.download_file, path)
|
||||
sandbox, _ = await get_or_create_sandbox(thread_id)
|
||||
content: bytes = await asyncio.to_thread(sandbox.download_file, path)
|
||||
except Exception as exc:
|
||||
logger.warning("Sandbox file download failed for %s: %s", path, exc)
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -636,9 +636,16 @@ async def delete_search_source_connector(
|
|||
)
|
||||
|
||||
# Delete the connector record
|
||||
search_space_id = db_connector.search_space_id
|
||||
is_mcp = db_connector.connector_type == SearchSourceConnectorType.MCP_CONNECTOR
|
||||
await session.delete(db_connector)
|
||||
await session.commit()
|
||||
|
||||
if is_mcp:
|
||||
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
||||
|
||||
invalidate_mcp_tools_cache(search_space_id)
|
||||
|
||||
logger.info(
|
||||
f"Connector {connector_id} ({connector_name}) deleted successfully. "
|
||||
f"Total documents deleted: {total_deleted}"
|
||||
|
|
@ -3624,3 +3631,114 @@ async def get_drive_picker_token(
|
|||
status_code=500,
|
||||
detail="Failed to retrieve access token. Check server logs for details.",
|
||||
) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MCP Tool Trust (Allow-List) Routes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class MCPTrustToolRequest(BaseModel):
|
||||
tool_name: str
|
||||
|
||||
|
||||
@router.post("/connectors/mcp/{connector_id}/trust-tool")
|
||||
async def trust_mcp_tool(
|
||||
connector_id: int,
|
||||
body: MCPTrustToolRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Add a tool to the MCP connector's trusted (always-allow) list.
|
||||
|
||||
Once trusted, the tool executes without HITL approval on subsequent calls.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.MCP_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(status_code=404, detail="MCP connector not found")
|
||||
|
||||
config = dict(connector.config or {})
|
||||
trusted: list[str] = list(config.get("trusted_tools", []))
|
||||
if body.tool_name not in trusted:
|
||||
trusted.append(body.tool_name)
|
||||
config["trusted_tools"] = trusted
|
||||
connector.config = config
|
||||
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
|
||||
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
||||
|
||||
invalidate_mcp_tools_cache(connector.search_space_id)
|
||||
|
||||
return {"status": "ok", "trusted_tools": trusted}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to trust MCP tool: {e!s}", exc_info=True)
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to trust tool: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/connectors/mcp/{connector_id}/untrust-tool")
|
||||
async def untrust_mcp_tool(
|
||||
connector_id: int,
|
||||
body: MCPTrustToolRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Remove a tool from the MCP connector's trusted list.
|
||||
|
||||
The tool will require HITL approval again on subsequent calls.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.MCP_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise HTTPException(status_code=404, detail="MCP connector not found")
|
||||
|
||||
config = dict(connector.config or {})
|
||||
trusted: list[str] = list(config.get("trusted_tools", []))
|
||||
if body.tool_name in trusted:
|
||||
trusted.remove(body.tool_name)
|
||||
config["trusted_tools"] = trusted
|
||||
connector.config = config
|
||||
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
|
||||
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
||||
|
||||
invalidate_mcp_tools_cache(connector.search_space_id)
|
||||
|
||||
return {"status": "ok", "trusted_tools": trusted}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to untrust MCP tool: {e!s}", exc_info=True)
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to untrust tool: {e!s}"
|
||||
) from e
|
||||
|
|
|
|||
|
|
@ -216,6 +216,7 @@ async def read_search_spaces(
|
|||
user_id=space.user_id,
|
||||
citations_enabled=space.citations_enabled,
|
||||
qna_custom_instructions=space.qna_custom_instructions,
|
||||
ai_file_sort_enabled=space.ai_file_sort_enabled,
|
||||
member_count=member_count,
|
||||
is_owner=is_owner,
|
||||
)
|
||||
|
|
@ -384,6 +385,42 @@ async def edit_team_memory(
|
|||
return db_search_space
|
||||
|
||||
|
||||
@router.post("/searchspaces/{search_space_id}/ai-sort")
|
||||
async def trigger_ai_sort(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Trigger a full AI file sort for all documents in the search space."""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.SETTINGS_UPDATE.value,
|
||||
"You don't have permission to trigger AI sort on this search space",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
db_search_space = result.scalars().first()
|
||||
if not db_search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
from app.tasks.celery_tasks.document_tasks import ai_sort_search_space_task
|
||||
|
||||
ai_sort_search_space_task.delay(search_space_id, str(user.id))
|
||||
return {"message": "AI sort started"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to trigger AI sort: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to trigger AI sort: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/searchspaces/{search_space_id}", response_model=dict)
|
||||
async def delete_search_space(
|
||||
search_space_id: int,
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ class SearchSpaceUpdate(BaseModel):
|
|||
citations_enabled: bool | None = None
|
||||
qna_custom_instructions: str | None = None
|
||||
shared_memory_md: str | None = None
|
||||
ai_file_sort_enabled: bool | None = None
|
||||
|
||||
|
||||
class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
|
||||
|
|
@ -31,6 +32,7 @@ class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
|
|||
citations_enabled: bool
|
||||
qna_custom_instructions: str | None = None
|
||||
shared_memory_md: str | None = None
|
||||
ai_file_sort_enabled: bool = False
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
|
|||
329
surfsense_backend/app/services/ai_file_sort_service.py
Normal file
329
surfsense_backend/app/services/ai_file_sort_service.py
Normal file
|
|
@ -0,0 +1,329 @@
|
|||
"""AI File Sort Service: builds connector-type/date/category/subcategory folder paths."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.db import (
|
||||
Chunk,
|
||||
Document,
|
||||
DocumentType,
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
)
|
||||
from app.services.folder_service import ensure_folder_hierarchy_with_depth_validation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DOCTYPE_TO_CONNECTOR_LABEL: dict[str, str] = {
|
||||
DocumentType.EXTENSION: "Browser Extension",
|
||||
DocumentType.CRAWLED_URL: "Web Crawl",
|
||||
DocumentType.FILE: "File Upload",
|
||||
DocumentType.SLACK_CONNECTOR: "Slack",
|
||||
DocumentType.TEAMS_CONNECTOR: "Teams",
|
||||
DocumentType.ONEDRIVE_FILE: "OneDrive",
|
||||
DocumentType.NOTION_CONNECTOR: "Notion",
|
||||
DocumentType.YOUTUBE_VIDEO: "YouTube",
|
||||
DocumentType.GITHUB_CONNECTOR: "GitHub",
|
||||
DocumentType.LINEAR_CONNECTOR: "Linear",
|
||||
DocumentType.DISCORD_CONNECTOR: "Discord",
|
||||
DocumentType.JIRA_CONNECTOR: "Jira",
|
||||
DocumentType.CONFLUENCE_CONNECTOR: "Confluence",
|
||||
DocumentType.CLICKUP_CONNECTOR: "ClickUp",
|
||||
DocumentType.GOOGLE_CALENDAR_CONNECTOR: "Google Calendar",
|
||||
DocumentType.GOOGLE_GMAIL_CONNECTOR: "Gmail",
|
||||
DocumentType.GOOGLE_DRIVE_FILE: "Google Drive",
|
||||
DocumentType.AIRTABLE_CONNECTOR: "Airtable",
|
||||
DocumentType.LUMA_CONNECTOR: "Luma",
|
||||
DocumentType.ELASTICSEARCH_CONNECTOR: "Elasticsearch",
|
||||
DocumentType.BOOKSTACK_CONNECTOR: "BookStack",
|
||||
DocumentType.CIRCLEBACK: "Circleback",
|
||||
DocumentType.OBSIDIAN_CONNECTOR: "Obsidian",
|
||||
DocumentType.NOTE: "Notes",
|
||||
DocumentType.DROPBOX_FILE: "Dropbox",
|
||||
DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR: "Google Drive (Composio)",
|
||||
DocumentType.COMPOSIO_GMAIL_CONNECTOR: "Gmail (Composio)",
|
||||
DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: "Google Calendar (Composio)",
|
||||
DocumentType.LOCAL_FOLDER_FILE: "Local Folder",
|
||||
}
|
||||
|
||||
_CONNECTOR_TYPE_LABEL: dict[str, str] = {
|
||||
SearchSourceConnectorType.SERPER_API: "Serper Search",
|
||||
SearchSourceConnectorType.TAVILY_API: "Tavily Search",
|
||||
SearchSourceConnectorType.SEARXNG_API: "SearXNG Search",
|
||||
SearchSourceConnectorType.LINKUP_API: "Linkup Search",
|
||||
SearchSourceConnectorType.BAIDU_SEARCH_API: "Baidu Search",
|
||||
SearchSourceConnectorType.SLACK_CONNECTOR: "Slack",
|
||||
SearchSourceConnectorType.TEAMS_CONNECTOR: "Teams",
|
||||
SearchSourceConnectorType.ONEDRIVE_CONNECTOR: "OneDrive",
|
||||
SearchSourceConnectorType.NOTION_CONNECTOR: "Notion",
|
||||
SearchSourceConnectorType.GITHUB_CONNECTOR: "GitHub",
|
||||
SearchSourceConnectorType.LINEAR_CONNECTOR: "Linear",
|
||||
SearchSourceConnectorType.DISCORD_CONNECTOR: "Discord",
|
||||
SearchSourceConnectorType.JIRA_CONNECTOR: "Jira",
|
||||
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: "Confluence",
|
||||
SearchSourceConnectorType.CLICKUP_CONNECTOR: "ClickUp",
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: "Google Calendar",
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: "Gmail",
|
||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR: "Google Drive",
|
||||
SearchSourceConnectorType.AIRTABLE_CONNECTOR: "Airtable",
|
||||
SearchSourceConnectorType.LUMA_CONNECTOR: "Luma",
|
||||
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: "Elasticsearch",
|
||||
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: "Web Crawl",
|
||||
SearchSourceConnectorType.BOOKSTACK_CONNECTOR: "BookStack",
|
||||
SearchSourceConnectorType.CIRCLEBACK_CONNECTOR: "Circleback",
|
||||
SearchSourceConnectorType.OBSIDIAN_CONNECTOR: "Obsidian",
|
||||
SearchSourceConnectorType.MCP_CONNECTOR: "MCP",
|
||||
SearchSourceConnectorType.DROPBOX_CONNECTOR: "Dropbox",
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR: "Google Drive (Composio)",
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR: "Gmail (Composio)",
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: "Google Calendar (Composio)",
|
||||
}
|
||||
|
||||
_MAX_CONTENT_CHARS = 4000
|
||||
_MAX_CHUNKS_FOR_CONTEXT = 5
|
||||
|
||||
_CATEGORY_PROMPT = (
|
||||
"Based on the document information below, classify it into a broad category "
|
||||
"and a more specific subcategory.\n\n"
|
||||
"Rules:\n"
|
||||
"- category: 1-2 word broad theme (e.g. Science, Finance, Engineering, Communication, Media)\n"
|
||||
"- subcategory: 1-2 word specific topic within the category "
|
||||
"(e.g. Physics, Tax Reports, Backend, Team Updates)\n"
|
||||
"- Use nouns only. Do not include generic terms like 'General' or 'Miscellaneous'.\n\n"
|
||||
"Title: {title}\n\n"
|
||||
"Content: {summary}\n\n"
|
||||
'Respond with ONLY a JSON object: {{"category": "...", "subcategory": "..."}}'
|
||||
)
|
||||
|
||||
_SAFE_NAME_RE = re.compile(r"[^a-zA-Z0-9 _\-()]")
|
||||
_FALLBACK_CATEGORY = "Uncategorized"
|
||||
_FALLBACK_SUBCATEGORY = "General"
|
||||
|
||||
|
||||
def resolve_root_folder_label(
|
||||
document: Document, connector: SearchSourceConnector | None
|
||||
) -> str:
|
||||
if connector is not None:
|
||||
return _CONNECTOR_TYPE_LABEL.get(
|
||||
connector.connector_type, str(connector.connector_type)
|
||||
)
|
||||
return _DOCTYPE_TO_CONNECTOR_LABEL.get(
|
||||
document.document_type, str(document.document_type)
|
||||
)
|
||||
|
||||
|
||||
def resolve_date_folder(document: Document) -> str:
|
||||
ts = document.updated_at or document.created_at
|
||||
if ts is None:
|
||||
ts = datetime.now(UTC)
|
||||
return ts.strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
def sanitize_category_folder_name(
|
||||
value: str | None, fallback: str = _FALLBACK_CATEGORY
|
||||
) -> str:
|
||||
if not value or not value.strip():
|
||||
return fallback
|
||||
cleaned = _SAFE_NAME_RE.sub("", value.strip())
|
||||
cleaned = " ".join(cleaned.split())
|
||||
if not cleaned:
|
||||
return fallback
|
||||
return cleaned[:50]
|
||||
|
||||
|
||||
async def _resolve_document_text(
|
||||
session: AsyncSession,
|
||||
document: Document,
|
||||
) -> str:
|
||||
"""Build the best available text representation for taxonomy generation.
|
||||
|
||||
Prefers ``document.content``; falls back to joining the first few chunks
|
||||
when content is empty or too short to be useful.
|
||||
"""
|
||||
text = (document.content or "").strip()
|
||||
if len(text) >= 100:
|
||||
return text[:_MAX_CONTENT_CHARS]
|
||||
|
||||
stmt = (
|
||||
select(Chunk.content)
|
||||
.where(Chunk.document_id == document.id)
|
||||
.order_by(Chunk.id)
|
||||
.limit(_MAX_CHUNKS_FOR_CONTEXT)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
chunk_texts = [row[0] for row in result.all() if row[0]]
|
||||
if chunk_texts:
|
||||
combined = "\n\n".join(chunk_texts)
|
||||
return combined[:_MAX_CONTENT_CHARS]
|
||||
|
||||
return text[:_MAX_CONTENT_CHARS]
|
||||
|
||||
|
||||
def _get_cached_taxonomy(document: Document) -> tuple[str, str] | None:
|
||||
"""Return (category, subcategory) from document metadata cache, or None."""
|
||||
meta = document.document_metadata
|
||||
if not isinstance(meta, dict):
|
||||
return None
|
||||
cat = meta.get("ai_sort_category")
|
||||
subcat = meta.get("ai_sort_subcategory")
|
||||
if cat and subcat and isinstance(cat, str) and isinstance(subcat, str):
|
||||
return cat, subcat
|
||||
return None
|
||||
|
||||
|
||||
def _set_cached_taxonomy(document: Document, category: str, subcategory: str) -> None:
|
||||
"""Persist the AI taxonomy on document metadata for deterministic re-sorts."""
|
||||
meta = dict(document.document_metadata or {})
|
||||
meta["ai_sort_category"] = category
|
||||
meta["ai_sort_subcategory"] = subcategory
|
||||
document.document_metadata = meta
|
||||
|
||||
|
||||
async def generate_ai_taxonomy(
|
||||
title: str,
|
||||
summary_or_content: str,
|
||||
llm,
|
||||
) -> tuple[str, str]:
|
||||
"""Return (category, subcategory) using a single structured LLM call."""
|
||||
text = (summary_or_content or "").strip()
|
||||
if not text:
|
||||
return _FALLBACK_CATEGORY, _FALLBACK_SUBCATEGORY
|
||||
|
||||
if len(text) > _MAX_CONTENT_CHARS:
|
||||
text = text[:_MAX_CONTENT_CHARS]
|
||||
|
||||
prompt = _CATEGORY_PROMPT.format(title=title or "Untitled", summary=text)
|
||||
try:
|
||||
result = await llm.ainvoke([HumanMessage(content=prompt)])
|
||||
raw = result.content.strip()
|
||||
if raw.startswith("```"):
|
||||
raw = raw.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
|
||||
parsed = json.loads(raw)
|
||||
category = sanitize_category_folder_name(
|
||||
parsed.get("category"), _FALLBACK_CATEGORY
|
||||
)
|
||||
subcategory = sanitize_category_folder_name(
|
||||
parsed.get("subcategory"), _FALLBACK_SUBCATEGORY
|
||||
)
|
||||
return category, subcategory
|
||||
except Exception:
|
||||
logger.warning("AI taxonomy generation failed, using fallback", exc_info=True)
|
||||
return _FALLBACK_CATEGORY, _FALLBACK_SUBCATEGORY
|
||||
|
||||
|
||||
def _build_path_segments(
|
||||
root_label: str,
|
||||
date_label: str,
|
||||
category: str,
|
||||
subcategory: str,
|
||||
) -> list[dict]:
|
||||
return [
|
||||
{"name": root_label, "metadata": {"ai_sort": True, "ai_sort_level": 1}},
|
||||
{"name": date_label, "metadata": {"ai_sort": True, "ai_sort_level": 2}},
|
||||
{"name": category, "metadata": {"ai_sort": True, "ai_sort_level": 3}},
|
||||
{"name": subcategory, "metadata": {"ai_sort": True, "ai_sort_level": 4}},
|
||||
]
|
||||
|
||||
|
||||
async def _resolve_taxonomy(
|
||||
session: AsyncSession,
|
||||
document: Document,
|
||||
llm,
|
||||
) -> tuple[str, str]:
|
||||
"""Return (category, subcategory), reusing cached values when available."""
|
||||
cached = _get_cached_taxonomy(document)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
content_text = await _resolve_document_text(session, document)
|
||||
category, subcategory = await generate_ai_taxonomy(
|
||||
document.title, content_text, llm
|
||||
)
|
||||
_set_cached_taxonomy(document, category, subcategory)
|
||||
return category, subcategory
|
||||
|
||||
|
||||
async def ai_sort_document(
|
||||
session: AsyncSession,
|
||||
document: Document,
|
||||
llm,
|
||||
) -> Document:
|
||||
"""Sort a single document into the 4-level AI folder hierarchy."""
|
||||
connector: SearchSourceConnector | None = None
|
||||
if document.connector_id is not None:
|
||||
connector = await session.get(SearchSourceConnector, document.connector_id)
|
||||
|
||||
root_label = resolve_root_folder_label(document, connector)
|
||||
date_label = resolve_date_folder(document)
|
||||
|
||||
category, subcategory = await _resolve_taxonomy(session, document, llm)
|
||||
|
||||
segments = _build_path_segments(root_label, date_label, category, subcategory)
|
||||
|
||||
leaf_folder = await ensure_folder_hierarchy_with_depth_validation(
|
||||
session,
|
||||
document.search_space_id,
|
||||
segments,
|
||||
)
|
||||
|
||||
document.folder_id = leaf_folder.id
|
||||
await session.flush()
|
||||
return document
|
||||
|
||||
|
||||
async def ai_sort_all_documents(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
llm,
|
||||
) -> tuple[int, int]:
|
||||
"""Sort all documents in a search space. Returns (sorted_count, failed_count)."""
|
||||
stmt = (
|
||||
select(Document)
|
||||
.where(Document.search_space_id == search_space_id)
|
||||
.options(selectinload(Document.connector))
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
documents = list(result.scalars().all())
|
||||
|
||||
sorted_count = 0
|
||||
failed_count = 0
|
||||
|
||||
for doc in documents:
|
||||
try:
|
||||
connector = doc.connector
|
||||
root_label = resolve_root_folder_label(doc, connector)
|
||||
date_label = resolve_date_folder(doc)
|
||||
|
||||
category, subcategory = await _resolve_taxonomy(session, doc, llm)
|
||||
segments = _build_path_segments(
|
||||
root_label, date_label, category, subcategory
|
||||
)
|
||||
|
||||
leaf_folder = await ensure_folder_hierarchy_with_depth_validation(
|
||||
session,
|
||||
search_space_id,
|
||||
segments,
|
||||
)
|
||||
doc.folder_id = leaf_folder.id
|
||||
sorted_count += 1
|
||||
except Exception:
|
||||
logger.error("Failed to AI-sort document %s", doc.id, exc_info=True)
|
||||
failed_count += 1
|
||||
|
||||
await session.commit()
|
||||
logger.info(
|
||||
"AI sort complete for search_space=%d: sorted=%d, failed=%d",
|
||||
search_space_id,
|
||||
sorted_count,
|
||||
failed_count,
|
||||
)
|
||||
return sorted_count, failed_count
|
||||
|
|
@ -142,6 +142,58 @@ async def generate_folder_position(
|
|||
return generate_key_between(last_position, None)
|
||||
|
||||
|
||||
async def ensure_folder_hierarchy_with_depth_validation(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
path_segments: list[dict],
|
||||
) -> Folder:
|
||||
"""Create or return a nested folder chain, validating depth at each step.
|
||||
|
||||
Each item in ``path_segments`` is a dict with:
|
||||
- ``name`` (str): folder display name
|
||||
- ``metadata`` (dict | None): optional ``folder_metadata`` JSONB payload
|
||||
|
||||
Returns the deepest (leaf) Folder in the chain.
|
||||
"""
|
||||
parent_id: int | None = None
|
||||
current_folder: Folder | None = None
|
||||
|
||||
for segment in path_segments:
|
||||
name = segment["name"]
|
||||
metadata = segment.get("metadata")
|
||||
|
||||
stmt = select(Folder).where(
|
||||
Folder.search_space_id == search_space_id,
|
||||
Folder.name == name,
|
||||
Folder.parent_id == parent_id
|
||||
if parent_id is not None
|
||||
else Folder.parent_id.is_(None),
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
folder = result.scalar_one_or_none()
|
||||
|
||||
if folder is None:
|
||||
await validate_folder_depth(session, parent_id, subtree_depth=0)
|
||||
position = await generate_folder_position(
|
||||
session, search_space_id, parent_id
|
||||
)
|
||||
folder = Folder(
|
||||
name=name,
|
||||
search_space_id=search_space_id,
|
||||
parent_id=parent_id,
|
||||
position=position,
|
||||
folder_metadata=metadata,
|
||||
)
|
||||
session.add(folder)
|
||||
await session.flush()
|
||||
|
||||
current_folder = folder
|
||||
parent_id = folder.id
|
||||
|
||||
assert current_folder is not None, "path_segments must not be empty"
|
||||
return current_folder
|
||||
|
||||
|
||||
async def get_folder_subtree_ids(session: AsyncSession, folder_id: int) -> list[int]:
|
||||
"""Return all folder IDs in the subtree rooted at folder_id (inclusive)."""
|
||||
result = await session.execute(
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import asyncio
|
|||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from uuid import UUID
|
||||
|
||||
from app.celery_app import celery_app
|
||||
|
|
@ -1551,3 +1552,121 @@ async def _index_uploaded_folder_files_async(
|
|||
heartbeat_task.cancel()
|
||||
if notification_id is not None:
|
||||
_stop_heartbeat(notification_id)
|
||||
|
||||
|
||||
# ===== AI File Sort tasks =====
|
||||
|
||||
AI_SORT_LOCK_TTL_SECONDS = 600 # 10 minutes
|
||||
_ai_sort_redis = None
|
||||
|
||||
|
||||
def _get_ai_sort_redis():
|
||||
import redis
|
||||
|
||||
global _ai_sort_redis
|
||||
if _ai_sort_redis is None:
|
||||
_ai_sort_redis = redis.from_url(config.REDIS_APP_URL, decode_responses=True)
|
||||
return _ai_sort_redis
|
||||
|
||||
|
||||
def _ai_sort_lock_key(search_space_id: int) -> str:
|
||||
return f"ai_sort:search_space:{search_space_id}:lock"
|
||||
|
||||
|
||||
@celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1)
|
||||
def ai_sort_search_space_task(self, search_space_id: int, user_id: str):
|
||||
"""Full AI sort for all documents in a search space."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(_ai_sort_search_space_async(search_space_id, user_id))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
|
||||
r = _get_ai_sort_redis()
|
||||
lock_key = _ai_sort_lock_key(search_space_id)
|
||||
|
||||
if not r.set(lock_key, "running", nx=True, ex=AI_SORT_LOCK_TTL_SECONDS):
|
||||
logger.info(
|
||||
"AI sort already running for search_space=%d, skipping",
|
||||
search_space_id,
|
||||
)
|
||||
return
|
||||
|
||||
t_start = time.perf_counter()
|
||||
try:
|
||||
from app.services.ai_file_sort_service import ai_sort_all_documents
|
||||
from app.services.llm_service import get_document_summary_llm
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
llm = await get_document_summary_llm(
|
||||
session, search_space_id, disable_streaming=True
|
||||
)
|
||||
if llm is None:
|
||||
logger.warning(
|
||||
"No LLM configured for search_space=%d, skipping AI sort",
|
||||
search_space_id,
|
||||
)
|
||||
return
|
||||
|
||||
sorted_count, failed_count = await ai_sort_all_documents(
|
||||
session, search_space_id, llm
|
||||
)
|
||||
elapsed = time.perf_counter() - t_start
|
||||
logger.info(
|
||||
"AI sort search_space=%d done in %.1fs: sorted=%d failed=%d",
|
||||
search_space_id,
|
||||
elapsed,
|
||||
sorted_count,
|
||||
failed_count,
|
||||
)
|
||||
finally:
|
||||
r.delete(lock_key)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="ai_sort_document", bind=True, max_retries=2, default_retry_delay=10
|
||||
)
|
||||
def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int):
|
||||
"""Incremental AI sort for a single document after indexing."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_ai_sort_document_async(search_space_id, user_id, document_id)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _ai_sort_document_async(search_space_id: int, user_id: str, document_id: int):
|
||||
from app.db import Document
|
||||
from app.services.ai_file_sort_service import ai_sort_document
|
||||
from app.services.llm_service import get_document_summary_llm
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
document = await session.get(Document, document_id)
|
||||
if document is None:
|
||||
logger.warning("Document %d not found, skipping AI sort", document_id)
|
||||
return
|
||||
|
||||
llm = await get_document_summary_llm(
|
||||
session, search_space_id, disable_streaming=True
|
||||
)
|
||||
if llm is None:
|
||||
logger.warning(
|
||||
"No LLM for search_space=%d, skipping AI sort of doc=%d",
|
||||
search_space_id,
|
||||
document_id,
|
||||
)
|
||||
return
|
||||
|
||||
await ai_sort_document(session, document, llm)
|
||||
await session.commit()
|
||||
logger.info(
|
||||
"AI sorted document=%d into search_space=%d",
|
||||
document_id,
|
||||
search_space_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -61,6 +61,7 @@ from app.services.new_streaming_service import VercelStreamingService
|
|||
from app.utils.content_utils import bootstrap_history_from_db
|
||||
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
|
||||
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
|
|
@ -142,7 +143,7 @@ class StreamResult:
|
|||
accumulated_text: str = ""
|
||||
is_interrupted: bool = False
|
||||
interrupt_value: dict[str, Any] | None = None
|
||||
sandbox_files: list[str] = field(default_factory=list) # unused, kept for compat
|
||||
sandbox_files: list[str] = field(default_factory=list)
|
||||
agent_called_update_memory: bool = False
|
||||
|
||||
|
||||
|
|
@ -440,7 +441,7 @@ async def _stream_agent_events(
|
|||
status="in_progress",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "execute":
|
||||
elif tool_name in ("execute", "execute_code"):
|
||||
cmd = (
|
||||
tool_input.get("command", "")
|
||||
if isinstance(tool_input, dict)
|
||||
|
|
@ -738,7 +739,7 @@ async def _stream_agent_events(
|
|||
status="completed",
|
||||
items=completed_items,
|
||||
)
|
||||
elif tool_name == "execute":
|
||||
elif tool_name in ("execute", "execute_code"):
|
||||
raw_text = (
|
||||
tool_output.get("result", "")
|
||||
if isinstance(tool_output, dict)
|
||||
|
|
@ -985,7 +986,7 @@ async def _stream_agent_events(
|
|||
if isinstance(tool_output, dict)
|
||||
else {"result": tool_output},
|
||||
)
|
||||
elif tool_name == "execute":
|
||||
elif tool_name in ("execute", "execute_code"):
|
||||
raw_text = (
|
||||
tool_output.get("result", "")
|
||||
if isinstance(tool_output, dict)
|
||||
|
|
@ -1552,7 +1553,7 @@ async def stream_new_chat(
|
|||
# Shared threads write to team memory; private threads write to user memory.
|
||||
if not stream_result.agent_called_update_memory:
|
||||
if visibility == ChatVisibility.SEARCH_SPACE:
|
||||
asyncio.create_task(
|
||||
task = asyncio.create_task(
|
||||
extract_and_save_team_memory(
|
||||
user_message=user_query,
|
||||
search_space_id=search_space_id,
|
||||
|
|
@ -1560,14 +1561,18 @@ async def stream_new_chat(
|
|||
author_display_name=current_user_display_name,
|
||||
)
|
||||
)
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
elif user_id:
|
||||
asyncio.create_task(
|
||||
task = asyncio.create_task(
|
||||
extract_and_save_memory(
|
||||
user_message=user_query,
|
||||
user_id=user_id,
|
||||
llm=llm,
|
||||
)
|
||||
)
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
# Finish the step and message
|
||||
yield streaming_service.format_finish_step()
|
||||
|
|
@ -1617,6 +1622,21 @@ async def stream_new_chat(
|
|||
with contextlib.suppress(Exception):
|
||||
await session.close()
|
||||
|
||||
# Persist any sandbox-produced files to local storage so they
|
||||
# remain downloadable after the Daytona sandbox auto-deletes.
|
||||
if stream_result and stream_result.sandbox_files:
|
||||
with contextlib.suppress(Exception):
|
||||
from app.agents.new_chat.sandbox import (
|
||||
is_sandbox_enabled,
|
||||
persist_and_delete_sandbox,
|
||||
)
|
||||
|
||||
if is_sandbox_enabled():
|
||||
with anyio.CancelScope(shield=True):
|
||||
await persist_and_delete_sandbox(
|
||||
chat_id, stream_result.sandbox_files
|
||||
)
|
||||
|
||||
# Break circular refs held by the agent graph, tools, and LLM
|
||||
# wrappers so the GC can reclaim them in a single pass.
|
||||
agent = llm = connector_service = None
|
||||
|
|
|
|||
|
|
@ -961,6 +961,7 @@ async def index_google_drive_files(
|
|||
vision_llm = None
|
||||
if connector_enable_vision_llm:
|
||||
from app.services.llm_service import get_vision_llm
|
||||
|
||||
vision_llm = await get_vision_llm(session, search_space_id)
|
||||
drive_client = GoogleDriveClient(
|
||||
session, connector_id, credentials=pre_built_credentials
|
||||
|
|
@ -1168,6 +1169,7 @@ async def index_google_drive_single_file(
|
|||
vision_llm = None
|
||||
if connector_enable_vision_llm:
|
||||
from app.services.llm_service import get_vision_llm
|
||||
|
||||
vision_llm = await get_vision_llm(session, search_space_id)
|
||||
drive_client = GoogleDriveClient(
|
||||
session, connector_id, credentials=pre_built_credentials
|
||||
|
|
@ -1306,6 +1308,7 @@ async def index_google_drive_selected_files(
|
|||
vision_llm = None
|
||||
if connector_enable_vision_llm:
|
||||
from app.services.llm_service import get_vision_llm
|
||||
|
||||
vision_llm = await get_vision_llm(session, search_space_id)
|
||||
drive_client = GoogleDriveClient(
|
||||
session, connector_id, credentials=pre_built_credentials
|
||||
|
|
|
|||
|
|
@ -1360,7 +1360,9 @@ async def index_uploaded_files(
|
|||
|
||||
try:
|
||||
content, content_hash = await _compute_file_content_hash(
|
||||
temp_path, filename, search_space_id,
|
||||
temp_path,
|
||||
filename,
|
||||
search_space_id,
|
||||
vision_llm=vision_llm_instance,
|
||||
)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -656,6 +656,7 @@ async def index_onedrive_files(
|
|||
vision_llm = None
|
||||
if connector_enable_vision_llm:
|
||||
from app.services.llm_service import get_vision_llm
|
||||
|
||||
vision_llm = await get_vision_llm(session, search_space_id)
|
||||
|
||||
onedrive_client = OneDriveClient(session, connector_id)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,11 @@ from pathlib import Path
|
|||
from dotenv import load_dotenv
|
||||
|
||||
_here = Path(__file__).parent
|
||||
for candidate in [_here / "../surfsense_backend/.env", _here / ".env", _here / "../.env"]:
|
||||
for candidate in [
|
||||
_here / "../surfsense_backend/.env",
|
||||
_here / ".env",
|
||||
_here / "../.env",
|
||||
]:
|
||||
if candidate.exists():
|
||||
load_dotenv(candidate)
|
||||
break
|
||||
|
|
@ -57,7 +61,10 @@ def main() -> None:
|
|||
api_key = os.environ.get("DAYTONA_API_KEY")
|
||||
if not api_key:
|
||||
print("ERROR: DAYTONA_API_KEY is not set.", file=sys.stderr)
|
||||
print("Add it to surfsense_backend/.env or export it in your shell.", file=sys.stderr)
|
||||
print(
|
||||
"Add it to surfsense_backend/.env or export it in your shell.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
daytona = Daytona()
|
||||
|
|
@ -67,7 +74,7 @@ def main() -> None:
|
|||
print(f"Deleting existing snapshot '{SNAPSHOT_NAME}' …")
|
||||
daytona.snapshot.delete(existing)
|
||||
print(f"Deleted '{SNAPSHOT_NAME}'. Waiting for removal to propagate …")
|
||||
for attempt in range(30):
|
||||
for _attempt in range(30):
|
||||
time.sleep(2)
|
||||
try:
|
||||
daytona.snapshot.get(SNAPSHOT_NAME)
|
||||
|
|
@ -75,7 +82,9 @@ def main() -> None:
|
|||
print(f"Confirmed '{SNAPSHOT_NAME}' is gone.\n")
|
||||
break
|
||||
else:
|
||||
print(f"WARNING: '{SNAPSHOT_NAME}' may still exist after 60s. Proceeding anyway.\n")
|
||||
print(
|
||||
f"WARNING: '{SNAPSHOT_NAME}' may still exist after 60s. Proceeding anyway.\n"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -431,7 +431,9 @@ async def test_llamacloud_heif_accepted_only_with_azure_di(tmp_path, mocker):
|
|||
mocker.patch("app.config.config.AZURE_DI_ENDPOINT", None, create=True)
|
||||
mocker.patch("app.config.config.AZURE_DI_KEY", None, create=True)
|
||||
|
||||
with pytest.raises(EtlUnsupportedFileError, match="document parser does not support this format"):
|
||||
with pytest.raises(
|
||||
EtlUnsupportedFileError, match="document parser does not support this format"
|
||||
):
|
||||
await EtlPipelineService().extract(
|
||||
EtlRequest(file_path=str(heif_file), filename="photo.heif")
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from app.agents.new_chat.middleware.knowledge_search import (
|
||||
KBSearchPlan,
|
||||
KnowledgeBaseSearchMiddleware,
|
||||
_build_document_xml,
|
||||
_normalize_optional_date_range,
|
||||
|
|
@ -366,3 +367,146 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
|||
assert captured["query"] == "deel founders guide summary"
|
||||
assert captured["start_date"] is None
|
||||
assert captured["end_date"] is None
|
||||
|
||||
async def test_middleware_routes_to_recency_browse_when_flagged(
|
||||
self,
|
||||
monkeypatch,
|
||||
):
|
||||
"""When the planner sets is_recency_query=true, browse_recent_documents
|
||||
is called instead of search_knowledge_base."""
|
||||
browse_captured: dict = {}
|
||||
search_called = False
|
||||
|
||||
async def fake_browse_recent_documents(**kwargs):
|
||||
browse_captured.update(kwargs)
|
||||
return []
|
||||
|
||||
async def fake_search_knowledge_base(**kwargs):
|
||||
nonlocal search_called
|
||||
search_called = True
|
||||
return []
|
||||
|
||||
async def fake_build_scoped_filesystem(**kwargs):
|
||||
return {}, {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.browse_recent_documents",
|
||||
fake_browse_recent_documents,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
||||
fake_search_knowledge_base,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
|
||||
fake_build_scoped_filesystem,
|
||||
)
|
||||
|
||||
llm = FakeLLM(
|
||||
json.dumps(
|
||||
{
|
||||
"optimized_query": "latest uploaded file",
|
||||
"start_date": None,
|
||||
"end_date": None,
|
||||
"is_recency_query": True,
|
||||
}
|
||||
)
|
||||
)
|
||||
middleware = KnowledgeBaseSearchMiddleware(llm=llm, search_space_id=42)
|
||||
|
||||
result = await middleware.abefore_agent(
|
||||
{"messages": [HumanMessage(content="what's my latest file?")]},
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert browse_captured["search_space_id"] == 42
|
||||
assert not search_called
|
||||
|
||||
async def test_middleware_uses_hybrid_search_when_not_recency(
|
||||
self,
|
||||
monkeypatch,
|
||||
):
|
||||
"""When is_recency_query is false (default), hybrid search is used."""
|
||||
search_captured: dict = {}
|
||||
browse_called = False
|
||||
|
||||
async def fake_browse_recent_documents(**kwargs):
|
||||
nonlocal browse_called
|
||||
browse_called = True
|
||||
return []
|
||||
|
||||
async def fake_search_knowledge_base(**kwargs):
|
||||
search_captured.update(kwargs)
|
||||
return []
|
||||
|
||||
async def fake_build_scoped_filesystem(**kwargs):
|
||||
return {}, {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.browse_recent_documents",
|
||||
fake_browse_recent_documents,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
||||
fake_search_knowledge_base,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
|
||||
fake_build_scoped_filesystem,
|
||||
)
|
||||
|
||||
llm = FakeLLM(
|
||||
json.dumps(
|
||||
{
|
||||
"optimized_query": "quarterly revenue report analysis",
|
||||
"start_date": None,
|
||||
"end_date": None,
|
||||
"is_recency_query": False,
|
||||
}
|
||||
)
|
||||
)
|
||||
middleware = KnowledgeBaseSearchMiddleware(llm=llm, search_space_id=42)
|
||||
|
||||
await middleware.abefore_agent(
|
||||
{"messages": [HumanMessage(content="find the quarterly revenue report")]},
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
assert search_captured["query"] == "quarterly revenue report analysis"
|
||||
assert not browse_called
|
||||
|
||||
|
||||
# ── KBSearchPlan schema ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestKBSearchPlanSchema:
|
||||
def test_is_recency_query_defaults_to_false(self):
|
||||
plan = KBSearchPlan(optimized_query="test query")
|
||||
assert plan.is_recency_query is False
|
||||
|
||||
def test_is_recency_query_parses_true(self):
|
||||
plan = _parse_kb_search_plan_response(
|
||||
json.dumps(
|
||||
{
|
||||
"optimized_query": "latest uploaded file",
|
||||
"start_date": None,
|
||||
"end_date": None,
|
||||
"is_recency_query": True,
|
||||
}
|
||||
)
|
||||
)
|
||||
assert plan.is_recency_query is True
|
||||
assert plan.optimized_query == "latest uploaded file"
|
||||
|
||||
def test_missing_is_recency_query_defaults_to_false(self):
|
||||
plan = _parse_kb_search_plan_response(
|
||||
json.dumps(
|
||||
{
|
||||
"optimized_query": "meeting notes",
|
||||
"start_date": None,
|
||||
"end_date": None,
|
||||
}
|
||||
)
|
||||
)
|
||||
assert plan.is_recency_query is False
|
||||
|
|
|
|||
|
|
@ -0,0 +1,275 @@
|
|||
"""Unit tests for AI file sort service: folder label resolution, date extraction, category sanitization."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ── resolve_root_folder_label ──
|
||||
|
||||
|
||||
def _make_document(document_type: str, connector_id=None):
|
||||
doc = MagicMock()
|
||||
doc.document_type = document_type
|
||||
doc.connector_id = connector_id
|
||||
return doc
|
||||
|
||||
|
||||
def _make_connector(connector_type: str):
|
||||
conn = MagicMock()
|
||||
conn.connector_type = connector_type
|
||||
return conn
|
||||
|
||||
|
||||
def test_root_label_uses_connector_type_when_available():
|
||||
from app.services.ai_file_sort_service import resolve_root_folder_label
|
||||
|
||||
doc = _make_document("FILE", connector_id=1)
|
||||
conn = _make_connector("GOOGLE_DRIVE_CONNECTOR")
|
||||
assert resolve_root_folder_label(doc, conn) == "Google Drive"
|
||||
|
||||
|
||||
def test_root_label_falls_back_to_document_type():
|
||||
from app.services.ai_file_sort_service import resolve_root_folder_label
|
||||
|
||||
doc = _make_document("SLACK_CONNECTOR")
|
||||
assert resolve_root_folder_label(doc, None) == "Slack"
|
||||
|
||||
|
||||
def test_root_label_unknown_doctype_returns_raw_value():
|
||||
from app.services.ai_file_sort_service import resolve_root_folder_label
|
||||
|
||||
doc = _make_document("UNKNOWN_TYPE")
|
||||
assert resolve_root_folder_label(doc, None) == "UNKNOWN_TYPE"
|
||||
|
||||
|
||||
# ── resolve_date_folder ──
|
||||
|
||||
|
||||
def test_date_folder_from_updated_at():
|
||||
from app.services.ai_file_sort_service import resolve_date_folder
|
||||
|
||||
doc = MagicMock()
|
||||
doc.updated_at = datetime(2025, 3, 15, 10, 30, 0, tzinfo=UTC)
|
||||
doc.created_at = datetime(2025, 1, 1, 0, 0, 0, tzinfo=UTC)
|
||||
assert resolve_date_folder(doc) == "2025-03-15"
|
||||
|
||||
|
||||
def test_date_folder_falls_back_to_created_at():
|
||||
from app.services.ai_file_sort_service import resolve_date_folder
|
||||
|
||||
doc = MagicMock()
|
||||
doc.updated_at = None
|
||||
doc.created_at = datetime(2024, 12, 25, 23, 59, 0, tzinfo=UTC)
|
||||
assert resolve_date_folder(doc) == "2024-12-25"
|
||||
|
||||
|
||||
def test_date_folder_both_none_uses_today():
|
||||
from app.services.ai_file_sort_service import resolve_date_folder
|
||||
|
||||
doc = MagicMock()
|
||||
doc.updated_at = None
|
||||
doc.created_at = None
|
||||
result = resolve_date_folder(doc)
|
||||
today = datetime.now(UTC).strftime("%Y-%m-%d")
|
||||
assert result == today
|
||||
|
||||
|
||||
# ── sanitize_category_folder_name ──
|
||||
|
||||
|
||||
def test_sanitize_normal_value():
|
||||
from app.services.ai_file_sort_service import sanitize_category_folder_name
|
||||
|
||||
assert sanitize_category_folder_name("Machine Learning") == "Machine Learning"
|
||||
|
||||
|
||||
def test_sanitize_strips_special_chars():
|
||||
from app.services.ai_file_sort_service import sanitize_category_folder_name
|
||||
|
||||
assert sanitize_category_folder_name("Tax/Reports!") == "TaxReports"
|
||||
|
||||
|
||||
def test_sanitize_empty_returns_fallback():
|
||||
from app.services.ai_file_sort_service import sanitize_category_folder_name
|
||||
|
||||
assert sanitize_category_folder_name("") == "Uncategorized"
|
||||
assert sanitize_category_folder_name(None) == "Uncategorized"
|
||||
|
||||
|
||||
def test_sanitize_truncates_long_names():
|
||||
from app.services.ai_file_sort_service import sanitize_category_folder_name
|
||||
|
||||
long_name = "A" * 100
|
||||
result = sanitize_category_folder_name(long_name)
|
||||
assert len(result) <= 50
|
||||
|
||||
|
||||
# ── generate_ai_taxonomy ──
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_ai_taxonomy_parses_json():
|
||||
from app.services.ai_file_sort_service import generate_ai_taxonomy
|
||||
|
||||
mock_llm = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.content = '{"category": "Science", "subcategory": "Physics"}'
|
||||
mock_llm.ainvoke.return_value = mock_result
|
||||
|
||||
cat, sub = await generate_ai_taxonomy(
|
||||
"Physics Paper", "Some science document about physics", mock_llm
|
||||
)
|
||||
assert cat == "Science"
|
||||
assert sub == "Physics"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_ai_taxonomy_handles_markdown_code_block():
|
||||
from app.services.ai_file_sort_service import generate_ai_taxonomy
|
||||
|
||||
mock_llm = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.content = (
|
||||
'```json\n{"category": "Finance", "subcategory": "Tax Reports"}\n```'
|
||||
)
|
||||
mock_llm.ainvoke.return_value = mock_result
|
||||
|
||||
cat, sub = await generate_ai_taxonomy("Tax Doc", "A tax report document", mock_llm)
|
||||
assert cat == "Finance"
|
||||
assert sub == "Tax Reports"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_ai_taxonomy_includes_title_in_prompt():
|
||||
from app.services.ai_file_sort_service import generate_ai_taxonomy
|
||||
|
||||
mock_llm = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.content = '{"category": "Engineering", "subcategory": "Backend"}'
|
||||
mock_llm.ainvoke.return_value = mock_result
|
||||
|
||||
await generate_ai_taxonomy("API Design Guide", "content about REST APIs", mock_llm)
|
||||
|
||||
prompt_text = mock_llm.ainvoke.call_args[0][0][0].content
|
||||
assert "API Design Guide" in prompt_text
|
||||
assert "content about REST APIs" in prompt_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_ai_taxonomy_fallback_on_error():
|
||||
from app.services.ai_file_sort_service import generate_ai_taxonomy
|
||||
|
||||
mock_llm = AsyncMock()
|
||||
mock_llm.ainvoke.side_effect = RuntimeError("LLM down")
|
||||
|
||||
cat, sub = await generate_ai_taxonomy("Title", "some content", mock_llm)
|
||||
assert cat == "Uncategorized"
|
||||
assert sub == "General"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_ai_taxonomy_fallback_on_empty_content():
|
||||
from app.services.ai_file_sort_service import generate_ai_taxonomy
|
||||
|
||||
mock_llm = AsyncMock()
|
||||
cat, sub = await generate_ai_taxonomy("Title", "", mock_llm)
|
||||
assert cat == "Uncategorized"
|
||||
assert sub == "General"
|
||||
mock_llm.ainvoke.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_ai_taxonomy_fallback_on_invalid_json():
|
||||
from app.services.ai_file_sort_service import generate_ai_taxonomy
|
||||
|
||||
mock_llm = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.content = "not valid json at all"
|
||||
mock_llm.ainvoke.return_value = mock_result
|
||||
|
||||
cat, sub = await generate_ai_taxonomy("Title", "some content", mock_llm)
|
||||
assert cat == "Uncategorized"
|
||||
assert sub == "General"
|
||||
|
||||
|
||||
# ── taxonomy caching ──
|
||||
|
||||
|
||||
def test_get_cached_taxonomy_returns_none_when_no_metadata():
|
||||
from app.services.ai_file_sort_service import _get_cached_taxonomy
|
||||
|
||||
doc = MagicMock()
|
||||
doc.document_metadata = None
|
||||
assert _get_cached_taxonomy(doc) is None
|
||||
|
||||
|
||||
def test_get_cached_taxonomy_returns_none_when_keys_missing():
|
||||
from app.services.ai_file_sort_service import _get_cached_taxonomy
|
||||
|
||||
doc = MagicMock()
|
||||
doc.document_metadata = {"some_other_key": "value"}
|
||||
assert _get_cached_taxonomy(doc) is None
|
||||
|
||||
|
||||
def test_get_cached_taxonomy_returns_cached_values():
|
||||
from app.services.ai_file_sort_service import _get_cached_taxonomy
|
||||
|
||||
doc = MagicMock()
|
||||
doc.document_metadata = {
|
||||
"ai_sort_category": "Finance",
|
||||
"ai_sort_subcategory": "Tax Reports",
|
||||
}
|
||||
assert _get_cached_taxonomy(doc) == ("Finance", "Tax Reports")
|
||||
|
||||
|
||||
def test_set_cached_taxonomy_persists_on_metadata():
|
||||
from app.services.ai_file_sort_service import _set_cached_taxonomy
|
||||
|
||||
doc = MagicMock()
|
||||
doc.document_metadata = {"existing_key": "keep_me"}
|
||||
_set_cached_taxonomy(doc, "Science", "Physics")
|
||||
assert doc.document_metadata["ai_sort_category"] == "Science"
|
||||
assert doc.document_metadata["ai_sort_subcategory"] == "Physics"
|
||||
assert doc.document_metadata["existing_key"] == "keep_me"
|
||||
|
||||
|
||||
def test_set_cached_taxonomy_creates_metadata_when_none():
|
||||
from app.services.ai_file_sort_service import _set_cached_taxonomy
|
||||
|
||||
doc = MagicMock()
|
||||
doc.document_metadata = None
|
||||
_set_cached_taxonomy(doc, "Engineering", "Backend")
|
||||
assert doc.document_metadata == {
|
||||
"ai_sort_category": "Engineering",
|
||||
"ai_sort_subcategory": "Backend",
|
||||
}
|
||||
|
||||
|
||||
# ── _build_path_segments ──
|
||||
|
||||
|
||||
def test_build_path_segments_structure():
|
||||
from app.services.ai_file_sort_service import _build_path_segments
|
||||
|
||||
segments = _build_path_segments("Google Drive", "2025-03-15", "Science", "Physics")
|
||||
assert len(segments) == 4
|
||||
assert segments[0] == {
|
||||
"name": "Google Drive",
|
||||
"metadata": {"ai_sort": True, "ai_sort_level": 1},
|
||||
}
|
||||
assert segments[1] == {
|
||||
"name": "2025-03-15",
|
||||
"metadata": {"ai_sort": True, "ai_sort_level": 2},
|
||||
}
|
||||
assert segments[2] == {
|
||||
"name": "Science",
|
||||
"metadata": {"ai_sort": True, "ai_sort_level": 3},
|
||||
}
|
||||
assert segments[3] == {
|
||||
"name": "Physics",
|
||||
"metadata": {"ai_sort": True, "ai_sort_level": 4},
|
||||
}
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
"""Unit tests for AI sort task Redis deduplication lock."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_lock_key_format():
|
||||
from app.tasks.celery_tasks.document_tasks import _ai_sort_lock_key
|
||||
|
||||
key = _ai_sort_lock_key(42)
|
||||
assert key == "ai_sort:search_space:42:lock"
|
||||
|
||||
|
||||
def test_lock_prevents_duplicate_run():
|
||||
"""When the Redis lock already exists, the task should skip execution."""
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.set.return_value = False # Lock already held
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.tasks.celery_tasks.document_tasks._get_ai_sort_redis",
|
||||
return_value=mock_redis,
|
||||
),
|
||||
patch(
|
||||
"app.tasks.celery_tasks.document_tasks.get_celery_session_maker"
|
||||
) as mock_session_maker,
|
||||
):
|
||||
import asyncio
|
||||
|
||||
from app.tasks.celery_tasks.document_tasks import _ai_sort_search_space_async
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
loop.run_until_complete(_ai_sort_search_space_async(1, "user-123"))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Session maker should never be called since lock was not acquired
|
||||
mock_session_maker.assert_not_called()
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
"""Unit tests for ensure_folder_hierarchy_with_depth_validation."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_missing_folders_in_chain():
|
||||
"""Should create all folders when none exist."""
|
||||
from app.services.folder_service import (
|
||||
ensure_folder_hierarchy_with_depth_validation,
|
||||
)
|
||||
|
||||
session = AsyncMock()
|
||||
# All lookups return None (no existing folders)
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
session.execute.return_value = mock_result
|
||||
|
||||
folder_instances = []
|
||||
|
||||
def track_add(obj):
|
||||
folder_instances.append(obj)
|
||||
|
||||
session.add = track_add
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.folder_service.validate_folder_depth", new_callable=AsyncMock
|
||||
),
|
||||
patch(
|
||||
"app.services.folder_service.generate_folder_position",
|
||||
new_callable=AsyncMock,
|
||||
return_value="a0",
|
||||
),
|
||||
):
|
||||
# Mock flush to assign IDs
|
||||
call_count = 0
|
||||
|
||||
async def mock_flush():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if folder_instances:
|
||||
folder_instances[-1].id = call_count
|
||||
|
||||
session.flush = mock_flush
|
||||
|
||||
segments = [
|
||||
{"name": "Slack", "metadata": {"ai_sort": True, "ai_sort_level": 1}},
|
||||
{"name": "2025-03-15", "metadata": {"ai_sort": True, "ai_sort_level": 2}},
|
||||
]
|
||||
|
||||
result = await ensure_folder_hierarchy_with_depth_validation(
|
||||
session, 1, segments
|
||||
)
|
||||
|
||||
assert len(folder_instances) == 2
|
||||
assert folder_instances[0].name == "Slack"
|
||||
assert folder_instances[1].name == "2025-03-15"
|
||||
assert result is folder_instances[-1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reuses_existing_folder():
|
||||
"""When a folder already exists, it should be reused, not created."""
|
||||
from app.services.folder_service import (
|
||||
ensure_folder_hierarchy_with_depth_validation,
|
||||
)
|
||||
|
||||
session = AsyncMock()
|
||||
|
||||
existing_folder = MagicMock()
|
||||
existing_folder.id = 42
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = existing_folder
|
||||
session.execute.return_value = mock_result
|
||||
|
||||
segments = [{"name": "Existing", "metadata": None}]
|
||||
|
||||
result = await ensure_folder_hierarchy_with_depth_validation(session, 1, segments)
|
||||
|
||||
assert result is existing_folder
|
||||
session.add.assert_not_called()
|
||||
|
|
@ -798,7 +798,7 @@ export default function NewChatPage() {
|
|||
});
|
||||
} else {
|
||||
const tcId = `interrupt-${action.name}`;
|
||||
addToolCall(contentPartsState, TOOLS_WITH_UI, tcId, action.name, action.args);
|
||||
addToolCall(contentPartsState, TOOLS_WITH_UI, tcId, action.name, action.args, true);
|
||||
updateToolCall(contentPartsState, tcId, {
|
||||
result: { __interrupt__: true, ...interruptData },
|
||||
});
|
||||
|
|
@ -1125,7 +1125,7 @@ export default function NewChatPage() {
|
|||
});
|
||||
} else {
|
||||
const tcId = `interrupt-${action.name}`;
|
||||
addToolCall(contentPartsState, TOOLS_WITH_UI, tcId, action.name, action.args);
|
||||
addToolCall(contentPartsState, TOOLS_WITH_UI, tcId, action.name, action.args, true);
|
||||
updateToolCall(contentPartsState, tcId, {
|
||||
result: {
|
||||
__interrupt__: true,
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ export function ProfileContent() {
|
|||
</div>
|
||||
) : (
|
||||
<form onSubmit={handleSubmit} className="space-y-6">
|
||||
<div className="rounded-lg border bg-card p-6">
|
||||
<div className="rounded-lg bg-card">
|
||||
<div className="flex flex-col gap-6">
|
||||
<div className="space-y-2">
|
||||
<Label>{t("profile_avatar")}</Label>
|
||||
|
|
|
|||
44
surfsense_web/app/dashboard/error.tsx
Normal file
44
surfsense_web/app/dashboard/error.tsx
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"use client";
|
||||
|
||||
import Link from "next/link";
|
||||
import { useEffect } from "react";
|
||||
|
||||
export default function DashboardError({
|
||||
error,
|
||||
reset,
|
||||
}: {
|
||||
error: globalThis.Error & { digest?: string };
|
||||
reset: () => void;
|
||||
}) {
|
||||
useEffect(() => {
|
||||
import("posthog-js")
|
||||
.then(({ default: posthog }) => {
|
||||
posthog.captureException(error);
|
||||
})
|
||||
.catch(() => {});
|
||||
}, [error]);
|
||||
|
||||
return (
|
||||
<div className="flex flex-1 flex-col items-center justify-center gap-4 p-8 text-center">
|
||||
<h2 className="text-xl font-semibold">Something went wrong</h2>
|
||||
<p className="text-muted-foreground max-w-md">
|
||||
An error occurred in this section. Your dashboard is still available.
|
||||
</p>
|
||||
<div className="flex gap-2">
|
||||
<button
|
||||
type="button"
|
||||
onClick={reset}
|
||||
className="rounded-md bg-primary px-4 py-2 text-sm font-medium text-primary-foreground hover:bg-primary/90 transition-colors"
|
||||
>
|
||||
Try again
|
||||
</button>
|
||||
<Link
|
||||
href="/dashboard"
|
||||
className="rounded-md border border-input bg-background px-4 py-2 text-sm font-medium hover:bg-accent hover:text-accent-foreground transition-colors"
|
||||
>
|
||||
Go to dashboard home
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -43,6 +43,7 @@ import { useComments } from "@/hooks/use-comments";
|
|||
import { useMediaQuery } from "@/hooks/use-media-query";
|
||||
import { useElectronAPI } from "@/hooks/use-platform";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { openSafeNavigationHref, resolveSafeNavigationHref } from "@/components/tool-ui/shared/media";
|
||||
|
||||
// Captured once at module load — survives client-side navigations that strip the query param.
|
||||
const IS_QUICK_ASSIST_WINDOW =
|
||||
|
|
@ -384,6 +385,7 @@ const AssistantMessageInner: FC = () => {
|
|||
generate_image: GenerateImageToolUI,
|
||||
update_memory: UpdateMemoryToolUI,
|
||||
execute: SandboxExecuteToolUI,
|
||||
execute_code: SandboxExecuteToolUI,
|
||||
create_notion_page: CreateNotionPageToolUI,
|
||||
update_notion_page: UpdateNotionPageToolUI,
|
||||
delete_notion_page: DeleteNotionPageToolUI,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
import { Search, Unplug } from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import { getDocumentTypeLabel } from "@/components/documents/DocumentTypeIcon";
|
||||
import { getDocumentTypeLabel } from "@/lib/documents/document-type-labels";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { TabsContent } from "@/components/ui/tabs";
|
||||
|
|
|
|||
|
|
@ -499,10 +499,14 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
const empty = text.length === 0 && mentionedDocs.size === 0;
|
||||
setIsEmpty(empty);
|
||||
|
||||
// Check for @ mentions
|
||||
// Unified trigger scan: find the leftmost @ or / in the current word.
|
||||
// Whichever trigger was typed first owns the token — the other character
|
||||
// is treated as part of the query, not as a separate trigger.
|
||||
const selection = window.getSelection();
|
||||
let shouldTriggerMention = false;
|
||||
let mentionQuery = "";
|
||||
let shouldTriggerAction = false;
|
||||
let actionQuery = "";
|
||||
|
||||
if (selection && selection.rangeCount > 0) {
|
||||
const range = selection.getRangeAt(0);
|
||||
|
|
@ -512,63 +516,41 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
const textContent = textNode.textContent || "";
|
||||
const cursorPos = range.startOffset;
|
||||
|
||||
// Look for @ before cursor
|
||||
let atIndex = -1;
|
||||
let wordStart = 0;
|
||||
for (let i = cursorPos - 1; i >= 0; i--) {
|
||||
if (textContent[i] === "@") {
|
||||
atIndex = i;
|
||||
break;
|
||||
}
|
||||
// Stop if we hit a space (@ must be at word boundary)
|
||||
if (textContent[i] === " " || textContent[i] === "\n") {
|
||||
wordStart = i + 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (atIndex !== -1) {
|
||||
const query = textContent.slice(atIndex + 1, cursorPos);
|
||||
// Only trigger if query doesn't start with space
|
||||
let triggerChar: "@" | "/" | null = null;
|
||||
let triggerIndex = -1;
|
||||
for (let i = wordStart; i < cursorPos; i++) {
|
||||
if (textContent[i] === "@" || textContent[i] === "/") {
|
||||
triggerChar = textContent[i] as "@" | "/";
|
||||
triggerIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (triggerChar === "@" && triggerIndex !== -1) {
|
||||
const query = textContent.slice(triggerIndex + 1, cursorPos);
|
||||
if (!query.startsWith(" ")) {
|
||||
shouldTriggerMention = true;
|
||||
mentionQuery = query;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for / actions (same pattern as @)
|
||||
let shouldTriggerAction = false;
|
||||
let actionQuery = "";
|
||||
|
||||
if (!shouldTriggerMention && selection && selection.rangeCount > 0) {
|
||||
const range = selection.getRangeAt(0);
|
||||
const textNode = range.startContainer;
|
||||
|
||||
if (textNode.nodeType === Node.TEXT_NODE) {
|
||||
const textContent = textNode.textContent || "";
|
||||
const cursorPos = range.startOffset;
|
||||
|
||||
let slashIndex = -1;
|
||||
for (let i = cursorPos - 1; i >= 0; i--) {
|
||||
if (textContent[i] === "/") {
|
||||
slashIndex = i;
|
||||
break;
|
||||
}
|
||||
if (textContent[i] === " " || textContent[i] === "\n") {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
slashIndex !== -1 &&
|
||||
(slashIndex === 0 ||
|
||||
textContent[slashIndex - 1] === " " ||
|
||||
textContent[slashIndex - 1] === "\n")
|
||||
) {
|
||||
const query = textContent.slice(slashIndex + 1, cursorPos);
|
||||
if (!query.startsWith(" ")) {
|
||||
shouldTriggerAction = true;
|
||||
actionQuery = query;
|
||||
} else if (triggerChar === "/" && triggerIndex !== -1) {
|
||||
if (
|
||||
triggerIndex === 0 ||
|
||||
textContent[triggerIndex - 1] === " " ||
|
||||
textContent[triggerIndex - 1] === "\n"
|
||||
) {
|
||||
const query = textContent.slice(triggerIndex + 1, cursorPos);
|
||||
if (!query.startsWith(" ")) {
|
||||
shouldTriggerAction = true;
|
||||
actionQuery = query;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,8 +28,7 @@ import {
|
|||
import { AnimatePresence, motion } from "motion/react";
|
||||
import Image from "next/image";
|
||||
import { useParams } from "next/navigation";
|
||||
import { type FC, useCallback, useEffect, useLayoutEffect, useMemo, useRef, useState } from "react";
|
||||
import { createPortal } from "react-dom";
|
||||
import { type FC, useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import {
|
||||
agentToolsAtom,
|
||||
disabledToolsAtom,
|
||||
|
|
@ -61,7 +60,7 @@ import {
|
|||
} from "@/components/assistant-ui/inline-mention-editor";
|
||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||
import { UserMessage } from "@/components/assistant-ui/user-message";
|
||||
import { SLIDEOUT_PANEL_OPENED_EVENT } from "@/components/layout/ui/sidebar/SidebarSlideOutPanel";
|
||||
import { SLIDEOUT_PANEL_OPENED_EVENT } from "@/lib/layout-events";
|
||||
import {
|
||||
DocumentMentionPicker,
|
||||
type DocumentMentionPickerRef,
|
||||
|
|
@ -124,16 +123,18 @@ const ThreadContent: FC = () => {
|
|||
}}
|
||||
/>
|
||||
|
||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||
<div className="grow" />
|
||||
</AuiIf>
|
||||
|
||||
<ThreadPrimitive.ViewportFooter
|
||||
className="aui-thread-viewport-footer sticky bottom-0 z-10 mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-main-panel pb-4 md:pb-6"
|
||||
style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }}
|
||||
>
|
||||
<ThreadScrollToBottom />
|
||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||
<div className="fade-in slide-in-from-bottom-4 animate-in duration-500 ease-out fill-mode-both">
|
||||
<Composer />
|
||||
</div>
|
||||
</AuiIf>
|
||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||
<Composer />
|
||||
</AuiIf>
|
||||
</ThreadPrimitive.ViewportFooter>
|
||||
</ThreadPrimitive.Viewport>
|
||||
</ThreadPrimitive.Root>
|
||||
|
|
@ -339,10 +340,7 @@ const Composer: FC = () => {
|
|||
const [showPromptPicker, setShowPromptPicker] = useState(false);
|
||||
const [mentionQuery, setMentionQuery] = useState("");
|
||||
const [actionQuery, setActionQuery] = useState("");
|
||||
const [containerPos, setContainerPos] = useState({ bottom: "200px", left: "50%", top: "auto" });
|
||||
const editorRef = useRef<InlineMentionEditorRef>(null);
|
||||
const editorContainerRef = useRef<HTMLDivElement>(null);
|
||||
const composerBoxRef = useRef<HTMLDivElement>(null);
|
||||
const documentPickerRef = useRef<DocumentMentionPickerRef>(null);
|
||||
const promptPickerRef = useRef<PromptPickerRef>(null);
|
||||
const viewportRef = useRef<Element | null>(null);
|
||||
|
|
@ -363,38 +361,13 @@ const Composer: FC = () => {
|
|||
viewportRef.current = document.querySelector(".aui-thread-viewport");
|
||||
}, []);
|
||||
|
||||
// Compute picker positions using ResizeObserver to avoid layout reads during render
|
||||
useLayoutEffect(() => {
|
||||
if (!editorContainerRef.current) return;
|
||||
|
||||
const updatePosition = () => {
|
||||
if (!editorContainerRef.current) return;
|
||||
const rect = editorContainerRef.current.getBoundingClientRect();
|
||||
const composerRect = composerBoxRef.current?.getBoundingClientRect();
|
||||
setContainerPos({
|
||||
bottom: `${window.innerHeight - rect.top + 8}px`,
|
||||
left: `${rect.left}px`,
|
||||
top: composerRect ? `${composerRect.bottom + 8}px` : "auto",
|
||||
});
|
||||
};
|
||||
|
||||
updatePosition();
|
||||
const ro = new ResizeObserver(updatePosition);
|
||||
ro.observe(editorContainerRef.current);
|
||||
if (composerBoxRef.current) {
|
||||
ro.observe(composerBoxRef.current);
|
||||
}
|
||||
|
||||
return () => ro.disconnect();
|
||||
}, []);
|
||||
|
||||
const electronAPI = useElectronAPI();
|
||||
const [clipboardInitialText, setClipboardInitialText] = useState<string | undefined>();
|
||||
const clipboardLoadedRef = useRef(false);
|
||||
useEffect(() => {
|
||||
if (!electronAPI || clipboardLoadedRef.current) return;
|
||||
clipboardLoadedRef.current = true;
|
||||
electronAPI.getQuickAskText().then((text) => {
|
||||
electronAPI.getQuickAskText().then((text: string) => {
|
||||
if (text) {
|
||||
setClipboardInitialText(text);
|
||||
}
|
||||
|
|
@ -587,23 +560,15 @@ const Composer: FC = () => {
|
|||
|
||||
// Submit message (blocked during streaming, document picker open, or AI responding to another user)
|
||||
const handleSubmit = useCallback(() => {
|
||||
if (isThreadRunning || isBlockedByOtherUser) {
|
||||
return;
|
||||
}
|
||||
if (!showDocumentPopover && !showPromptPicker) {
|
||||
if (clipboardInitialText) {
|
||||
const userText = editorRef.current?.getText() ?? "";
|
||||
const combined = userText ? `${userText}\n\n${clipboardInitialText}` : clipboardInitialText;
|
||||
aui.composer().setText(combined);
|
||||
setClipboardInitialText(undefined);
|
||||
}
|
||||
aui.composer().send();
|
||||
editorRef.current?.clear();
|
||||
setMentionedDocuments([]);
|
||||
setSidebarDocs([]);
|
||||
}
|
||||
if (isThreadRunning || isBlockedByOtherUser) return;
|
||||
if (showDocumentPopover) return;
|
||||
if (showDocumentPopover || showPromptPicker) return;
|
||||
|
||||
if (clipboardInitialText) {
|
||||
const userText = editorRef.current?.getText() ?? "";
|
||||
const combined = userText ? `${userText}\n\n${clipboardInitialText}` : clipboardInitialText;
|
||||
aui.composer().setText(combined);
|
||||
setClipboardInitialText(undefined);
|
||||
}
|
||||
|
||||
const viewportEl = viewportRef.current;
|
||||
const heightBefore = viewportEl?.scrollHeight ?? 0;
|
||||
|
|
@ -617,18 +582,14 @@ const Composer: FC = () => {
|
|||
// assistant message so that scrolling-to-bottom actually positions the
|
||||
// user message at the TOP of the viewport. That slack height is
|
||||
// calculated asynchronously (ResizeObserver → style → layout).
|
||||
//
|
||||
// We poll via rAF for ~2 s, re-scrolling whenever scrollHeight changes
|
||||
// (user msg render → assistant placeholder → ViewportSlack min-height →
|
||||
// first streamed content). Backup setTimeout calls cover cases where
|
||||
// the batcher's 50 ms throttle delays the DOM update past the rAF.
|
||||
// Poll via rAF for ~500ms, re-scrolling whenever scrollHeight changes.
|
||||
const scrollToBottom = () =>
|
||||
threadViewportStore.getState().scrollToBottom({ behavior: "instant" });
|
||||
|
||||
let lastHeight = heightBefore;
|
||||
let frames = 0;
|
||||
let cancelled = false;
|
||||
const POLL_FRAMES = 120;
|
||||
const POLL_FRAMES = 30;
|
||||
|
||||
const pollAndScroll = () => {
|
||||
if (cancelled) return;
|
||||
|
|
@ -648,16 +609,11 @@ const Composer: FC = () => {
|
|||
|
||||
const t1 = setTimeout(scrollToBottom, 100);
|
||||
const t2 = setTimeout(scrollToBottom, 300);
|
||||
const t3 = setTimeout(scrollToBottom, 600);
|
||||
|
||||
// Cleanup if component unmounts during the polling window. The ref is
|
||||
// checked inside pollAndScroll; timeouts are cleared in the return below.
|
||||
// Store cleanup fn so it can be called from a useEffect cleanup if needed.
|
||||
submitCleanupRef.current = () => {
|
||||
cancelled = true;
|
||||
clearTimeout(t1);
|
||||
clearTimeout(t2);
|
||||
clearTimeout(t3);
|
||||
};
|
||||
}, [
|
||||
showDocumentPopover,
|
||||
|
|
@ -705,28 +661,54 @@ const Composer: FC = () => {
|
|||
);
|
||||
|
||||
return (
|
||||
<ComposerPrimitive.Root
|
||||
className="aui-composer-root relative flex w-full flex-col gap-2"
|
||||
style={showPromptPicker && clipboardInitialText ? { marginBottom: 220 } : undefined}
|
||||
>
|
||||
<ComposerPrimitive.Root className="aui-composer-root relative flex w-full flex-col gap-2">
|
||||
<ChatSessionStatus
|
||||
isAiResponding={isAiResponding}
|
||||
respondingToUserId={respondingToUserId}
|
||||
currentUserId={currentUser?.id ?? null}
|
||||
members={members ?? []}
|
||||
/>
|
||||
<div
|
||||
ref={composerBoxRef}
|
||||
className="aui-composer-attachment-dropzone flex w-full flex-col overflow-hidden rounded-2xl border-input bg-muted pt-2 outline-none transition-shadow"
|
||||
>
|
||||
{showDocumentPopover && (
|
||||
<div className="absolute bottom-full left-0 z-[9999] mb-2">
|
||||
<DocumentMentionPicker
|
||||
ref={documentPickerRef}
|
||||
searchSpaceId={Number(search_space_id)}
|
||||
onSelectionChange={handleDocumentsMention}
|
||||
onDone={() => {
|
||||
setShowDocumentPopover(false);
|
||||
setMentionQuery("");
|
||||
}}
|
||||
initialSelectedDocuments={mentionedDocuments}
|
||||
externalSearch={mentionQuery}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{showPromptPicker && (
|
||||
<div
|
||||
className={cn(
|
||||
"absolute left-0 z-[9999]",
|
||||
clipboardInitialText ? "top-full mt-2" : "bottom-full mb-2"
|
||||
)}
|
||||
>
|
||||
<PromptPicker
|
||||
ref={promptPickerRef}
|
||||
onSelect={clipboardInitialText ? handleQuickAskSelect : handleActionSelect}
|
||||
onDone={() => {
|
||||
setShowPromptPicker(false);
|
||||
setActionQuery("");
|
||||
}}
|
||||
externalSearch={actionQuery}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<div className="aui-composer-attachment-dropzone flex w-full flex-col overflow-hidden rounded-2xl border-input bg-muted pt-2 outline-none transition-shadow">
|
||||
{clipboardInitialText && (
|
||||
<ClipboardChip
|
||||
text={clipboardInitialText}
|
||||
onDismiss={() => setClipboardInitialText(undefined)}
|
||||
/>
|
||||
)}
|
||||
{/* Inline editor with @mention support */}
|
||||
<div ref={editorContainerRef} className="aui-composer-input-wrapper px-4 pt-3 pb-6">
|
||||
<div className="aui-composer-input-wrapper px-4 pt-3 pb-6">
|
||||
<InlineMentionEditor
|
||||
ref={editorRef}
|
||||
placeholder={currentPlaceholder}
|
||||
|
|
@ -741,49 +723,6 @@ const Composer: FC = () => {
|
|||
className="min-h-[24px]"
|
||||
/>
|
||||
</div>
|
||||
{/* Document picker popover (portal to body for proper z-index stacking) */}
|
||||
{showDocumentPopover &&
|
||||
typeof document !== "undefined" &&
|
||||
createPortal(
|
||||
<DocumentMentionPicker
|
||||
ref={documentPickerRef}
|
||||
searchSpaceId={Number(search_space_id)}
|
||||
onSelectionChange={handleDocumentsMention}
|
||||
onDone={() => {
|
||||
setShowDocumentPopover(false);
|
||||
setMentionQuery("");
|
||||
}}
|
||||
initialSelectedDocuments={mentionedDocuments}
|
||||
externalSearch={mentionQuery}
|
||||
containerStyle={{
|
||||
bottom: containerPos.bottom,
|
||||
left: containerPos.left,
|
||||
}}
|
||||
/>,
|
||||
document.body
|
||||
)}
|
||||
{showPromptPicker &&
|
||||
typeof document !== "undefined" &&
|
||||
createPortal(
|
||||
<PromptPicker
|
||||
ref={promptPickerRef}
|
||||
onSelect={clipboardInitialText ? handleQuickAskSelect : handleActionSelect}
|
||||
onDone={() => {
|
||||
setShowPromptPicker(false);
|
||||
setActionQuery("");
|
||||
}}
|
||||
externalSearch={actionQuery}
|
||||
containerStyle={{
|
||||
position: "fixed",
|
||||
...(clipboardInitialText
|
||||
? { top: containerPos.top }
|
||||
: { bottom: containerPos.bottom }),
|
||||
left: containerPos.left,
|
||||
zIndex: 50,
|
||||
}}
|
||||
/>,
|
||||
document.body
|
||||
)}
|
||||
<ComposerAction isBlockedByOtherUser={isBlockedByOtherUser} />
|
||||
<ConnectorIndicator showTrigger={false} />
|
||||
<ConnectToolsBanner isThreadEmpty={isThreadEmpty} />
|
||||
|
|
|
|||
|
|
@ -1,14 +1,16 @@
|
|||
import type { ToolCallMessagePartComponent } from "@assistant-ui/react";
|
||||
import { CheckIcon, ChevronDownIcon, ChevronUpIcon, XCircleIcon } from "lucide-react";
|
||||
import { useMemo, useState } from "react";
|
||||
import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval";
|
||||
import { getToolIcon } from "@/contracts/enums/toolIcons";
|
||||
import { isInterruptResult } from "@/lib/hitl";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
function formatToolName(name: string): string {
|
||||
return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase());
|
||||
}
|
||||
|
||||
export const ToolFallback: ToolCallMessagePartComponent = ({
|
||||
const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({
|
||||
toolName,
|
||||
argsText,
|
||||
result,
|
||||
|
|
@ -145,3 +147,10 @@ export const ToolFallback: ToolCallMessagePartComponent = ({
|
|||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export const ToolFallback: ToolCallMessagePartComponent = (props) => {
|
||||
if (isInterruptResult(props.result)) {
|
||||
return <GenericHitlApprovalToolUI {...props} />;
|
||||
}
|
||||
return <DefaultToolFallbackInner {...props} />;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import { Button } from "@/components/ui/button";
|
|||
import { cn } from "@/lib/utils";
|
||||
import { CommentComposer } from "../comment-composer/comment-composer";
|
||||
import { CommentActions } from "./comment-actions";
|
||||
import { convertRenderedToDisplay } from "@/lib/comments/utils";
|
||||
import type { CommentItemProps } from "./types";
|
||||
|
||||
function getInitials(name: string | null, email: string): string {
|
||||
|
|
@ -69,10 +70,6 @@ function formatTimestamp(dateString: string): string {
|
|||
);
|
||||
}
|
||||
|
||||
export function convertRenderedToDisplay(contentRendered: string): string {
|
||||
// Convert @{DisplayName} format to @DisplayName for editing
|
||||
return contentRendered.replace(/@\{([^}]+)\}/g, "@$1");
|
||||
}
|
||||
|
||||
function renderMentions(content: string): React.ReactNode {
|
||||
// Match @{DisplayName} format from backend
|
||||
|
|
|
|||
|
|
@ -82,11 +82,12 @@ export const DocumentNode = React.memo(function DocumentNode({
|
|||
onContextMenuOpenChange,
|
||||
}: DocumentNodeProps) {
|
||||
const statusState = doc.status?.state ?? "ready";
|
||||
const isSelectable = statusState !== "pending" && statusState !== "processing";
|
||||
const isFailed = statusState === "failed";
|
||||
const isProcessing = statusState === "pending" || statusState === "processing";
|
||||
const isUnavailable = isProcessing || isFailed;
|
||||
const isSelectable = !isUnavailable;
|
||||
const isEditable =
|
||||
EDITABLE_DOCUMENT_TYPES.has(doc.document_type) &&
|
||||
statusState !== "pending" &&
|
||||
statusState !== "processing";
|
||||
EDITABLE_DOCUMENT_TYPES.has(doc.document_type) && !isUnavailable;
|
||||
|
||||
const handleCheckChange = useCallback(() => {
|
||||
if (isSelectable) {
|
||||
|
|
@ -103,7 +104,6 @@ export const DocumentNode = React.memo(function DocumentNode({
|
|||
[doc.id]
|
||||
);
|
||||
|
||||
const isProcessing = statusState === "pending" || statusState === "processing";
|
||||
const [dropdownOpen, setDropdownOpen] = useState(false);
|
||||
const [exporting, setExporting] = useState<string | null>(null);
|
||||
const [titleTooltipOpen, setTitleTooltipOpen] = useState(false);
|
||||
|
|
@ -261,38 +261,38 @@ export const DocumentNode = React.memo(function DocumentNode({
|
|||
className="w-40"
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
<DropdownMenuItem onClick={() => onPreview(doc)} disabled={isProcessing}>
|
||||
<Eye className="mr-2 h-4 w-4" />
|
||||
Open
|
||||
<DropdownMenuItem onClick={() => onPreview(doc)} disabled={isUnavailable}>
|
||||
<Eye className="mr-2 h-4 w-4" />
|
||||
Open
|
||||
</DropdownMenuItem>
|
||||
{isEditable && (
|
||||
<DropdownMenuItem onClick={() => onEdit(doc)}>
|
||||
<PenLine className="mr-2 h-4 w-4" />
|
||||
Edit
|
||||
</DropdownMenuItem>
|
||||
{isEditable && (
|
||||
<DropdownMenuItem onClick={() => onEdit(doc)}>
|
||||
<PenLine className="mr-2 h-4 w-4" />
|
||||
Edit
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
<DropdownMenuItem onClick={() => onMove(doc)}>
|
||||
<Move className="mr-2 h-4 w-4" />
|
||||
Move to...
|
||||
)}
|
||||
<DropdownMenuItem onClick={() => onMove(doc)}>
|
||||
<Move className="mr-2 h-4 w-4" />
|
||||
Move to...
|
||||
</DropdownMenuItem>
|
||||
{onExport && (
|
||||
<DropdownMenuSub>
|
||||
<DropdownMenuSubTrigger disabled={isUnavailable}>
|
||||
<Download className="mr-2 h-4 w-4" />
|
||||
Export
|
||||
</DropdownMenuSubTrigger>
|
||||
<DropdownMenuSubContent className="min-w-[180px]">
|
||||
<ExportDropdownItems onExport={handleExport} exporting={exporting} />
|
||||
</DropdownMenuSubContent>
|
||||
</DropdownMenuSub>
|
||||
)}
|
||||
{onVersionHistory && isVersionableType(doc.document_type) && (
|
||||
<DropdownMenuItem disabled={isUnavailable} onClick={() => onVersionHistory(doc)}>
|
||||
<History className="mr-2 h-4 w-4" />
|
||||
Versions
|
||||
</DropdownMenuItem>
|
||||
{onExport && (
|
||||
<DropdownMenuSub>
|
||||
<DropdownMenuSubTrigger disabled={isProcessing}>
|
||||
<Download className="mr-2 h-4 w-4" />
|
||||
Export
|
||||
</DropdownMenuSubTrigger>
|
||||
<DropdownMenuSubContent className="min-w-[180px]">
|
||||
<ExportDropdownItems onExport={handleExport} exporting={exporting} />
|
||||
</DropdownMenuSubContent>
|
||||
</DropdownMenuSub>
|
||||
)}
|
||||
{onVersionHistory && isVersionableType(doc.document_type) && (
|
||||
<DropdownMenuItem disabled={isProcessing} onClick={() => onVersionHistory(doc)}>
|
||||
<History className="mr-2 h-4 w-4" />
|
||||
Versions
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
<DropdownMenuItem disabled={isProcessing} onClick={() => onDelete(doc)}>
|
||||
)}
|
||||
<DropdownMenuItem disabled={isProcessing} onClick={() => onDelete(doc)}>
|
||||
<Trash2 className="mr-2 h-4 w-4" />
|
||||
Delete
|
||||
</DropdownMenuItem>
|
||||
|
|
@ -304,38 +304,38 @@ export const DocumentNode = React.memo(function DocumentNode({
|
|||
|
||||
{contextMenuOpen && (
|
||||
<ContextMenuContent className="w-40" onClick={(e) => e.stopPropagation()}>
|
||||
<ContextMenuItem onClick={() => onPreview(doc)} disabled={isProcessing}>
|
||||
<Eye className="mr-2 h-4 w-4" />
|
||||
Open
|
||||
<ContextMenuItem onClick={() => onPreview(doc)} disabled={isUnavailable}>
|
||||
<Eye className="mr-2 h-4 w-4" />
|
||||
Open
|
||||
</ContextMenuItem>
|
||||
{isEditable && (
|
||||
<ContextMenuItem onClick={() => onEdit(doc)}>
|
||||
<PenLine className="mr-2 h-4 w-4" />
|
||||
Edit
|
||||
</ContextMenuItem>
|
||||
{isEditable && (
|
||||
<ContextMenuItem onClick={() => onEdit(doc)}>
|
||||
<PenLine className="mr-2 h-4 w-4" />
|
||||
Edit
|
||||
</ContextMenuItem>
|
||||
)}
|
||||
<ContextMenuItem onClick={() => onMove(doc)}>
|
||||
<Move className="mr-2 h-4 w-4" />
|
||||
Move to...
|
||||
)}
|
||||
<ContextMenuItem onClick={() => onMove(doc)}>
|
||||
<Move className="mr-2 h-4 w-4" />
|
||||
Move to...
|
||||
</ContextMenuItem>
|
||||
{onExport && (
|
||||
<ContextMenuSub>
|
||||
<ContextMenuSubTrigger disabled={isUnavailable}>
|
||||
<Download className="mr-2 h-4 w-4" />
|
||||
Export
|
||||
</ContextMenuSubTrigger>
|
||||
<ContextMenuSubContent className="min-w-[180px]">
|
||||
<ExportContextItems onExport={handleExport} exporting={exporting} />
|
||||
</ContextMenuSubContent>
|
||||
</ContextMenuSub>
|
||||
)}
|
||||
{onVersionHistory && isVersionableType(doc.document_type) && (
|
||||
<ContextMenuItem disabled={isUnavailable} onClick={() => onVersionHistory(doc)}>
|
||||
<History className="mr-2 h-4 w-4" />
|
||||
Versions
|
||||
</ContextMenuItem>
|
||||
{onExport && (
|
||||
<ContextMenuSub>
|
||||
<ContextMenuSubTrigger disabled={isProcessing}>
|
||||
<Download className="mr-2 h-4 w-4" />
|
||||
Export
|
||||
</ContextMenuSubTrigger>
|
||||
<ContextMenuSubContent className="min-w-[180px]">
|
||||
<ExportContextItems onExport={handleExport} exporting={exporting} />
|
||||
</ContextMenuSubContent>
|
||||
</ContextMenuSub>
|
||||
)}
|
||||
{onVersionHistory && isVersionableType(doc.document_type) && (
|
||||
<ContextMenuItem disabled={isProcessing} onClick={() => onVersionHistory(doc)}>
|
||||
<History className="mr-2 h-4 w-4" />
|
||||
Versions
|
||||
</ContextMenuItem>
|
||||
)}
|
||||
<ContextMenuItem disabled={isProcessing} onClick={() => onDelete(doc)}>
|
||||
)}
|
||||
<ContextMenuItem disabled={isProcessing} onClick={() => onDelete(doc)}>
|
||||
<Trash2 className="mr-2 h-4 w-4" />
|
||||
Delete
|
||||
</ContextMenuItem>
|
||||
|
|
|
|||
|
|
@ -4,52 +4,12 @@ import type React from "react";
|
|||
import { useEffect, useRef, useState } from "react";
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
|
||||
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||
import { getDocumentTypeLabel } from "@/lib/documents/document-type-labels";
|
||||
|
||||
export function getDocumentTypeIcon(type: string, className?: string): React.ReactNode {
|
||||
return getConnectorIcon(type, className);
|
||||
}
|
||||
|
||||
export function getDocumentTypeLabel(type: string): string {
|
||||
const labelMap: Record<string, string> = {
|
||||
EXTENSION: "Extension",
|
||||
CRAWLED_URL: "Web Page",
|
||||
FILE: "File",
|
||||
SLACK_CONNECTOR: "Slack",
|
||||
TEAMS_CONNECTOR: "Microsoft Teams",
|
||||
ONEDRIVE_FILE: "OneDrive",
|
||||
DROPBOX_FILE: "Dropbox",
|
||||
NOTION_CONNECTOR: "Notion",
|
||||
YOUTUBE_VIDEO: "YouTube Video",
|
||||
GITHUB_CONNECTOR: "GitHub",
|
||||
LINEAR_CONNECTOR: "Linear",
|
||||
DISCORD_CONNECTOR: "Discord",
|
||||
JIRA_CONNECTOR: "Jira",
|
||||
CONFLUENCE_CONNECTOR: "Confluence",
|
||||
CLICKUP_CONNECTOR: "ClickUp",
|
||||
GOOGLE_CALENDAR_CONNECTOR: "Google Calendar",
|
||||
GOOGLE_GMAIL_CONNECTOR: "Gmail",
|
||||
GOOGLE_DRIVE_FILE: "Google Drive",
|
||||
AIRTABLE_CONNECTOR: "Airtable",
|
||||
LUMA_CONNECTOR: "Luma",
|
||||
ELASTICSEARCH_CONNECTOR: "Elasticsearch",
|
||||
BOOKSTACK_CONNECTOR: "BookStack",
|
||||
CIRCLEBACK: "Circleback",
|
||||
OBSIDIAN_CONNECTOR: "Obsidian",
|
||||
LOCAL_FOLDER_FILE: "Local Folder",
|
||||
SURFSENSE_DOCS: "SurfSense Docs",
|
||||
NOTE: "Note",
|
||||
COMPOSIO_GOOGLE_DRIVE_CONNECTOR: "Composio Google Drive",
|
||||
COMPOSIO_GMAIL_CONNECTOR: "Composio Gmail",
|
||||
COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: "Composio Google Calendar",
|
||||
};
|
||||
return (
|
||||
labelMap[type] ||
|
||||
type
|
||||
.split("_")
|
||||
.map((word) => word.charAt(0) + word.slice(1).toLowerCase())
|
||||
.join(" ")
|
||||
);
|
||||
}
|
||||
|
||||
export function DocumentTypeChip({ type, className }: { type: string; className?: string }) {
|
||||
const icon = getDocumentTypeIcon(type, "h-4 w-4");
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
"use client";
|
||||
|
||||
import { Download, FolderPlus, ListFilter, Loader2, Search, Upload, X } from "lucide-react";
|
||||
import { IconBinaryTree, IconBinaryTreeFilled } from "@tabler/icons-react";
|
||||
import { FolderPlus, ListFilter, Search, Upload, X } from "lucide-react";
|
||||
import { AnimatePresence, motion } from "motion/react";
|
||||
import { useTranslations } from "next-intl";
|
||||
import React, { useCallback, useMemo, useRef, useState } from "react";
|
||||
import { useDocumentUploadDialog } from "@/components/assistant-ui/document-upload-popup";
|
||||
|
|
@ -10,8 +12,10 @@ import { Input } from "@/components/ui/input";
|
|||
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
|
||||
import { ToggleGroup, ToggleGroupItem } from "@/components/ui/toggle-group";
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
|
||||
import { cn } from "@/lib/utils";
|
||||
import type { DocumentTypeEnum } from "@/contracts/types/document.types";
|
||||
import { getDocumentTypeIcon, getDocumentTypeLabel } from "./DocumentTypeIcon";
|
||||
import { getDocumentTypeIcon } from "./DocumentTypeIcon";
|
||||
import { getDocumentTypeLabel } from "@/lib/documents/document-type-labels";
|
||||
|
||||
export function DocumentsFilters({
|
||||
typeCounts: typeCountsRecord,
|
||||
|
|
@ -20,8 +24,9 @@ export function DocumentsFilters({
|
|||
onToggleType,
|
||||
activeTypes,
|
||||
onCreateFolder,
|
||||
onExportKB,
|
||||
isExporting,
|
||||
aiSortEnabled = false,
|
||||
aiSortBusy = false,
|
||||
onToggleAiSort,
|
||||
}: {
|
||||
typeCounts: Partial<Record<DocumentTypeEnum, number>>;
|
||||
onSearch: (v: string) => void;
|
||||
|
|
@ -29,8 +34,9 @@ export function DocumentsFilters({
|
|||
onToggleType: (type: DocumentTypeEnum, checked: boolean) => void;
|
||||
activeTypes: DocumentTypeEnum[];
|
||||
onCreateFolder?: () => void;
|
||||
onExportKB?: () => void;
|
||||
isExporting?: boolean;
|
||||
aiSortEnabled?: boolean;
|
||||
aiSortBusy?: boolean;
|
||||
onToggleAiSort?: () => void;
|
||||
}) {
|
||||
const t = useTranslations("documents");
|
||||
const id = React.useId();
|
||||
|
|
@ -68,7 +74,7 @@ export function DocumentsFilters({
|
|||
return (
|
||||
<div className="flex select-none">
|
||||
<div className="flex items-center gap-2 w-full">
|
||||
{/* Filter + New Folder Toggle Group */}
|
||||
{/* New Folder + Filter Toggle Group */}
|
||||
<ToggleGroup type="multiple" variant="outline" value={[]} className="overflow-visible">
|
||||
{onCreateFolder && (
|
||||
<Tooltip>
|
||||
|
|
@ -85,33 +91,8 @@ export function DocumentsFilters({
|
|||
</ToggleGroupItem>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>New folder</TooltipContent>
|
||||
</Tooltip>
|
||||
)}
|
||||
|
||||
{onExportKB && (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<ToggleGroupItem
|
||||
value="export"
|
||||
disabled={isExporting}
|
||||
className="h-9 w-9 shrink-0 border-sidebar-border text-sidebar-foreground/60 hover:text-sidebar-foreground hover:border-sidebar-border bg-sidebar"
|
||||
onClick={(e) => {
|
||||
e.preventDefault();
|
||||
onExportKB();
|
||||
}}
|
||||
>
|
||||
{isExporting ? (
|
||||
<Loader2 size={14} className="animate-spin" />
|
||||
) : (
|
||||
<Download size={14} />
|
||||
)}
|
||||
</ToggleGroupItem>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
{isExporting ? "Exporting…" : "Export knowledge base"}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
)}
|
||||
</Tooltip>
|
||||
)}
|
||||
|
||||
<Popover>
|
||||
<Tooltip>
|
||||
|
|
@ -201,6 +182,70 @@ export function DocumentsFilters({
|
|||
</Popover>
|
||||
</ToggleGroup>
|
||||
|
||||
{/* AI Sort Toggle */}
|
||||
{onToggleAiSort && (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<button
|
||||
type="button"
|
||||
disabled={aiSortBusy}
|
||||
onClick={onToggleAiSort}
|
||||
className={cn(
|
||||
"relative h-9 w-9 shrink-0 rounded-md border inline-flex items-center justify-center transition-all duration-300 ease-out",
|
||||
"focus-visible:border-ring focus-visible:ring-[3px] focus-visible:ring-ring/50 outline-none",
|
||||
"disabled:pointer-events-none disabled:opacity-50",
|
||||
aiSortEnabled
|
||||
? "border-violet-400/60 bg-violet-50 text-violet-600 shadow-[0_0_8px_-1px_rgba(139,92,246,0.3)] hover:bg-violet-100 dark:border-violet-500/40 dark:bg-violet-500/15 dark:text-violet-400 dark:shadow-[0_0_8px_-1px_rgba(139,92,246,0.2)] dark:hover:bg-violet-500/25"
|
||||
: "border-sidebar-border bg-sidebar text-sidebar-foreground/60 hover:text-sidebar-foreground hover:border-sidebar-border hover:bg-accent"
|
||||
)}
|
||||
aria-label={aiSortEnabled ? "Disable AI sort" : "Enable AI sort"}
|
||||
aria-pressed={aiSortEnabled}
|
||||
>
|
||||
<AnimatePresence mode="wait" initial={false}>
|
||||
{aiSortBusy ? (
|
||||
<motion.div
|
||||
key="busy"
|
||||
initial={{ opacity: 0, scale: 0.6, rotate: -90 }}
|
||||
animate={{ opacity: 1, scale: 1, rotate: 0 }}
|
||||
exit={{ opacity: 0, scale: 0.6, rotate: 90 }}
|
||||
transition={{ duration: 0.2, ease: "easeInOut" }}
|
||||
>
|
||||
<IconBinaryTree size={16} className="animate-pulse" />
|
||||
</motion.div>
|
||||
) : aiSortEnabled ? (
|
||||
<motion.div
|
||||
key="on"
|
||||
initial={{ opacity: 0, scale: 0.6, rotate: -90 }}
|
||||
animate={{ opacity: 1, scale: 1, rotate: 0 }}
|
||||
exit={{ opacity: 0, scale: 0.6, rotate: 90 }}
|
||||
transition={{ duration: 0.25, ease: "easeInOut" }}
|
||||
>
|
||||
<IconBinaryTreeFilled size={16} />
|
||||
</motion.div>
|
||||
) : (
|
||||
<motion.div
|
||||
key="off"
|
||||
initial={{ opacity: 0, scale: 0.6, rotate: 90 }}
|
||||
animate={{ opacity: 1, scale: 1, rotate: 0 }}
|
||||
exit={{ opacity: 0, scale: 0.6, rotate: -90 }}
|
||||
transition={{ duration: 0.25, ease: "easeInOut" }}
|
||||
>
|
||||
<IconBinaryTree size={16} />
|
||||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
</button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
{aiSortBusy
|
||||
? "AI sort in progress..."
|
||||
: aiSortEnabled
|
||||
? "AI sort active — click to disable"
|
||||
: "Enable AI sort"}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
)}
|
||||
|
||||
{/* Search Input */}
|
||||
<div className="relative flex-1 min-w-0">
|
||||
<div className="pointer-events-none absolute inset-y-0 left-0 flex items-center pl-3 text-muted-foreground">
|
||||
|
|
|
|||
|
|
@ -90,6 +90,8 @@ export function FolderTreeView({
|
|||
|
||||
const [openContextMenuId, setOpenContextMenuId] = useState<string | null>(null);
|
||||
|
||||
const [manuallyCollapsedAiIds, setManuallyCollapsedAiIds] = useState<Set<number>>(new Set());
|
||||
|
||||
// Single subscription for rename state — derived boolean passed to each FolderNode
|
||||
const [renamingFolderId, setRenamingFolderId] = useAtom(renamingFolderIdAtom);
|
||||
const handleStartRename = useCallback(
|
||||
|
|
@ -98,6 +100,38 @@ export function FolderTreeView({
|
|||
);
|
||||
const handleCancelRename = useCallback(() => setRenamingFolderId(null), [setRenamingFolderId]);
|
||||
|
||||
const aiSortFolderLevels = useMemo(() => {
|
||||
const map = new Map<number, number>();
|
||||
for (const f of folders) {
|
||||
if (f.metadata?.ai_sort === true && typeof f.metadata?.ai_sort_level === "number") {
|
||||
map.set(f.id, f.metadata.ai_sort_level as number);
|
||||
}
|
||||
}
|
||||
return map;
|
||||
}, [folders]);
|
||||
|
||||
const handleToggleExpand = useCallback(
|
||||
(folderId: number) => {
|
||||
const aiLevel = aiSortFolderLevels.get(folderId);
|
||||
if (aiLevel !== undefined && aiLevel < 4) {
|
||||
// AI-auto-expanded folder: only toggle the manual-collapse set.
|
||||
// Calling onToggleExpand would add it to expandedIds and fight auto-expand.
|
||||
setManuallyCollapsedAiIds((prev) => {
|
||||
const next = new Set(prev);
|
||||
if (next.has(folderId)) {
|
||||
next.delete(folderId);
|
||||
} else {
|
||||
next.add(folderId);
|
||||
}
|
||||
return next;
|
||||
});
|
||||
return;
|
||||
}
|
||||
onToggleExpand(folderId);
|
||||
},
|
||||
[onToggleExpand, aiSortFolderLevels]
|
||||
);
|
||||
|
||||
const effectiveActiveTypes = useMemo(() => {
|
||||
if (
|
||||
activeTypes.includes("FILE" as DocumentTypeEnum) &&
|
||||
|
|
@ -212,9 +246,16 @@ export function FolderTreeView({
|
|||
|
||||
function renderLevel(parentId: number | null, depth: number): React.ReactNode[] {
|
||||
const key = parentId ?? "root";
|
||||
const childFolders = (foldersByParent[key] ?? [])
|
||||
.slice()
|
||||
.sort((a, b) => a.position.localeCompare(b.position));
|
||||
const childFolders = (foldersByParent[key] ?? []).slice().sort((a, b) => {
|
||||
const aIsDate =
|
||||
a.metadata?.ai_sort === true && a.metadata?.ai_sort_level === 2;
|
||||
const bIsDate =
|
||||
b.metadata?.ai_sort === true && b.metadata?.ai_sort_level === 2;
|
||||
if (aIsDate && bIsDate) {
|
||||
return b.name.localeCompare(a.name);
|
||||
}
|
||||
return a.position.localeCompare(b.position);
|
||||
});
|
||||
const visibleFolders = hasDescendantMatch
|
||||
? childFolders.filter((f) => hasDescendantMatch[f.id])
|
||||
: childFolders;
|
||||
|
|
@ -226,6 +267,32 @@ export function FolderTreeView({
|
|||
|
||||
const nodes: React.ReactNode[] = [];
|
||||
|
||||
if (parentId === null) {
|
||||
const processingDocs = childDocs.filter((d) => {
|
||||
const state = d.status?.state;
|
||||
return state === "pending" || state === "processing";
|
||||
});
|
||||
for (const d of processingDocs) {
|
||||
nodes.push(
|
||||
<DocumentNode
|
||||
key={`doc-${d.id}`}
|
||||
doc={d}
|
||||
depth={depth}
|
||||
isMentioned={mentionedDocIds.has(d.id)}
|
||||
onToggleChatMention={onToggleChatMention}
|
||||
onPreview={onPreviewDocument}
|
||||
onEdit={onEditDocument}
|
||||
onDelete={onDeleteDocument}
|
||||
onMove={onMoveDocument}
|
||||
onExport={onExportDocument}
|
||||
onVersionHistory={onVersionHistory}
|
||||
contextMenuOpen={openContextMenuId === `doc-${d.id}`}
|
||||
onContextMenuOpenChange={(open) => setOpenContextMenuId(open ? `doc-${d.id}` : null)}
|
||||
/>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
for (let i = 0; i < visibleFolders.length; i++) {
|
||||
const f = visibleFolders[i];
|
||||
const siblingPositions = {
|
||||
|
|
@ -233,8 +300,15 @@ export function FolderTreeView({
|
|||
after: i < visibleFolders.length - 1 ? visibleFolders[i + 1].position : null,
|
||||
};
|
||||
|
||||
const isAutoExpanded = !!searchQuery && !!hasDescendantMatch?.[f.id];
|
||||
const isExpanded = expandedIds.has(f.id) || isAutoExpanded;
|
||||
const isSearchAutoExpanded = !!searchQuery && !!hasDescendantMatch?.[f.id];
|
||||
const isAiAutoExpandCandidate =
|
||||
f.metadata?.ai_sort === true &&
|
||||
typeof f.metadata?.ai_sort_level === "number" &&
|
||||
(f.metadata.ai_sort_level as number) < 4;
|
||||
const isManuallyCollapsed = manuallyCollapsedAiIds.has(f.id);
|
||||
const isExpanded = isManuallyCollapsed
|
||||
? isSearchAutoExpanded
|
||||
: expandedIds.has(f.id) || isSearchAutoExpanded || isAiAutoExpandCandidate;
|
||||
|
||||
nodes.push(
|
||||
<FolderNode
|
||||
|
|
@ -246,7 +320,7 @@ export function FolderTreeView({
|
|||
selectionState={folderSelectionStates[f.id] ?? "none"}
|
||||
processingState={folderProcessingStates[f.id] ?? "idle"}
|
||||
onToggleSelect={onToggleFolderSelect}
|
||||
onToggleExpand={onToggleExpand}
|
||||
onToggleExpand={handleToggleExpand}
|
||||
onRename={onRenameFolder}
|
||||
onStartRename={handleStartRename}
|
||||
onCancelRename={handleCancelRename}
|
||||
|
|
@ -270,7 +344,15 @@ export function FolderTreeView({
|
|||
}
|
||||
}
|
||||
|
||||
for (const d of childDocs) {
|
||||
const remainingDocs =
|
||||
parentId === null
|
||||
? childDocs.filter((d) => {
|
||||
const state = d.status?.state;
|
||||
return state !== "pending" && state !== "processing";
|
||||
})
|
||||
: childDocs;
|
||||
|
||||
for (const d of remainingDocs) {
|
||||
nodes.push(
|
||||
<DocumentNode
|
||||
key={`doc-${d.id}`}
|
||||
|
|
|
|||
|
|
@ -1,25 +1,60 @@
|
|||
// ---------------------------------------------------------------------------
|
||||
// MDX curly-brace escaping helper
|
||||
// MDX pre-processing helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
// remarkMdx treats { } as JSX expression delimiters. Arbitrary markdown
|
||||
// (e.g. AI-generated reports) can contain curly braces that are NOT valid JS
|
||||
// expressions, which makes acorn throw "Could not parse expression".
|
||||
// We escape unescaped { and } *outside* of fenced code blocks and inline code
|
||||
// so remarkMdx treats them as literal characters while still parsing
|
||||
// <mark>, <u>, <kbd>, etc. tags correctly.
|
||||
// remarkMdx treats { } as JSX expression delimiters and does NOT support
|
||||
// HTML comments (<!-- -->). Arbitrary markdown from document conversions
|
||||
// (e.g. PDF-to-markdown via Azure/DocIntel) can contain constructs that
|
||||
// break the MDX parser. This module sanitises them before deserialization.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const FENCED_OR_INLINE_CODE = /(```[\s\S]*?```|`[^`\n]+`)/g;
|
||||
|
||||
export function escapeMdxExpressions(md: string): string {
|
||||
// Strip HTML comments that MDX cannot parse.
|
||||
// PDF converters emit <!-- PageHeader="..." -->, <!-- PageBreak -->, etc.
|
||||
// MDX uses JSX-style comments and chokes on HTML comments, causing the
|
||||
// parser to stop at the first occurrence.
|
||||
// - <!-- PageBreak --> becomes a thematic break (---)
|
||||
// - All other HTML comments are removed
|
||||
function stripHtmlComments(md: string): string {
|
||||
return md
|
||||
.replace(/<!--\s*PageBreak\s*-->/gi, "\n---\n")
|
||||
.replace(/<!--[\s\S]*?-->/g, "");
|
||||
}
|
||||
|
||||
// Convert <figure>...</figure> blocks to plain text blockquotes.
|
||||
// <figure> with arbitrary text content is not valid JSX, causing the MDX
|
||||
// parser to fail.
|
||||
function convertFigureBlocks(md: string): string {
|
||||
return md.replace(/<figure[^>]*>([\s\S]*?)<\/figure>/gi, (_match, inner: string) => {
|
||||
const trimmed = (inner as string).trim();
|
||||
if (!trimmed) return "";
|
||||
const quoted = trimmed
|
||||
.split("\n")
|
||||
.map((line) => `> ${line}`)
|
||||
.join("\n");
|
||||
return `\n${quoted}\n`;
|
||||
});
|
||||
}
|
||||
|
||||
// Escape unescaped { and } outside of fenced/inline code so remarkMdx
|
||||
// treats them as literal characters rather than JSX expression delimiters.
|
||||
function escapeCurlyBraces(md: string): string {
|
||||
const parts = md.split(FENCED_OR_INLINE_CODE);
|
||||
|
||||
return parts
|
||||
.map((part, i) => {
|
||||
// Odd indices are code blocks / inline code – leave untouched
|
||||
if (i % 2 === 1) return part;
|
||||
// Escape { and } that are NOT already escaped (no preceding \)
|
||||
return part.replace(/(?<!\\)\{/g, "\\{").replace(/(?<!\\)\}/g, "\\}");
|
||||
})
|
||||
.join("");
|
||||
}
|
||||
|
||||
// Pre-process raw markdown so it can be safely parsed by the MDX-enabled
|
||||
// Plate editor. Applies all sanitisation steps in order.
|
||||
export function escapeMdxExpressions(md: string): string {
|
||||
let result = md;
|
||||
result = stripHtmlComments(result);
|
||||
result = convertFigureBlocks(result);
|
||||
result = escapeCurlyBraces(result);
|
||||
return result;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -532,16 +532,14 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid
|
|||
const isOutOfSync = currentThreadState.id !== null && !params?.chat_id;
|
||||
|
||||
if (isOutOfSync) {
|
||||
// First sync Next.js router by navigating to the current chat's actual URL
|
||||
// This updates the router's internal state to match the browser URL
|
||||
resetCurrentThread();
|
||||
router.replace(`/dashboard/${searchSpaceId}/new-chat/${currentThreadState.id}`);
|
||||
// Allow router to sync, then navigate to fresh new-chat
|
||||
setTimeout(() => {
|
||||
router.push(`/dashboard/${searchSpaceId}/new-chat`);
|
||||
}, 0);
|
||||
// Immediately set the browser URL so the page remounts with a clean /new-chat path
|
||||
window.history.replaceState(null, "", `/dashboard/${searchSpaceId}/new-chat`);
|
||||
// Force-remount the page component to reset all React state synchronously
|
||||
setChatResetKey((k) => k + 1);
|
||||
// Sync Next.js router internals so useParams/usePathname stay correct going forward
|
||||
router.replace(`/dashboard/${searchSpaceId}/new-chat`);
|
||||
} else {
|
||||
// Normal navigation - router is in sync
|
||||
router.push(`/dashboard/${searchSpaceId}/new-chat`);
|
||||
}
|
||||
}, [router, searchSpaceId, currentThreadState.id, params?.chat_id, resetCurrentThread]);
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ import { expandedFolderIdsAtom } from "@/atoms/documents/folder.atoms";
|
|||
import { agentCreatedDocumentsAtom } from "@/atoms/documents/ui.atoms";
|
||||
import { openEditorPanelAtom } from "@/atoms/editor/editor-panel.atom";
|
||||
import { rightPanelCollapsedAtom } from "@/atoms/layout/right-panel.atom";
|
||||
import { searchSpacesAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
||||
import { CreateFolderDialog } from "@/components/documents/CreateFolderDialog";
|
||||
import type { DocumentNodeDoc } from "@/components/documents/DocumentNode";
|
||||
import { DocumentsFilters } from "@/components/documents/DocumentsFilters";
|
||||
|
|
@ -49,6 +50,7 @@ import { useMediaQuery } from "@/hooks/use-media-query";
|
|||
import { useElectronAPI } from "@/hooks/use-platform";
|
||||
import { documentsApiService } from "@/lib/apis/documents-api.service";
|
||||
import { foldersApiService } from "@/lib/apis/folders-api.service";
|
||||
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
|
||||
import { authenticatedFetch } from "@/lib/auth-utils";
|
||||
import { uploadFolderScan } from "@/lib/folder-sync-upload";
|
||||
import { getSupportedExtensionsSet } from "@/lib/supported-extensions";
|
||||
|
|
@ -108,6 +110,47 @@ export function DocumentsSidebar({
|
|||
const [watchInitialFolder, setWatchInitialFolder] = useState<SelectedFolder | null>(null);
|
||||
const isElectron = typeof window !== "undefined" && !!window.electronAPI;
|
||||
|
||||
// AI File Sort state
|
||||
const { data: searchSpaces, refetch: refetchSearchSpaces } = useAtomValue(searchSpacesAtom);
|
||||
const activeSearchSpace = useMemo(
|
||||
() => searchSpaces?.find((s) => s.id === searchSpaceId),
|
||||
[searchSpaces, searchSpaceId]
|
||||
);
|
||||
const aiSortEnabled = activeSearchSpace?.ai_file_sort_enabled ?? false;
|
||||
const [aiSortBusy, setAiSortBusy] = useState(false);
|
||||
const [aiSortConfirmOpen, setAiSortConfirmOpen] = useState(false);
|
||||
|
||||
const handleToggleAiSort = useCallback(() => {
|
||||
if (aiSortEnabled) {
|
||||
// Disable: just update the setting, no confirmation needed
|
||||
setAiSortBusy(true);
|
||||
searchSpacesApiService
|
||||
.updateSearchSpace({ id: searchSpaceId, data: { ai_file_sort_enabled: false } })
|
||||
.then(() => {
|
||||
refetchSearchSpaces();
|
||||
toast.success("AI file sorting disabled");
|
||||
})
|
||||
.catch(() => toast.error("Failed to disable AI file sorting"))
|
||||
.finally(() => setAiSortBusy(false));
|
||||
} else {
|
||||
setAiSortConfirmOpen(true);
|
||||
}
|
||||
}, [aiSortEnabled, searchSpaceId, refetchSearchSpaces]);
|
||||
|
||||
const handleConfirmEnableAiSort = useCallback(() => {
|
||||
setAiSortConfirmOpen(false);
|
||||
setAiSortBusy(true);
|
||||
searchSpacesApiService
|
||||
.updateSearchSpace({ id: searchSpaceId, data: { ai_file_sort_enabled: true } })
|
||||
.then(() => searchSpacesApiService.triggerAiSort(searchSpaceId))
|
||||
.then(() => {
|
||||
refetchSearchSpaces();
|
||||
toast.success("AI file sorting enabled — organizing your documents in the background");
|
||||
})
|
||||
.catch(() => toast.error("Failed to enable AI file sorting"))
|
||||
.finally(() => setAiSortBusy(false));
|
||||
}, [searchSpaceId, refetchSearchSpaces]);
|
||||
|
||||
const handleWatchLocalFolder = useCallback(async () => {
|
||||
const api = window.electronAPI;
|
||||
if (!api?.selectFolder) return;
|
||||
|
|
@ -406,22 +449,13 @@ export function DocumentsSidebar({
|
|||
setFolderPickerOpen(true);
|
||||
}, []);
|
||||
|
||||
const [isExportingKB, setIsExportingKB] = useState(false);
|
||||
const [, setIsExportingKB] = useState(false);
|
||||
const [exportWarningOpen, setExportWarningOpen] = useState(false);
|
||||
const [exportWarningContext, setExportWarningContext] = useState<{
|
||||
type: "kb" | "folder";
|
||||
folder?: FolderDisplay;
|
||||
folder: FolderDisplay;
|
||||
pendingCount: number;
|
||||
} | null>(null);
|
||||
|
||||
const pendingDocuments = useMemo(
|
||||
() =>
|
||||
treeDocuments.filter(
|
||||
(d) => d.status?.state === "pending" || d.status?.state === "processing"
|
||||
),
|
||||
[treeDocuments]
|
||||
);
|
||||
|
||||
const doExport = useCallback(async (url: string, downloadName: string) => {
|
||||
const response = await authenticatedFetch(url, { method: "GET" });
|
||||
if (!response.ok) {
|
||||
|
|
@ -440,68 +474,28 @@ export function DocumentsSidebar({
|
|||
URL.revokeObjectURL(blobUrl);
|
||||
}, []);
|
||||
|
||||
const handleExportKB = useCallback(async () => {
|
||||
if (isExportingKB) return;
|
||||
|
||||
if (pendingDocuments.length > 0) {
|
||||
setExportWarningContext({ type: "kb", pendingCount: pendingDocuments.length });
|
||||
setExportWarningOpen(true);
|
||||
return;
|
||||
}
|
||||
|
||||
setIsExportingKB(true);
|
||||
try {
|
||||
await doExport(
|
||||
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/export`,
|
||||
"knowledge-base.zip"
|
||||
);
|
||||
toast.success("Knowledge base exported");
|
||||
} catch (err) {
|
||||
console.error("KB export failed:", err);
|
||||
toast.error(err instanceof Error ? err.message : "Export failed");
|
||||
} finally {
|
||||
setIsExportingKB(false);
|
||||
}
|
||||
}, [searchSpaceId, isExportingKB, pendingDocuments.length, doExport]);
|
||||
|
||||
const handleExportWarningConfirm = useCallback(async () => {
|
||||
setExportWarningOpen(false);
|
||||
const ctx = exportWarningContext;
|
||||
if (!ctx) return;
|
||||
if (!ctx?.folder) return;
|
||||
|
||||
if (ctx.type === "kb") {
|
||||
setIsExportingKB(true);
|
||||
try {
|
||||
await doExport(
|
||||
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/export`,
|
||||
"knowledge-base.zip"
|
||||
);
|
||||
toast.success("Knowledge base exported");
|
||||
} catch (err) {
|
||||
console.error("KB export failed:", err);
|
||||
toast.error(err instanceof Error ? err.message : "Export failed");
|
||||
} finally {
|
||||
setIsExportingKB(false);
|
||||
}
|
||||
} else if (ctx.type === "folder" && ctx.folder) {
|
||||
setIsExportingKB(true);
|
||||
try {
|
||||
const safeName =
|
||||
ctx.folder.name
|
||||
.replace(/[^a-zA-Z0-9 _-]/g, "_")
|
||||
.trim()
|
||||
.slice(0, 80) || "folder";
|
||||
await doExport(
|
||||
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/export?folder_id=${ctx.folder.id}`,
|
||||
`${safeName}.zip`
|
||||
);
|
||||
toast.success(`Folder "${ctx.folder.name}" exported`);
|
||||
} catch (err) {
|
||||
console.error("Folder export failed:", err);
|
||||
toast.error(err instanceof Error ? err.message : "Export failed");
|
||||
} finally {
|
||||
setIsExportingKB(false);
|
||||
}
|
||||
setIsExportingKB(true);
|
||||
try {
|
||||
const safeName =
|
||||
ctx.folder.name
|
||||
.replace(/[^a-zA-Z0-9 _-]/g, "_")
|
||||
.trim()
|
||||
.slice(0, 80) || "folder";
|
||||
await doExport(
|
||||
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/export?folder_id=${ctx.folder.id}`,
|
||||
`${safeName}.zip`
|
||||
);
|
||||
toast.success(`Folder "${ctx.folder.name}" exported`);
|
||||
} catch (err) {
|
||||
console.error("Folder export failed:", err);
|
||||
toast.error(err instanceof Error ? err.message : "Export failed");
|
||||
} finally {
|
||||
setIsExportingKB(false);
|
||||
}
|
||||
setExportWarningContext(null);
|
||||
}, [exportWarningContext, searchSpaceId, doExport]);
|
||||
|
|
@ -530,7 +524,6 @@ export function DocumentsSidebar({
|
|||
const folderPendingCount = getPendingCountInSubtree(folder.id);
|
||||
if (folderPendingCount > 0) {
|
||||
setExportWarningContext({
|
||||
type: "folder",
|
||||
folder,
|
||||
pendingCount: folderPendingCount,
|
||||
});
|
||||
|
|
@ -677,9 +670,10 @@ export function DocumentsSidebar({
|
|||
function collectSubtreeDocs(parentId: number): DocumentNodeDoc[] {
|
||||
const directDocs = (treeDocuments ?? []).filter(
|
||||
(d) =>
|
||||
d.folderId === parentId &&
|
||||
d.status?.state !== "pending" &&
|
||||
d.status?.state !== "processing"
|
||||
d.folderId === parentId &&
|
||||
d.status?.state !== "pending" &&
|
||||
d.status?.state !== "processing" &&
|
||||
d.status?.state !== "failed"
|
||||
);
|
||||
const childFolders = foldersByParent[String(parentId)] ?? [];
|
||||
const descendantDocs = childFolders.flatMap((cf) => collectSubtreeDocs(cf.id));
|
||||
|
|
@ -954,8 +948,9 @@ export function DocumentsSidebar({
|
|||
onToggleType={onToggleType}
|
||||
activeTypes={activeTypes}
|
||||
onCreateFolder={() => handleCreateFolder(null)}
|
||||
onExportKB={handleExportKB}
|
||||
isExporting={isExportingKB}
|
||||
aiSortEnabled={aiSortEnabled}
|
||||
aiSortBusy={aiSortBusy}
|
||||
onToggleAiSort={handleToggleAiSort}
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
|
@ -1117,6 +1112,25 @@ export function DocumentsSidebar({
|
|||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
|
||||
<AlertDialog open={aiSortConfirmOpen} onOpenChange={setAiSortConfirmOpen}>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader>
|
||||
<AlertDialogTitle>Enable AI File Sorting?</AlertDialogTitle>
|
||||
<AlertDialogDescription>
|
||||
All documents in this search space will be organized into folders by
|
||||
connector type, date, and AI-generated categories. New documents will
|
||||
also be sorted automatically. You can disable this at any time.
|
||||
</AlertDialogDescription>
|
||||
</AlertDialogHeader>
|
||||
<AlertDialogFooter>
|
||||
<AlertDialogCancel>Cancel</AlertDialogCancel>
|
||||
<AlertDialogAction onClick={handleConfirmEnableAiSort}>
|
||||
Enable
|
||||
</AlertDialogAction>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
</>
|
||||
);
|
||||
|
||||
|
|
|
|||
|
|
@ -22,8 +22,8 @@ import { useParams, useRouter } from "next/navigation";
|
|||
import { useTranslations } from "next-intl";
|
||||
import { useCallback, useDeferredValue, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { setTargetCommentIdAtom } from "@/atoms/chat/current-thread.atom";
|
||||
import { convertRenderedToDisplay } from "@/components/chat-comments/comment-item/comment-item";
|
||||
import { getDocumentTypeLabel } from "@/components/documents/DocumentTypeIcon";
|
||||
import { convertRenderedToDisplay } from "@/lib/comments/utils";
|
||||
import { getDocumentTypeLabel } from "@/lib/documents/document-type-labels";
|
||||
import { Tabs, TabsList, TabsTrigger } from "@/components/ui/animated-tabs";
|
||||
import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar";
|
||||
import { Button } from "@/components/ui/button";
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@
|
|||
import { AnimatePresence, motion } from "motion/react";
|
||||
import { useCallback, useEffect } from "react";
|
||||
import { useMediaQuery } from "@/hooks/use-media-query";
|
||||
import { SLIDEOUT_PANEL_OPENED_EVENT } from "@/lib/layout-events";
|
||||
|
||||
export const SLIDEOUT_PANEL_OPENED_EVENT = "slideout-panel-opened";
|
||||
|
||||
interface SidebarSlideOutPanelProps {
|
||||
open: boolean;
|
||||
|
|
|
|||
|
|
@ -44,21 +44,28 @@ export function ChatHeader({ searchSpaceId, className }: ChatHeaderProps) {
|
|||
const [isVisionGlobal, setIsVisionGlobal] = useState(false);
|
||||
const [visionDialogMode, setVisionDialogMode] = useState<"create" | "edit" | "view">("view");
|
||||
|
||||
// Default provider for create dialogs
|
||||
const [defaultLLMProvider, setDefaultLLMProvider] = useState<string | undefined>();
|
||||
const [defaultImageProvider, setDefaultImageProvider] = useState<string | undefined>();
|
||||
const [defaultVisionProvider, setDefaultVisionProvider] = useState<string | undefined>();
|
||||
|
||||
// LLM handlers
|
||||
const handleEditLLMConfig = useCallback(
|
||||
(config: NewLLMConfigPublic | GlobalNewLLMConfig, global: boolean) => {
|
||||
setSelectedConfig(config);
|
||||
setIsGlobal(global);
|
||||
setDialogMode(global ? "view" : "edit");
|
||||
setDefaultLLMProvider(undefined);
|
||||
setDialogOpen(true);
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
const handleAddNewLLM = useCallback(() => {
|
||||
const handleAddNewLLM = useCallback((provider?: string) => {
|
||||
setSelectedConfig(null);
|
||||
setIsGlobal(false);
|
||||
setDialogMode("create");
|
||||
setDefaultLLMProvider(provider);
|
||||
setDialogOpen(true);
|
||||
}, []);
|
||||
|
||||
|
|
@ -68,10 +75,11 @@ export function ChatHeader({ searchSpaceId, className }: ChatHeaderProps) {
|
|||
}, []);
|
||||
|
||||
// Image model handlers
|
||||
const handleAddImageModel = useCallback(() => {
|
||||
const handleAddImageModel = useCallback((provider?: string) => {
|
||||
setSelectedImageConfig(null);
|
||||
setIsImageGlobal(false);
|
||||
setImageDialogMode("create");
|
||||
setDefaultImageProvider(provider);
|
||||
setImageDialogOpen(true);
|
||||
}, []);
|
||||
|
||||
|
|
@ -80,6 +88,7 @@ export function ChatHeader({ searchSpaceId, className }: ChatHeaderProps) {
|
|||
setSelectedImageConfig(config);
|
||||
setIsImageGlobal(global);
|
||||
setImageDialogMode(global ? "view" : "edit");
|
||||
setDefaultImageProvider(undefined);
|
||||
setImageDialogOpen(true);
|
||||
},
|
||||
[]
|
||||
|
|
@ -91,10 +100,11 @@ export function ChatHeader({ searchSpaceId, className }: ChatHeaderProps) {
|
|||
}, []);
|
||||
|
||||
// Vision model handlers
|
||||
const handleAddVisionModel = useCallback(() => {
|
||||
const handleAddVisionModel = useCallback((provider?: string) => {
|
||||
setSelectedVisionConfig(null);
|
||||
setIsVisionGlobal(false);
|
||||
setVisionDialogMode("create");
|
||||
setDefaultVisionProvider(provider);
|
||||
setVisionDialogOpen(true);
|
||||
}, []);
|
||||
|
||||
|
|
@ -103,6 +113,7 @@ export function ChatHeader({ searchSpaceId, className }: ChatHeaderProps) {
|
|||
setSelectedVisionConfig(config);
|
||||
setIsVisionGlobal(global);
|
||||
setVisionDialogMode(global ? "view" : "edit");
|
||||
setDefaultVisionProvider(undefined);
|
||||
setVisionDialogOpen(true);
|
||||
},
|
||||
[]
|
||||
|
|
@ -131,6 +142,7 @@ export function ChatHeader({ searchSpaceId, className }: ChatHeaderProps) {
|
|||
isGlobal={isGlobal}
|
||||
searchSpaceId={searchSpaceId}
|
||||
mode={dialogMode}
|
||||
defaultProvider={defaultLLMProvider}
|
||||
/>
|
||||
<ImageConfigDialog
|
||||
open={imageDialogOpen}
|
||||
|
|
@ -139,6 +151,7 @@ export function ChatHeader({ searchSpaceId, className }: ChatHeaderProps) {
|
|||
isGlobal={isImageGlobal}
|
||||
searchSpaceId={searchSpaceId}
|
||||
mode={imageDialogMode}
|
||||
defaultProvider={defaultImageProvider}
|
||||
/>
|
||||
<VisionConfigDialog
|
||||
open={visionDialogOpen}
|
||||
|
|
@ -147,6 +160,7 @@ export function ChatHeader({ searchSpaceId, className }: ChatHeaderProps) {
|
|||
isGlobal={isVisionGlobal}
|
||||
searchSpaceId={searchSpaceId}
|
||||
mode={visionDialogMode}
|
||||
defaultProvider={defaultVisionProvider}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -29,8 +29,6 @@ interface DocumentMentionPickerProps {
|
|||
onDone: () => void;
|
||||
initialSelectedDocuments?: Pick<Document, "id" | "title" | "document_type">[];
|
||||
externalSearch?: string;
|
||||
/** Positioning styles for the container */
|
||||
containerStyle?: React.CSSProperties;
|
||||
}
|
||||
|
||||
const PAGE_SIZE = 20;
|
||||
|
|
@ -75,7 +73,6 @@ export const DocumentMentionPicker = forwardRef<
|
|||
onDone,
|
||||
initialSelectedDocuments = [],
|
||||
externalSearch = "",
|
||||
containerStyle,
|
||||
},
|
||||
ref
|
||||
) {
|
||||
|
|
@ -394,19 +391,9 @@ export const DocumentMentionPicker = forwardRef<
|
|||
[selectableDocuments, highlightedIndex, handleSelectDocument, onDone]
|
||||
);
|
||||
|
||||
// Hide popup when there are no documents to display (regardless of fetch state)
|
||||
// Search continues in background; popup reappears when results arrive
|
||||
if (!actualLoading && actualDocuments.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className="fixed shadow-2xl rounded-lg border border-border dark:border-white/5 overflow-hidden bg-popover dark:bg-neutral-900 flex flex-col w-[280px] sm:w-[320px] select-none"
|
||||
style={{
|
||||
zIndex: 9999,
|
||||
...containerStyle,
|
||||
}}
|
||||
className="shadow-2xl rounded-lg border border-border dark:border-white/5 overflow-hidden bg-popover dark:bg-neutral-900 flex flex-col w-[280px] sm:w-[320px] select-none"
|
||||
onKeyDown={handleKeyDown}
|
||||
role="listbox"
|
||||
tabIndex={-1}
|
||||
|
|
@ -547,7 +534,11 @@ export const DocumentMentionPicker = forwardRef<
|
|||
</div>
|
||||
)}
|
||||
</div>
|
||||
) : null}
|
||||
) : (
|
||||
<div className="py-1 px-2">
|
||||
<p className="px-3 py-2 text-xs text-muted-foreground">No matching documents</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -15,7 +15,7 @@ import {
|
|||
|
||||
import { promptsAtom } from "@/atoms/prompts/prompts-query.atoms";
|
||||
import { userSettingsDialogAtom } from "@/atoms/settings/settings-dialog.atoms";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export interface PromptPickerRef {
|
||||
|
|
@ -28,11 +28,10 @@ interface PromptPickerProps {
|
|||
onSelect: (action: { name: string; prompt: string; mode: "transform" | "explore" }) => void;
|
||||
onDone: () => void;
|
||||
externalSearch?: string;
|
||||
containerStyle?: React.CSSProperties;
|
||||
}
|
||||
|
||||
export const PromptPicker = forwardRef<PromptPickerRef, PromptPickerProps>(function PromptPicker(
|
||||
{ onSelect, onDone, externalSearch = "", containerStyle },
|
||||
{ onSelect, onDone, externalSearch = "" },
|
||||
ref
|
||||
) {
|
||||
const setUserSettingsDialog = useSetAtom(userSettingsDialogAtom);
|
||||
|
|
@ -60,13 +59,21 @@ export const PromptPicker = forwardRef<PromptPickerRef, PromptPickerProps>(funct
|
|||
}
|
||||
}
|
||||
|
||||
const createPromptIndex = filtered.length;
|
||||
const totalItems = filtered.length + 1;
|
||||
|
||||
const handleSelect = useCallback(
|
||||
(index: number) => {
|
||||
if (index === createPromptIndex) {
|
||||
onDone();
|
||||
setUserSettingsDialog({ open: true, initialTab: "prompts" });
|
||||
return;
|
||||
}
|
||||
const action = filtered[index];
|
||||
if (!action) return;
|
||||
onSelect({ name: action.name, prompt: action.prompt, mode: action.mode });
|
||||
},
|
||||
[filtered, onSelect]
|
||||
[filtered, onSelect, createPromptIndex, onDone, setUserSettingsDialog]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
|
|
@ -93,69 +100,98 @@ export const PromptPicker = forwardRef<PromptPickerRef, PromptPickerProps>(funct
|
|||
() => ({
|
||||
selectHighlighted: () => handleSelect(highlightedIndex),
|
||||
moveUp: () => {
|
||||
if (filtered.length === 0) return;
|
||||
shouldScrollRef.current = true;
|
||||
setHighlightedIndex((prev) => (prev > 0 ? prev - 1 : filtered.length - 1));
|
||||
setHighlightedIndex((prev) => (prev > 0 ? prev - 1 : totalItems - 1));
|
||||
},
|
||||
moveDown: () => {
|
||||
if (filtered.length === 0) return;
|
||||
shouldScrollRef.current = true;
|
||||
setHighlightedIndex((prev) => (prev < filtered.length - 1 ? prev + 1 : 0));
|
||||
setHighlightedIndex((prev) => (prev < totalItems - 1 ? prev + 1 : 0));
|
||||
},
|
||||
}),
|
||||
[filtered.length, highlightedIndex, handleSelect]
|
||||
[totalItems, highlightedIndex, handleSelect]
|
||||
);
|
||||
|
||||
return (
|
||||
<div
|
||||
className="w-64 rounded-lg border bg-popover shadow-lg overflow-hidden"
|
||||
style={containerStyle}
|
||||
>
|
||||
<div ref={scrollContainerRef} className="max-h-48 overflow-y-auto py-1">
|
||||
<div className="shadow-2xl rounded-lg border border-border dark:border-white/5 overflow-hidden bg-popover dark:bg-neutral-900 flex flex-col w-[280px] sm:w-[320px] select-none">
|
||||
<div ref={scrollContainerRef} className="max-h-[180px] sm:max-h-[280px] overflow-y-auto">
|
||||
{isLoading ? (
|
||||
<div className="flex items-center justify-center py-3">
|
||||
<Spinner className="size-4" />
|
||||
<div className="py-1 px-2">
|
||||
<div className="px-3 py-2">
|
||||
<Skeleton className="h-[16px] w-24" />
|
||||
</div>
|
||||
{["a", "b", "c", "d", "e"].map((id, i) => (
|
||||
<div
|
||||
key={id}
|
||||
className={cn(
|
||||
"w-full flex items-center gap-2 px-3 py-2 text-left rounded-md",
|
||||
i >= 3 && "hidden sm:flex"
|
||||
)}
|
||||
>
|
||||
<span className="shrink-0">
|
||||
<Skeleton className="h-4 w-4" />
|
||||
</span>
|
||||
<span className="flex-1 text-sm">
|
||||
<Skeleton className="h-[20px]" style={{ width: `${60 + ((i * 7) % 30)}%` }} />
|
||||
</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
) : isError ? (
|
||||
<p className="px-3 py-2 text-xs text-destructive">Failed to load prompts</p>
|
||||
<div className="py-1 px-2">
|
||||
<p className="px-3 py-2 text-xs text-destructive">Failed to load prompts</p>
|
||||
</div>
|
||||
) : filtered.length === 0 ? (
|
||||
<p className="px-3 py-2 text-xs text-muted-foreground">No matching prompts</p>
|
||||
<div className="py-1 px-2">
|
||||
<p className="px-3 py-2 text-xs text-muted-foreground">No matching prompts</p>
|
||||
</div>
|
||||
) : (
|
||||
filtered.map((action, index) => (
|
||||
<div className="py-1 px-2">
|
||||
<div className="px-3 py-2 text-xs font-bold text-muted-foreground/55">
|
||||
Saved Prompts
|
||||
</div>
|
||||
{filtered.map((action, index) => (
|
||||
<button
|
||||
key={action.id}
|
||||
ref={(el) => {
|
||||
if (el) itemRefs.current.set(index, el);
|
||||
else itemRefs.current.delete(index);
|
||||
}}
|
||||
type="button"
|
||||
onClick={() => handleSelect(index)}
|
||||
onMouseEnter={() => setHighlightedIndex(index)}
|
||||
className={cn(
|
||||
"w-full flex items-center gap-2 px-3 py-2 text-left text-sm transition-colors rounded-md cursor-pointer",
|
||||
index === highlightedIndex && "bg-accent"
|
||||
)}
|
||||
>
|
||||
<span className="shrink-0 text-muted-foreground">
|
||||
<Zap className="size-4" />
|
||||
</span>
|
||||
<span className="flex-1 text-sm truncate">{action.name}</span>
|
||||
</button>
|
||||
))}
|
||||
|
||||
<div className="mx-2 my-1 border-t border-border dark:border-white/5" />
|
||||
<button
|
||||
key={action.id}
|
||||
ref={(el) => {
|
||||
if (el) itemRefs.current.set(index, el);
|
||||
else itemRefs.current.delete(index);
|
||||
if (el) itemRefs.current.set(createPromptIndex, el);
|
||||
else itemRefs.current.delete(createPromptIndex);
|
||||
}}
|
||||
type="button"
|
||||
onClick={() => handleSelect(index)}
|
||||
onMouseEnter={() => setHighlightedIndex(index)}
|
||||
onClick={() => handleSelect(createPromptIndex)}
|
||||
onMouseEnter={() => setHighlightedIndex(createPromptIndex)}
|
||||
className={cn(
|
||||
"flex w-full items-center gap-2 px-3 py-1.5 text-sm cursor-pointer",
|
||||
index === highlightedIndex ? "bg-accent" : "hover:bg-accent/50"
|
||||
"w-full flex items-center gap-2 px-3 py-2 text-left text-sm transition-colors rounded-md cursor-pointer text-muted-foreground",
|
||||
highlightedIndex === createPromptIndex ? "bg-accent text-foreground" : "hover:text-foreground hover:bg-accent/50"
|
||||
)}
|
||||
>
|
||||
<span className="text-muted-foreground">
|
||||
<Zap className="size-3.5" />
|
||||
<span className="shrink-0">
|
||||
<Plus className="size-4" />
|
||||
</span>
|
||||
<span className="truncate">{action.name}</span>
|
||||
<span>Create prompt</span>
|
||||
</button>
|
||||
))
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="my-1 h-px bg-border mx-2" />
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
onDone();
|
||||
setUserSettingsDialog({ open: true, initialTab: "prompts" });
|
||||
}}
|
||||
className="flex w-full items-center gap-2 px-3 py-1.5 text-sm text-muted-foreground hover:text-foreground hover:bg-accent/50 cursor-pointer"
|
||||
>
|
||||
<Plus className="size-3.5" />
|
||||
<span>Create prompt</span>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -1,4 +1,9 @@
|
|||
"use client";
|
||||
import React, { useRef, useEffect, useState } from "react";
|
||||
import { AnimatePresence, motion } from "motion/react";
|
||||
import { IconPlus } from "@tabler/icons-react";
|
||||
import { Pricing } from "@/components/pricing";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const demoPlans = [
|
||||
{
|
||||
|
|
@ -59,13 +64,280 @@ const demoPlans = [
|
|||
},
|
||||
];
|
||||
|
||||
interface FAQItem {
|
||||
question: string;
|
||||
answer: string;
|
||||
}
|
||||
|
||||
interface FAQSection {
|
||||
title: string;
|
||||
items: FAQItem[];
|
||||
}
|
||||
|
||||
const faqData: FAQSection[] = [
|
||||
{
|
||||
title: "Pages & Billing",
|
||||
items: [
|
||||
{
|
||||
question: "What exactly is a \"page\" in SurfSense?",
|
||||
answer:
|
||||
"A page is a simple billing unit that measures how much content you add to your knowledge base. For PDFs, one page equals one real PDF page. For other document types like Word, PowerPoint, and Excel files, pages are automatically estimated based on the file. Every file uses at least 1 page.",
|
||||
},
|
||||
{
|
||||
question: "How does the Pay As You Go plan work?",
|
||||
answer:
|
||||
"There's no monthly subscription. When you need more pages, simply purchase 1,000-page packs at $1 each. Purchased pages are added to your account immediately so you can keep indexing right away. You only pay when you actually need more.",
|
||||
},
|
||||
{
|
||||
question: "What happens if I run out of pages?",
|
||||
answer:
|
||||
"SurfSense checks your remaining pages before processing each file. If you don't have enough, the upload is paused and you'll be notified. You can purchase additional page packs at any time to continue. For cloud connector syncs, a small overage may be allowed so your sync doesn't partially fail.",
|
||||
},
|
||||
{
|
||||
question: "If I delete a document, do I get my pages back?",
|
||||
answer:
|
||||
"No. Deleting a document removes it from your knowledge base, but the pages it used are not refunded. Pages track your total usage over time, not how much is currently stored. So be mindful of what you index. Once pages are spent, they're spent even if you later remove the document.",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
title: "File Types & Connectors",
|
||||
items: [
|
||||
{
|
||||
question: "Which file types count toward my page limit?",
|
||||
answer:
|
||||
"Page limits only apply to document files that need processing, including PDFs, Word documents (DOC, DOCX, ODT, RTF), presentations (PPT, PPTX, ODP), spreadsheets (XLS, XLSX, ODS), ebooks (EPUB), and images (JPG, PNG, TIFF, WebP, BMP). Plain text files, code files, Markdown, CSV, TSV, HTML, audio, and video files do not consume any pages.",
|
||||
},
|
||||
{
|
||||
question: "How are pages consumed?",
|
||||
answer:
|
||||
"Pages are deducted whenever a document file is successfully indexed into your knowledge base, whether through direct uploads or file-based connector syncs (Google Drive, OneDrive, Dropbox, Local Folder). SurfSense checks your remaining pages before processing and only charges you after the file is indexed. Duplicate documents are automatically detected and won't cost you extra pages.",
|
||||
},
|
||||
{
|
||||
question: "Do connectors like Slack, Notion, or Gmail use pages?",
|
||||
answer:
|
||||
"No. Connectors that work with structured text data like Slack, Discord, Notion, Confluence, Jira, Linear, ClickUp, GitHub, Gmail, Google Calendar, Microsoft Teams, Airtable, Elasticsearch, Web Crawler, BookStack, Obsidian, and Luma do not use pages at all. Page limits only apply to file-based connectors that need document processing, such as Google Drive, OneDrive, Dropbox, and Local Folder syncs.",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
title: "Self-Hosting",
|
||||
items: [
|
||||
{
|
||||
question: "Can I self-host SurfSense with unlimited pages?",
|
||||
answer:
|
||||
"Yes! When self-hosting, you have full control over your page limits. The default self-hosted setup gives you effectively unlimited pages, so you can index as much data as your infrastructure supports.",
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const GridLineHorizontal = ({
|
||||
className,
|
||||
offset,
|
||||
}: {
|
||||
className?: string;
|
||||
offset?: string;
|
||||
}) => {
|
||||
return (
|
||||
<div
|
||||
style={
|
||||
{
|
||||
"--background": "#ffffff",
|
||||
"--color": "rgba(0, 0, 0, 0.2)",
|
||||
"--height": "1px",
|
||||
"--width": "5px",
|
||||
"--fade-stop": "90%",
|
||||
"--offset": offset || "200px",
|
||||
"--color-dark": "rgba(255, 255, 255, 0.2)",
|
||||
maskComposite: "exclude",
|
||||
} as React.CSSProperties
|
||||
}
|
||||
className={cn(
|
||||
"[--background:var(--color-neutral-200)] [--color:var(--color-neutral-400)] dark:[--background:var(--color-neutral-800)] dark:[--color:var(--color-neutral-600)]",
|
||||
"absolute left-[calc(var(--offset)/2*-1)] h-(--height) w-[calc(100%+var(--offset))]",
|
||||
"bg-[linear-gradient(to_right,var(--color),var(--color)_50%,transparent_0,transparent)]",
|
||||
"bg-size-[var(--width)_var(--height)]",
|
||||
"[mask:linear-gradient(to_left,var(--background)_var(--fade-stop),transparent),linear-gradient(to_right,var(--background)_var(--fade-stop),transparent),linear-gradient(black,black)]",
|
||||
"mask-exclude",
|
||||
"z-30",
|
||||
"dark:bg-[linear-gradient(to_right,var(--color-dark),var(--color-dark)_50%,transparent_0,transparent)]",
|
||||
className,
|
||||
)}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
const GridLineVertical = ({
|
||||
className,
|
||||
offset,
|
||||
}: {
|
||||
className?: string;
|
||||
offset?: string;
|
||||
}) => {
|
||||
return (
|
||||
<div
|
||||
style={
|
||||
{
|
||||
"--background": "#ffffff",
|
||||
"--color": "rgba(0, 0, 0, 0.2)",
|
||||
"--height": "5px",
|
||||
"--width": "1px",
|
||||
"--fade-stop": "90%",
|
||||
"--offset": offset || "150px",
|
||||
"--color-dark": "rgba(255, 255, 255, 0.2)",
|
||||
maskComposite: "exclude",
|
||||
} as React.CSSProperties
|
||||
}
|
||||
className={cn(
|
||||
"absolute top-[calc(var(--offset)/2*-1)] h-[calc(100%+var(--offset))] w-(--width)",
|
||||
"bg-[linear-gradient(to_bottom,var(--color),var(--color)_50%,transparent_0,transparent)]",
|
||||
"bg-size-[var(--width)_var(--height)]",
|
||||
"[mask:linear-gradient(to_top,var(--background)_var(--fade-stop),transparent),linear-gradient(to_bottom,var(--background)_var(--fade-stop),transparent),linear-gradient(black,black)]",
|
||||
"mask-exclude",
|
||||
"z-30",
|
||||
"dark:bg-[linear-gradient(to_bottom,var(--color-dark),var(--color-dark)_50%,transparent_0,transparent)]",
|
||||
className,
|
||||
)}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
function PricingFAQ() {
|
||||
const [activeId, setActiveId] = useState<string | null>(null);
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
function handleClickOutside(event: MouseEvent) {
|
||||
if (
|
||||
containerRef.current &&
|
||||
!containerRef.current.contains(event.target as Node)
|
||||
) {
|
||||
setActiveId(null);
|
||||
}
|
||||
}
|
||||
|
||||
document.addEventListener("mousedown", handleClickOutside);
|
||||
return () => document.removeEventListener("mousedown", handleClickOutside);
|
||||
}, []);
|
||||
|
||||
const toggleQuestion = (id: string) => {
|
||||
setActiveId(activeId === id ? null : id);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="mx-auto w-full max-w-4xl overflow-hidden px-4 py-20 md:px-8 md:py-32">
|
||||
<div className="text-center">
|
||||
<h2 className="text-4xl font-bold tracking-tight sm:text-5xl">
|
||||
Frequently Asked Questions
|
||||
</h2>
|
||||
<p className="mx-auto mt-4 max-w-2xl text-lg text-muted-foreground">
|
||||
Everything you need to know about SurfSense pages and billing.
|
||||
Can't find what you need? Reach out at{" "}
|
||||
<a
|
||||
href="mailto:rohan@surfsense.com"
|
||||
className="text-blue-500 underline"
|
||||
>
|
||||
rohan@surfsense.com
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div
|
||||
ref={containerRef}
|
||||
className="relative mt-16 flex w-full flex-col gap-12 px-4 md:px-8"
|
||||
>
|
||||
{faqData.map((section) => (
|
||||
<div key={section.title + "faq"}>
|
||||
<h3 className="mb-6 text-lg font-medium text-neutral-800 dark:text-neutral-200">
|
||||
{section.title}
|
||||
</h3>
|
||||
<div className="flex flex-col gap-3">
|
||||
{section.items.map((item, index) => {
|
||||
const id = `${section.title}-${index}`;
|
||||
const isActive = activeId === id;
|
||||
|
||||
return (
|
||||
<div
|
||||
key={id + "faq-item"}
|
||||
className={cn(
|
||||
"relative rounded-lg transition-all duration-200",
|
||||
isActive
|
||||
? "bg-white shadow-sm ring-1 shadow-black/10 ring-black/10 dark:bg-neutral-900 dark:shadow-white/5 dark:ring-white/10"
|
||||
: "hover:bg-neutral-50 dark:hover:bg-neutral-900",
|
||||
)}
|
||||
>
|
||||
{isActive && (
|
||||
<div className="absolute inset-0">
|
||||
<GridLineHorizontal
|
||||
className="-top-[2px]"
|
||||
offset="100px"
|
||||
/>
|
||||
<GridLineHorizontal
|
||||
className="-bottom-[2px]"
|
||||
offset="100px"
|
||||
/>
|
||||
<GridLineVertical
|
||||
className="-left-[2px]"
|
||||
offset="100px"
|
||||
/>
|
||||
<GridLineVertical
|
||||
className="-right-[2px] left-auto"
|
||||
offset="100px"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<button
|
||||
onClick={() => toggleQuestion(id)}
|
||||
className="flex w-full items-center justify-between px-4 py-4 text-left"
|
||||
>
|
||||
<span className="text-sm font-medium text-neutral-700 md:text-base dark:text-neutral-300">
|
||||
{item.question}
|
||||
</span>
|
||||
<motion.div
|
||||
animate={{ rotate: isActive ? 45 : 0 }}
|
||||
transition={{ duration: 0.2 }}
|
||||
className="ml-4 shrink-0"
|
||||
>
|
||||
<IconPlus className="size-5 text-neutral-500 dark:text-neutral-400" />
|
||||
</motion.div>
|
||||
</button>
|
||||
<AnimatePresence initial={false}>
|
||||
{isActive && (
|
||||
<motion.div
|
||||
initial={{ height: 0, opacity: 0 }}
|
||||
animate={{ height: "auto", opacity: 1 }}
|
||||
exit={{ height: 0, opacity: 0 }}
|
||||
transition={{ duration: 0.15, ease: "easeInOut" }}
|
||||
className="relative"
|
||||
>
|
||||
<p className="max-w-[90%] px-4 pb-4 text-sm text-neutral-600 dark:text-neutral-400">
|
||||
{item.answer}
|
||||
</p>
|
||||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function PricingBasic() {
|
||||
return (
|
||||
<Pricing
|
||||
plans={demoPlans}
|
||||
title="SurfSense Pricing"
|
||||
description="Start free with 500 pages and pay as you go."
|
||||
/>
|
||||
<>
|
||||
<Pricing
|
||||
plans={demoPlans}
|
||||
title="SurfSense Pricing"
|
||||
description="Start free with 500 pages and pay as you go."
|
||||
/>
|
||||
<PricingFAQ />
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ import {
|
|||
MessageSquareQuote,
|
||||
RefreshCw,
|
||||
Trash2,
|
||||
Wand2,
|
||||
} from "lucide-react";
|
||||
import { useMemo, useState } from "react";
|
||||
import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms";
|
||||
|
|
@ -43,7 +42,7 @@ import { useMediaQuery } from "@/hooks/use-media-query";
|
|||
import { getProviderIcon } from "@/lib/provider-icons";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface ModelConfigManagerProps {
|
||||
interface AgentModelManagerProps {
|
||||
searchSpaceId: number;
|
||||
}
|
||||
|
||||
|
|
@ -55,7 +54,7 @@ function getInitials(name: string): string {
|
|||
return name.slice(0, 2).toUpperCase();
|
||||
}
|
||||
|
||||
export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) {
|
||||
export function AgentModelManager({ searchSpaceId }: AgentModelManagerProps) {
|
||||
const isDesktop = useMediaQuery("(min-width: 768px)");
|
||||
// Mutations
|
||||
const { mutateAsync: deleteConfig, isPending: isDeleting } = useAtomValue(
|
||||
|
|
@ -208,28 +207,26 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) {
|
|||
{["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => (
|
||||
<Card key={key} className="border-border/60">
|
||||
<CardContent className="p-4 flex flex-col gap-3">
|
||||
{/* Header */}
|
||||
<div className="flex items-start justify-between gap-2">
|
||||
{/* Header: Icon + Name */}
|
||||
<div className="flex items-start gap-2.5">
|
||||
<Skeleton className="size-4 rounded-full shrink-0 mt-0.5" />
|
||||
<div className="space-y-1.5 flex-1 min-w-0">
|
||||
<Skeleton className="h-4 w-28 md:w-32" />
|
||||
<Skeleton className="h-3 w-40 md:w-48" />
|
||||
</div>
|
||||
</div>
|
||||
{/* Provider + Model */}
|
||||
<div className="flex items-center gap-2">
|
||||
<Skeleton className="h-5 w-16 rounded-full" />
|
||||
<Skeleton className="h-5 w-24 rounded-md" />
|
||||
</div>
|
||||
{/* Feature badges */}
|
||||
<div className="flex items-center gap-1.5">
|
||||
<Skeleton className="h-5 w-20 rounded-full" />
|
||||
<Skeleton className="h-5 w-16 rounded-full" />
|
||||
</div>
|
||||
{/* Footer */}
|
||||
<div className="flex items-center gap-2 pt-2 border-t border-border/40">
|
||||
<Skeleton className="h-3 w-20" />
|
||||
<Skeleton className="h-4 w-4 rounded-full" />
|
||||
<Skeleton className="h-3 w-16" />
|
||||
<div className="flex items-center pt-2 border-t border-border/40">
|
||||
<Skeleton className="h-3 w-20 flex-1" />
|
||||
<Skeleton className="h-3 w-3 rounded-full shrink-0 mx-1" />
|
||||
<div className="flex-1 flex items-center justify-end gap-1.5">
|
||||
<Skeleton className="h-4 w-4 rounded-full" />
|
||||
<Skeleton className="h-3 w-16" />
|
||||
</div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
|
@ -262,20 +259,25 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) {
|
|||
<div key={config.id}>
|
||||
<Card className="group relative overflow-hidden transition-all duration-200 border-border/60 hover:shadow-md h-full">
|
||||
<CardContent className="p-4 flex flex-col gap-3 h-full">
|
||||
{/* Header: Name + Actions */}
|
||||
<div className="flex items-start justify-between gap-2">
|
||||
<div className="min-w-0 flex-1">
|
||||
<h4 className="text-sm font-semibold tracking-tight truncate">
|
||||
{config.name}
|
||||
</h4>
|
||||
{config.description && (
|
||||
<p className="text-[11px] text-muted-foreground/70 truncate mt-0.5">
|
||||
{config.description}
|
||||
</p>
|
||||
)}
|
||||
{/* Header: Icon + Name + Actions */}
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<div className="flex items-center gap-2.5 min-w-0 flex-1">
|
||||
<div className="shrink-0">
|
||||
{getProviderIcon(config.provider, { className: "size-4" })}
|
||||
</div>
|
||||
<div className="min-w-0 flex-1">
|
||||
<h4 className="text-sm font-semibold tracking-tight truncate">
|
||||
{config.name}
|
||||
</h4>
|
||||
{config.description && (
|
||||
<p className="text-[11px] text-muted-foreground/70 truncate mt-0.5">
|
||||
{config.description}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
{(canUpdate || canDelete) && (
|
||||
<div className="flex items-center gap-0.5 shrink-0 sm:opacity-0 sm:group-hover:opacity-100 transition-opacity duration-150">
|
||||
<div className="flex items-center gap-1 shrink-0 sm:w-0 sm:overflow-hidden sm:group-hover:w-auto sm:opacity-0 sm:group-hover:opacity-100 transition-all duration-150">
|
||||
{canUpdate && (
|
||||
<TooltipProvider>
|
||||
<Tooltip open={isDesktop ? undefined : false}>
|
||||
|
|
@ -284,7 +286,7 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) {
|
|||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={() => openEditDialog(config)}
|
||||
className="h-7 w-7 text-muted-foreground hover:text-foreground"
|
||||
className="h-7 w-7 rounded-lg text-muted-foreground hover:text-foreground"
|
||||
>
|
||||
<Edit3 className="h-3 w-3" />
|
||||
</Button>
|
||||
|
|
@ -301,7 +303,7 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) {
|
|||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={() => setConfigToDelete(config)}
|
||||
className="h-7 w-7 text-muted-foreground hover:text-destructive"
|
||||
className="h-7 w-7 rounded-lg text-muted-foreground hover:text-destructive"
|
||||
>
|
||||
<Trash2 className="h-3 w-3" />
|
||||
</Button>
|
||||
|
|
@ -314,20 +316,12 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) {
|
|||
)}
|
||||
</div>
|
||||
|
||||
{/* Provider + Model */}
|
||||
<div className="flex items-center gap-2 flex-wrap">
|
||||
{getProviderIcon(config.provider, { className: "size-3.5 shrink-0" })}
|
||||
<code className="text-[11px] font-mono text-muted-foreground bg-muted/60 px-2 py-0.5 rounded-md truncate max-w-[160px]">
|
||||
{config.model_name}
|
||||
</code>
|
||||
</div>
|
||||
|
||||
{/* Feature badges */}
|
||||
<div className="flex items-center gap-1.5 flex-wrap">
|
||||
{config.citations_enabled && (
|
||||
<Badge
|
||||
variant="outline"
|
||||
className="text-[10px] px-1.5 py-0.5 border-emerald-500/30 text-emerald-700 dark:text-emerald-300 bg-emerald-500/5"
|
||||
variant="secondary"
|
||||
className="text-[10px] px-1.5 py-0.5 border-0 text-muted-foreground bg-muted"
|
||||
>
|
||||
<MessageSquareQuote className="h-2.5 w-2.5 mr-1" />
|
||||
Citations
|
||||
|
|
@ -336,8 +330,8 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) {
|
|||
{!config.use_default_system_instructions &&
|
||||
config.system_instructions && (
|
||||
<Badge
|
||||
variant="outline"
|
||||
className="text-[10px] px-1.5 py-0.5 border-blue-500/30 text-blue-700 dark:text-blue-300 bg-blue-500/5"
|
||||
variant="secondary"
|
||||
className="text-[10px] px-1.5 py-0.5 border-0 text-muted-foreground bg-muted"
|
||||
>
|
||||
<FileText className="h-2.5 w-2.5 mr-1" />
|
||||
Custom
|
||||
|
|
@ -346,8 +340,8 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) {
|
|||
</div>
|
||||
|
||||
{/* Footer: Date + Creator */}
|
||||
<div className="flex items-center gap-2 pt-2 border-t border-border/40 mt-auto">
|
||||
<span className="text-[11px] text-muted-foreground/60">
|
||||
<div className="flex items-center pt-2 border-t border-border/40 mt-auto">
|
||||
<span className="shrink-0 text-[11px] text-muted-foreground/60 whitespace-nowrap">
|
||||
{new Date(config.created_at).toLocaleDateString(undefined, {
|
||||
year: "numeric",
|
||||
month: "short",
|
||||
|
|
@ -356,11 +350,11 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) {
|
|||
</span>
|
||||
{member && (
|
||||
<>
|
||||
<Dot className="h-4 w-4 text-muted-foreground/30" />
|
||||
<Dot className="h-4 w-4 text-muted-foreground/30 shrink-0" />
|
||||
<TooltipProvider>
|
||||
<Tooltip open={isDesktop ? undefined : false}>
|
||||
<TooltipTrigger asChild>
|
||||
<div className="flex items-center gap-1.5 cursor-default">
|
||||
<div className="min-w-0 flex items-center gap-1.5 cursor-default">
|
||||
<Avatar className="size-4.5 shrink-0">
|
||||
{member.avatarUrl && (
|
||||
<AvatarImage src={member.avatarUrl} alt={member.name} />
|
||||
|
|
@ -369,7 +363,7 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) {
|
|||
{getInitials(member.name)}
|
||||
</AvatarFallback>
|
||||
</Avatar>
|
||||
<span className="text-[11px] text-muted-foreground/60 truncate max-w-[120px]">
|
||||
<span className="text-[11px] text-muted-foreground/60 truncate">
|
||||
{member.name}
|
||||
</span>
|
||||
</div>
|
||||
|
|
@ -2,18 +2,18 @@
|
|||
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import { useAtomValue } from "jotai";
|
||||
import { Info } from "lucide-react";
|
||||
import { FolderArchive, Info } from "lucide-react";
|
||||
import { useTranslations } from "next-intl";
|
||||
import { useEffect, useState } from "react";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { updateSearchSpaceMutationAtom } from "@/atoms/search-spaces/search-space-mutation.atoms";
|
||||
import { Alert, AlertDescription } from "@/components/ui/alert";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
|
||||
import { authenticatedFetch } from "@/lib/auth-utils";
|
||||
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
||||
import { Spinner } from "../ui/spinner";
|
||||
|
||||
|
|
@ -40,6 +40,37 @@ export function GeneralSettingsManager({ searchSpaceId }: GeneralSettingsManager
|
|||
const [name, setName] = useState("");
|
||||
const [description, setDescription] = useState("");
|
||||
const [saving, setSaving] = useState(false);
|
||||
const [isExporting, setIsExporting] = useState(false);
|
||||
|
||||
const handleExportKB = useCallback(async () => {
|
||||
if (isExporting) return;
|
||||
setIsExporting(true);
|
||||
try {
|
||||
const response = await authenticatedFetch(
|
||||
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/export`,
|
||||
{ method: "GET" }
|
||||
);
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json().catch(() => ({ detail: "Export failed" }));
|
||||
throw new Error(errorData.detail || "Export failed");
|
||||
}
|
||||
const blob = await response.blob();
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement("a");
|
||||
a.href = url;
|
||||
a.download = "knowledge-base.zip";
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
URL.revokeObjectURL(url);
|
||||
toast.success("Knowledge base exported");
|
||||
} catch (err) {
|
||||
console.error("KB export failed:", err);
|
||||
toast.error(err instanceof Error ? err.message : "Export failed");
|
||||
} finally {
|
||||
setIsExporting(false);
|
||||
}
|
||||
}, [searchSpaceId, isExporting]);
|
||||
|
||||
// Initialize state from fetched search space
|
||||
useEffect(() => {
|
||||
|
|
@ -83,16 +114,10 @@ export function GeneralSettingsManager({ searchSpaceId }: GeneralSettingsManager
|
|||
if (loading) {
|
||||
return (
|
||||
<div className="space-y-4 md:space-y-6">
|
||||
<Card>
|
||||
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">
|
||||
<Skeleton className="h-5 md:h-6 w-36 md:w-48" />
|
||||
<Skeleton className="h-3 md:h-4 w-full max-w-md mt-2" />
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-3 md:space-y-4 px-3 md:px-6 pb-3 md:pb-6">
|
||||
<Skeleton className="h-10 md:h-12 w-full" />
|
||||
<Skeleton className="h-10 md:h-12 w-full" />
|
||||
</CardContent>
|
||||
</Card>
|
||||
<div className="flex flex-col gap-6">
|
||||
<Skeleton className="h-10 md:h-12 w-full" />
|
||||
<Skeleton className="h-10 md:h-12 w-full" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -113,61 +138,45 @@ export function GeneralSettingsManager({ searchSpaceId }: GeneralSettingsManager
|
|||
<Alert className="bg-muted/50 py-3 md:py-4">
|
||||
<Info className="h-3 w-3 md:h-4 md:w-4 shrink-0" />
|
||||
<AlertDescription className="text-xs md:text-sm">
|
||||
Update your search space name and description. These details help identify and organize
|
||||
your workspace.
|
||||
Update your search space name and description.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
||||
{/* Search Space Details Card */}
|
||||
<form onSubmit={onSubmit} className="space-y-4 md:space-y-6">
|
||||
<Card>
|
||||
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">
|
||||
<CardTitle className="text-base md:text-lg">Search Space Details</CardTitle>
|
||||
<CardDescription className="text-xs md:text-sm">
|
||||
Manage the basic information for this search space.
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-4 md:space-y-5 px-3 md:px-6 pb-3 md:pb-6">
|
||||
<div className="space-y-1.5 md:space-y-2">
|
||||
<Label htmlFor="search-space-name" className="text-sm md:text-base font-medium">
|
||||
{t("general_name_label")}
|
||||
</Label>
|
||||
<Input
|
||||
id="search-space-name"
|
||||
placeholder={t("general_name_placeholder")}
|
||||
value={name}
|
||||
onChange={(e) => setName(e.target.value)}
|
||||
className="text-sm md:text-base h-9 md:h-10"
|
||||
/>
|
||||
<p className="text-[10px] md:text-xs text-muted-foreground">
|
||||
{t("general_name_description")}
|
||||
</p>
|
||||
</div>
|
||||
<form onSubmit={onSubmit} className="space-y-6">
|
||||
<div className="flex flex-col gap-6">
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="search-space-name">
|
||||
{t("general_name_label")}
|
||||
</Label>
|
||||
<Input
|
||||
id="search-space-name"
|
||||
placeholder={t("general_name_placeholder")}
|
||||
value={name}
|
||||
onChange={(e) => setName(e.target.value)}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
{t("general_name_description")}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="space-y-1.5 md:space-y-2">
|
||||
<Label
|
||||
htmlFor="search-space-description"
|
||||
className="text-sm md:text-base font-medium"
|
||||
>
|
||||
{t("general_description_label")}{" "}
|
||||
<span className="text-muted-foreground font-normal">({tCommon("optional")})</span>
|
||||
</Label>
|
||||
<Input
|
||||
id="search-space-description"
|
||||
placeholder={t("general_description_placeholder")}
|
||||
value={description}
|
||||
onChange={(e) => setDescription(e.target.value)}
|
||||
className="text-sm md:text-base h-9 md:h-10"
|
||||
/>
|
||||
<p className="text-[10px] md:text-xs text-muted-foreground">
|
||||
{t("general_description_description")}
|
||||
</p>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="search-space-description">
|
||||
{t("general_description_label")}{" "}
|
||||
<span className="text-muted-foreground font-normal">({tCommon("optional")})</span>
|
||||
</Label>
|
||||
<Input
|
||||
id="search-space-description"
|
||||
placeholder={t("general_description_placeholder")}
|
||||
value={description}
|
||||
onChange={(e) => setDescription(e.target.value)}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
{t("general_description_description")}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Action Buttons */}
|
||||
<div className="flex justify-end pt-3 md:pt-4">
|
||||
<div className="flex justify-end">
|
||||
<Button
|
||||
type="submit"
|
||||
variant="outline"
|
||||
|
|
@ -179,6 +188,29 @@ export function GeneralSettingsManager({ searchSpaceId }: GeneralSettingsManager
|
|||
</Button>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
<div className="border-t pt-6 flex flex-col gap-3 md:flex-row md:items-center md:justify-between">
|
||||
<div className="space-y-1">
|
||||
<Label>Export knowledge base</Label>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Download all documents in this search space as a ZIP of markdown files.
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
type="button"
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
disabled={isExporting}
|
||||
onClick={handleExportKB}
|
||||
className="relative w-fit shrink-0"
|
||||
>
|
||||
<span className={isExporting ? "opacity-0" : ""}>
|
||||
<FolderArchive className="h-3 w-3 opacity-60" />
|
||||
</span>
|
||||
<span className={isExporting ? "opacity-0" : ""}>Export</span>
|
||||
{isExporting && <Spinner size="sm" className="absolute" />}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"use client";
|
||||
|
||||
import { useAtomValue } from "jotai";
|
||||
import { AlertCircle, Dot, Edit3, Info, RefreshCw, Trash2, Wand2 } from "lucide-react";
|
||||
import { AlertCircle, Dot, Edit3, Info, RefreshCw, Trash2 } from "lucide-react";
|
||||
import { useMemo, useState } from "react";
|
||||
import { deleteImageGenConfigMutationAtom } from "@/atoms/image-gen-config/image-gen-config-mutation.atoms";
|
||||
import {
|
||||
|
|
@ -209,20 +209,20 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
|
|||
{["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => (
|
||||
<Card key={key} className="border-border/60">
|
||||
<CardContent className="p-4 flex flex-col gap-3">
|
||||
<div className="flex items-start justify-between gap-2">
|
||||
<div className="flex items-center gap-2.5">
|
||||
<Skeleton className="size-4 rounded-full shrink-0" />
|
||||
<div className="space-y-1.5 flex-1 min-w-0">
|
||||
<Skeleton className="h-4 w-28 md:w-32" />
|
||||
<Skeleton className="h-3 w-40 md:w-48" />
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<Skeleton className="h-5 w-16 rounded-full" />
|
||||
<Skeleton className="h-5 w-24 rounded-md" />
|
||||
</div>
|
||||
<div className="flex items-center gap-2 pt-2 border-t border-border/40">
|
||||
<Skeleton className="h-3 w-20" />
|
||||
<Skeleton className="h-4 w-4 rounded-full" />
|
||||
<Skeleton className="h-3 w-16" />
|
||||
<div className="flex items-center pt-2 border-t border-border/40">
|
||||
<Skeleton className="h-3 w-20 flex-1" />
|
||||
<Skeleton className="h-3 w-3 rounded-full shrink-0 mx-1" />
|
||||
<div className="flex-1 flex items-center justify-end gap-1.5">
|
||||
<Skeleton className="h-4 w-4 rounded-full" />
|
||||
<Skeleton className="h-3 w-16" />
|
||||
</div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
|
@ -255,20 +255,25 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
|
|||
<div key={config.id}>
|
||||
<Card className="group relative overflow-hidden transition-all duration-200 border-border/60 hover:shadow-md h-full">
|
||||
<CardContent className="p-4 flex flex-col gap-3 h-full">
|
||||
{/* Header: Name + Actions */}
|
||||
<div className="flex items-start justify-between gap-2">
|
||||
<div className="min-w-0 flex-1">
|
||||
<h4 className="text-sm font-semibold tracking-tight truncate">
|
||||
{config.name}
|
||||
</h4>
|
||||
{config.description && (
|
||||
<p className="text-[11px] text-muted-foreground/70 truncate mt-0.5">
|
||||
{config.description}
|
||||
</p>
|
||||
)}
|
||||
{/* Header: Icon + Name + Actions */}
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<div className="flex items-center gap-2.5 min-w-0 flex-1">
|
||||
<div className="shrink-0">
|
||||
{getProviderIcon(config.provider, { className: "size-4" })}
|
||||
</div>
|
||||
<div className="min-w-0 flex-1">
|
||||
<h4 className="text-sm font-semibold tracking-tight truncate">
|
||||
{config.name}
|
||||
</h4>
|
||||
{config.description && (
|
||||
<p className="text-[11px] text-muted-foreground/70 truncate mt-0.5">
|
||||
{config.description}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
{(canUpdate || canDelete) && (
|
||||
<div className="flex items-center gap-0.5 shrink-0 sm:opacity-0 sm:group-hover:opacity-100 transition-opacity duration-150">
|
||||
<div className="flex items-center gap-1 shrink-0 sm:w-0 sm:overflow-hidden sm:group-hover:w-auto sm:opacity-0 sm:group-hover:opacity-100 transition-all duration-150">
|
||||
{canUpdate && (
|
||||
<TooltipProvider>
|
||||
<Tooltip open={isDesktop ? undefined : false}>
|
||||
|
|
@ -277,7 +282,7 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
|
|||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={() => openEditDialog(config)}
|
||||
className="h-7 w-7 text-muted-foreground hover:text-foreground"
|
||||
className="h-7 w-7 rounded-lg text-muted-foreground hover:text-foreground"
|
||||
>
|
||||
<Edit3 className="h-3 w-3" />
|
||||
</Button>
|
||||
|
|
@ -294,7 +299,7 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
|
|||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={() => setConfigToDelete(config)}
|
||||
className="h-7 w-7 text-muted-foreground hover:text-destructive"
|
||||
className="h-7 w-7 rounded-lg text-muted-foreground hover:text-destructive"
|
||||
>
|
||||
<Trash2 className="h-3 w-3" />
|
||||
</Button>
|
||||
|
|
@ -307,17 +312,9 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
|
|||
)}
|
||||
</div>
|
||||
|
||||
{/* Provider + Model */}
|
||||
<div className="flex items-center gap-2 flex-wrap">
|
||||
{getProviderIcon(config.provider, { className: "size-3.5 shrink-0" })}
|
||||
<code className="text-[11px] font-mono text-muted-foreground bg-muted/60 px-2 py-0.5 rounded-md truncate max-w-[160px]">
|
||||
{config.model_name}
|
||||
</code>
|
||||
</div>
|
||||
|
||||
{/* Footer: Date + Creator */}
|
||||
<div className="flex items-center gap-2 pt-2 border-t border-border/40 mt-auto">
|
||||
<span className="text-[11px] text-muted-foreground/60">
|
||||
<div className="flex items-center pt-2 border-t border-border/40 mt-auto">
|
||||
<span className="shrink-0 text-[11px] text-muted-foreground/60 whitespace-nowrap">
|
||||
{new Date(config.created_at).toLocaleDateString(undefined, {
|
||||
year: "numeric",
|
||||
month: "short",
|
||||
|
|
@ -326,11 +323,11 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
|
|||
</span>
|
||||
{member && (
|
||||
<>
|
||||
<Dot className="h-4 w-4 text-muted-foreground/30" />
|
||||
<Dot className="h-4 w-4 text-muted-foreground/30 shrink-0" />
|
||||
<TooltipProvider>
|
||||
<Tooltip open={isDesktop ? undefined : false}>
|
||||
<TooltipTrigger asChild>
|
||||
<div className="flex items-center gap-1.5 cursor-default">
|
||||
<div className="min-w-0 flex items-center gap-1.5 cursor-default">
|
||||
<Avatar className="size-4.5 shrink-0">
|
||||
{member.avatarUrl && (
|
||||
<AvatarImage src={member.avatarUrl} alt={member.name} />
|
||||
|
|
@ -339,7 +336,7 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
|
|||
{getInitials(member.name)}
|
||||
</AvatarFallback>
|
||||
</Avatar>
|
||||
<span className="text-[11px] text-muted-foreground/60 truncate max-w-[120px]">
|
||||
<span className="text-[11px] text-muted-foreground/60 truncate">
|
||||
{member.name}
|
||||
</span>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import {
|
|||
Bot,
|
||||
CircleCheck,
|
||||
CircleDashed,
|
||||
Eye,
|
||||
ScanEye,
|
||||
FileText,
|
||||
ImageIcon,
|
||||
RefreshCw,
|
||||
|
|
@ -74,7 +74,7 @@ const ROLE_DESCRIPTIONS = {
|
|||
configType: "image" as const,
|
||||
},
|
||||
vision: {
|
||||
icon: Eye,
|
||||
icon: ScanEye,
|
||||
title: "Vision LLM",
|
||||
description: "Vision-capable model for screenshot analysis and context extraction",
|
||||
color: "text-muted-foreground",
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import { useQuery } from "@tanstack/react-query";
|
|||
import { useAtomValue } from "jotai";
|
||||
import {
|
||||
Bot,
|
||||
ChevronDown,
|
||||
Edit2,
|
||||
FileText,
|
||||
Globe,
|
||||
|
|
@ -47,7 +48,6 @@ import {
|
|||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogTrigger,
|
||||
} from "@/components/ui/dialog";
|
||||
import {
|
||||
DropdownMenu,
|
||||
|
|
@ -58,7 +58,6 @@ import {
|
|||
} from "@/components/ui/dropdown-menu";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import type { PermissionInfo } from "@/contracts/types/permissions.types";
|
||||
import type {
|
||||
|
|
@ -319,100 +318,6 @@ export function RolesManager({ searchSpaceId }: { searchSpaceId: number }) {
|
|||
);
|
||||
}
|
||||
|
||||
// ============ Role Permissions Display ============
|
||||
|
||||
function RolePermissionsDialog({
|
||||
permissions,
|
||||
roleName,
|
||||
children,
|
||||
}: {
|
||||
permissions: string[];
|
||||
roleName: string;
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
const isFullAccess = permissions.includes("*");
|
||||
|
||||
const grouped: Record<string, string[]> = {};
|
||||
if (!isFullAccess) {
|
||||
for (const perm of permissions) {
|
||||
const [category, action] = perm.split(":");
|
||||
if (!grouped[category]) grouped[category] = [];
|
||||
grouped[category].push(action);
|
||||
}
|
||||
}
|
||||
|
||||
const sortedCategories = Object.keys(grouped).sort((a, b) => {
|
||||
const orderA = CATEGORY_CONFIG[a]?.order ?? 99;
|
||||
const orderB = CATEGORY_CONFIG[b]?.order ?? 99;
|
||||
return orderA - orderB;
|
||||
});
|
||||
|
||||
const categoryCount = sortedCategories.length;
|
||||
|
||||
return (
|
||||
<Dialog>
|
||||
<DialogTrigger asChild>{children}</DialogTrigger>
|
||||
<DialogContent className="w-[92vw] max-w-md p-0 gap-0">
|
||||
<DialogHeader className="p-4 md:p-5">
|
||||
<DialogTitle className="text-base">{roleName} — Permissions</DialogTitle>
|
||||
<DialogDescription className="text-xs">
|
||||
{isFullAccess
|
||||
? "This role has unrestricted access to all resources"
|
||||
: `${permissions.length} permissions across ${categoryCount} categories`}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
{isFullAccess ? (
|
||||
<div className="flex items-center gap-3 px-4 md:px-5 py-6">
|
||||
<div className="h-9 w-9 rounded-lg bg-muted/60 flex items-center justify-center shrink-0">
|
||||
<Shield className="h-4 w-4 text-muted-foreground" />
|
||||
</div>
|
||||
<div>
|
||||
<p className="text-sm font-medium">Full access</p>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
All permissions granted across every category
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<ScrollArea className="max-h-[55vh]">
|
||||
<div className="divide-y divide-border/50">
|
||||
{sortedCategories.map((category) => {
|
||||
const actions = grouped[category];
|
||||
const config = CATEGORY_CONFIG[category] || {
|
||||
label: category,
|
||||
icon: FileText,
|
||||
};
|
||||
const IconComponent = config.icon;
|
||||
return (
|
||||
<div
|
||||
key={category}
|
||||
className="flex items-center justify-between gap-3 px-4 md:px-5 py-2.5"
|
||||
>
|
||||
<div className="flex items-center gap-2 shrink-0">
|
||||
<IconComponent className="h-3.5 w-3.5 text-muted-foreground" />
|
||||
<span className="text-sm text-muted-foreground">{config.label}</span>
|
||||
</div>
|
||||
<div className="flex flex-wrap justify-end gap-1">
|
||||
{actions.map((action) => (
|
||||
<span
|
||||
key={action}
|
||||
className="px-1.5 py-0.5 rounded bg-muted text-muted-foreground text-[11px] font-medium"
|
||||
>
|
||||
{ACTION_LABELS[action] || action.replace(/_/g, " ")}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
)}
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
|
||||
function PermissionsBadge({ permissions }: { permissions: string[] }) {
|
||||
if (permissions.includes("*")) {
|
||||
return (
|
||||
|
|
@ -463,6 +368,7 @@ function RolesContent({
|
|||
}) {
|
||||
const [showCreateRole, setShowCreateRole] = useState(false);
|
||||
const [editingRoleId, setEditingRoleId] = useState<number | null>(null);
|
||||
const [expandedRoleId, setExpandedRoleId] = useState<number | null>(null);
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
|
|
@ -508,91 +414,170 @@ function RolesContent({
|
|||
)}
|
||||
|
||||
<div className="space-y-3">
|
||||
{roles.map((role) => (
|
||||
<div key={role.id}>
|
||||
<div className="w-full text-left relative flex items-center gap-4 rounded-lg border border-border/60 p-4 transition-colors hover:bg-muted/30">
|
||||
<div className="flex-1 min-w-0">
|
||||
<RolePermissionsDialog permissions={role.permissions} roleName={role.name}>
|
||||
<button type="button" className="w-full text-left cursor-pointer">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="font-medium text-sm">{role.name}</span>
|
||||
{role.is_system_role && (
|
||||
<span className="text-[10px] px-1.5 py-0.5 rounded bg-muted text-muted-foreground font-medium">
|
||||
System
|
||||
</span>
|
||||
)}
|
||||
{role.is_default && (
|
||||
<span className="text-[10px] px-1.5 py-0.5 rounded bg-muted text-muted-foreground font-medium">
|
||||
Default
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
{role.description && (
|
||||
<p className="text-xs text-muted-foreground mt-0.5 truncate">
|
||||
{role.description}
|
||||
</p>
|
||||
{roles.map((role) => {
|
||||
const isExpanded = expandedRoleId === role.id;
|
||||
const isFullAccess = role.permissions.includes("*");
|
||||
|
||||
const grouped: Record<string, string[]> = {};
|
||||
if (!isFullAccess) {
|
||||
for (const perm of role.permissions) {
|
||||
const [category, action] = perm.split(":");
|
||||
if (!grouped[category]) grouped[category] = [];
|
||||
grouped[category].push(action);
|
||||
}
|
||||
}
|
||||
const sortedCategories = Object.keys(grouped).sort((a, b) => {
|
||||
const orderA = CATEGORY_CONFIG[a]?.order ?? 99;
|
||||
const orderB = CATEGORY_CONFIG[b]?.order ?? 99;
|
||||
return orderA - orderB;
|
||||
});
|
||||
|
||||
return (
|
||||
<div key={role.id} className="rounded-lg border border-border/60 overflow-hidden">
|
||||
<div className="flex items-center gap-4 p-4 transition-colors hover:bg-muted/30">
|
||||
<button
|
||||
type="button"
|
||||
className="flex-1 min-w-0 text-left cursor-pointer"
|
||||
onClick={() => setExpandedRoleId(isExpanded ? null : role.id)}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="font-medium text-sm">{role.name}</span>
|
||||
{role.is_system_role && (
|
||||
<span className="text-[10px] px-1.5 py-0.5 rounded bg-muted text-muted-foreground font-medium">
|
||||
System
|
||||
</span>
|
||||
)}
|
||||
</button>
|
||||
</RolePermissionsDialog>
|
||||
{role.is_default && (
|
||||
<span className="text-[10px] px-1.5 py-0.5 rounded bg-muted text-muted-foreground font-medium">
|
||||
Default
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
{role.description && (
|
||||
<p className="text-xs text-muted-foreground mt-0.5 truncate">
|
||||
{role.description}
|
||||
</p>
|
||||
)}
|
||||
</button>
|
||||
|
||||
<div className="shrink-0">
|
||||
<PermissionsBadge permissions={role.permissions} />
|
||||
</div>
|
||||
|
||||
{!role.is_system_role && (
|
||||
<div className="shrink-0" role="none">
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button variant="ghost" size="icon" className="h-8 w-8">
|
||||
<MoreHorizontal className="h-4 w-4" />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end" onCloseAutoFocus={(e) => e.preventDefault()}>
|
||||
{canUpdate && (
|
||||
<DropdownMenuItem onClick={() => setEditingRoleId(role.id)}>
|
||||
<Edit2 className="h-4 w-4 mr-2" />
|
||||
Edit Role
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
{canDelete && (
|
||||
<>
|
||||
<DropdownMenuSeparator />
|
||||
<AlertDialog>
|
||||
<AlertDialogTrigger asChild>
|
||||
<DropdownMenuItem onSelect={(e) => e.preventDefault()}>
|
||||
<Trash2 className="h-4 w-4 mr-2" />
|
||||
Delete Role
|
||||
</DropdownMenuItem>
|
||||
</AlertDialogTrigger>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader>
|
||||
<AlertDialogTitle>Delete role?</AlertDialogTitle>
|
||||
<AlertDialogDescription>
|
||||
This will permanently delete the "{role.name}" role.
|
||||
Members with this role will lose their permissions.
|
||||
</AlertDialogDescription>
|
||||
</AlertDialogHeader>
|
||||
<AlertDialogFooter>
|
||||
<AlertDialogCancel>Cancel</AlertDialogCancel>
|
||||
<AlertDialogAction
|
||||
onClick={() => onDeleteRole(role.id)}
|
||||
className="bg-destructive text-destructive-foreground hover:bg-destructive/90"
|
||||
>
|
||||
Delete
|
||||
</AlertDialogAction>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
</>
|
||||
)}
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<button
|
||||
type="button"
|
||||
className="shrink-0 p-1 cursor-pointer"
|
||||
onClick={() => setExpandedRoleId(isExpanded ? null : role.id)}
|
||||
>
|
||||
<ChevronDown
|
||||
className={cn(
|
||||
"h-4 w-4 text-muted-foreground transition-transform duration-200",
|
||||
isExpanded && "rotate-180"
|
||||
)}
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div className="shrink-0">
|
||||
<PermissionsBadge permissions={role.permissions} />
|
||||
</div>
|
||||
|
||||
{!role.is_system_role && (
|
||||
<div className="shrink-0" role="none">
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button variant="ghost" size="icon" className="h-8 w-8">
|
||||
<MoreHorizontal className="h-4 w-4" />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end" onCloseAutoFocus={(e) => e.preventDefault()}>
|
||||
{canUpdate && (
|
||||
<DropdownMenuItem onClick={() => setEditingRoleId(role.id)}>
|
||||
<Edit2 className="h-4 w-4 mr-2" />
|
||||
Edit Role
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
{canDelete && (
|
||||
<>
|
||||
<DropdownMenuSeparator />
|
||||
<AlertDialog>
|
||||
<AlertDialogTrigger asChild>
|
||||
<DropdownMenuItem onSelect={(e) => e.preventDefault()}>
|
||||
<Trash2 className="h-4 w-4 mr-2" />
|
||||
Delete Role
|
||||
</DropdownMenuItem>
|
||||
</AlertDialogTrigger>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader>
|
||||
<AlertDialogTitle>Delete role?</AlertDialogTitle>
|
||||
<AlertDialogDescription>
|
||||
This will permanently delete the "{role.name}" role.
|
||||
Members with this role will lose their permissions.
|
||||
</AlertDialogDescription>
|
||||
</AlertDialogHeader>
|
||||
<AlertDialogFooter>
|
||||
<AlertDialogCancel>Cancel</AlertDialogCancel>
|
||||
<AlertDialogAction
|
||||
onClick={() => onDeleteRole(role.id)}
|
||||
className="bg-destructive text-destructive-foreground hover:bg-destructive/90"
|
||||
{isExpanded && (
|
||||
<div className="border-t border-border/40 px-4 py-3">
|
||||
{isFullAccess ? (
|
||||
<div className="flex items-center gap-3 py-2">
|
||||
<Shield className="h-4 w-4 text-muted-foreground shrink-0" />
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Full access — all permissions granted across every category
|
||||
</p>
|
||||
</div>
|
||||
) : (
|
||||
<div className="divide-y divide-border/30">
|
||||
{sortedCategories.map((category) => {
|
||||
const actions = grouped[category];
|
||||
const config = CATEGORY_CONFIG[category] || {
|
||||
label: category,
|
||||
icon: FileText,
|
||||
};
|
||||
const IconComponent = config.icon;
|
||||
return (
|
||||
<div
|
||||
key={category}
|
||||
className="flex items-center justify-between gap-3 py-2.5"
|
||||
>
|
||||
<div className="flex items-center gap-2 shrink-0">
|
||||
<IconComponent className="h-3.5 w-3.5 text-muted-foreground" />
|
||||
<span className="text-sm text-muted-foreground">
|
||||
{config.label}
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex flex-wrap justify-end gap-1">
|
||||
{actions.map((action) => (
|
||||
<span
|
||||
key={action}
|
||||
className="px-1.5 py-0.5 rounded bg-muted text-muted-foreground text-[11px] font-medium"
|
||||
>
|
||||
Delete
|
||||
</AlertDialogAction>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
</>
|
||||
)}
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
{ACTION_LABELS[action] || action.replace(/_/g, " ")}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
|
@ -676,46 +661,51 @@ function PermissionsEditor({
|
|||
|
||||
return (
|
||||
<div key={category} className="rounded-lg border border-border/60 overflow-hidden">
|
||||
<button
|
||||
type="button"
|
||||
className="w-full flex items-center justify-between px-3 py-2.5 cursor-pointer hover:bg-muted/40 transition-colors"
|
||||
onClick={() => toggleCategoryExpanded(category)}
|
||||
>
|
||||
<div className="flex items-center gap-2.5">
|
||||
<div className="flex items-center justify-between px-3 py-2.5 hover:bg-muted/40 transition-colors">
|
||||
<button
|
||||
type="button"
|
||||
className="flex-1 flex items-center gap-2.5 cursor-pointer"
|
||||
onClick={() => toggleCategoryExpanded(category)}
|
||||
>
|
||||
<IconComponent className="h-4 w-4 text-muted-foreground shrink-0" />
|
||||
<span className="font-medium text-sm">{config.label}</span>
|
||||
<span className="text-[11px] text-muted-foreground tabular-nums">
|
||||
{stats.selected}/{stats.total}
|
||||
</span>
|
||||
</div>
|
||||
</button>
|
||||
<div className="flex items-center gap-2">
|
||||
<Checkbox
|
||||
checked={stats.allSelected}
|
||||
onCheckedChange={() => onToggleCategory(category)}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
aria-label={`Select all ${config.label} permissions`}
|
||||
/>
|
||||
<div
|
||||
className={cn("transition-transform duration-200", isExpanded && "rotate-180")}
|
||||
<button
|
||||
type="button"
|
||||
className="cursor-pointer"
|
||||
onClick={() => toggleCategoryExpanded(category)}
|
||||
>
|
||||
<svg
|
||||
className="h-4 w-4 text-muted-foreground"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
aria-hidden="true"
|
||||
<div
|
||||
className={cn("transition-transform duration-200", isExpanded && "rotate-180")}
|
||||
>
|
||||
<title>Toggle</title>
|
||||
<path
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
strokeWidth={2}
|
||||
d="M19 9l-7 7-7-7"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
<svg
|
||||
className="h-4 w-4 text-muted-foreground"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
aria-hidden="true"
|
||||
>
|
||||
<title>Toggle</title>
|
||||
<path
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
strokeWidth={2}
|
||||
d="M19 9l-7 7-7-7"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{isExpanded && (
|
||||
<div className="border-t border-border/60">
|
||||
|
|
@ -726,28 +716,29 @@ function PermissionsEditor({
|
|||
const isSelected = selectedPermissions.includes(perm.value);
|
||||
|
||||
return (
|
||||
<button
|
||||
<div
|
||||
key={perm.value}
|
||||
type="button"
|
||||
className={cn(
|
||||
"w-full flex items-center justify-between gap-3 px-2.5 py-2 rounded-md cursor-pointer transition-colors",
|
||||
"flex items-center justify-between gap-3 px-2.5 py-2 rounded-md transition-colors",
|
||||
isSelected ? "bg-muted/60 hover:bg-muted/80" : "hover:bg-muted/40"
|
||||
)}
|
||||
onClick={() => onTogglePermission(perm.value)}
|
||||
>
|
||||
<div className="flex-1 min-w-0 text-left">
|
||||
<button
|
||||
type="button"
|
||||
className="flex-1 min-w-0 text-left cursor-pointer"
|
||||
onClick={() => onTogglePermission(perm.value)}
|
||||
>
|
||||
<span className="text-sm font-medium">{actionLabel}</span>
|
||||
<p className="text-xs text-muted-foreground truncate">
|
||||
{perm.description}
|
||||
</p>
|
||||
</div>
|
||||
</button>
|
||||
<Checkbox
|
||||
checked={isSelected}
|
||||
onCheckedChange={() => onTogglePermission(perm.value)}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
className="shrink-0"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import {
|
|||
Brain,
|
||||
CircleUser,
|
||||
Earth,
|
||||
Eye,
|
||||
ScanEye,
|
||||
ImageIcon,
|
||||
ListChecks,
|
||||
UserKey,
|
||||
|
|
@ -25,10 +25,10 @@ const GeneralSettingsManager = dynamic(
|
|||
})),
|
||||
{ ssr: false }
|
||||
);
|
||||
const ModelConfigManager = dynamic(
|
||||
const AgentModelManager = dynamic(
|
||||
() =>
|
||||
import("@/components/settings/model-config-manager").then((m) => ({
|
||||
default: m.ModelConfigManager,
|
||||
import("@/components/settings/agent-model-manager").then((m) => ({
|
||||
default: m.AgentModelManager,
|
||||
})),
|
||||
{ ssr: false }
|
||||
);
|
||||
|
|
@ -88,7 +88,7 @@ export function SearchSpaceSettingsDialog({ searchSpaceId }: SearchSpaceSettings
|
|||
const navItems = [
|
||||
{ value: "general", label: t("nav_general"), icon: <CircleUser className="h-4 w-4" /> },
|
||||
{ value: "roles", label: t("nav_role_assignments"), icon: <ListChecks className="h-4 w-4" /> },
|
||||
{ value: "models", label: t("nav_agent_configs"), icon: <Bot className="h-4 w-4" /> },
|
||||
{ value: "models", label: t("nav_agent_models"), icon: <Bot className="h-4 w-4" /> },
|
||||
{
|
||||
value: "image-models",
|
||||
label: t("nav_image_models"),
|
||||
|
|
@ -97,7 +97,7 @@ export function SearchSpaceSettingsDialog({ searchSpaceId }: SearchSpaceSettings
|
|||
{
|
||||
value: "vision-models",
|
||||
label: t("nav_vision_models"),
|
||||
icon: <Eye className="h-4 w-4" />,
|
||||
icon: <ScanEye className="h-4 w-4" />,
|
||||
},
|
||||
{ value: "team-roles", label: t("nav_team_roles"), icon: <UserKey className="h-4 w-4" /> },
|
||||
{
|
||||
|
|
@ -115,7 +115,7 @@ export function SearchSpaceSettingsDialog({ searchSpaceId }: SearchSpaceSettings
|
|||
|
||||
const content: Record<string, React.ReactNode> = {
|
||||
general: <GeneralSettingsManager searchSpaceId={searchSpaceId} />,
|
||||
models: <ModelConfigManager searchSpaceId={searchSpaceId} />,
|
||||
models: <AgentModelManager searchSpaceId={searchSpaceId} />,
|
||||
roles: <LLMRoleManager searchSpaceId={searchSpaceId} />,
|
||||
"image-models": <ImageModelManager searchSpaceId={searchSpaceId} />,
|
||||
"vision-models": <VisionModelManager searchSpaceId={searchSpaceId} />,
|
||||
|
|
|
|||
|
|
@ -208,20 +208,20 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) {
|
|||
{["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => (
|
||||
<Card key={key} className="border-border/60">
|
||||
<CardContent className="p-4 flex flex-col gap-3">
|
||||
<div className="flex items-start justify-between gap-2">
|
||||
<div className="flex items-center gap-2.5">
|
||||
<Skeleton className="size-4 rounded-full shrink-0" />
|
||||
<div className="space-y-1.5 flex-1 min-w-0">
|
||||
<Skeleton className="h-4 w-28 md:w-32" />
|
||||
<Skeleton className="h-3 w-40 md:w-48" />
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<Skeleton className="h-5 w-16 rounded-full" />
|
||||
<Skeleton className="h-5 w-24 rounded-md" />
|
||||
</div>
|
||||
<div className="flex items-center gap-2 pt-2 border-t border-border/40">
|
||||
<Skeleton className="h-3 w-20" />
|
||||
<Skeleton className="h-4 w-4 rounded-full" />
|
||||
<Skeleton className="h-3 w-16" />
|
||||
<div className="flex items-center pt-2 border-t border-border/40">
|
||||
<Skeleton className="h-3 w-20 flex-1" />
|
||||
<Skeleton className="h-3 w-3 rounded-full shrink-0 mx-1" />
|
||||
<div className="flex-1 flex items-center justify-end gap-1.5">
|
||||
<Skeleton className="h-4 w-4 rounded-full" />
|
||||
<Skeleton className="h-3 w-16" />
|
||||
</div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
|
@ -253,19 +253,25 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) {
|
|||
<div key={config.id}>
|
||||
<Card className="group relative overflow-hidden transition-all duration-200 border-border/60 hover:shadow-md h-full">
|
||||
<CardContent className="p-4 flex flex-col gap-3 h-full">
|
||||
<div className="flex items-start justify-between gap-2">
|
||||
<div className="min-w-0 flex-1">
|
||||
<h4 className="text-sm font-semibold tracking-tight truncate">
|
||||
{config.name}
|
||||
</h4>
|
||||
{config.description && (
|
||||
<p className="text-[11px] text-muted-foreground/70 truncate mt-0.5">
|
||||
{config.description}
|
||||
</p>
|
||||
)}
|
||||
{/* Header: Icon + Name + Actions */}
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<div className="flex items-center gap-2.5 min-w-0 flex-1">
|
||||
<div className="shrink-0">
|
||||
{getProviderIcon(config.provider, { className: "size-4" })}
|
||||
</div>
|
||||
<div className="min-w-0 flex-1">
|
||||
<h4 className="text-sm font-semibold tracking-tight truncate">
|
||||
{config.name}
|
||||
</h4>
|
||||
{config.description && (
|
||||
<p className="text-[11px] text-muted-foreground/70 truncate mt-0.5">
|
||||
{config.description}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
{(canUpdate || canDelete) && (
|
||||
<div className="flex items-center gap-0.5 shrink-0 sm:opacity-0 sm:group-hover:opacity-100 transition-opacity duration-150">
|
||||
<div className="flex items-center gap-1 shrink-0 sm:w-0 sm:overflow-hidden sm:group-hover:w-auto sm:opacity-0 sm:group-hover:opacity-100 transition-all duration-150">
|
||||
{canUpdate && (
|
||||
<TooltipProvider>
|
||||
<Tooltip open={isDesktop ? undefined : false}>
|
||||
|
|
@ -274,7 +280,7 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) {
|
|||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={() => openEditDialog(config)}
|
||||
className="h-7 w-7 text-muted-foreground hover:text-foreground"
|
||||
className="h-6 w-6 text-muted-foreground hover:text-foreground"
|
||||
>
|
||||
<Edit3 className="h-3 w-3" />
|
||||
</Button>
|
||||
|
|
@ -291,7 +297,7 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) {
|
|||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={() => setConfigToDelete(config)}
|
||||
className="h-7 w-7 text-muted-foreground hover:text-destructive"
|
||||
className="h-7 w-7 rounded-lg text-muted-foreground hover:text-destructive"
|
||||
>
|
||||
<Trash2 className="h-3 w-3" />
|
||||
</Button>
|
||||
|
|
@ -304,17 +310,9 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) {
|
|||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex items-center gap-2 flex-wrap">
|
||||
{getProviderIcon(config.provider, {
|
||||
className: "size-3.5 shrink-0",
|
||||
})}
|
||||
<code className="text-[11px] font-mono text-muted-foreground bg-muted/60 px-2 py-0.5 rounded-md truncate max-w-[160px]">
|
||||
{config.model_name}
|
||||
</code>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center gap-2 pt-2 border-t border-border/40 mt-auto">
|
||||
<span className="text-[11px] text-muted-foreground/60">
|
||||
{/* Footer: Date + Creator */}
|
||||
<div className="flex items-center pt-2 border-t border-border/40 mt-auto">
|
||||
<span className="shrink-0 text-[11px] text-muted-foreground/60 whitespace-nowrap">
|
||||
{new Date(config.created_at).toLocaleDateString(undefined, {
|
||||
year: "numeric",
|
||||
month: "short",
|
||||
|
|
@ -323,11 +321,11 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) {
|
|||
</span>
|
||||
{member && (
|
||||
<>
|
||||
<Dot className="h-4 w-4 text-muted-foreground/30" />
|
||||
<Dot className="h-4 w-4 text-muted-foreground/30 shrink-0" />
|
||||
<TooltipProvider>
|
||||
<Tooltip open={isDesktop ? undefined : false}>
|
||||
<TooltipTrigger asChild>
|
||||
<div className="flex items-center gap-1.5 cursor-default">
|
||||
<div className="min-w-0 flex items-center gap-1.5 cursor-default">
|
||||
<Avatar className="size-4.5 shrink-0">
|
||||
{member.avatarUrl && (
|
||||
<AvatarImage src={member.avatarUrl} alt={member.name} />
|
||||
|
|
@ -336,7 +334,7 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) {
|
|||
{getInitials(member.name)}
|
||||
</AvatarFallback>
|
||||
</Avatar>
|
||||
<span className="text-[11px] text-muted-foreground/60 truncate max-w-[120px]">
|
||||
<span className="text-[11px] text-muted-foreground/60 truncate">
|
||||
{member.name}
|
||||
</span>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ interface ImageConfigDialogProps {
|
|||
isGlobal: boolean;
|
||||
searchSpaceId: number;
|
||||
mode: "create" | "edit" | "view";
|
||||
defaultProvider?: string;
|
||||
}
|
||||
|
||||
const INITIAL_FORM = {
|
||||
|
|
@ -67,6 +68,7 @@ export function ImageConfigDialog({
|
|||
isGlobal,
|
||||
searchSpaceId,
|
||||
mode,
|
||||
defaultProvider,
|
||||
}: ImageConfigDialogProps) {
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [formData, setFormData] = useState(INITIAL_FORM);
|
||||
|
|
@ -87,11 +89,11 @@ export function ImageConfigDialog({
|
|||
api_version: config.api_version || "",
|
||||
});
|
||||
} else if (mode === "create") {
|
||||
setFormData(INITIAL_FORM);
|
||||
setFormData({ ...INITIAL_FORM, provider: defaultProvider ?? "" });
|
||||
}
|
||||
setScrollPos("top");
|
||||
}
|
||||
}, [open, mode, config, isGlobal]);
|
||||
}, [open, mode, config, isGlobal, defaultProvider]);
|
||||
|
||||
const { mutateAsync: createConfig } = useAtomValue(createImageGenConfigMutationAtom);
|
||||
const { mutateAsync: updateConfig } = useAtomValue(updateImageGenConfigMutationAtom);
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ interface ModelConfigDialogProps {
|
|||
isGlobal: boolean;
|
||||
searchSpaceId: number;
|
||||
mode: "create" | "edit" | "view";
|
||||
defaultProvider?: string;
|
||||
}
|
||||
|
||||
export function ModelConfigDialog({
|
||||
|
|
@ -37,6 +38,7 @@ export function ModelConfigDialog({
|
|||
isGlobal,
|
||||
searchSpaceId,
|
||||
mode,
|
||||
defaultProvider,
|
||||
}: ModelConfigDialogProps) {
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [scrollPos, setScrollPos] = useState<"top" | "middle" | "bottom">("top");
|
||||
|
|
@ -194,10 +196,12 @@ export function ModelConfigDialog({
|
|||
|
||||
{mode === "create" ? (
|
||||
<LLMConfigForm
|
||||
key={defaultProvider ?? "no-provider"}
|
||||
searchSpaceId={searchSpaceId}
|
||||
onSubmit={handleSubmit}
|
||||
mode="create"
|
||||
formId="model-config-form"
|
||||
initialData={defaultProvider ? { provider: defaultProvider as LiteLLMProvider } : undefined}
|
||||
/>
|
||||
) : isGlobal && config ? (
|
||||
<div className="space-y-6">
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ interface VisionConfigDialogProps {
|
|||
isGlobal: boolean;
|
||||
searchSpaceId: number;
|
||||
mode: "create" | "edit" | "view";
|
||||
defaultProvider?: string;
|
||||
}
|
||||
|
||||
const INITIAL_FORM = {
|
||||
|
|
@ -68,6 +69,7 @@ export function VisionConfigDialog({
|
|||
isGlobal,
|
||||
searchSpaceId,
|
||||
mode,
|
||||
defaultProvider,
|
||||
}: VisionConfigDialogProps) {
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [formData, setFormData] = useState(INITIAL_FORM);
|
||||
|
|
@ -87,11 +89,11 @@ export function VisionConfigDialog({
|
|||
api_version: (config as VisionLLMConfig).api_version || "",
|
||||
});
|
||||
} else if (mode === "create") {
|
||||
setFormData(INITIAL_FORM);
|
||||
setFormData({ ...INITIAL_FORM, provider: defaultProvider ?? "" });
|
||||
}
|
||||
setScrollPos("top");
|
||||
}
|
||||
}, [open, mode, config, isGlobal]);
|
||||
}, [open, mode, config, isGlobal, defaultProvider]);
|
||||
|
||||
const { mutateAsync: createConfig } = useAtomValue(createVisionLLMConfigMutationAtom);
|
||||
const { mutateAsync: updateConfig } = useAtomValue(updateVisionLLMConfigMutationAtom);
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ import {
|
|||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { isInterruptResult, useHitlDecision } from "@/lib/hitl";
|
||||
import type { InterruptResult, HitlDecision } from "@/lib/hitl";
|
||||
import { useHitlPhase } from "@/hooks/use-hitl-phase";
|
||||
|
||||
interface ConfluenceAccount {
|
||||
|
|
@ -30,24 +32,10 @@ interface ConfluenceSpace {
|
|||
name: string;
|
||||
}
|
||||
|
||||
interface InterruptResult {
|
||||
__interrupt__: true;
|
||||
__decided__?: "approve" | "reject" | "edit";
|
||||
__completed__?: boolean;
|
||||
action_requests: Array<{
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
}>;
|
||||
review_configs: Array<{
|
||||
action_name: string;
|
||||
allowed_decisions: Array<"approve" | "edit" | "reject">;
|
||||
}>;
|
||||
interrupt_type?: string;
|
||||
context?: {
|
||||
accounts?: ConfluenceAccount[];
|
||||
spaces?: ConfluenceSpace[];
|
||||
error?: string;
|
||||
};
|
||||
type CreateConfluencePageInterruptContext = {
|
||||
accounts?: ConfluenceAccount[];
|
||||
spaces?: ConfluenceSpace[];
|
||||
error?: string;
|
||||
}
|
||||
|
||||
interface SuccessResult {
|
||||
|
|
@ -76,21 +64,12 @@ interface InsufficientPermissionsResult {
|
|||
}
|
||||
|
||||
type CreateConfluencePageResult =
|
||||
| InterruptResult
|
||||
| InterruptResult<CreateConfluencePageInterruptContext>
|
||||
| SuccessResult
|
||||
| ErrorResult
|
||||
| AuthErrorResult
|
||||
| InsufficientPermissionsResult;
|
||||
|
||||
function isInterruptResult(result: unknown): result is InterruptResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
result !== null &&
|
||||
"__interrupt__" in result &&
|
||||
(result as InterruptResult).__interrupt__ === true
|
||||
);
|
||||
}
|
||||
|
||||
function isErrorResult(result: unknown): result is ErrorResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
|
|
@ -124,12 +103,8 @@ function ApprovalCard({
|
|||
onDecision,
|
||||
}: {
|
||||
args: { title: string; content?: string; space_id?: string };
|
||||
interruptData: InterruptResult;
|
||||
onDecision: (decision: {
|
||||
type: "approve" | "reject" | "edit";
|
||||
message?: string;
|
||||
edited_action?: { name: string; args: Record<string, unknown> };
|
||||
}) => void;
|
||||
interruptData: InterruptResult<CreateConfluencePageInterruptContext>;
|
||||
onDecision: (decision: HitlDecision) => void;
|
||||
}) {
|
||||
const { phase, setProcessing, setRejected } = useHitlPhase(interruptData);
|
||||
const [isPanelOpen, setIsPanelOpen] = useState(false);
|
||||
|
|
@ -464,18 +439,16 @@ export const CreateConfluencePageToolUI = ({
|
|||
{ title: string; content?: string; space_id?: string },
|
||||
CreateConfluencePageResult
|
||||
>) => {
|
||||
const { dispatch } = useHitlDecision();
|
||||
|
||||
if (!result) return null;
|
||||
|
||||
if (isInterruptResult(result)) {
|
||||
return (
|
||||
<ApprovalCard
|
||||
args={args}
|
||||
interruptData={result}
|
||||
onDecision={(decision) => {
|
||||
window.dispatchEvent(
|
||||
new CustomEvent("hitl-decision", { detail: { decisions: [decision] } })
|
||||
);
|
||||
}}
|
||||
interruptData={result as InterruptResult<CreateConfluencePageInterruptContext>}
|
||||
onDecision={(decision) => dispatch([decision])}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,38 +6,26 @@ import { useCallback, useEffect, useState } from "react";
|
|||
import { TextShimmerLoader } from "@/components/prompt-kit/loader";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import { isInterruptResult, useHitlDecision } from "@/lib/hitl";
|
||||
import type { InterruptResult, HitlDecision } from "@/lib/hitl";
|
||||
import { useHitlPhase } from "@/hooks/use-hitl-phase";
|
||||
|
||||
interface InterruptResult {
|
||||
__interrupt__: true;
|
||||
__decided__?: "approve" | "reject";
|
||||
__completed__?: boolean;
|
||||
action_requests: Array<{
|
||||
type DeleteConfluencePageInterruptContext = {
|
||||
account?: {
|
||||
id: number;
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
}>;
|
||||
review_configs: Array<{
|
||||
action_name: string;
|
||||
allowed_decisions: Array<"approve" | "reject">;
|
||||
}>;
|
||||
interrupt_type?: string;
|
||||
context?: {
|
||||
account?: {
|
||||
id: number;
|
||||
name: string;
|
||||
base_url: string;
|
||||
auth_expired?: boolean;
|
||||
};
|
||||
page?: {
|
||||
page_id: string;
|
||||
page_title: string;
|
||||
space_id: string;
|
||||
connector_id?: number;
|
||||
document_id?: number;
|
||||
indexed_at?: string;
|
||||
};
|
||||
error?: string;
|
||||
base_url: string;
|
||||
auth_expired?: boolean;
|
||||
};
|
||||
page?: {
|
||||
page_id: string;
|
||||
page_title: string;
|
||||
space_id: string;
|
||||
connector_id?: number;
|
||||
document_id?: number;
|
||||
indexed_at?: string;
|
||||
};
|
||||
error?: string;
|
||||
}
|
||||
|
||||
interface SuccessResult {
|
||||
|
|
@ -77,7 +65,7 @@ interface InsufficientPermissionsResult {
|
|||
}
|
||||
|
||||
type DeleteConfluencePageResult =
|
||||
| InterruptResult
|
||||
| InterruptResult<DeleteConfluencePageInterruptContext>
|
||||
| SuccessResult
|
||||
| ErrorResult
|
||||
| NotFoundResult
|
||||
|
|
@ -85,15 +73,6 @@ type DeleteConfluencePageResult =
|
|||
| AuthErrorResult
|
||||
| InsufficientPermissionsResult;
|
||||
|
||||
function isInterruptResult(result: unknown): result is InterruptResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
result !== null &&
|
||||
"__interrupt__" in result &&
|
||||
(result as InterruptResult).__interrupt__ === true
|
||||
);
|
||||
}
|
||||
|
||||
function isErrorResult(result: unknown): result is ErrorResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
|
|
@ -145,12 +124,8 @@ function ApprovalCard({
|
|||
interruptData,
|
||||
onDecision,
|
||||
}: {
|
||||
interruptData: InterruptResult;
|
||||
onDecision: (decision: {
|
||||
type: "approve" | "reject";
|
||||
message?: string;
|
||||
edited_action?: { name: string; args: Record<string, unknown> };
|
||||
}) => void;
|
||||
interruptData: InterruptResult<DeleteConfluencePageInterruptContext>;
|
||||
onDecision: (decision: HitlDecision) => void;
|
||||
}) {
|
||||
const { phase, setProcessing, setRejected } = useHitlPhase(interruptData);
|
||||
const [deleteFromKb, setDeleteFromKb] = useState(false);
|
||||
|
|
@ -402,18 +377,15 @@ export const DeleteConfluencePageToolUI = ({
|
|||
{ page_title_or_id: string; delete_from_kb?: boolean },
|
||||
DeleteConfluencePageResult
|
||||
>) => {
|
||||
const { dispatch } = useHitlDecision();
|
||||
|
||||
if (!result) return null;
|
||||
|
||||
if (isInterruptResult(result)) {
|
||||
return (
|
||||
<ApprovalCard
|
||||
interruptData={result}
|
||||
onDecision={(decision) => {
|
||||
const event = new CustomEvent("hitl-decision", {
|
||||
detail: { decisions: [decision] },
|
||||
});
|
||||
window.dispatchEvent(event);
|
||||
}}
|
||||
interruptData={result as InterruptResult<DeleteConfluencePageInterruptContext>}
|
||||
onDecision={(decision) => dispatch([decision])}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,39 +8,27 @@ import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom";
|
|||
import { PlateEditor } from "@/components/editor/plate-editor";
|
||||
import { TextShimmerLoader } from "@/components/prompt-kit/loader";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { isInterruptResult, useHitlDecision } from "@/lib/hitl";
|
||||
import type { InterruptResult, HitlDecision } from "@/lib/hitl";
|
||||
import { useHitlPhase } from "@/hooks/use-hitl-phase";
|
||||
|
||||
interface InterruptResult {
|
||||
__interrupt__: true;
|
||||
__decided__?: "approve" | "reject" | "edit";
|
||||
__completed__?: boolean;
|
||||
action_requests: Array<{
|
||||
type UpdateConfluencePageInterruptContext = {
|
||||
account?: {
|
||||
id: number;
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
}>;
|
||||
review_configs: Array<{
|
||||
action_name: string;
|
||||
allowed_decisions: Array<"approve" | "edit" | "reject">;
|
||||
}>;
|
||||
interrupt_type?: string;
|
||||
context?: {
|
||||
account?: {
|
||||
id: number;
|
||||
name: string;
|
||||
base_url: string;
|
||||
auth_expired?: boolean;
|
||||
};
|
||||
page?: {
|
||||
page_id: string;
|
||||
page_title: string;
|
||||
space_id: string;
|
||||
body: string;
|
||||
version: number;
|
||||
document_id: number;
|
||||
indexed_at?: string;
|
||||
};
|
||||
error?: string;
|
||||
base_url: string;
|
||||
auth_expired?: boolean;
|
||||
};
|
||||
page?: {
|
||||
page_id: string;
|
||||
page_title: string;
|
||||
space_id: string;
|
||||
body: string;
|
||||
version: number;
|
||||
document_id: number;
|
||||
indexed_at?: string;
|
||||
};
|
||||
error?: string;
|
||||
}
|
||||
|
||||
interface SuccessResult {
|
||||
|
|
@ -74,22 +62,13 @@ interface InsufficientPermissionsResult {
|
|||
}
|
||||
|
||||
type UpdateConfluencePageResult =
|
||||
| InterruptResult
|
||||
| InterruptResult<UpdateConfluencePageInterruptContext>
|
||||
| SuccessResult
|
||||
| ErrorResult
|
||||
| NotFoundResult
|
||||
| AuthErrorResult
|
||||
| InsufficientPermissionsResult;
|
||||
|
||||
function isInterruptResult(result: unknown): result is InterruptResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
result !== null &&
|
||||
"__interrupt__" in result &&
|
||||
(result as InterruptResult).__interrupt__ === true
|
||||
);
|
||||
}
|
||||
|
||||
function isErrorResult(result: unknown): result is ErrorResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
|
|
@ -136,12 +115,8 @@ function ApprovalCard({
|
|||
new_title?: string;
|
||||
new_content?: string;
|
||||
};
|
||||
interruptData: InterruptResult;
|
||||
onDecision: (decision: {
|
||||
type: "approve" | "reject" | "edit";
|
||||
message?: string;
|
||||
edited_action?: { name: string; args: Record<string, unknown> };
|
||||
}) => void;
|
||||
interruptData: InterruptResult<UpdateConfluencePageInterruptContext>;
|
||||
onDecision: (decision: HitlDecision) => void;
|
||||
}) {
|
||||
const { phase, setProcessing, setRejected } = useHitlPhase(interruptData);
|
||||
|
||||
|
|
@ -502,18 +477,16 @@ export const UpdateConfluencePageToolUI = ({
|
|||
},
|
||||
UpdateConfluencePageResult
|
||||
>) => {
|
||||
const { dispatch } = useHitlDecision();
|
||||
|
||||
if (!result) return null;
|
||||
|
||||
if (isInterruptResult(result)) {
|
||||
return (
|
||||
<ApprovalCard
|
||||
args={args}
|
||||
interruptData={result}
|
||||
onDecision={(decision) => {
|
||||
window.dispatchEvent(
|
||||
new CustomEvent("hitl-decision", { detail: { decisions: [decision] } })
|
||||
);
|
||||
}}
|
||||
interruptData={result as InterruptResult<UpdateConfluencePageInterruptContext>}
|
||||
onDecision={(decision) => dispatch([decision])}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ import {
|
|||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { useHitlPhase } from "@/hooks/use-hitl-phase";
|
||||
import { isInterruptResult, useHitlDecision } from "@/lib/hitl";
|
||||
import type { InterruptResult, HitlDecision } from "@/lib/hitl";
|
||||
|
||||
interface DropboxAccount {
|
||||
id: number;
|
||||
|
|
@ -29,21 +31,11 @@ interface SupportedType {
|
|||
label: string;
|
||||
}
|
||||
|
||||
interface InterruptResult {
|
||||
__interrupt__: true;
|
||||
__decided__?: "approve" | "reject" | "edit";
|
||||
__completed__?: boolean;
|
||||
action_requests: Array<{ name: string; args: Record<string, unknown> }>;
|
||||
review_configs: Array<{
|
||||
action_name: string;
|
||||
allowed_decisions: Array<"approve" | "edit" | "reject">;
|
||||
}>;
|
||||
context?: {
|
||||
accounts?: DropboxAccount[];
|
||||
parent_folders?: Record<number, Array<{ folder_path: string; name: string }>>;
|
||||
supported_types?: SupportedType[];
|
||||
error?: string;
|
||||
};
|
||||
type DropboxCreateFileContext = {
|
||||
accounts?: DropboxAccount[];
|
||||
parent_folders?: Record<number, Array<{ folder_path: string; name: string }>>;
|
||||
supported_types?: SupportedType[];
|
||||
error?: string;
|
||||
}
|
||||
|
||||
interface SuccessResult {
|
||||
|
|
@ -65,16 +57,7 @@ interface AuthErrorResult {
|
|||
connector_type?: string;
|
||||
}
|
||||
|
||||
type CreateDropboxFileResult = InterruptResult | SuccessResult | ErrorResult | AuthErrorResult;
|
||||
|
||||
function isInterruptResult(result: unknown): result is InterruptResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
result !== null &&
|
||||
"__interrupt__" in result &&
|
||||
(result as InterruptResult).__interrupt__ === true
|
||||
);
|
||||
}
|
||||
type CreateDropboxFileResult = InterruptResult<DropboxCreateFileContext> | SuccessResult | ErrorResult | AuthErrorResult;
|
||||
|
||||
function isErrorResult(result: unknown): result is ErrorResult {
|
||||
return (
|
||||
|
|
@ -100,12 +83,8 @@ function ApprovalCard({
|
|||
onDecision,
|
||||
}: {
|
||||
args: { name: string; file_type?: string; content?: string };
|
||||
interruptData: InterruptResult;
|
||||
onDecision: (decision: {
|
||||
type: "approve" | "reject" | "edit";
|
||||
message?: string;
|
||||
edited_action?: { name: string; args: Record<string, unknown> };
|
||||
}) => void;
|
||||
interruptData: InterruptResult<DropboxCreateFileContext>;
|
||||
onDecision: (decision: HitlDecision) => void;
|
||||
}) {
|
||||
const { phase, setProcessing, setRejected } = useHitlPhase(interruptData);
|
||||
const [isPanelOpen, setIsPanelOpen] = useState(false);
|
||||
|
|
@ -455,17 +434,14 @@ export const CreateDropboxFileToolUI = ({
|
|||
{ name: string; file_type?: string; content?: string },
|
||||
CreateDropboxFileResult
|
||||
>) => {
|
||||
const { dispatch } = useHitlDecision();
|
||||
if (!result) return null;
|
||||
if (isInterruptResult(result)) {
|
||||
return (
|
||||
<ApprovalCard
|
||||
args={args}
|
||||
interruptData={result}
|
||||
onDecision={(decision) => {
|
||||
window.dispatchEvent(
|
||||
new CustomEvent("hitl-decision", { detail: { decisions: [decision] } })
|
||||
);
|
||||
}}
|
||||
interruptData={result as InterruptResult<DropboxCreateFileContext>}
|
||||
onDecision={(decision) => dispatch([decision])}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ import { TextShimmerLoader } from "@/components/prompt-kit/loader";
|
|||
import { Button } from "@/components/ui/button";
|
||||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import { useHitlPhase } from "@/hooks/use-hitl-phase";
|
||||
import { isInterruptResult, useHitlDecision } from "@/lib/hitl";
|
||||
import type { InterruptResult, HitlDecision } from "@/lib/hitl";
|
||||
|
||||
interface DropboxAccount {
|
||||
id: number;
|
||||
|
|
@ -22,13 +24,10 @@ interface DropboxFile {
|
|||
document_id?: number;
|
||||
}
|
||||
|
||||
interface InterruptResult {
|
||||
__interrupt__: true;
|
||||
__decided__?: "approve" | "reject";
|
||||
__completed__?: boolean;
|
||||
action_requests: Array<{ name: string; args: Record<string, unknown> }>;
|
||||
review_configs: Array<{ action_name: string; allowed_decisions: Array<"approve" | "reject"> }>;
|
||||
context?: { account?: DropboxAccount; file?: DropboxFile; error?: string };
|
||||
type DropboxTrashFileContext = {
|
||||
account?: DropboxAccount;
|
||||
file?: DropboxFile;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
interface SuccessResult {
|
||||
|
|
@ -52,20 +51,12 @@ interface AuthErrorResult {
|
|||
}
|
||||
|
||||
type DeleteDropboxFileResult =
|
||||
| InterruptResult
|
||||
| InterruptResult<DropboxTrashFileContext>
|
||||
| SuccessResult
|
||||
| ErrorResult
|
||||
| NotFoundResult
|
||||
| AuthErrorResult;
|
||||
|
||||
function isInterruptResult(result: unknown): result is InterruptResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
result !== null &&
|
||||
"__interrupt__" in result &&
|
||||
(result as InterruptResult).__interrupt__ === true
|
||||
);
|
||||
}
|
||||
function isErrorResult(result: unknown): result is ErrorResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
|
|
@ -95,12 +86,8 @@ function ApprovalCard({
|
|||
interruptData,
|
||||
onDecision,
|
||||
}: {
|
||||
interruptData: InterruptResult;
|
||||
onDecision: (decision: {
|
||||
type: "approve" | "reject";
|
||||
message?: string;
|
||||
edited_action?: { name: string; args: Record<string, unknown> };
|
||||
}) => void;
|
||||
interruptData: InterruptResult<DropboxTrashFileContext>;
|
||||
onDecision: (decision: HitlDecision) => void;
|
||||
}) {
|
||||
const { phase, setProcessing, setRejected } = useHitlPhase(interruptData);
|
||||
const [deleteFromKb, setDeleteFromKb] = useState(false);
|
||||
|
|
@ -308,16 +295,13 @@ export const DeleteDropboxFileToolUI = ({
|
|||
{ file_name: string; delete_from_kb?: boolean },
|
||||
DeleteDropboxFileResult
|
||||
>) => {
|
||||
const { dispatch } = useHitlDecision();
|
||||
if (!result) return null;
|
||||
if (isInterruptResult(result)) {
|
||||
return (
|
||||
<ApprovalCard
|
||||
interruptData={result}
|
||||
onDecision={(decision) => {
|
||||
window.dispatchEvent(
|
||||
new CustomEvent("hitl-decision", { detail: { decisions: [decision] } })
|
||||
);
|
||||
}}
|
||||
interruptData={result as InterruptResult<DropboxTrashFileContext>}
|
||||
onDecision={(decision) => dispatch([decision])}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
264
surfsense_web/components/tool-ui/generic-hitl-approval.tsx
Normal file
264
surfsense_web/components/tool-ui/generic-hitl-approval.tsx
Normal file
|
|
@ -0,0 +1,264 @@
|
|||
"use client";
|
||||
|
||||
import type { ToolCallMessagePartComponent } from "@assistant-ui/react";
|
||||
import { CornerDownLeftIcon, Pen } from "lucide-react";
|
||||
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||
import { TextShimmerLoader } from "@/components/prompt-kit/loader";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { useHitlPhase } from "@/hooks/use-hitl-phase";
|
||||
import { connectorsApiService } from "@/lib/apis/connectors-api.service";
|
||||
import { isInterruptResult, useHitlDecision } from "@/lib/hitl";
|
||||
import type { HitlDecision, InterruptResult } from "@/lib/hitl";
|
||||
|
||||
function ParamEditor({
|
||||
params,
|
||||
onChange,
|
||||
disabled,
|
||||
}: {
|
||||
params: Record<string, unknown>;
|
||||
onChange: (updated: Record<string, unknown>) => void;
|
||||
disabled: boolean;
|
||||
}) {
|
||||
const entries = Object.entries(params);
|
||||
if (entries.length === 0) return null;
|
||||
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
{entries.map(([key, value]) => {
|
||||
const strValue = value == null ? "" : String(value);
|
||||
const isLong = strValue.length > 120;
|
||||
const fieldId = `hitl-param-${key}`;
|
||||
|
||||
return (
|
||||
<div key={key} className="space-y-1">
|
||||
<label htmlFor={fieldId} className="text-xs font-medium text-muted-foreground">
|
||||
{key}
|
||||
</label>
|
||||
{isLong ? (
|
||||
<Textarea
|
||||
id={fieldId}
|
||||
value={strValue}
|
||||
disabled={disabled}
|
||||
rows={3}
|
||||
onChange={(e) => onChange({ ...params, [key]: e.target.value })}
|
||||
className="text-xs"
|
||||
/>
|
||||
) : (
|
||||
<Input
|
||||
id={fieldId}
|
||||
value={strValue}
|
||||
disabled={disabled}
|
||||
onChange={(e) => onChange({ ...params, [key]: e.target.value })}
|
||||
className="text-xs"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function GenericApprovalCard({
|
||||
toolName,
|
||||
args,
|
||||
interruptData,
|
||||
onDecision,
|
||||
}: {
|
||||
toolName: string;
|
||||
args: Record<string, unknown>;
|
||||
interruptData: InterruptResult;
|
||||
onDecision: (decision: HitlDecision) => void;
|
||||
}) {
|
||||
const { phase, setProcessing, setRejected } = useHitlPhase(interruptData);
|
||||
const [editedParams, setEditedParams] = useState<Record<string, unknown>>(args);
|
||||
const [isEditing, setIsEditing] = useState(false);
|
||||
|
||||
const displayName = toolName.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase());
|
||||
|
||||
const mcpServer = interruptData.context?.mcp_server as string | undefined;
|
||||
const toolDescription = interruptData.context?.tool_description as string | undefined;
|
||||
const mcpConnectorId = interruptData.context?.mcp_connector_id as number | undefined;
|
||||
const isMCPTool = mcpConnectorId != null;
|
||||
|
||||
const reviewConfig = interruptData.review_configs?.[0];
|
||||
const allowedDecisions = reviewConfig?.allowed_decisions ?? ["approve", "reject"];
|
||||
const canEdit = allowedDecisions.includes("edit");
|
||||
|
||||
const hasChanged = useMemo(() => {
|
||||
return JSON.stringify(editedParams) !== JSON.stringify(args);
|
||||
}, [editedParams, args]);
|
||||
|
||||
const handleApprove = useCallback(() => {
|
||||
if (phase !== "pending") return;
|
||||
const isEdited = isEditing && hasChanged;
|
||||
setProcessing();
|
||||
onDecision({
|
||||
type: isEdited ? "edit" : "approve",
|
||||
edited_action: isEdited
|
||||
? { name: interruptData.action_requests[0]?.name ?? toolName, args: editedParams }
|
||||
: undefined,
|
||||
});
|
||||
}, [
|
||||
phase,
|
||||
setProcessing,
|
||||
isEditing,
|
||||
hasChanged,
|
||||
onDecision,
|
||||
interruptData,
|
||||
toolName,
|
||||
editedParams,
|
||||
]);
|
||||
|
||||
const handleAlwaysAllow = useCallback(() => {
|
||||
if (phase !== "pending" || !isMCPTool) return;
|
||||
setProcessing();
|
||||
onDecision({ type: "approve" });
|
||||
connectorsApiService.trustMCPTool(mcpConnectorId, toolName).catch((err) => {
|
||||
console.error("Failed to trust MCP tool:", err);
|
||||
});
|
||||
}, [phase, setProcessing, onDecision, isMCPTool, mcpConnectorId, toolName]);
|
||||
|
||||
useEffect(() => {
|
||||
const handler = (e: KeyboardEvent) => {
|
||||
if (e.key === "Enter" && !e.shiftKey && !e.ctrlKey && !e.metaKey && phase === "pending") {
|
||||
handleApprove();
|
||||
}
|
||||
};
|
||||
window.addEventListener("keydown", handler);
|
||||
return () => window.removeEventListener("keydown", handler);
|
||||
}, [handleApprove, phase]);
|
||||
|
||||
return (
|
||||
<div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 transition-[box-shadow] duration-300">
|
||||
{/* Header */}
|
||||
<div className="flex items-start justify-between px-5 pt-5 pb-4 select-none">
|
||||
<div>
|
||||
<p className="text-sm font-semibold text-foreground">
|
||||
{phase === "rejected"
|
||||
? `${displayName} — Rejected`
|
||||
: phase === "processing" || phase === "complete"
|
||||
? `${displayName} — Approved`
|
||||
: displayName}
|
||||
</p>
|
||||
{phase === "processing" ? (
|
||||
<TextShimmerLoader text="Executing..." size="sm" />
|
||||
) : phase === "complete" ? (
|
||||
<p className="text-xs text-muted-foreground mt-0.5">Action completed</p>
|
||||
) : phase === "rejected" ? (
|
||||
<p className="text-xs text-muted-foreground mt-0.5">Action was cancelled</p>
|
||||
) : (
|
||||
<p className="text-xs text-muted-foreground mt-0.5">
|
||||
Requires your approval to proceed
|
||||
</p>
|
||||
)}
|
||||
{mcpServer && (
|
||||
<p className="text-[10px] text-muted-foreground/70 mt-1">
|
||||
via <span className="font-medium">{mcpServer}</span>
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
{phase === "pending" && canEdit && !isEditing && (
|
||||
<Button
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
className="rounded-lg text-muted-foreground -mt-1 -mr-2"
|
||||
onClick={() => setIsEditing(true)}
|
||||
>
|
||||
<Pen className="size-3.5" />
|
||||
Edit
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Description */}
|
||||
{toolDescription && phase === "pending" && (
|
||||
<>
|
||||
<div className="mx-5 h-px bg-border/50" />
|
||||
<div className="px-5 py-3">
|
||||
<p className="text-xs text-muted-foreground">{toolDescription}</p>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Parameters */}
|
||||
{Object.keys(args).length > 0 && (
|
||||
<>
|
||||
<div className="mx-5 h-px bg-border/50" />
|
||||
<div className="px-5 py-4 space-y-2">
|
||||
<p className="text-xs font-medium text-muted-foreground">Parameters</p>
|
||||
{phase === "pending" && isEditing ? (
|
||||
<ParamEditor
|
||||
params={editedParams}
|
||||
onChange={setEditedParams}
|
||||
disabled={phase !== "pending"}
|
||||
/>
|
||||
) : (
|
||||
<pre className="text-xs text-foreground/80 whitespace-pre-wrap break-all bg-muted/50 rounded-lg p-3">
|
||||
{JSON.stringify(args, null, 2)}
|
||||
</pre>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Action buttons */}
|
||||
{phase === "pending" && (
|
||||
<>
|
||||
<div className="mx-5 h-px bg-border/50" />
|
||||
<div className="px-5 py-4 flex items-center gap-2 select-none">
|
||||
{allowedDecisions.includes("approve") && (
|
||||
<Button size="sm" className="rounded-lg gap-1.5" onClick={handleApprove}>
|
||||
{isEditing && hasChanged ? "Approve with edits" : "Approve"}
|
||||
<CornerDownLeftIcon className="size-3 opacity-60" />
|
||||
</Button>
|
||||
)}
|
||||
{isMCPTool && (
|
||||
<Button
|
||||
size="sm"
|
||||
className="rounded-lg"
|
||||
onClick={handleAlwaysAllow}
|
||||
>
|
||||
Always Allow
|
||||
</Button>
|
||||
)}
|
||||
{allowedDecisions.includes("reject") && (
|
||||
<Button
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
className="rounded-lg text-muted-foreground"
|
||||
onClick={() => {
|
||||
setRejected();
|
||||
onDecision({ type: "reject", message: "User rejected the action." });
|
||||
}}
|
||||
>
|
||||
Reject
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export const GenericHitlApprovalToolUI: ToolCallMessagePartComponent = ({
|
||||
toolName,
|
||||
args,
|
||||
result,
|
||||
}) => {
|
||||
const { dispatch } = useHitlDecision();
|
||||
|
||||
if (!result || !isInterruptResult(result)) return null;
|
||||
|
||||
return (
|
||||
<GenericApprovalCard
|
||||
toolName={toolName}
|
||||
args={args as Record<string, unknown>}
|
||||
interruptData={result}
|
||||
onDecision={(decision) => dispatch([decision])}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
|
@ -16,6 +16,8 @@ import {
|
|||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { isInterruptResult, useHitlDecision } from "@/lib/hitl";
|
||||
import type { HitlDecision, InterruptResult } from "@/lib/hitl";
|
||||
import { useHitlPhase } from "@/hooks/use-hitl-phase";
|
||||
|
||||
interface GmailAccount {
|
||||
|
|
@ -25,22 +27,9 @@ interface GmailAccount {
|
|||
auth_expired?: boolean;
|
||||
}
|
||||
|
||||
interface InterruptResult {
|
||||
__interrupt__: true;
|
||||
__decided__?: "approve" | "reject" | "edit";
|
||||
__completed__?: boolean;
|
||||
action_requests: Array<{
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
}>;
|
||||
review_configs: Array<{
|
||||
action_name: string;
|
||||
allowed_decisions: Array<"approve" | "edit" | "reject">;
|
||||
}>;
|
||||
context?: {
|
||||
accounts?: GmailAccount[];
|
||||
error?: string;
|
||||
};
|
||||
type GmailCreateDraftContext = {
|
||||
accounts?: GmailAccount[];
|
||||
error?: string;
|
||||
}
|
||||
|
||||
interface SuccessResult {
|
||||
|
|
@ -68,21 +57,12 @@ interface InsufficientPermissionsResult {
|
|||
}
|
||||
|
||||
type CreateGmailDraftResult =
|
||||
| InterruptResult
|
||||
| InterruptResult<GmailCreateDraftContext>
|
||||
| SuccessResult
|
||||
| ErrorResult
|
||||
| InsufficientPermissionsResult
|
||||
| AuthErrorResult;
|
||||
|
||||
function isInterruptResult(result: unknown): result is InterruptResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
result !== null &&
|
||||
"__interrupt__" in result &&
|
||||
(result as InterruptResult).__interrupt__ === true
|
||||
);
|
||||
}
|
||||
|
||||
function isErrorResult(result: unknown): result is ErrorResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
|
|
@ -116,12 +96,8 @@ function ApprovalCard({
|
|||
onDecision,
|
||||
}: {
|
||||
args: { to: string; subject: string; body: string; cc?: string; bcc?: string };
|
||||
interruptData: InterruptResult;
|
||||
onDecision: (decision: {
|
||||
type: "approve" | "reject" | "edit";
|
||||
message?: string;
|
||||
edited_action?: { name: string; args: Record<string, unknown> };
|
||||
}) => void;
|
||||
interruptData: InterruptResult<GmailCreateDraftContext>;
|
||||
onDecision: (decision: HitlDecision) => void;
|
||||
}) {
|
||||
const { phase, setProcessing, setRejected } = useHitlPhase(interruptData);
|
||||
const [isPanelOpen, setIsPanelOpen] = useState(false);
|
||||
|
|
@ -473,18 +449,16 @@ export const CreateGmailDraftToolUI = ({
|
|||
{ to: string; subject: string; body: string; cc?: string; bcc?: string },
|
||||
CreateGmailDraftResult
|
||||
>) => {
|
||||
const { dispatch } = useHitlDecision();
|
||||
|
||||
if (!result) return null;
|
||||
|
||||
if (isInterruptResult(result)) {
|
||||
return (
|
||||
<ApprovalCard
|
||||
args={args}
|
||||
interruptData={result}
|
||||
onDecision={(decision) => {
|
||||
window.dispatchEvent(
|
||||
new CustomEvent("hitl-decision", { detail: { decisions: [decision] } })
|
||||
);
|
||||
}}
|
||||
interruptData={result as InterruptResult<GmailCreateDraftContext>}
|
||||
onDecision={(decision) => dispatch([decision])}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ import {
|
|||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { isInterruptResult, useHitlDecision } from "@/lib/hitl";
|
||||
import type { HitlDecision, InterruptResult } from "@/lib/hitl";
|
||||
import { useHitlPhase } from "@/hooks/use-hitl-phase";
|
||||
|
||||
interface GmailAccount {
|
||||
|
|
@ -25,22 +27,9 @@ interface GmailAccount {
|
|||
auth_expired?: boolean;
|
||||
}
|
||||
|
||||
interface InterruptResult {
|
||||
__interrupt__: true;
|
||||
__decided__?: "approve" | "reject" | "edit";
|
||||
__completed__?: boolean;
|
||||
action_requests: Array<{
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
}>;
|
||||
review_configs: Array<{
|
||||
action_name: string;
|
||||
allowed_decisions: Array<"approve" | "edit" | "reject">;
|
||||
}>;
|
||||
context?: {
|
||||
accounts?: GmailAccount[];
|
||||
error?: string;
|
||||
};
|
||||
type GmailSendEmailContext = {
|
||||
accounts?: GmailAccount[];
|
||||
error?: string;
|
||||
}
|
||||
|
||||
interface SuccessResult {
|
||||
|
|
@ -67,21 +56,12 @@ interface InsufficientPermissionsResult {
|
|||
}
|
||||
|
||||
type SendGmailEmailResult =
|
||||
| InterruptResult
|
||||
| InterruptResult<GmailSendEmailContext>
|
||||
| SuccessResult
|
||||
| ErrorResult
|
||||
| InsufficientPermissionsResult
|
||||
| AuthErrorResult;
|
||||
|
||||
function isInterruptResult(result: unknown): result is InterruptResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
result !== null &&
|
||||
"__interrupt__" in result &&
|
||||
(result as InterruptResult).__interrupt__ === true
|
||||
);
|
||||
}
|
||||
|
||||
function isErrorResult(result: unknown): result is ErrorResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
|
|
@ -115,12 +95,8 @@ function ApprovalCard({
|
|||
onDecision,
|
||||
}: {
|
||||
args: { to: string; subject: string; body: string; cc?: string; bcc?: string };
|
||||
interruptData: InterruptResult;
|
||||
onDecision: (decision: {
|
||||
type: "approve" | "reject" | "edit";
|
||||
message?: string;
|
||||
edited_action?: { name: string; args: Record<string, unknown> };
|
||||
}) => void;
|
||||
interruptData: InterruptResult<GmailSendEmailContext>;
|
||||
onDecision: (decision: HitlDecision) => void;
|
||||
}) {
|
||||
const { phase, setProcessing, setRejected } = useHitlPhase(interruptData);
|
||||
const [isPanelOpen, setIsPanelOpen] = useState(false);
|
||||
|
|
@ -471,18 +447,16 @@ export const SendGmailEmailToolUI = ({
|
|||
{ to: string; subject: string; body: string; cc?: string; bcc?: string },
|
||||
SendGmailEmailResult
|
||||
>) => {
|
||||
const { dispatch } = useHitlDecision();
|
||||
|
||||
if (!result) return null;
|
||||
|
||||
if (isInterruptResult(result)) {
|
||||
return (
|
||||
<ApprovalCard
|
||||
args={args}
|
||||
interruptData={result}
|
||||
onDecision={(decision) => {
|
||||
window.dispatchEvent(
|
||||
new CustomEvent("hitl-decision", { detail: { decisions: [decision] } })
|
||||
);
|
||||
}}
|
||||
interruptData={result as InterruptResult<GmailSendEmailContext>}
|
||||
onDecision={(decision) => dispatch([decision])}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ import { useCallback, useEffect, useState } from "react";
|
|||
import { TextShimmerLoader } from "@/components/prompt-kit/loader";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import { isInterruptResult, useHitlDecision } from "@/lib/hitl";
|
||||
import type { HitlDecision, InterruptResult } from "@/lib/hitl";
|
||||
import { useHitlPhase } from "@/hooks/use-hitl-phase";
|
||||
|
||||
interface GmailAccount {
|
||||
|
|
@ -25,23 +27,10 @@ interface GmailMessage {
|
|||
document_id: number;
|
||||
}
|
||||
|
||||
interface InterruptResult {
|
||||
__interrupt__: true;
|
||||
__decided__?: "approve" | "reject";
|
||||
__completed__?: boolean;
|
||||
action_requests: Array<{
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
}>;
|
||||
review_configs: Array<{
|
||||
action_name: string;
|
||||
allowed_decisions: Array<"approve" | "reject">;
|
||||
}>;
|
||||
context?: {
|
||||
account?: GmailAccount;
|
||||
email?: GmailMessage;
|
||||
error?: string;
|
||||
};
|
||||
type GmailTrashEmailContext = {
|
||||
account?: GmailAccount;
|
||||
email?: GmailMessage;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
interface SuccessResult {
|
||||
|
|
@ -74,22 +63,13 @@ interface InsufficientPermissionsResult {
|
|||
}
|
||||
|
||||
type TrashGmailEmailResult =
|
||||
| InterruptResult
|
||||
| InterruptResult<GmailTrashEmailContext>
|
||||
| SuccessResult
|
||||
| ErrorResult
|
||||
| NotFoundResult
|
||||
| InsufficientPermissionsResult
|
||||
| AuthErrorResult;
|
||||
|
||||
function isInterruptResult(result: unknown): result is InterruptResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
result !== null &&
|
||||
"__interrupt__" in result &&
|
||||
(result as InterruptResult).__interrupt__ === true
|
||||
);
|
||||
}
|
||||
|
||||
function isErrorResult(result: unknown): result is ErrorResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
|
|
@ -134,12 +114,8 @@ function ApprovalCard({
|
|||
interruptData,
|
||||
onDecision,
|
||||
}: {
|
||||
interruptData: InterruptResult;
|
||||
onDecision: (decision: {
|
||||
type: "approve" | "reject";
|
||||
message?: string;
|
||||
edited_action?: { name: string; args: Record<string, unknown> };
|
||||
}) => void;
|
||||
interruptData: InterruptResult<GmailTrashEmailContext>;
|
||||
onDecision: (decision: HitlDecision) => void;
|
||||
}) {
|
||||
const { phase, setProcessing, setRejected } = useHitlPhase(interruptData);
|
||||
const [deleteFromKb, setDeleteFromKb] = useState(false);
|
||||
|
|
@ -385,18 +361,15 @@ export const TrashGmailEmailToolUI = ({
|
|||
{ email_subject_or_id: string; delete_from_kb?: boolean },
|
||||
TrashGmailEmailResult
|
||||
>) => {
|
||||
const { dispatch } = useHitlDecision();
|
||||
|
||||
if (!result) return null;
|
||||
|
||||
if (isInterruptResult(result)) {
|
||||
return (
|
||||
<ApprovalCard
|
||||
interruptData={result}
|
||||
onDecision={(decision) => {
|
||||
const event = new CustomEvent("hitl-decision", {
|
||||
detail: { decisions: [decision] },
|
||||
});
|
||||
window.dispatchEvent(event);
|
||||
}}
|
||||
interruptData={result as InterruptResult<GmailTrashEmailContext>}
|
||||
onDecision={(decision) => dispatch([decision])}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom";
|
|||
import { PlateEditor } from "@/components/editor/plate-editor";
|
||||
import { TextShimmerLoader } from "@/components/prompt-kit/loader";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { isInterruptResult, useHitlDecision } from "@/lib/hitl";
|
||||
import type { HitlDecision, InterruptResult } from "@/lib/hitl";
|
||||
import { useHitlPhase } from "@/hooks/use-hitl-phase";
|
||||
|
||||
interface GmailAccount {
|
||||
|
|
@ -28,25 +30,12 @@ interface GmailMessage {
|
|||
document_id: number;
|
||||
}
|
||||
|
||||
interface InterruptResult {
|
||||
__interrupt__: true;
|
||||
__decided__?: "approve" | "reject" | "edit";
|
||||
__completed__?: boolean;
|
||||
action_requests: Array<{
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
}>;
|
||||
review_configs: Array<{
|
||||
action_name: string;
|
||||
allowed_decisions: Array<"approve" | "edit" | "reject">;
|
||||
}>;
|
||||
context?: {
|
||||
account?: GmailAccount;
|
||||
email?: GmailMessage;
|
||||
draft_id?: string;
|
||||
existing_body?: string;
|
||||
error?: string;
|
||||
};
|
||||
type GmailUpdateDraftContext = {
|
||||
account?: GmailAccount;
|
||||
email?: GmailMessage;
|
||||
draft_id?: string;
|
||||
existing_body?: string;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
interface SuccessResult {
|
||||
|
|
@ -78,22 +67,13 @@ interface InsufficientPermissionsResult {
|
|||
}
|
||||
|
||||
type UpdateGmailDraftResult =
|
||||
| InterruptResult
|
||||
| InterruptResult<GmailUpdateDraftContext>
|
||||
| SuccessResult
|
||||
| ErrorResult
|
||||
| NotFoundResult
|
||||
| InsufficientPermissionsResult
|
||||
| AuthErrorResult;
|
||||
|
||||
function isInterruptResult(result: unknown): result is InterruptResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
result !== null &&
|
||||
"__interrupt__" in result &&
|
||||
(result as InterruptResult).__interrupt__ === true
|
||||
);
|
||||
}
|
||||
|
||||
function isErrorResult(result: unknown): result is ErrorResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
|
|
@ -143,12 +123,8 @@ function ApprovalCard({
|
|||
cc?: string;
|
||||
bcc?: string;
|
||||
};
|
||||
interruptData: InterruptResult;
|
||||
onDecision: (decision: {
|
||||
type: "approve" | "reject" | "edit";
|
||||
message?: string;
|
||||
edited_action?: { name: string; args: Record<string, unknown> };
|
||||
}) => void;
|
||||
interruptData: InterruptResult<GmailUpdateDraftContext>;
|
||||
onDecision: (decision: HitlDecision) => void;
|
||||
}) {
|
||||
const { phase, setProcessing, setRejected } = useHitlPhase(interruptData);
|
||||
const [isPanelOpen, setIsPanelOpen] = useState(false);
|
||||
|
|
@ -522,20 +498,16 @@ export const UpdateGmailDraftToolUI = ({
|
|||
},
|
||||
UpdateGmailDraftResult
|
||||
>) => {
|
||||
const { dispatch } = useHitlDecision();
|
||||
|
||||
if (!result) return null;
|
||||
|
||||
if (isInterruptResult(result)) {
|
||||
return (
|
||||
<ApprovalCard
|
||||
args={args}
|
||||
interruptData={result}
|
||||
onDecision={(decision) => {
|
||||
window.dispatchEvent(
|
||||
new CustomEvent("hitl-decision", {
|
||||
detail: { decisions: [decision] },
|
||||
})
|
||||
);
|
||||
}}
|
||||
interruptData={result as InterruptResult<GmailUpdateDraftContext>}
|
||||
onDecision={(decision) => dispatch([decision])}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ import {
|
|||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { isInterruptResult, useHitlDecision } from "@/lib/hitl";
|
||||
import type { HitlDecision, InterruptResult } from "@/lib/hitl";
|
||||
import { useHitlPhase } from "@/hooks/use-hitl-phase";
|
||||
|
||||
interface GoogleCalendarAccount {
|
||||
|
|
@ -30,24 +32,11 @@ interface CalendarEntry {
|
|||
primary?: boolean;
|
||||
}
|
||||
|
||||
interface InterruptResult {
|
||||
__interrupt__: true;
|
||||
__decided__?: "approve" | "reject" | "edit";
|
||||
__completed__?: boolean;
|
||||
action_requests: Array<{
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
}>;
|
||||
review_configs: Array<{
|
||||
action_name: string;
|
||||
allowed_decisions: Array<"approve" | "edit" | "reject">;
|
||||
}>;
|
||||
context?: {
|
||||
accounts?: GoogleCalendarAccount[];
|
||||
calendars?: CalendarEntry[];
|
||||
timezone?: string;
|
||||
error?: string;
|
||||
};
|
||||
type CalendarCreateEventContext = {
|
||||
accounts?: GoogleCalendarAccount[];
|
||||
calendars?: CalendarEntry[];
|
||||
timezone?: string;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
interface SuccessResult {
|
||||
|
|
@ -75,21 +64,12 @@ interface InsufficientPermissionsResult {
|
|||
}
|
||||
|
||||
type CreateCalendarEventResult =
|
||||
| InterruptResult
|
||||
| InterruptResult<CalendarCreateEventContext>
|
||||
| SuccessResult
|
||||
| ErrorResult
|
||||
| InsufficientPermissionsResult
|
||||
| AuthErrorResult;
|
||||
|
||||
function isInterruptResult(result: unknown): result is InterruptResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
result !== null &&
|
||||
"__interrupt__" in result &&
|
||||
(result as InterruptResult).__interrupt__ === true
|
||||
);
|
||||
}
|
||||
|
||||
function isErrorResult(result: unknown): result is ErrorResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
|
|
@ -141,12 +121,8 @@ function ApprovalCard({
|
|||
location?: string;
|
||||
attendees?: string[];
|
||||
};
|
||||
interruptData: InterruptResult;
|
||||
onDecision: (decision: {
|
||||
type: "approve" | "reject" | "edit";
|
||||
message?: string;
|
||||
edited_action?: { name: string; args: Record<string, unknown> };
|
||||
}) => void;
|
||||
interruptData: InterruptResult<CalendarCreateEventContext>;
|
||||
onDecision: (decision: HitlDecision) => void;
|
||||
}) {
|
||||
const { phase, setProcessing, setRejected } = useHitlPhase(interruptData);
|
||||
const [isPanelOpen, setIsPanelOpen] = useState(false);
|
||||
|
|
@ -620,18 +596,16 @@ export const CreateCalendarEventToolUI = ({
|
|||
},
|
||||
CreateCalendarEventResult
|
||||
>) => {
|
||||
const { dispatch } = useHitlDecision();
|
||||
|
||||
if (!result) return null;
|
||||
|
||||
if (isInterruptResult(result)) {
|
||||
return (
|
||||
<ApprovalCard
|
||||
args={args}
|
||||
interruptData={result}
|
||||
onDecision={(decision) => {
|
||||
window.dispatchEvent(
|
||||
new CustomEvent("hitl-decision", { detail: { decisions: [decision] } })
|
||||
);
|
||||
}}
|
||||
interruptData={result as InterruptResult<CalendarCreateEventContext>}
|
||||
onDecision={(decision) => dispatch([decision])}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ import { useCallback, useEffect, useState } from "react";
|
|||
import { TextShimmerLoader } from "@/components/prompt-kit/loader";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import { isInterruptResult, useHitlDecision } from "@/lib/hitl";
|
||||
import type { HitlDecision, InterruptResult } from "@/lib/hitl";
|
||||
import { useHitlPhase } from "@/hooks/use-hitl-phase";
|
||||
|
||||
interface GoogleCalendarAccount {
|
||||
|
|
@ -27,23 +29,10 @@ interface CalendarEvent {
|
|||
indexed_at?: string;
|
||||
}
|
||||
|
||||
interface InterruptResult {
|
||||
__interrupt__: true;
|
||||
__decided__?: "approve" | "reject";
|
||||
__completed__?: boolean;
|
||||
action_requests: Array<{
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
}>;
|
||||
review_configs: Array<{
|
||||
action_name: string;
|
||||
allowed_decisions: Array<"approve" | "reject">;
|
||||
}>;
|
||||
context?: {
|
||||
account?: GoogleCalendarAccount;
|
||||
event?: CalendarEvent;
|
||||
error?: string;
|
||||
};
|
||||
type CalendarDeleteEventContext = {
|
||||
account?: GoogleCalendarAccount;
|
||||
event?: CalendarEvent;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
interface SuccessResult {
|
||||
|
|
@ -83,7 +72,7 @@ interface InsufficientPermissionsResult {
|
|||
}
|
||||
|
||||
type DeleteCalendarEventResult =
|
||||
| InterruptResult
|
||||
| InterruptResult<CalendarDeleteEventContext>
|
||||
| SuccessResult
|
||||
| ErrorResult
|
||||
| NotFoundResult
|
||||
|
|
@ -91,15 +80,6 @@ type DeleteCalendarEventResult =
|
|||
| InsufficientPermissionsResult
|
||||
| AuthErrorResult;
|
||||
|
||||
function isInterruptResult(result: unknown): result is InterruptResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
result !== null &&
|
||||
"__interrupt__" in result &&
|
||||
(result as InterruptResult).__interrupt__ === true
|
||||
);
|
||||
}
|
||||
|
||||
function isErrorResult(result: unknown): result is ErrorResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
|
|
@ -162,12 +142,8 @@ function ApprovalCard({
|
|||
interruptData,
|
||||
onDecision,
|
||||
}: {
|
||||
interruptData: InterruptResult;
|
||||
onDecision: (decision: {
|
||||
type: "approve" | "reject";
|
||||
message?: string;
|
||||
edited_action?: { name: string; args: Record<string, unknown> };
|
||||
}) => void;
|
||||
interruptData: InterruptResult<CalendarDeleteEventContext>;
|
||||
onDecision: (decision: HitlDecision) => void;
|
||||
}) {
|
||||
const { phase, setProcessing, setRejected } = useHitlPhase(interruptData);
|
||||
const [deleteFromKb, setDeleteFromKb] = useState(false);
|
||||
|
|
@ -437,18 +413,15 @@ export const DeleteCalendarEventToolUI = ({
|
|||
{ event_title_or_id: string; delete_from_kb?: boolean },
|
||||
DeleteCalendarEventResult
|
||||
>) => {
|
||||
const { dispatch } = useHitlDecision();
|
||||
|
||||
if (!result) return null;
|
||||
|
||||
if (isInterruptResult(result)) {
|
||||
return (
|
||||
<ApprovalCard
|
||||
interruptData={result}
|
||||
onDecision={(decision) => {
|
||||
const event = new CustomEvent("hitl-decision", {
|
||||
detail: { decisions: [decision] },
|
||||
});
|
||||
window.dispatchEvent(event);
|
||||
}}
|
||||
interruptData={result as InterruptResult<CalendarDeleteEventContext>}
|
||||
onDecision={(decision) => dispatch([decision])}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ import { openHitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom";
|
|||
import { PlateEditor } from "@/components/editor/plate-editor";
|
||||
import { TextShimmerLoader } from "@/components/prompt-kit/loader";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { isInterruptResult, useHitlDecision } from "@/lib/hitl";
|
||||
import type { HitlDecision, InterruptResult } from "@/lib/hitl";
|
||||
import { useHitlPhase } from "@/hooks/use-hitl-phase";
|
||||
|
||||
interface GoogleCalendarAccount {
|
||||
|
|
@ -37,23 +39,10 @@ interface CalendarEvent {
|
|||
indexed_at?: string;
|
||||
}
|
||||
|
||||
interface InterruptResult {
|
||||
__interrupt__: true;
|
||||
__decided__?: "approve" | "reject" | "edit";
|
||||
__completed__?: boolean;
|
||||
action_requests: Array<{
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
}>;
|
||||
review_configs: Array<{
|
||||
action_name: string;
|
||||
allowed_decisions: Array<"approve" | "edit" | "reject">;
|
||||
}>;
|
||||
context?: {
|
||||
account?: GoogleCalendarAccount;
|
||||
event?: CalendarEvent;
|
||||
error?: string;
|
||||
};
|
||||
type CalendarUpdateEventContext = {
|
||||
account?: GoogleCalendarAccount;
|
||||
event?: CalendarEvent;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
interface SuccessResult {
|
||||
|
|
@ -86,22 +75,13 @@ interface InsufficientPermissionsResult {
|
|||
}
|
||||
|
||||
type UpdateCalendarEventResult =
|
||||
| InterruptResult
|
||||
| InterruptResult<CalendarUpdateEventContext>
|
||||
| SuccessResult
|
||||
| ErrorResult
|
||||
| NotFoundResult
|
||||
| InsufficientPermissionsResult
|
||||
| AuthErrorResult;
|
||||
|
||||
function isInterruptResult(result: unknown): result is InterruptResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
result !== null &&
|
||||
"__interrupt__" in result &&
|
||||
(result as InterruptResult).__interrupt__ === true
|
||||
);
|
||||
}
|
||||
|
||||
function isErrorResult(result: unknown): result is ErrorResult {
|
||||
return (
|
||||
typeof result === "object" &&
|
||||
|
|
@ -163,12 +143,8 @@ function ApprovalCard({
|
|||
new_location?: string;
|
||||
new_attendees?: string[];
|
||||
};
|
||||
interruptData: InterruptResult;
|
||||
onDecision: (decision: {
|
||||
type: "approve" | "reject" | "edit";
|
||||
message?: string;
|
||||
edited_action?: { name: string; args: Record<string, unknown> };
|
||||
}) => void;
|
||||
interruptData: InterruptResult<CalendarUpdateEventContext>;
|
||||
onDecision: (decision: HitlDecision) => void;
|
||||
}) {
|
||||
const { phase, setProcessing, setRejected } = useHitlPhase(interruptData);
|
||||
const actionArgs = interruptData.action_requests[0]?.args ?? {};
|
||||
|
|
@ -686,18 +662,16 @@ export const UpdateCalendarEventToolUI = ({
|
|||
},
|
||||
UpdateCalendarEventResult
|
||||
>) => {
|
||||
const { dispatch } = useHitlDecision();
|
||||
|
||||
if (!result) return null;
|
||||
|
||||
if (isInterruptResult(result)) {
|
||||
return (
|
||||
<ApprovalCard
|
||||
args={args}
|
||||
interruptData={result}
|
||||
onDecision={(decision) => {
|
||||
window.dispatchEvent(
|
||||
new CustomEvent("hitl-decision", { detail: { decisions: [decision] } })
|
||||
);
|
||||
}}
|
||||
interruptData={result as InterruptResult<CalendarUpdateEventContext>}
|
||||
onDecision={(decision) => dispatch([decision])}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue