mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
feat: test script of new agent
This commit is contained in:
parent
e48aa3f1c7
commit
c6cc7c2a6a
14 changed files with 5079 additions and 2923 deletions
13
.vscode/launch.json
vendored
13
.vscode/launch.json
vendored
|
|
@@ -32,6 +32,19 @@
|
|||
"console": "integratedTerminal",
|
||||
"justMyCode": false,
|
||||
"cwd": "${workspaceFolder}/surfsense_backend"
|
||||
},
|
||||
{
|
||||
"name": "Python Debugger: Chat DeepAgent",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "app.agents.new_chat.chat_deepagent",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false,
|
||||
"cwd": "${workspaceFolder}/surfsense_backend",
|
||||
"python": "${command:python.interpreterPath}",
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}/surfsense_backend"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
27
surfsense_backend/app/agents/new_chat/__init__.py
Normal file
27
surfsense_backend/app/agents/new_chat/__init__.py
Normal file
|
|
@@ -0,0 +1,27 @@
|
|||
"""Chat agents module."""
|
||||
|
||||
from app.agents.new_chat.chat_deepagent import (
|
||||
SURFSENSE_CITATION_INSTRUCTIONS,
|
||||
SURFSENSE_SYSTEM_PROMPT,
|
||||
SurfSenseContextSchema,
|
||||
build_surfsense_system_prompt,
|
||||
create_chat_litellm_from_config,
|
||||
create_search_knowledge_base_tool,
|
||||
create_surfsense_deep_agent,
|
||||
format_documents_for_context,
|
||||
load_llm_config_from_yaml,
|
||||
search_knowledge_base_async,
|
||||
)
|
||||
|
||||
# Public API of the new_chat agents package; keep in sync with the
# re-exports imported from chat_deepagent above (sorted alphabetically).
__all__ = [
    "SURFSENSE_CITATION_INSTRUCTIONS",
    "SURFSENSE_SYSTEM_PROMPT",
    "SurfSenseContextSchema",
    "build_surfsense_system_prompt",
    "create_chat_litellm_from_config",
    "create_search_knowledge_base_tool",
    "create_surfsense_deep_agent",
    "format_documents_for_context",
    "load_llm_config_from_yaml",
    "search_knowledge_base_async",
]
|
||||
998
surfsense_backend/app/agents/new_chat/chat_deepagent.py
Normal file
998
surfsense_backend/app/agents/new_chat/chat_deepagent.py
Normal file
|
|
@@ -0,0 +1,998 @@
|
|||
"""
|
||||
Test script for create_deep_agent with ChatLiteLLM from global_llm_config.yaml
|
||||
|
||||
This demonstrates:
|
||||
1. Loading LLM config from global_llm_config.yaml
|
||||
2. Creating a ChatLiteLLM instance
|
||||
3. Using context_schema to add custom state fields
|
||||
4. Creating a search_knowledge_base tool similar to fetch_relevant_documents
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path so 'app' module can be found when running directly
|
||||
_THIS_FILE = Path(__file__).resolve()
|
||||
_BACKEND_ROOT = _THIS_FILE.parent.parent.parent.parent # surfsense_backend/
|
||||
if str(_BACKEND_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(_BACKEND_ROOT))
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import yaml
|
||||
from deepagents import create_deep_agent
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.tools import tool
|
||||
from langchain_litellm import ChatLiteLLM
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import async_session_maker
|
||||
from app.services.connector_service import ConnectorService
|
||||
|
||||
# =============================================================================
|
||||
# LLM Configuration Loading
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None:
|
||||
"""
|
||||
Load a specific LLM config from global_llm_config.yaml.
|
||||
|
||||
Args:
|
||||
llm_config_id: The id of the config to load (default: -1)
|
||||
|
||||
Returns:
|
||||
LLM config dict or None if not found
|
||||
"""
|
||||
# Get the config file path
|
||||
base_dir = Path(__file__).resolve().parent.parent.parent.parent
|
||||
config_file = base_dir / "app" / "config" / "global_llm_config.yaml"
|
||||
|
||||
# Fallback to example file if main config doesn't exist
|
||||
if not config_file.exists():
|
||||
config_file = base_dir / "app" / "config" / "global_llm_config.example.yaml"
|
||||
if not config_file.exists():
|
||||
print("Error: No global_llm_config.yaml or example file found")
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
configs = data.get("global_llm_configs", [])
|
||||
for cfg in configs:
|
||||
if isinstance(cfg, dict) and cfg.get("id") == llm_config_id:
|
||||
return cfg
|
||||
|
||||
print(f"Error: Global LLM config id {llm_config_id} not found")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error loading config: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
    """
    Build a ChatLiteLLM client from a global LLM config entry.

    Args:
        llm_config: LLM configuration dictionary from YAML

    Returns:
        A configured ChatLiteLLM instance
    """
    # Map SurfSense provider enums to litellm route prefixes (mirrors
    # llm_service.py). Several providers speak the OpenAI wire protocol and
    # are therefore routed through the "openai" prefix.
    provider_map = {
        "OPENAI": "openai",
        "ANTHROPIC": "anthropic",
        "GROQ": "groq",
        "COHERE": "cohere",
        "GOOGLE": "gemini",
        "OLLAMA": "ollama",
        "MISTRAL": "mistral",
        "AZURE_OPENAI": "azure",
        "OPENROUTER": "openrouter",
        "XAI": "xai",
        "BEDROCK": "bedrock",
        "VERTEX_AI": "vertex_ai",
        "TOGETHER_AI": "together_ai",
        "FIREWORKS_AI": "fireworks_ai",
        "DEEPSEEK": "openai",
        "ALIBABA_QWEN": "openai",
        "MOONSHOT": "openai",
        "ZHIPU": "openai",
    }

    # A custom provider string takes precedence over the mapped enum.
    custom = llm_config.get("custom_provider")
    if custom:
        prefix = custom
    else:
        provider = llm_config.get("provider", "").upper()
        # Unknown enums fall through as their lowercased name.
        prefix = provider_map.get(provider, provider.lower())
    model_string = f"{prefix}/{llm_config['model_name']}"

    litellm_kwargs: dict = {
        "model": model_string,
        "api_key": llm_config.get("api_key"),
    }

    # Optional endpoint override (self-hosted / proxy deployments).
    api_base = llm_config.get("api_base")
    if api_base:
        litellm_kwargs["api_base"] = api_base

    # Pass through any extra litellm parameters verbatim.
    extra = llm_config.get("litellm_params")
    if extra:
        litellm_kwargs.update(extra)

    return ChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Custom Context Schema
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class SurfSenseContextSchema(TypedDict):
    """
    Extra state schema for the SurfSense deep agent.

    The deepagents default state already carries:
    - messages: Conversation history
    - todos: Task list from TodoListMiddleware
    - files: Virtual filesystem from FilesystemMiddleware

    This schema only adds what knowledge-base search needs. Runtime-only
    dependencies (the database session and the connector service) are
    injected when the agent is invoked and are deliberately NOT part of the
    serialized state.
    """

    # The user's search space ID, used to scope knowledge-base queries.
    search_space_id: int
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Knowledge Base Search Tool
|
||||
# =============================================================================
|
||||
|
||||
# Canonical connector values used internally by ConnectorService
|
||||
_ALL_CONNECTORS: list[str] = [
|
||||
"EXTENSION",
|
||||
"FILE",
|
||||
"SLACK_CONNECTOR",
|
||||
"NOTION_CONNECTOR",
|
||||
"YOUTUBE_VIDEO",
|
||||
"GITHUB_CONNECTOR",
|
||||
"ELASTICSEARCH_CONNECTOR",
|
||||
"LINEAR_CONNECTOR",
|
||||
"JIRA_CONNECTOR",
|
||||
"CONFLUENCE_CONNECTOR",
|
||||
"CLICKUP_CONNECTOR",
|
||||
"GOOGLE_CALENDAR_CONNECTOR",
|
||||
"GOOGLE_GMAIL_CONNECTOR",
|
||||
"DISCORD_CONNECTOR",
|
||||
"AIRTABLE_CONNECTOR",
|
||||
"TAVILY_API",
|
||||
"SEARXNG_API",
|
||||
"LINKUP_API",
|
||||
"BAIDU_SEARCH_API",
|
||||
"LUMA_CONNECTOR",
|
||||
"NOTE",
|
||||
"BOOKSTACK_CONNECTOR",
|
||||
"CRAWLED_URL",
|
||||
]
|
||||
|
||||
|
||||
def _normalize_connectors(connectors_to_search: list[str] | None) -> list[str]:
|
||||
"""
|
||||
Normalize connectors provided by the model.
|
||||
|
||||
- Accepts user-facing enums like WEBCRAWLER_CONNECTOR and maps them to canonical
|
||||
ConnectorService types.
|
||||
- Drops unknown values.
|
||||
- If None/empty, defaults to searching across all known connectors.
|
||||
"""
|
||||
if not connectors_to_search:
|
||||
return list(_ALL_CONNECTORS)
|
||||
|
||||
normalized: list[str] = []
|
||||
for raw in connectors_to_search:
|
||||
c = (raw or "").strip().upper()
|
||||
if not c:
|
||||
continue
|
||||
if c == "WEBCRAWLER_CONNECTOR":
|
||||
c = "CRAWLED_URL"
|
||||
normalized.append(c)
|
||||
|
||||
# de-dupe while preserving order + filter unknown
|
||||
seen: set[str] = set()
|
||||
out: list[str] = []
|
||||
for c in normalized:
|
||||
if c in seen:
|
||||
continue
|
||||
if c not in _ALL_CONNECTORS:
|
||||
continue
|
||||
seen.add(c)
|
||||
out.append(c)
|
||||
return out if out else list(_ALL_CONNECTORS)
|
||||
|
||||
|
||||
SURFSENSE_CITATION_INSTRUCTIONS = """
|
||||
<citation_instructions>
|
||||
CRITICAL CITATION REQUIREMENTS:
|
||||
|
||||
1. For EVERY piece of information you include from the documents, add a citation in the format [citation:chunk_id] where chunk_id is the exact value from the `<chunk id='...'>` tag inside `<document_content>`.
|
||||
2. Make sure ALL factual statements from the documents have proper citations.
|
||||
3. If multiple chunks support the same point, include all relevant citations [citation:chunk_id1], [citation:chunk_id2].
|
||||
4. You MUST use the exact chunk_id values from the `<chunk id='...'>` attributes. Do not create your own citation numbers.
|
||||
5. Every citation MUST be in the format [citation:chunk_id] where chunk_id is the exact chunk id value.
|
||||
6. Never modify or change the chunk_id - always use the original values exactly as provided in the chunk tags.
|
||||
7. Do not return citations as clickable links.
|
||||
8. Never format citations as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only.
|
||||
9. Citations must ONLY appear as [citation:chunk_id] or [citation:chunk_id1], [citation:chunk_id2] format - never with parentheses, hyperlinks, or other formatting.
|
||||
10. Never make up chunk IDs. Only use chunk_id values that are explicitly provided in the `<chunk id='...'>` tags.
|
||||
11. If you are unsure about a chunk_id, do not include a citation rather than guessing or making one up.
|
||||
|
||||
<document_structure_example>
|
||||
The documents you receive are structured like this:
|
||||
|
||||
<document>
|
||||
<document_metadata>
|
||||
<document_id>42</document_id>
|
||||
<document_type>GITHUB_CONNECTOR</document_type>
|
||||
<title><![CDATA[Some repo / file / issue title]]></title>
|
||||
<url><![CDATA[https://example.com]]></url>
|
||||
<metadata_json><![CDATA[{{"any":"other metadata"}}]]></metadata_json>
|
||||
</document_metadata>
|
||||
|
||||
<document_content>
|
||||
<chunk id='123'><![CDATA[First chunk text...]]></chunk>
|
||||
<chunk id='124'><![CDATA[Second chunk text...]]></chunk>
|
||||
</document_content>
|
||||
</document>
|
||||
|
||||
IMPORTANT: You MUST cite using the chunk ids (e.g. 123, 124). Do NOT cite document_id.
|
||||
</document_structure_example>
|
||||
|
||||
<citation_format>
|
||||
- Every fact from the documents must have a citation in the format [citation:chunk_id] where chunk_id is the EXACT id value from a `<chunk id='...'>` tag
|
||||
- Citations should appear at the end of the sentence containing the information they support
|
||||
- Multiple citations should be separated by commas: [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3]
|
||||
- No need to return references section. Just citations in answer.
|
||||
- NEVER create your own citation format - use the exact chunk_id values from the documents in the [citation:chunk_id] format
|
||||
- NEVER format citations as clickable links or as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only
|
||||
- NEVER make up chunk IDs if you are unsure about the chunk_id. It is better to omit the citation than to guess
|
||||
</citation_format>
|
||||
|
||||
<citation_examples>
|
||||
CORRECT citation formats:
|
||||
- [citation:5]
|
||||
- [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3]
|
||||
|
||||
INCORRECT citation formats (DO NOT use):
|
||||
- Using parentheses and markdown links: ([citation:5](https://github.com/MODSetter/SurfSense))
|
||||
- Using parentheses around brackets: ([citation:5])
|
||||
- Using hyperlinked text: [link to source 5](https://example.com)
|
||||
- Using footnote style: ... library¹
|
||||
- Making up source IDs when source_id is unknown
|
||||
- Using old IEEE format: [1], [2], [3]
|
||||
- Using source types instead of IDs: [citation:GITHUB_CONNECTOR] instead of [citation:5]
|
||||
</citation_examples>
|
||||
|
||||
<citation_output_example>
|
||||
Based on your GitHub repositories and video content, Python's asyncio library provides tools for writing concurrent code using the async/await syntax [citation:5]. It's particularly useful for I/O-bound and high-level structured network code [citation:5].
|
||||
|
||||
The key advantage of asyncio is that it can improve performance by allowing other code to run while waiting for I/O operations to complete [citation:12]. This makes it excellent for scenarios like web scraping, API calls, database operations, or any situation where your program spends time waiting for external resources.
|
||||
|
||||
However, from your video learning, it's important to note that asyncio is not suitable for CPU-bound tasks as it runs on a single thread [citation:12]. For computationally intensive work, you'd want to use multiprocessing instead.
|
||||
</citation_output_example>
|
||||
</citation_instructions>
|
||||
"""
|
||||
|
||||
|
||||
def _parse_date_or_datetime(value: str) -> datetime:
|
||||
"""
|
||||
Parse either an ISO date (YYYY-MM-DD) or ISO datetime into an aware UTC datetime.
|
||||
|
||||
- If `value` is a date, interpret it as start-of-day in UTC.
|
||||
- If `value` is a datetime without timezone, assume UTC.
|
||||
"""
|
||||
raw = (value or "").strip()
|
||||
if not raw:
|
||||
raise ValueError("Empty date string")
|
||||
|
||||
# Date-only
|
||||
if "T" not in raw:
|
||||
d = datetime.fromisoformat(raw).date()
|
||||
return datetime(d.year, d.month, d.day, tzinfo=UTC)
|
||||
|
||||
# Datetime (may be naive)
|
||||
dt = datetime.fromisoformat(raw)
|
||||
if dt.tzinfo is None:
|
||||
return dt.replace(tzinfo=UTC)
|
||||
return dt.astimezone(UTC)
|
||||
|
||||
|
||||
def _resolve_date_range(
|
||||
start_date: datetime | None,
|
||||
end_date: datetime | None,
|
||||
) -> tuple[datetime, datetime]:
|
||||
"""
|
||||
Resolve a date range, defaulting to the last 2 years if not provided.
|
||||
Ensures start_date <= end_date.
|
||||
"""
|
||||
resolved_end = end_date or datetime.now(UTC)
|
||||
resolved_start = start_date or (resolved_end - timedelta(days=730))
|
||||
|
||||
if resolved_start > resolved_end:
|
||||
resolved_start, resolved_end = resolved_end, resolved_start
|
||||
|
||||
return resolved_start, resolved_end
|
||||
|
||||
|
||||
def format_documents_for_context(documents: list[dict[str, Any]]) -> str:
    """
    Format retrieved documents into a readable context string for the LLM.

    Groups incoming results by document and renders them as the
    <document>/<document_content>/<chunk id='...'> XML structure that
    SURFSENSE_CITATION_INSTRUCTIONS tells the model to cite from.

    Args:
        documents: List of document dictionaries from connector search

    Returns:
        Formatted string with document contents and metadata ("" when the
        input is empty)
    """
    if not documents:
        return ""

    # Group chunks by document id (preferred) to produce the XML structure.
    #
    # IMPORTANT: ConnectorService returns **document-grouped** results of the form:
    # {
    #     "document": {...},
    #     "chunks": [{"chunk_id": 123, "content": "..."}, ...],
    #     "source": "NOTION_CONNECTOR" | "FILE" | ...
    # }
    #
    # We must preserve chunk_id so citations like [citation:123] are possible.
    grouped: dict[str, dict[str, Any]] = {}

    for doc in documents:
        # Non-dict entries contribute no metadata and are skipped in the
        # fallback branch at the bottom of this loop.
        document_info = (doc.get("document") or {}) if isinstance(doc, dict) else {}
        metadata = (
            (document_info.get("metadata") or {})
            if isinstance(document_info, dict)
            else {}
        )
        if not metadata and isinstance(doc, dict):
            # Some result shapes may place metadata at the top level.
            metadata = doc.get("metadata") or {}

        source = (
            (doc.get("source") if isinstance(doc, dict) else None)
            or metadata.get("document_type")
            or "UNKNOWN"
        )

        # Document identity (prefer document_id; otherwise fall back to type+title+url)
        document_id_val = document_info.get("id")
        title = (
            document_info.get("title") or metadata.get("title") or "Untitled Document"
        )
        url = (
            metadata.get("url")
            or metadata.get("source")
            or metadata.get("page_url")
            or ""
        )

        doc_key = (
            str(document_id_val)
            if document_id_val is not None
            else f"{source}::{title}::{url}"
        )

        # First occurrence of a document wins for metadata/title/url.
        if doc_key not in grouped:
            grouped[doc_key] = {
                "document_id": document_id_val
                if document_id_val is not None
                else doc_key,
                "document_type": metadata.get("document_type") or source,
                "title": title,
                "url": url,
                "metadata": metadata,
                "chunks": [],
            }

        # Prefer document-grouped chunks if available
        chunks_list = doc.get("chunks") if isinstance(doc, dict) else None
        if isinstance(chunks_list, list) and chunks_list:
            for ch in chunks_list:
                if not isinstance(ch, dict):
                    continue
                # NOTE(review): a falsy chunk_id (0 or "") falls through to
                # "id" — assumes real chunk ids are always truthy; confirm.
                chunk_id = ch.get("chunk_id") or ch.get("id")
                content = (ch.get("content") or "").strip()
                if not content:
                    continue
                grouped[doc_key]["chunks"].append(
                    {"chunk_id": chunk_id, "content": content}
                )
            continue

        # Fallback: treat this as a flat chunk-like object
        if not isinstance(doc, dict):
            continue
        chunk_id = doc.get("chunk_id") or doc.get("id")
        content = (doc.get("content") or "").strip()
        if not content:
            continue
        grouped[doc_key]["chunks"].append({"chunk_id": chunk_id, "content": content})

    # Render XML expected by citation instructions
    parts: list[str] = []
    for g in grouped.values():
        metadata_json = json.dumps(g["metadata"], ensure_ascii=False)

        parts.append("<document>")
        parts.append("<document_metadata>")
        parts.append(f" <document_id>{g['document_id']}</document_id>")
        parts.append(f" <document_type>{g['document_type']}</document_type>")
        parts.append(f" <title><![CDATA[{g['title']}]]></title>")
        parts.append(f" <url><![CDATA[{g['url']}]]></url>")
        parts.append(f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>")
        parts.append("</document_metadata>")
        parts.append("")
        parts.append("<document_content>")

        for ch in g["chunks"]:
            ch_content = ch["content"]
            ch_id = ch["chunk_id"]
            if ch_id is None:
                # Chunks without an id are still shown but cannot be cited.
                parts.append(f" <chunk><![CDATA[{ch_content}]]></chunk>")
            else:
                parts.append(f" <chunk id='{ch_id}'><![CDATA[{ch_content}]]></chunk>")

        parts.append("</document_content>")
        parts.append("</document>")
        parts.append("")

    return "\n".join(parts).strip()
|
||||
|
||||
|
||||
# Connector type -> ConnectorService method name for sources that accept a
# date-range filter. Every one of these methods shares the signature
# (user_query, search_space_id, top_k, start_date, end_date).
_DATE_FILTERED_SEARCHERS: dict[str, str] = {
    "YOUTUBE_VIDEO": "search_youtube",
    "EXTENSION": "search_extension",
    "CRAWLED_URL": "search_crawled_urls",
    "FILE": "search_files",
    "SLACK_CONNECTOR": "search_slack",
    "NOTION_CONNECTOR": "search_notion",
    "GITHUB_CONNECTOR": "search_github",
    "LINEAR_CONNECTOR": "search_linear",
    "DISCORD_CONNECTOR": "search_discord",
    "JIRA_CONNECTOR": "search_jira",
    "GOOGLE_CALENDAR_CONNECTOR": "search_google_calendar",
    "AIRTABLE_CONNECTOR": "search_airtable",
    "GOOGLE_GMAIL_CONNECTOR": "search_google_gmail",
    "CONFLUENCE_CONNECTOR": "search_confluence",
    "CLICKUP_CONNECTOR": "search_clickup",
    "LUMA_CONNECTOR": "search_luma",
    "ELASTICSEARCH_CONNECTOR": "search_elasticsearch",
    "NOTE": "search_notes",
    "BOOKSTACK_CONNECTOR": "search_bookstack",
}

# Web-search connectors take only (user_query, search_space_id, top_k).
_WEB_SEARCHERS: dict[str, str] = {
    "TAVILY_API": "search_tavily",
    "SEARXNG_API": "search_searxng",
    "BAIDU_SEARCH_API": "search_baidu",
}


async def search_knowledge_base_async(
    query: str,
    search_space_id: int,
    db_session: AsyncSession,
    connector_service: ConnectorService,
    connectors_to_search: list[str] | None = None,
    top_k: int = 10,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
) -> str:
    """
    Search the user's knowledge base for relevant documents.

    Dispatches the query to every requested connector on the
    ConnectorService, merges the per-connector results, deduplicates them,
    and renders them as the XML document blocks expected by the citation
    instructions. A failing connector is logged and skipped so one broken
    source cannot abort the whole search.

    Args:
        query: The search query
        search_space_id: The user's search space ID
        db_session: Database session (currently unused in this function;
            kept for interface parity with the tool factory)
        connector_service: Initialized connector service
        connectors_to_search: Optional list of connector types to search. If omitted, searches all.
        top_k: Number of results per connector
        start_date: Optional start datetime (UTC) for filtering documents
        end_date: Optional end datetime (UTC) for filtering documents

    Returns:
        Formatted string with search results ("" when nothing was found)
    """
    all_documents: list[dict[str, Any]] = []

    # Resolve date range (default last 2 years)
    resolved_start_date, resolved_end_date = _resolve_date_range(
        start_date=start_date,
        end_date=end_date,
    )

    for connector in _normalize_connectors(connectors_to_search):
        try:
            if connector in _DATE_FILTERED_SEARCHERS:
                searcher = getattr(
                    connector_service, _DATE_FILTERED_SEARCHERS[connector]
                )
                _, chunks = await searcher(
                    user_query=query,
                    search_space_id=search_space_id,
                    top_k=top_k,
                    start_date=resolved_start_date,
                    end_date=resolved_end_date,
                )
            elif connector in _WEB_SEARCHERS:
                searcher = getattr(connector_service, _WEB_SEARCHERS[connector])
                _, chunks = await searcher(
                    user_query=query,
                    search_space_id=search_space_id,
                    top_k=top_k,
                )
            elif connector == "LINKUP_API":
                # Keep behavior aligned with researcher: default "standard"
                _, chunks = await connector_service.search_linkup(
                    user_query=query,
                    search_space_id=search_space_id,
                    mode="standard",
                )
            else:
                # Unknown types are skipped; normalization should already
                # have filtered them out.
                continue
            all_documents.extend(chunks)
        except Exception as e:
            # Best-effort per connector: log and move on.
            print(f"Error searching connector {connector}: {e}")
            continue

    return format_documents_for_context(_deduplicate_documents(all_documents))


def _deduplicate_documents(
    all_documents: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Drop repeated results by document id and by exact (stripped) content."""
    seen_doc_ids: set[Any] = set()
    seen_hashes: set[int] = set()
    deduplicated: list[dict[str, Any]] = []
    for doc in all_documents:
        doc_id = (doc.get("document", {}) or {}).get("id")
        content = (doc.get("content", "") or "").strip()
        content_hash = hash(content)

        if (doc_id and doc_id in seen_doc_ids) or content_hash in seen_hashes:
            continue

        if doc_id:
            seen_doc_ids.add(doc_id)
        seen_hashes.add(content_hash)
        deduplicated.append(doc)
    return deduplicated
|
||||
|
||||
|
||||
def create_search_knowledge_base_tool(
    search_space_id: int,
    db_session: AsyncSession,
    connector_service: ConnectorService,
):
    """
    Factory function to create the search_knowledge_base tool with injected dependencies.

    The returned tool closes over the search space, DB session, and connector
    service, so the agent only has to supply query-level arguments. The
    connectors to search are chosen per-call by the LLM via the inner tool's
    ``connectors_to_search`` parameter, not at factory time.

    Args:
        search_space_id: The user's search space ID
        db_session: Database session
        connector_service: Initialized connector service

    Returns:
        A configured tool function
    """

    @tool
    async def search_knowledge_base(
        query: str,
        top_k: int = 10,
        start_date: str | None = None,
        end_date: str | None = None,
        connectors_to_search: list[str] | None = None,
    ) -> str:
        """
        Search the user's personal knowledge base for relevant information.

        Use this tool to find documents, notes, files, web pages, and other content
        that may help answer the user's question.

        IMPORTANT:
        - If the user requests a specific source type (e.g. "my notes", "Slack messages"),
          pass `connectors_to_search=[...]` using the enums below.
        - If `connectors_to_search` is omitted/empty, the system will search broadly.

        ## Available connector enums for `connectors_to_search`

        - EXTENSION: "Web content saved via SurfSense browser extension" (personal browsing history)
        - FILE: "User-uploaded documents (PDFs, Word, etc.)" (personal files)
        - NOTE: "SurfSense Notes" (notes created inside SurfSense)
        - SLACK_CONNECTOR: "Slack conversations and shared content" (personal workspace communications)
        - NOTION_CONNECTOR: "Notion workspace pages and databases" (personal knowledge management)
        - YOUTUBE_VIDEO: "YouTube video transcripts and metadata" (personally saved videos)
        - GITHUB_CONNECTOR: "GitHub repository content and issues" (personal repositories and interactions)
        - ELASTICSEARCH_CONNECTOR: "Elasticsearch indexed documents and data" (personal Elasticsearch instances and custom data sources)
        - LINEAR_CONNECTOR: "Linear project issues and discussions" (personal project management)
        - JIRA_CONNECTOR: "Jira project issues, tickets, and comments" (personal project tracking)
        - CONFLUENCE_CONNECTOR: "Confluence pages and comments" (personal project documentation)
        - CLICKUP_CONNECTOR: "ClickUp tasks and project data" (personal task management)
        - GOOGLE_CALENDAR_CONNECTOR: "Google Calendar events, meetings, and schedules" (personal calendar and time management)
        - GOOGLE_GMAIL_CONNECTOR: "Google Gmail emails and conversations" (personal emails and communications)
        - DISCORD_CONNECTOR: "Discord server conversations and shared content" (personal community communications)
        - AIRTABLE_CONNECTOR: "Airtable records, tables, and database content" (personal data management and organization)
        - TAVILY_API: "Tavily search API results" (personalized search results)
        - SEARXNG_API: "SearxNG search API results" (personalized search results)
        - LINKUP_API: "Linkup search API results" (personalized search results)
        - BAIDU_SEARCH_API: "Baidu search API results" (personalized search results)
        - LUMA_CONNECTOR: "Luma events"
        - WEBCRAWLER_CONNECTOR: "Webpages indexed by SurfSense" (personally selected websites)
        - BOOKSTACK_CONNECTOR: "BookStack pages" (personal documentation)

        NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type `CRAWLED_URL`.

        Args:
            query: The search query - be specific and include key terms
            top_k: Number of results to retrieve (default: 10)
            start_date: Optional ISO date/datetime (e.g. "2025-12-12" or "2025-12-12T00:00:00+00:00")
            end_date: Optional ISO date/datetime (e.g. "2025-12-19" or "2025-12-19T23:59:59+00:00")
            connectors_to_search: Optional list of connector enums to search. If omitted, searches all.

        Returns:
            Formatted string with relevant documents and their content
        """
        # The LLM passes dates as ISO strings; convert them to datetimes
        # before handing off to the async search implementation.
        parsed_start: datetime | None = None
        parsed_end: datetime | None = None

        if start_date:
            parsed_start = _parse_date_or_datetime(start_date)
        if end_date:
            parsed_end = _parse_date_or_datetime(end_date)

        return await search_knowledge_base_async(
            query=query,
            search_space_id=search_space_id,
            db_session=db_session,
            connector_service=connector_service,
            connectors_to_search=connectors_to_search,
            top_k=top_k,
            start_date=parsed_start,
            end_date=parsed_end,
        )

    return search_knowledge_base
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# System Prompt
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def build_surfsense_system_prompt(today: datetime | None = None) -> str:
    """Build the SurfSense agent system prompt.

    Args:
        today: Optional reference datetime; defaults to the current UTC time.
            Only the UTC calendar date is embedded in the prompt.

    Returns:
        str: The full system prompt, including tool documentation, tool-call
        examples, and the shared citation instructions.
    """
    # Normalize to a UTC date string so the prompt is stable within a day.
    resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat()

    return f"""
<system_instruction>
You are SurfSense, a reasoning and acting AI agent designed to answer user questions using the user's personal knowledge base.

Today's date (UTC): {resolved_today}

</system_instruction>
<tools>
You have access to the following tools:
- search_knowledge_base: Search the user's personal knowledge base for relevant information.
  - Args:
    - query: The search query - be specific and include key terms
    - top_k: Number of results to retrieve (default: 10)
    - start_date: Optional ISO date/datetime (e.g. "2025-12-12" or "2025-12-12T00:00:00+00:00")
    - end_date: Optional ISO date/datetime (e.g. "2025-12-19" or "2025-12-19T23:59:59+00:00")
    - connectors_to_search: Optional list of connector enums to search. If omitted, searches all.
  - Returns: Formatted string with relevant documents and their content
</tools>
<tool_call_examples>
- User: "Fetch all my notes and what's in them?"
  - Call: `search_knowledge_base(query="*", top_k=50, connectors_to_search=["NOTE"])`

- User: "What did I discuss on Slack last week about the React migration?"
  - Call: `search_knowledge_base(query="React migration", connectors_to_search=["SLACK_CONNECTOR"], start_date="YYYY-MM-DD", end_date="YYYY-MM-DD")`
</tool_call_examples>

{SURFSENSE_CITATION_INSTRUCTIONS}
"""


# Built once at import time, so the embedded date is the process start date.
# NOTE(review): long-lived processes should prefer calling
# build_surfsense_system_prompt() for a fresh date (as the agent factory does).
SURFSENSE_SYSTEM_PROMPT = build_surfsense_system_prompt()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Deep Agent Factory
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_surfsense_deep_agent(
    llm: ChatLiteLLM,
    search_space_id: int,
    db_session: AsyncSession,
    connector_service: ConnectorService,
):
    """
    Create a SurfSense deep agent with knowledge base search capability.

    The agent is wired with a single search tool bound to the given search
    space; which connectors to query is decided per tool call by the model.

    Args:
        llm: ChatLiteLLM instance
        search_space_id: The user's search space ID
        db_session: Database session
        connector_service: Initialized connector service

    Returns:
        CompiledStateGraph: The configured deep agent
    """
    # Create the search tool with injected dependencies
    search_tool = create_search_knowledge_base_tool(
        search_space_id=search_space_id,
        db_session=db_session,
        connector_service=connector_service,
    )

    # Create the deep agent; the system prompt is rebuilt here so it always
    # carries the current UTC date rather than the import-time one.
    agent = create_deep_agent(
        model=llm,
        tools=[search_tool],
        system_prompt=build_surfsense_system_prompt(),
        context_schema=SurfSenseContextSchema,
    )

    return agent
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Runner
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def run_test():
    """Run a basic smoke test of the deep agent.

    Loads a global LLM config from YAML, opens a real database session,
    builds the full SurfSense deep agent, invokes it once with a hard-coded
    query, and prints every message in the result.

    Meant to be run manually (see the ``__main__`` guard below); it hits the
    real database and LLM provider, so it is not a unit test.
    """
    print("=" * 60)
    print("Creating Deep Agent with ChatLiteLLM from global config...")
    print("=" * 60)

    # Create ChatLiteLLM from global config
    # Use global LLM config by id (negative ids are reserved for global configs)
    llm_config = load_llm_config_from_yaml(llm_config_id=-2)
    if not llm_config:
        raise ValueError("Failed to load LLM config from YAML")
    llm = create_chat_litellm_from_config(llm_config)
    if not llm:
        raise ValueError("Failed to create ChatLiteLLM instance")

    # Create a real DB session + ConnectorService, then build the full SurfSense agent.
    async with async_session_maker() as session:
        # Use the known dev search space id
        # NOTE(review): 5 is a developer-environment value — confirm it exists locally.
        search_space_id = 5

        connector_service = ConnectorService(session, search_space_id=search_space_id)

        agent = create_surfsense_deep_agent(
            llm=llm,
            search_space_id=search_space_id,
            db_session=session,
            connector_service=connector_service,
        )

        print("\nAgent created successfully!")
        print(f"Agent type: {type(agent)}")

        # Invoke the agent with initial state
        print("\n" + "=" * 60)
        print("Invoking SurfSense agent (create_surfsense_deep_agent)...")
        print("=" * 60)

        initial_state = {
            "messages": [HumanMessage(content=("What are my notes from last 3 days?"))],
            "search_space_id": search_space_id,
        }

        print(f"\nUsing search_space_id: {search_space_id}")

        result = await agent.ainvoke(initial_state)

        print("\n" + "=" * 60)
        print("Agent Response:")
        print("=" * 60)

        # Print the response: one section per message in the final state.
        if "messages" in result:
            for msg in result["messages"]:
                msg_type = type(msg).__name__
                content = msg.content if hasattr(msg, "content") else str(msg)
                print(f"\n--- [{msg_type}] ---\n{content}\n")

        return result


if __name__ == "__main__":
    asyncio.run(run_test())
|
||||
|
|
@ -1163,6 +1163,33 @@ async def fetch_relevant_documents(
|
|||
}
|
||||
)
|
||||
|
||||
elif connector == "BOOKSTACK_CONNECTOR":
|
||||
(
|
||||
source_object,
|
||||
bookstack_chunks,
|
||||
) = await connector_service.search_bookstack(
|
||||
user_query=reformulated_query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
# Add to sources and raw documents
|
||||
if source_object:
|
||||
all_sources.append(source_object)
|
||||
all_raw_documents.extend(bookstack_chunks)
|
||||
|
||||
# Stream found document count
|
||||
if streaming_service and writer:
|
||||
writer(
|
||||
{
|
||||
"yield_value": streaming_service.format_terminal_info_delta(
|
||||
f"📚 Found {len(bookstack_chunks)} BookStack pages related to your query"
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
elif connector == "NOTE":
|
||||
(
|
||||
source_object,
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ def get_connector_emoji(connector_name: str) -> str:
|
|||
"LUMA_CONNECTOR": "✨",
|
||||
"ELASTICSEARCH_CONNECTOR": "⚡",
|
||||
"WEBCRAWLER_CONNECTOR": "🌐",
|
||||
"BOOKSTACK_CONNECTOR": "📚",
|
||||
"NOTE": "📝",
|
||||
}
|
||||
return connector_emojis.get(connector_name, "🔎")
|
||||
|
|
@ -60,6 +61,7 @@ def get_connector_friendly_name(connector_name: str) -> str:
|
|||
"LUMA_CONNECTOR": "Luma",
|
||||
"ELASTICSEARCH_CONNECTOR": "Elasticsearch",
|
||||
"WEBCRAWLER_CONNECTOR": "Web Pages",
|
||||
"BOOKSTACK_CONNECTOR": "BookStack",
|
||||
"NOTE": "Notes",
|
||||
}
|
||||
return connector_friendly_names.get(connector_name, connector_name)
|
||||
|
|
|
|||
|
|
@ -20,8 +20,13 @@ from app.schemas import (
|
|||
ChatRead,
|
||||
ChatReadWithoutMessages,
|
||||
ChatUpdate,
|
||||
NewChatRequest,
|
||||
)
|
||||
from app.tasks.stream_connector_search_results import stream_connector_search_results
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.tasks.chat.stream_connector_search_results import (
|
||||
stream_connector_search_results,
|
||||
)
|
||||
from app.tasks.chat.stream_new_chat import stream_new_chat
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
from app.utils.validators import (
|
||||
|
|
@ -152,6 +157,87 @@ async def handle_chat_data(
|
|||
return response
|
||||
|
||||
|
||||
@router.post("/new_chat")
async def handle_new_chat(
    request: NewChatRequest,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Handle new chat requests using the SurfSense deep agent.

    This endpoint uses the new deep agent with the Vercel AI SDK
    Data Stream Protocol (SSE format).

    Args:
        request: NewChatRequest containing chat_id, user_query, and search_space_id
        session: Database session
        user: Current authenticated user

    Returns:
        StreamingResponse with SSE formatted data

    Raises:
        HTTPException: 400 when the query is blank; 403 when the user lacks
            chat access to the requested search space.
    """
    # Validate the user query
    if not request.user_query or not request.user_query.strip():
        raise HTTPException(status_code=400, detail="User query cannot be empty")

    # Check if the user has chat access to the search space
    try:
        await check_permission(
            session,
            user,
            request.search_space_id,
            Permission.CHATS_CREATE.value,
            "You don't have permission to use chat in this search space",
        )
    except HTTPException:
        # Any permission failure is collapsed into a generic 403; `from None`
        # suppresses the original exception chain.
        raise HTTPException(
            status_code=403, detail="You don't have access to this search space"
        ) from None

    # Get LLM config ID from search space preferences (optional enhancement)
    # For now, we use the default global config (-1)
    llm_config_id = -1

    # Optionally load LLM preferences from search space
    try:
        search_space_result = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
        )
        search_space = search_space_result.scalars().first()

        if search_space:
            # Use strategic_llm_id if available, otherwise fall back to fast_llm_id
            if search_space.strategic_llm_id is not None:
                llm_config_id = search_space.strategic_llm_id
            elif search_space.fast_llm_id is not None:
                llm_config_id = search_space.fast_llm_id
    except Exception:
        # Fall back to default config on any error (deliberate best-effort:
        # preference lookup must never block the chat itself).
        pass

    # Create the streaming response
    # chat_id is used as LangGraph's thread_id for automatic chat history management
    response = StreamingResponse(
        stream_new_chat(
            user_query=request.user_query.strip(),
            user_id=user.id,
            search_space_id=request.search_space_id,
            chat_id=request.chat_id,
            session=session,
            llm_config_id=llm_config_id,
        ),
        media_type="text/event-stream",
    )

    # Set the required headers for Vercel AI SDK
    headers = VercelStreamingService.get_response_headers()
    for key, value in headers.items():
        response.headers[key] = value

    return response
|
||||
|
||||
|
||||
@router.post("/chats", response_model=ChatRead)
|
||||
async def create_chat(
|
||||
chat: ChatCreate,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from .chats import (
|
|||
ChatRead,
|
||||
ChatReadWithoutMessages,
|
||||
ChatUpdate,
|
||||
NewChatRequest,
|
||||
)
|
||||
from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate
|
||||
from .documents import (
|
||||
|
|
@ -97,6 +98,7 @@ __all__ = [
|
|||
"MembershipRead",
|
||||
"MembershipReadWithUser",
|
||||
"MembershipUpdate",
|
||||
"NewChatRequest",
|
||||
"PaginatedResponse",
|
||||
"PermissionInfo",
|
||||
"PermissionsListResponse",
|
||||
|
|
|
|||
|
|
@ -48,6 +48,14 @@ class AISDKChatRequest(BaseModel):
|
|||
data: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class NewChatRequest(BaseModel):
    """Request schema for the new deep agent chat endpoint."""

    # Target chat id; also reused as LangGraph's thread_id for history
    # management by the /new_chat route.
    chat_id: int
    # The user's message text; the route rejects blank/whitespace-only values.
    user_query: str
    # Search space scoping both retrieval and the permission check.
    search_space_id: int
|
||||
|
||||
|
||||
class ChatCreate(ChatBase):
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -252,24 +252,28 @@ class ConnectorService:
|
|||
# Get more results from each retriever for better fusion
|
||||
retriever_top_k = top_k * 2
|
||||
|
||||
# Run both searches in parallel
|
||||
chunk_results, doc_results = await asyncio.gather(
|
||||
self.chunk_retriever.hybrid_search(
|
||||
query_text=query_text,
|
||||
top_k=retriever_top_k,
|
||||
search_space_id=search_space_id,
|
||||
document_type=document_type,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
),
|
||||
self.document_retriever.hybrid_search(
|
||||
query_text=query_text,
|
||||
top_k=retriever_top_k,
|
||||
search_space_id=search_space_id,
|
||||
document_type=document_type,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
),
|
||||
# IMPORTANT:
|
||||
# These retrievers share the same AsyncSession. AsyncSession does not permit
|
||||
# concurrent awaits that require DB IO on the same session/connection.
|
||||
# Running these in parallel can raise:
|
||||
# "This session is provisioning a new connection; concurrent operations are not permitted"
|
||||
#
|
||||
# So we run them sequentially.
|
||||
chunk_results = await self.chunk_retriever.hybrid_search(
|
||||
query_text=query_text,
|
||||
top_k=retriever_top_k,
|
||||
search_space_id=search_space_id,
|
||||
document_type=document_type,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
doc_results = await self.document_retriever.hybrid_search(
|
||||
query_text=query_text,
|
||||
top_k=retriever_top_k,
|
||||
search_space_id=search_space_id,
|
||||
document_type=document_type,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
# Helper to extract document_id from our doc-grouped result
|
||||
|
|
@ -2432,7 +2436,6 @@ class ConnectorService:
|
|||
async def search_bookstack(
|
||||
self,
|
||||
user_query: str,
|
||||
user_id: str,
|
||||
search_space_id: int,
|
||||
top_k: int = 20,
|
||||
start_date: datetime | None = None,
|
||||
|
|
|
|||
699
surfsense_backend/app/services/new_streaming_service.py
Normal file
699
surfsense_backend/app/services/new_streaming_service.py
Normal file
|
|
@ -0,0 +1,699 @@
|
|||
"""
|
||||
Vercel AI SDK Data Stream Protocol Implementation
|
||||
|
||||
This module implements the Vercel AI SDK streaming protocol for use with
|
||||
@ai-sdk/react's useChat and useCompletion hooks.
|
||||
|
||||
Protocol Reference:
|
||||
- Uses Server-Sent Events (SSE) format
|
||||
- Requires 'x-vercel-ai-ui-message-stream: v1' header
|
||||
- Supports text, reasoning, sources, files, tools, data, and error parts
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
def generate_id() -> str:
    """Return a fresh, globally unique identifier for stream parts."""
    unique_suffix = uuid.uuid4().hex
    return "msg_" + unique_suffix
|
||||
|
||||
|
||||
@dataclass
class StreamContext:
    """
    Maintains context for streaming operations.
    Tracks active text and reasoning blocks.
    """

    # Unique id of the in-flight message (fresh per context instance).
    message_id: str = field(default_factory=generate_id)
    # Id of the currently open text block, or None when none is open.
    active_text_id: str | None = None
    # Id of the currently open reasoning block, or None when none is open.
    active_reasoning_id: str | None = None
    # Step counter; nothing in this file increments it — TODO confirm intended use.
    step_count: int = 0
|
||||
|
||||
|
||||
class VercelStreamingService:
|
||||
"""
|
||||
Implements the Vercel AI SDK Data Stream Protocol.
|
||||
|
||||
This service formats messages according to the SSE-based protocol
|
||||
that the AI SDK frontend expects. All messages are formatted as:
|
||||
data: {json_object}\n\n
|
||||
|
||||
Usage:
|
||||
service = VercelStreamingService()
|
||||
|
||||
# Start a message
|
||||
yield service.format_message_start()
|
||||
|
||||
# Stream text content
|
||||
text_id = service.generate_text_id()
|
||||
yield service.format_text_start(text_id)
|
||||
yield service.format_text_delta(text_id, "Hello, ")
|
||||
yield service.format_text_delta(text_id, "world!")
|
||||
yield service.format_text_end(text_id)
|
||||
|
||||
# Finish the message
|
||||
yield service.format_finish()
|
||||
yield service.format_done()
|
||||
"""
|
||||
|
||||
    def __init__(self):
        # Each service instance owns its own message/text/reasoning tracking state.
        self.context = StreamContext()
|
||||
|
||||
@staticmethod
|
||||
def get_response_headers() -> dict[str, str]:
|
||||
"""
|
||||
Get the required HTTP headers for Vercel AI SDK streaming.
|
||||
|
||||
Returns:
|
||||
dict: Headers to include in the streaming response
|
||||
"""
|
||||
return {
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _format_sse(data: Any) -> str:
|
||||
"""
|
||||
Format data as a Server-Sent Event.
|
||||
|
||||
Args:
|
||||
data: The data to format (will be JSON serialized if not a string)
|
||||
|
||||
Returns:
|
||||
str: SSE formatted string
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
return f"data: {data}\n\n"
|
||||
return f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
@staticmethod
|
||||
def generate_text_id() -> str:
|
||||
"""Generate a unique ID for a text block."""
|
||||
return f"text_{uuid.uuid4().hex}"
|
||||
|
||||
@staticmethod
|
||||
def generate_reasoning_id() -> str:
|
||||
"""Generate a unique ID for a reasoning block."""
|
||||
return f"reasoning_{uuid.uuid4().hex}"
|
||||
|
||||
@staticmethod
|
||||
def generate_tool_call_id() -> str:
|
||||
"""Generate a unique ID for a tool call."""
|
||||
return f"call_{uuid.uuid4().hex}"
|
||||
|
||||
# =========================================================================
|
||||
# Message Lifecycle Parts
|
||||
# =========================================================================
|
||||
|
||||
def format_message_start(self, message_id: str | None = None) -> str:
|
||||
"""
|
||||
Format the start of a new message.
|
||||
|
||||
Args:
|
||||
message_id: Optional custom message ID. If not provided, one is generated.
|
||||
|
||||
Returns:
|
||||
str: SSE formatted message start part
|
||||
|
||||
Example output:
|
||||
data: {"type":"start","messageId":"msg_abc123"}
|
||||
"""
|
||||
if message_id:
|
||||
self.context.message_id = message_id
|
||||
else:
|
||||
self.context.message_id = generate_id()
|
||||
|
||||
return self._format_sse({"type": "start", "messageId": self.context.message_id})
|
||||
|
||||
def format_finish(self) -> str:
|
||||
"""
|
||||
Format the finish message part.
|
||||
|
||||
Returns:
|
||||
str: SSE formatted finish part
|
||||
|
||||
Example output:
|
||||
data: {"type":"finish"}
|
||||
"""
|
||||
return self._format_sse({"type": "finish"})
|
||||
|
||||
def format_done(self) -> str:
|
||||
"""
|
||||
Format the stream termination marker.
|
||||
|
||||
This should be the last thing sent in a stream.
|
||||
|
||||
Returns:
|
||||
str: SSE formatted done marker
|
||||
|
||||
Example output:
|
||||
data: [DONE]
|
||||
"""
|
||||
return "data: [DONE]\n\n"
|
||||
|
||||
# =========================================================================
|
||||
# Text Parts (start/delta/end pattern)
|
||||
# =========================================================================
|
||||
|
||||
def format_text_start(self, text_id: str | None = None) -> str:
|
||||
"""
|
||||
Format the start of a text block.
|
||||
|
||||
Args:
|
||||
text_id: Optional custom text block ID. If not provided, one is generated.
|
||||
|
||||
Returns:
|
||||
str: SSE formatted text start part
|
||||
|
||||
Example output:
|
||||
data: {"type":"text-start","id":"text_abc123"}
|
||||
"""
|
||||
if text_id is None:
|
||||
text_id = self.generate_text_id()
|
||||
self.context.active_text_id = text_id
|
||||
return self._format_sse({"type": "text-start", "id": text_id})
|
||||
|
||||
def format_text_delta(self, text_id: str, delta: str) -> str:
|
||||
"""
|
||||
Format a text delta (incremental content).
|
||||
|
||||
Args:
|
||||
text_id: The text block ID
|
||||
delta: The incremental text content
|
||||
|
||||
Returns:
|
||||
str: SSE formatted text delta part
|
||||
|
||||
Example output:
|
||||
data: {"type":"text-delta","id":"text_abc123","delta":"Hello"}
|
||||
"""
|
||||
return self._format_sse({"type": "text-delta", "id": text_id, "delta": delta})
|
||||
|
||||
def format_text_end(self, text_id: str) -> str:
|
||||
"""
|
||||
Format the end of a text block.
|
||||
|
||||
Args:
|
||||
text_id: The text block ID
|
||||
|
||||
Returns:
|
||||
str: SSE formatted text end part
|
||||
|
||||
Example output:
|
||||
data: {"type":"text-end","id":"text_abc123"}
|
||||
"""
|
||||
if self.context.active_text_id == text_id:
|
||||
self.context.active_text_id = None
|
||||
return self._format_sse({"type": "text-end", "id": text_id})
|
||||
|
||||
def stream_text(self, text_id: str, text: str, chunk_size: int = 10) -> list[str]:
|
||||
"""
|
||||
Convenience method to stream text in chunks.
|
||||
|
||||
Args:
|
||||
text_id: The text block ID
|
||||
text: The full text to stream
|
||||
chunk_size: Size of each chunk (default 10 characters)
|
||||
|
||||
Returns:
|
||||
list[str]: List of SSE formatted text delta parts
|
||||
"""
|
||||
parts = []
|
||||
for i in range(0, len(text), chunk_size):
|
||||
chunk = text[i : i + chunk_size]
|
||||
parts.append(self.format_text_delta(text_id, chunk))
|
||||
return parts
|
||||
|
||||
# =========================================================================
|
||||
# Reasoning Parts (start/delta/end pattern)
|
||||
# =========================================================================
|
||||
|
||||
def format_reasoning_start(self, reasoning_id: str | None = None) -> str:
|
||||
"""
|
||||
Format the start of a reasoning block.
|
||||
|
||||
Args:
|
||||
reasoning_id: Optional custom reasoning block ID.
|
||||
|
||||
Returns:
|
||||
str: SSE formatted reasoning start part
|
||||
|
||||
Example output:
|
||||
data: {"type":"reasoning-start","id":"reasoning_abc123"}
|
||||
"""
|
||||
if reasoning_id is None:
|
||||
reasoning_id = self.generate_reasoning_id()
|
||||
self.context.active_reasoning_id = reasoning_id
|
||||
return self._format_sse({"type": "reasoning-start", "id": reasoning_id})
|
||||
|
||||
def format_reasoning_delta(self, reasoning_id: str, delta: str) -> str:
|
||||
"""
|
||||
Format a reasoning delta (incremental reasoning content).
|
||||
|
||||
Args:
|
||||
reasoning_id: The reasoning block ID
|
||||
delta: The incremental reasoning content
|
||||
|
||||
Returns:
|
||||
str: SSE formatted reasoning delta part
|
||||
|
||||
Example output:
|
||||
data: {"type":"reasoning-delta","id":"reasoning_abc123","delta":"Let me think..."}
|
||||
"""
|
||||
return self._format_sse(
|
||||
{"type": "reasoning-delta", "id": reasoning_id, "delta": delta}
|
||||
)
|
||||
|
||||
def format_reasoning_end(self, reasoning_id: str) -> str:
|
||||
"""
|
||||
Format the end of a reasoning block.
|
||||
|
||||
Args:
|
||||
reasoning_id: The reasoning block ID
|
||||
|
||||
Returns:
|
||||
str: SSE formatted reasoning end part
|
||||
|
||||
Example output:
|
||||
data: {"type":"reasoning-end","id":"reasoning_abc123"}
|
||||
"""
|
||||
if self.context.active_reasoning_id == reasoning_id:
|
||||
self.context.active_reasoning_id = None
|
||||
return self._format_sse({"type": "reasoning-end", "id": reasoning_id})
|
||||
|
||||
# =========================================================================
|
||||
# Source Parts
|
||||
# =========================================================================
|
||||
|
||||
def format_source_url(
|
||||
self, url: str, source_id: str | None = None, title: str | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Format a source URL reference.
|
||||
|
||||
Args:
|
||||
url: The source URL
|
||||
source_id: Optional source identifier (defaults to URL)
|
||||
title: Optional title for the source
|
||||
|
||||
Returns:
|
||||
str: SSE formatted source URL part
|
||||
|
||||
Example output:
|
||||
data: {"type":"source-url","sourceId":"https://example.com","url":"https://example.com"}
|
||||
"""
|
||||
data: dict[str, Any] = {
|
||||
"type": "source-url",
|
||||
"sourceId": source_id or url,
|
||||
"url": url,
|
||||
}
|
||||
if title:
|
||||
data["title"] = title
|
||||
return self._format_sse(data)
|
||||
|
||||
def format_source_document(
|
||||
self,
|
||||
source_id: str,
|
||||
media_type: str = "file",
|
||||
title: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Format a source document reference.
|
||||
|
||||
Args:
|
||||
source_id: The source identifier
|
||||
media_type: The media type (e.g., "file", "pdf", "document")
|
||||
title: Optional title for the document
|
||||
description: Optional description
|
||||
|
||||
Returns:
|
||||
str: SSE formatted source document part
|
||||
|
||||
Example output:
|
||||
data: {"type":"source-document","sourceId":"doc_123","mediaType":"file","title":"Report"}
|
||||
"""
|
||||
data: dict[str, Any] = {
|
||||
"type": "source-document",
|
||||
"sourceId": source_id,
|
||||
"mediaType": media_type,
|
||||
}
|
||||
if title:
|
||||
data["title"] = title
|
||||
if description:
|
||||
data["description"] = description
|
||||
return self._format_sse(data)
|
||||
|
||||
def format_sources(self, sources: list[dict[str, Any]]) -> list[str]:
|
||||
"""
|
||||
Format multiple sources.
|
||||
|
||||
Args:
|
||||
sources: List of source objects with 'url', 'title', 'type' fields
|
||||
|
||||
Returns:
|
||||
list[str]: List of SSE formatted source parts
|
||||
"""
|
||||
parts = []
|
||||
for source in sources:
|
||||
url = source.get("url")
|
||||
if url:
|
||||
parts.append(
|
||||
self.format_source_url(
|
||||
url=url,
|
||||
source_id=source.get("id", url),
|
||||
title=source.get("title"),
|
||||
)
|
||||
)
|
||||
else:
|
||||
parts.append(
|
||||
self.format_source_document(
|
||||
source_id=source.get("id", ""),
|
||||
media_type=source.get("type", "file"),
|
||||
title=source.get("title"),
|
||||
description=source.get("description"),
|
||||
)
|
||||
)
|
||||
return parts
|
||||
|
||||
# =========================================================================
|
||||
# File Part
|
||||
# =========================================================================
|
||||
|
||||
def format_file(self, url: str, media_type: str) -> str:
|
||||
"""
|
||||
Format a file reference.
|
||||
|
||||
Args:
|
||||
url: The file URL
|
||||
media_type: The MIME type (e.g., "image/png", "application/pdf")
|
||||
|
||||
Returns:
|
||||
str: SSE formatted file part
|
||||
|
||||
Example output:
|
||||
data: {"type":"file","url":"https://example.com/file.png","mediaType":"image/png"}
|
||||
"""
|
||||
return self._format_sse({"type": "file", "url": url, "mediaType": media_type})
|
||||
|
||||
# =========================================================================
|
||||
# Custom Data Parts
|
||||
# =========================================================================
|
||||
|
||||
def format_data(self, data_type: str, data: Any) -> str:
|
||||
"""
|
||||
Format custom data with a type-specific suffix.
|
||||
|
||||
The type will be prefixed with 'data-' automatically.
|
||||
|
||||
Args:
|
||||
data_type: The custom data type suffix (e.g., "weather", "chart")
|
||||
data: The data payload
|
||||
|
||||
Returns:
|
||||
str: SSE formatted data part
|
||||
|
||||
Example output:
|
||||
data: {"type":"data-weather","data":{"location":"SF","temperature":100}}
|
||||
"""
|
||||
return self._format_sse({"type": f"data-{data_type}", "data": data})
|
||||
|
||||
def format_terminal_info(self, text: str, message_type: str = "info") -> str:
|
||||
"""
|
||||
Format terminal info as custom data (SurfSense specific).
|
||||
|
||||
Args:
|
||||
text: The terminal message text
|
||||
message_type: The message type (info, error, success, warning)
|
||||
|
||||
Returns:
|
||||
str: SSE formatted terminal info data part
|
||||
"""
|
||||
return self.format_data("terminal-info", {"text": text, "type": message_type})
|
||||
|
||||
def format_further_questions(self, questions: list[str]) -> str:
    """
    Emit suggested follow-up questions as a custom data part (SurfSense specific).

    Args:
        questions: Follow-up questions to propose to the user.

    Returns:
        str: SSE formatted ``data-further-questions`` part.
    """
    body = {"questions": questions}
    return self.format_data("further-questions", body)
|
||||
|
||||
# =========================================================================
|
||||
# Error Part
|
||||
# =========================================================================
|
||||
|
||||
def format_error(self, error_text: str) -> str:
    """
    Emit an error part.

    Args:
        error_text: Human-readable error description.

    Returns:
        str: SSE formatted error part, e.g.
        ``data: {"type":"error","errorText":"Something went wrong"}``
    """
    payload = {"type": "error", "errorText": error_text}
    return self._format_sse(payload)
|
||||
|
||||
# =========================================================================
|
||||
# Tool Parts
|
||||
# =========================================================================
|
||||
|
||||
def format_tool_input_start(self, tool_call_id: str, tool_name: str) -> str:
    """
    Emit the start marker for streaming a tool call's input.

    Args:
        tool_call_id: Unique identifier for this tool call.
        tool_name: Name of the tool being invoked.

    Returns:
        str: SSE formatted tool input start part, e.g.
        ``data: {"type":"tool-input-start","toolCallId":"call_abc123","toolName":"getWeather"}``
    """
    payload = {
        "type": "tool-input-start",
        "toolCallId": tool_call_id,
        "toolName": tool_name,
    }
    return self._format_sse(payload)
|
||||
|
||||
def format_tool_input_delta(self, tool_call_id: str, input_text_delta: str) -> str:
    """
    Emit an incremental chunk of a tool call's input.

    Args:
        tool_call_id: Identifier of the tool call being streamed.
        input_text_delta: Next fragment of the input text.

    Returns:
        str: SSE formatted tool input delta part, e.g.
        ``data: {"type":"tool-input-delta","toolCallId":"call_abc123","inputTextDelta":"San Fran"}``
    """
    payload = {
        "type": "tool-input-delta",
        "toolCallId": tool_call_id,
        "inputTextDelta": input_text_delta,
    }
    return self._format_sse(payload)
|
||||
|
||||
def format_tool_input_available(
    self, tool_call_id: str, tool_name: str, input_data: dict[str, Any]
) -> str:
    """
    Emit the completed input for a tool call.

    Args:
        tool_call_id: Identifier of the tool call.
        tool_name: Name of the tool.
        input_data: Fully assembled tool input parameters.

    Returns:
        str: SSE formatted tool input available part, e.g.
        ``data: {"type":"tool-input-available","toolCallId":"call_abc123","toolName":"getWeather","input":{"city":"SF"}}``
    """
    payload = {
        "type": "tool-input-available",
        "toolCallId": tool_call_id,
        "toolName": tool_name,
        "input": input_data,
    }
    return self._format_sse(payload)
|
||||
|
||||
def format_tool_output_available(self, tool_call_id: str, output: Any) -> str:
    """
    Emit the result of a tool execution.

    Args:
        tool_call_id: Identifier of the tool call this output belongs to.
        output: Result produced by the tool.

    Returns:
        str: SSE formatted tool output available part, e.g.
        ``data: {"type":"tool-output-available","toolCallId":"call_abc123","output":{"weather":"sunny"}}``
    """
    payload = {
        "type": "tool-output-available",
        "toolCallId": tool_call_id,
        "output": output,
    }
    return self._format_sse(payload)
|
||||
|
||||
# =========================================================================
|
||||
# Step Parts
|
||||
# =========================================================================
|
||||
|
||||
def format_start_step(self) -> str:
    """
    Emit a start-step marker (one step corresponds to one LLM API call).

    Side effect: increments the step counter on the streaming context.

    Returns:
        str: SSE formatted start step part
        (``data: {"type":"start-step"}``).
    """
    # Keep a running count of steps for this message on the context.
    self.context.step_count += 1
    payload = {"type": "start-step"}
    return self._format_sse(payload)
|
||||
|
||||
def format_finish_step(self) -> str:
    """
    Emit a finish-step marker.

    Clients need this to correctly stitch together multiple assistant
    calls, e.g., when tools are executed in the backend between LLM calls.

    Returns:
        str: SSE formatted finish step part
        (``data: {"type":"finish-step"}``).
    """
    payload = {"type": "finish-step"}
    return self._format_sse(payload)
|
||||
|
||||
# =========================================================================
|
||||
# Convenience Methods
|
||||
# =========================================================================
|
||||
|
||||
def stream_full_text(self, text: str, chunk_size: int = 10) -> list[str]:
    """
    Produce a complete text block as a sequence of stream parts.

    Emits: text-start, text-deltas, text-end.

    Args:
        text: Full text to stream.
        chunk_size: Number of characters per delta chunk.

    Returns:
        list[str]: All SSE formatted parts, in emission order.
    """
    block_id = self.generate_text_id()
    opening = [self.format_text_start(block_id)]
    deltas = list(self.stream_text(block_id, text, chunk_size))
    closing = [self.format_text_end(block_id)]
    return opening + deltas + closing
|
||||
|
||||
def stream_full_reasoning(self, reasoning: str, chunk_size: int = 20) -> list[str]:
    """
    Produce a complete reasoning block as a sequence of stream parts.

    Emits: reasoning-start, reasoning-deltas, reasoning-end.

    Args:
        reasoning: Full reasoning/thinking text.
        chunk_size: Number of characters per delta chunk.

    Returns:
        list[str]: All SSE formatted parts, in emission order.
    """
    block_id = self.generate_reasoning_id()
    deltas = [
        self.format_reasoning_delta(block_id, reasoning[pos : pos + chunk_size])
        for pos in range(0, len(reasoning), chunk_size)
    ]
    return [
        self.format_reasoning_start(block_id),
        *deltas,
        self.format_reasoning_end(block_id),
    ]
|
||||
|
||||
def create_complete_response(
|
||||
self,
|
||||
text: str,
|
||||
sources: list[dict[str, Any]] | None = None,
|
||||
reasoning: str | None = None,
|
||||
further_questions: list[str] | None = None,
|
||||
chunk_size: int = 10,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Create a complete streaming response with all parts.
|
||||
|
||||
This is a convenience method that generates a full response
|
||||
including message start, optional reasoning, text, sources,
|
||||
further questions, and finish markers.
|
||||
|
||||
Args:
|
||||
text: The main response text
|
||||
sources: Optional list of source references
|
||||
reasoning: Optional reasoning/thinking content
|
||||
further_questions: Optional follow-up questions
|
||||
chunk_size: Size of text chunks
|
||||
|
||||
Returns:
|
||||
list[str]: List of all SSE formatted parts in correct order
|
||||
"""
|
||||
parts = []
|
||||
|
||||
# Start message
|
||||
parts.append(self.format_message_start())
|
||||
parts.append(self.format_start_step())
|
||||
|
||||
# Reasoning (if provided)
|
||||
if reasoning:
|
||||
parts.extend(self.stream_full_reasoning(reasoning))
|
||||
|
||||
# Sources (before main text)
|
||||
if sources:
|
||||
parts.extend(self.format_sources(sources))
|
||||
|
||||
# Main text content
|
||||
parts.extend(self.stream_full_text(text, chunk_size))
|
||||
|
||||
# Further questions (if provided)
|
||||
if further_questions:
|
||||
parts.append(self.format_further_questions(further_questions))
|
||||
|
||||
# Finish
|
||||
parts.append(self.format_finish_step())
|
||||
parts.append(self.format_finish())
|
||||
parts.append(self.format_done())
|
||||
|
||||
return parts
|
||||
|
||||
def reset(self) -> None:
    """Discard all per-message state by installing a fresh streaming context."""
    self.context = StreamContext()
|
||||
210
surfsense_backend/app/tasks/chat/stream_new_chat.py
Normal file
210
surfsense_backend/app/tasks/chat/stream_new_chat.py
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
"""
|
||||
Streaming task for the new SurfSense deep agent chat.
|
||||
|
||||
This module streams responses from the deep agent using the Vercel AI SDK
|
||||
Data Stream Protocol (SSE format).
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.chat_deepagent import (
|
||||
create_chat_litellm_from_config,
|
||||
create_surfsense_deep_agent,
|
||||
load_llm_config_from_yaml,
|
||||
)
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
||||
|
||||
async def stream_new_chat(
    user_query: str,
    user_id: str | UUID,
    search_space_id: int,
    chat_id: int,
    session: AsyncSession,
    llm_config_id: int = -1,
) -> AsyncGenerator[str, None]:
    """
    Stream chat responses from the new SurfSense deep agent.

    This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming.
    The chat_id is used as LangGraph's thread_id for memory/checkpointing,
    so chat history is automatically managed by LangGraph.

    Args:
        user_query: The user's query.
        user_id: The user's ID (UUID object or string). Currently not consumed
            by the streaming pipeline itself; kept for interface stability.
        search_space_id: The search space ID used to scope the agent's tools.
        chat_id: The chat ID (used as LangGraph thread_id for memory).
        session: The async database session.
        llm_config_id: The LLM configuration ID (default: -1 for first global config).

    Yields:
        str: SSE formatted response strings.
    """
    streaming_service = VercelStreamingService()

    # Track the currently open text block. Declared before the try block so the
    # exception handler can close a half-open block on failure.
    current_text_id: str | None = None

    try:
        # Load LLM config; surface failures as protocol-level error parts.
        llm_config = load_llm_config_from_yaml(llm_config_id=llm_config_id)
        if not llm_config:
            yield streaming_service.format_error(
                f"Failed to load LLM config with id {llm_config_id}"
            )
            yield streaming_service.format_done()
            return

        # Create ChatLiteLLM instance from the loaded config.
        llm = create_chat_litellm_from_config(llm_config)
        if not llm:
            yield streaming_service.format_error("Failed to create LLM instance")
            yield streaming_service.format_done()
            return

        # Connector service provides knowledge-base access for the agent's tools.
        connector_service = ConnectorService(session, search_space_id=search_space_id)

        # Create the deep agent wired to this search space and session.
        agent = create_surfsense_deep_agent(
            llm=llm,
            search_space_id=search_space_id,
            db_session=session,
            connector_service=connector_service,
        )

        # Input carries only the current user query; prior turns are restored
        # by LangGraph's checkpointer via the thread_id below.
        input_state = {
            "messages": [HumanMessage(content=user_query)],
            "search_space_id": search_space_id,
        }

        # Configure LangGraph with thread_id for memory.
        config = {
            "configurable": {
                "thread_id": str(chat_id),
            }
        }

        # Open the message envelope and the first step.
        yield streaming_service.format_message_start()
        yield streaming_service.format_start_step()

        # Stream the agent response with thread config for memory.
        async for event in agent.astream_events(
            input_state, config=config, version="v2"
        ):
            event_type = event.get("event", "")

            # Incremental LLM text: lazily open a text block, then emit deltas.
            if event_type == "on_chat_model_stream":
                chunk = event.get("data", {}).get("chunk")
                if chunk and hasattr(chunk, "content"):
                    content = chunk.content
                    if content and isinstance(content, str):
                        if current_text_id is None:
                            current_text_id = streaming_service.generate_text_id()
                            yield streaming_service.format_text_start(current_text_id)
                        yield streaming_service.format_text_delta(
                            current_text_id, content
                        )

            # Tool invocation begins: close any open text block, then emit
            # tool-input parts so the client can render the call.
            elif event_type == "on_tool_start":
                tool_name = event.get("name", "unknown_tool")
                run_id = event.get("run_id", "")
                tool_input = event.get("data", {}).get("input", {})

                if current_text_id is not None:
                    yield streaming_service.format_text_end(current_text_id)
                    current_text_id = None

                # Derive a stable tool call id from the run id when available.
                tool_call_id = (
                    f"call_{run_id[:32]}"
                    if run_id
                    else streaming_service.generate_tool_call_id()
                )
                yield streaming_service.format_tool_input_start(tool_call_id, tool_name)
                yield streaming_service.format_tool_input_available(
                    tool_call_id,
                    tool_name,
                    tool_input
                    if isinstance(tool_input, dict)
                    else {"input": tool_input},
                )

                # Surface a human-readable status line for knowledge-base searches.
                if tool_name == "search_knowledge_base":
                    query = (
                        tool_input.get("query", "")
                        if isinstance(tool_input, dict)
                        else str(tool_input)
                    )
                    yield streaming_service.format_terminal_info(
                        f"Searching knowledge base: {query[:100]}{'...' if len(query) > 100 else ''}",
                        "info",
                    )

            elif event_type == "on_tool_end":
                run_id = event.get("run_id", "")
                tool_output = event.get("data", {}).get("output", "")

                # NOTE(review): when run_id is missing this fallback id will not
                # match the generated id from on_tool_start — confirm whether
                # that case can occur in practice.
                tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown"

                # Don't stream the full output (can be very large), just acknowledge.
                yield streaming_service.format_tool_output_available(
                    tool_call_id,
                    {"status": "completed", "result_length": len(str(tool_output))},
                )

                yield streaming_service.format_terminal_info(
                    "Knowledge base search completed", "success"
                )

            # Chain/agent boundaries: close any text block left open.
            elif event_type in ("on_chain_end", "on_agent_end"):
                if current_text_id is not None:
                    yield streaming_service.format_text_end(current_text_id)
                    current_text_id = None

        # Ensure the final text block is closed.
        if current_text_id is not None:
            yield streaming_service.format_text_end(current_text_id)

        # Finish the step and message.
        yield streaming_service.format_finish_step()
        yield streaming_service.format_finish()
        yield streaming_service.format_done()

    except Exception as e:
        # Report the failure over the stream (rather than raising) so the
        # client receives a well-formed protocol termination.
        error_message = f"Error during chat: {e!s}"
        print(f"[stream_new_chat] {error_message}")

        # Close any open text block before emitting the error part.
        if current_text_id is not None:
            yield streaming_service.format_text_end(current_text_id)

        yield streaming_service.format_error(error_message)
        yield streaming_service.format_finish_step()
        yield streaming_service.format_finish()
        yield streaming_service.format_done()
|
||||
|
|
@ -52,6 +52,7 @@ dependencies = [
|
|||
"langchain-litellm>=0.3.5",
|
||||
"langgraph>=1.0.5",
|
||||
"fake-useragent>=2.2.0",
|
||||
"deepagents>=0.3.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
|
|
|||
5886
surfsense_backend/uv.lock
generated
5886
surfsense_backend/uv.lock
generated
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue