Merge remote-tracking branch 'upstream/dev' into fix/auth

This commit is contained in:
Anish Sarkar 2026-02-10 11:36:06 +05:30
commit 2dec643cb4
80 changed files with 2968 additions and 2379 deletions

View file

@ -22,6 +22,7 @@ from app.agents.new_chat.system_prompt import (
build_surfsense_system_prompt,
)
from app.agents.new_chat.tools.registry import build_tools_async
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService
# =============================================================================
@ -126,6 +127,7 @@ async def create_surfsense_deep_agent(
disabled_tools: list[str] | None = None,
additional_tools: Sequence[BaseTool] | None = None,
firecrawl_api_key: str | None = None,
thread_visibility: ChatVisibility | None = None,
):
"""
Create a SurfSense deep agent with configurable tools and prompts.
@ -228,14 +230,15 @@ async def create_surfsense_deep_agent(
logging.warning(f"Failed to discover available connectors/document types: {e}")
# Build dependencies dict for the tools registry
visibility = thread_visibility or ChatVisibility.PRIVATE
dependencies = {
"search_space_id": search_space_id,
"db_session": db_session,
"connector_service": connector_service,
"firecrawl_api_key": firecrawl_api_key,
"user_id": user_id, # Required for memory tools
"thread_id": thread_id, # For podcast tool
# Dynamic connector/document type discovery for knowledge base tool
"user_id": user_id,
"thread_id": thread_id,
"thread_visibility": visibility,
"available_connectors": available_connectors,
"available_document_types": available_document_types,
}
@ -255,10 +258,12 @@ async def create_surfsense_deep_agent(
custom_system_instructions=agent_config.system_instructions,
use_default_system_instructions=agent_config.use_default_system_instructions,
citations_enabled=agent_config.citations_enabled,
thread_visibility=thread_visibility,
)
else:
# Use default prompt (with citations enabled)
system_prompt = build_surfsense_system_prompt()
system_prompt = build_surfsense_system_prompt(
thread_visibility=thread_visibility,
)
# Create the deep agent with system prompt and checkpointer
# Note: TodoListMiddleware (write_todos) is included by default in create_deep_agent

View file

@ -45,6 +45,7 @@ PROVIDER_MAP = {
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai",
"GITHUB_MODELS": "github",
"REPLICATE": "replicate",
"PERPLEXITY": "perplexity",
"ANYSCALE": "anyscale",

View file

@ -12,6 +12,8 @@ The prompt is composed of three parts:
from datetime import UTC, datetime
from app.db import ChatVisibility
# Default system instructions - can be overridden via NewLLMConfig.system_instructions
SURFSENSE_SYSTEM_INSTRUCTIONS = """
<system_instruction>
@ -22,7 +24,34 @@ Today's date (UTC): {resolved_today}
</system_instruction>
"""
SURFSENSE_TOOLS_INSTRUCTIONS = """
# Default system instructions for shared (team) threads. The text gives the
# agent team context and explains the "**[DisplayName of the author]**" message
# prefix so it can attribute statements to their authors.
# The {resolved_today} placeholder is filled in with str.format() by
# _get_system_instructions().
_SYSTEM_INSTRUCTIONS_SHARED = """
<system_instruction>
You are SurfSense, a reasoning and acting AI agent designed to answer questions in this team space using the team's shared knowledge base.
In this team thread, each message is prefixed with **[DisplayName of the author]**. Use this to attribute and reference the author of anything in the discussion (who asked a question, made a suggestion, or contributed an idea) and to cite who said what in your answers.
Today's date (UTC): {resolved_today}
</system_instruction>
"""
def _get_system_instructions(
    thread_visibility: ChatVisibility | None = None, today: datetime | None = None
) -> str:
    """Return the system-instruction block matching the thread's visibility.

    Args:
        thread_visibility: Visibility of the chat thread. ``None`` is treated
            as ``ChatVisibility.PRIVATE``.
        today: Reference datetime; defaults to the current UTC time. Only its
            UTC calendar date (ISO format) ends up in the prompt.

    Returns:
        The shared (team) template when the thread is visible to the whole
        search space, otherwise the default private template, with
        ``{resolved_today}`` substituted.
    """
    reference_time = today or datetime.now(UTC)
    iso_date = reference_time.astimezone(UTC).date().isoformat()

    effective_visibility = thread_visibility or ChatVisibility.PRIVATE
    is_shared_thread = effective_visibility == ChatVisibility.SEARCH_SPACE

    # Pick the template first, then format once.
    template = (
        _SYSTEM_INSTRUCTIONS_SHARED
        if is_shared_thread
        else SURFSENSE_SYSTEM_INSTRUCTIONS
    )
    return template.format(resolved_today=iso_date)
# Tools 0-6 (common to both private and shared prompts)
_TOOLS_INSTRUCTIONS_COMMON = """
<tools>
You have access to the following tools:
@ -138,7 +167,11 @@ You have access to the following tools:
* Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content.
* Don't show every image - just the most relevant 1-3 images that enhance understanding.
7. save_memory: Save facts, preferences, or context about the user for personalized responses.
"""
# Private (user) memory: tools 7-8 + memory-specific examples
_TOOLS_INSTRUCTIONS_MEMORY_PRIVATE = """
7. save_memory: Save facts, preferences, or context for personalized responses.
- Use this when the user explicitly or implicitly shares information worth remembering.
- Trigger scenarios:
* User says "remember this", "keep this in mind", "note that", or similar
@ -178,6 +211,75 @@ You have access to the following tools:
stating "Based on your memory..." - integrate the context seamlessly.
</tools>
<tool_call_examples>
- User: "Remember that I prefer TypeScript over JavaScript"
- Call: `save_memory(content="User prefers TypeScript over JavaScript for development", category="preference")`
- User: "I'm a data scientist working on ML pipelines"
- Call: `save_memory(content="User is a data scientist working on ML pipelines", category="fact")`
- User: "Always give me code examples in Python"
- Call: `save_memory(content="User wants code examples to be written in Python", category="instruction")`
- User: "What programming language should I use for this project?"
- First recall: `recall_memory(query="programming language preferences")`
- Then provide a personalized recommendation based on their preferences
- User: "What do you know about me?"
- Call: `recall_memory(top_k=10)`
- Then summarize the stored memories
"""
# Shared (team) memory: descriptions of tools 7-8 (save_memory / recall_memory)
# worded for team-wide memory, followed by team-oriented call examples.
# _get_tools_instructions() selects this block when the thread visibility is
# ChatVisibility.SEARCH_SPACE. The block closes <tools> and opens
# <tool_call_examples>; the examples section is left open here because more
# text is concatenated after this block.
_TOOLS_INSTRUCTIONS_MEMORY_SHARED = """
7. save_memory: Save a fact, preference, or context to the team's shared memory for future reference.
- Use this when the user or a team member says "remember this", "keep this in mind", or similar in this shared chat.
- Use when the team agrees on something to remember (e.g., decisions, conventions).
- Someone shares a preference or fact that should be visible to the whole team.
- The saved information will be available in future shared conversations in this space.
- Args:
- content: The fact/preference/context to remember. Phrase it clearly, e.g. "API keys are stored in Vault", "The team prefers weekly demos on Fridays"
- category: Type of memory. One of:
* "preference": Team or workspace preferences
* "fact": Facts the team agreed on (e.g., processes, locations)
* "instruction": Standing instructions for the team
* "context": Current context (e.g., ongoing projects, goals)
- Returns: Confirmation of saved memory; returned context may include who added it (added_by).
- IMPORTANT: Only save information that would be genuinely useful for future team conversations in this space.
8. recall_memory: Recall relevant team memories for this space to provide contextual responses.
- Use when you need team context to answer (e.g., "where do we store X?", "what did we decide about Y?").
- Use when someone asks about something the team agreed to remember.
- Use when team preferences or conventions would improve the response.
- Args:
- query: Optional search query to find specific memories. If not provided, returns the most recent memories.
- category: Optional filter by category ("preference", "fact", "instruction", "context")
- top_k: Number of memories to retrieve (default: 5, max: 20)
- Returns: Relevant team memories and formatted context (may include added_by). Integrate naturally without saying "Based on team memory...".
</tools>
<tool_call_examples>
- User: "Remember that API keys are stored in Vault"
- Call: `save_memory(content="API keys are stored in Vault", category="fact")`
- User: "Let's remember that the team prefers weekly demos on Fridays"
- Call: `save_memory(content="The team prefers weekly demos on Fridays", category="preference")`
- User: "What did we decide about the release date?"
- First recall: `recall_memory(query="release date decision")`
- Then answer based on the team memories
- User: "Where do we document onboarding?"
- Call: `recall_memory(query="onboarding documentation")`
- Then answer using the recalled team context
- User: "What have we agreed to remember about our deployment process?"
- Call: `recall_memory(query="deployment process", top_k=10)`
- Then summarize the relevant team memories
"""
# Examples shared by both private and shared prompts (knowledge base, docs, podcast, links, images, etc.)
_TOOLS_INSTRUCTIONS_EXAMPLES_COMMON = """
- User: "What time is the team meeting today?"
- Call: `search_knowledge_base(query="team meeting time today")` (searches ALL sources - calendar, notes, Obsidian, etc.)
- DO NOT limit to just calendar - the info might be in notes!
@ -209,23 +311,6 @@ You have access to the following tools:
- User: "What's in my Obsidian vault about project ideas?"
- Call: `search_knowledge_base(query="project ideas", connectors_to_search=["OBSIDIAN_CONNECTOR"])`
- User: "Remember that I prefer TypeScript over JavaScript"
- Call: `save_memory(content="User prefers TypeScript over JavaScript for development", category="preference")`
- User: "I'm a data scientist working on ML pipelines"
- Call: `save_memory(content="User is a data scientist working on ML pipelines", category="fact")`
- User: "Always give me code examples in Python"
- Call: `save_memory(content="User wants code examples to be written in Python", category="instruction")`
- User: "What programming language should I use for this project?"
- First recall: `recall_memory(query="programming language preferences")`
- Then provide a personalized recommendation based on their preferences
- User: "What do you know about me?"
- Call: `recall_memory(top_k=10)`
- Then summarize the stored memories
- User: "Give me a podcast about AI trends based on what we discussed"
- First search for relevant content, then call: `generate_podcast(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", podcast_title="AI Trends Podcast")`
@ -315,6 +400,31 @@ You have access to the following tools:
</tool_call_examples>
"""
# Full tools prompt for private threads, kept under its original public name
# for code that references the constant directly. It is the same concatenation
# _get_tools_instructions() produces for a private thread.
SURFSENSE_TOOLS_INSTRUCTIONS = "".join(
    (
        _TOOLS_INSTRUCTIONS_COMMON,
        _TOOLS_INSTRUCTIONS_MEMORY_PRIVATE,
        _TOOLS_INSTRUCTIONS_EXAMPLES_COMMON,
    )
)
def _get_tools_instructions(thread_visibility: ChatVisibility | None = None) -> str:
    """Assemble the tools section of the prompt for a private or shared thread.

    The common tool descriptions and the common examples are always included;
    only the memory block in between varies.

    Args:
        thread_visibility: Visibility of the chat thread. ``None`` is treated
            as ``ChatVisibility.PRIVATE``.

    Returns:
        Common tools text + memory block + common examples, where the memory
        block uses team wording for ``ChatVisibility.SEARCH_SPACE`` threads and
        user wording otherwise.
    """
    effective_visibility = thread_visibility or ChatVisibility.PRIVATE

    if effective_visibility == ChatVisibility.SEARCH_SPACE:
        memory_section = _TOOLS_INSTRUCTIONS_MEMORY_SHARED
    else:
        memory_section = _TOOLS_INSTRUCTIONS_MEMORY_PRIVATE

    return "".join(
        (
            _TOOLS_INSTRUCTIONS_COMMON,
            memory_section,
            _TOOLS_INSTRUCTIONS_EXAMPLES_COMMON,
        )
    )
SURFSENSE_CITATION_INSTRUCTIONS = """
<citation_instructions>
CRITICAL CITATION REQUIREMENTS:
@ -413,6 +523,7 @@ Your goal is to provide helpful, informative answers in a clean, readable format
def build_surfsense_system_prompt(
today: datetime | None = None,
thread_visibility: ChatVisibility | None = None,
) -> str:
"""
Build the SurfSense system prompt with default settings.
@ -424,17 +535,17 @@ def build_surfsense_system_prompt(
Args:
today: Optional datetime for today's date (defaults to current UTC date)
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
Returns:
Complete system prompt string
"""
resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat()
return (
SURFSENSE_SYSTEM_INSTRUCTIONS.format(resolved_today=resolved_today)
+ SURFSENSE_TOOLS_INSTRUCTIONS
+ SURFSENSE_CITATION_INSTRUCTIONS
)
visibility = thread_visibility or ChatVisibility.PRIVATE
system_instructions = _get_system_instructions(visibility, today)
tools_instructions = _get_tools_instructions(visibility)
citation_instructions = SURFSENSE_CITATION_INSTRUCTIONS
return system_instructions + tools_instructions + citation_instructions
def build_configurable_system_prompt(
@ -442,6 +553,7 @@ def build_configurable_system_prompt(
use_default_system_instructions: bool = True,
citations_enabled: bool = True,
today: datetime | None = None,
thread_visibility: ChatVisibility | None = None,
) -> str:
"""
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
@ -460,6 +572,7 @@ def build_configurable_system_prompt(
citations_enabled: Whether to include citation instructions (True) or
anti-citation instructions (False).
today: Optional datetime for today's date (defaults to current UTC date)
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
Returns:
Complete system prompt string
@ -473,16 +586,14 @@ def build_configurable_system_prompt(
resolved_today=resolved_today
)
elif use_default_system_instructions:
# Use default instructions
system_instructions = SURFSENSE_SYSTEM_INSTRUCTIONS.format(
resolved_today=resolved_today
)
visibility = thread_visibility or ChatVisibility.PRIVATE
system_instructions = _get_system_instructions(visibility, today)
else:
# No system instructions (edge case)
system_instructions = ""
# Tools instructions are always included
tools_instructions = SURFSENSE_TOOLS_INSTRUCTIONS
# Tools instructions: conditional on thread_visibility (private vs shared memory wording)
tools_instructions = _get_tools_instructions(thread_visibility)
# Citation instructions based on toggle
citation_instructions = (

View file

@ -8,6 +8,7 @@ This module provides:
- Tool factory for creating search_knowledge_base tools
"""
import asyncio
import json
from datetime import datetime
from typing import Any
@ -16,6 +17,7 @@ from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from app.services.connector_service import ConnectorService
# =============================================================================
@ -333,7 +335,7 @@ async def search_knowledge_base_async(
Returns:
Formatted string with search results
"""
all_documents = []
all_documents: list[dict[str, Any]] = []
# Resolve date range (default last 2 years)
from app.agents.new_chat.utils import resolve_date_range
@ -345,323 +347,131 @@ async def search_knowledge_base_async(
connectors = _normalize_connectors(connectors_to_search, available_connectors)
for connector in connectors:
connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
"YOUTUBE_VIDEO": ("search_youtube", True, True, {}),
"EXTENSION": ("search_extension", True, True, {}),
"CRAWLED_URL": ("search_crawled_urls", True, True, {}),
"FILE": ("search_files", True, True, {}),
"SLACK_CONNECTOR": ("search_slack", True, True, {}),
"TEAMS_CONNECTOR": ("search_teams", True, True, {}),
"NOTION_CONNECTOR": ("search_notion", True, True, {}),
"GITHUB_CONNECTOR": ("search_github", True, True, {}),
"LINEAR_CONNECTOR": ("search_linear", True, True, {}),
"TAVILY_API": ("search_tavily", False, True, {}),
"SEARXNG_API": ("search_searxng", False, True, {}),
"LINKUP_API": ("search_linkup", False, False, {"mode": "standard"}),
"BAIDU_SEARCH_API": ("search_baidu", False, True, {}),
"DISCORD_CONNECTOR": ("search_discord", True, True, {}),
"JIRA_CONNECTOR": ("search_jira", True, True, {}),
"GOOGLE_CALENDAR_CONNECTOR": ("search_google_calendar", True, True, {}),
"AIRTABLE_CONNECTOR": ("search_airtable", True, True, {}),
"GOOGLE_GMAIL_CONNECTOR": ("search_google_gmail", True, True, {}),
"GOOGLE_DRIVE_FILE": ("search_google_drive", True, True, {}),
"CONFLUENCE_CONNECTOR": ("search_confluence", True, True, {}),
"CLICKUP_CONNECTOR": ("search_clickup", True, True, {}),
"LUMA_CONNECTOR": ("search_luma", True, True, {}),
"ELASTICSEARCH_CONNECTOR": ("search_elasticsearch", True, True, {}),
"NOTE": ("search_notes", True, True, {}),
"BOOKSTACK_CONNECTOR": ("search_bookstack", True, True, {}),
"CIRCLEBACK": ("search_circleback", True, True, {}),
"OBSIDIAN_CONNECTOR": ("search_obsidian", True, True, {}),
# Composio connectors
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR": (
"search_composio_google_drive",
True,
True,
{},
),
"COMPOSIO_GMAIL_CONNECTOR": ("search_composio_gmail", True, True, {}),
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": (
"search_composio_google_calendar",
True,
True,
{},
),
}
# Keep a conservative cap to avoid overloading DB/external services.
max_parallel_searches = 4
semaphore = asyncio.Semaphore(max_parallel_searches)
async def _search_one_connector(connector: str) -> list[dict[str, Any]]:
spec = connector_specs.get(connector)
if spec is None:
return []
method_name, includes_date_range, includes_top_k, extra_kwargs = spec
kwargs: dict[str, Any] = {
"user_query": query,
"search_space_id": search_space_id,
**extra_kwargs,
}
if includes_top_k:
kwargs["top_k"] = top_k
if includes_date_range:
kwargs["start_date"] = resolved_start_date
kwargs["end_date"] = resolved_end_date
try:
if connector == "YOUTUBE_VIDEO":
_, chunks = await connector_service.search_youtube(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
# Use isolated session per connector. Shared AsyncSession cannot safely
# run concurrent DB operations.
async with semaphore, async_session_maker() as isolated_session:
isolated_connector_service = ConnectorService(
isolated_session, search_space_id
)
all_documents.extend(chunks)
elif connector == "EXTENSION":
_, chunks = await connector_service.search_extension(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "CRAWLED_URL":
_, chunks = await connector_service.search_crawled_urls(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "FILE":
_, chunks = await connector_service.search_files(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "SLACK_CONNECTOR":
_, chunks = await connector_service.search_slack(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "TEAMS_CONNECTOR":
_, chunks = await connector_service.search_teams(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "NOTION_CONNECTOR":
_, chunks = await connector_service.search_notion(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "GITHUB_CONNECTOR":
_, chunks = await connector_service.search_github(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "LINEAR_CONNECTOR":
_, chunks = await connector_service.search_linear(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "TAVILY_API":
_, chunks = await connector_service.search_tavily(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
)
all_documents.extend(chunks)
elif connector == "SEARXNG_API":
_, chunks = await connector_service.search_searxng(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
)
all_documents.extend(chunks)
elif connector == "LINKUP_API":
# Keep behavior aligned with researcher: default "standard"
_, chunks = await connector_service.search_linkup(
user_query=query,
search_space_id=search_space_id,
mode="standard",
)
all_documents.extend(chunks)
elif connector == "BAIDU_SEARCH_API":
_, chunks = await connector_service.search_baidu(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
)
all_documents.extend(chunks)
elif connector == "DISCORD_CONNECTOR":
_, chunks = await connector_service.search_discord(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "JIRA_CONNECTOR":
_, chunks = await connector_service.search_jira(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "GOOGLE_CALENDAR_CONNECTOR":
_, chunks = await connector_service.search_google_calendar(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "AIRTABLE_CONNECTOR":
_, chunks = await connector_service.search_airtable(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "GOOGLE_GMAIL_CONNECTOR":
_, chunks = await connector_service.search_google_gmail(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "GOOGLE_DRIVE_FILE":
_, chunks = await connector_service.search_google_drive(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "CONFLUENCE_CONNECTOR":
_, chunks = await connector_service.search_confluence(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "CLICKUP_CONNECTOR":
_, chunks = await connector_service.search_clickup(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "LUMA_CONNECTOR":
_, chunks = await connector_service.search_luma(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "ELASTICSEARCH_CONNECTOR":
_, chunks = await connector_service.search_elasticsearch(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "NOTE":
_, chunks = await connector_service.search_notes(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "BOOKSTACK_CONNECTOR":
_, chunks = await connector_service.search_bookstack(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "CIRCLEBACK":
_, chunks = await connector_service.search_circleback(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "OBSIDIAN_CONNECTOR":
_, chunks = await connector_service.search_obsidian(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
# =========================================================
# Composio Connectors
# =========================================================
elif connector == "COMPOSIO_GOOGLE_DRIVE_CONNECTOR":
_, chunks = await connector_service.search_composio_google_drive(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "COMPOSIO_GMAIL_CONNECTOR":
_, chunks = await connector_service.search_composio_gmail(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
elif connector == "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR":
_, chunks = await connector_service.search_composio_google_calendar(
user_query=query,
search_space_id=search_space_id,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
all_documents.extend(chunks)
connector_method = getattr(isolated_connector_service, method_name)
_, chunks = await connector_method(**kwargs)
return chunks
except Exception as e:
print(f"Error searching connector {connector}: {e}")
continue
return []
# Deduplicate by content hash
connector_results = await asyncio.gather(
*[_search_one_connector(connector) for connector in connectors]
)
for chunks in connector_results:
all_documents.extend(chunks)
# Deduplicate primarily by document ID. Only fall back to content hashing
# when a document has no ID.
seen_doc_ids: set[Any] = set()
seen_hashes: set[int] = set()
seen_content_hashes: set[int] = set()
deduplicated: list[dict[str, Any]] = []
def _content_fingerprint(document: dict[str, Any]) -> int | None:
chunks = document.get("chunks")
if isinstance(chunks, list):
chunk_texts = []
for chunk in chunks:
if not isinstance(chunk, dict):
continue
chunk_content = (chunk.get("content") or "").strip()
if chunk_content:
chunk_texts.append(chunk_content)
if chunk_texts:
return hash("||".join(chunk_texts))
flat_content = (document.get("content") or "").strip()
if flat_content:
return hash(flat_content)
return None
for doc in all_documents:
doc_id = (doc.get("document", {}) or {}).get("id")
content = (doc.get("content", "") or "").strip()
content_hash = hash(content)
if (doc_id and doc_id in seen_doc_ids) or content_hash in seen_hashes:
if doc_id is not None:
if doc_id in seen_doc_ids:
continue
seen_doc_ids.add(doc_id)
deduplicated.append(doc)
continue
if doc_id:
seen_doc_ids.add(doc_id)
seen_hashes.add(content_hash)
content_hash = _content_fingerprint(doc)
if content_hash is not None:
if content_hash in seen_content_hashes:
continue
seen_content_hashes.add(content_hash)
deduplicated.append(doc)
return format_documents_for_context(deduplicated)

View file

@ -11,21 +11,18 @@ Duplicate request prevention:
- Returns a friendly message if a podcast is already being generated
"""
import os
from typing import Any
import redis
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import Podcast, PodcastStatus
# Redis connection for tracking active podcast tasks
# Defaults to the Celery broker when REDIS_APP_URL is not set
REDIS_URL = os.getenv(
"REDIS_APP_URL",
os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"),
)
REDIS_URL = config.REDIS_APP_URL
_redis_client: redis.Redis | None = None

View file

@ -43,6 +43,8 @@ from typing import Any
from langchain_core.tools import BaseTool
from app.db import ChatVisibility
from .display_image import create_display_image_tool
from .generate_image import create_generate_image_tool
from .knowledge_base import create_search_knowledge_base_tool
@ -51,6 +53,10 @@ from .mcp_tool import load_mcp_tools
from .podcast import create_generate_podcast_tool
from .scrape_webpage import create_scrape_webpage_tool
from .search_surfsense_docs import create_search_surfsense_docs_tool
from .shared_memory import (
create_recall_shared_memory_tool,
create_save_shared_memory_tool,
)
from .user_memory import create_recall_memory_tool, create_save_memory_tool
# =============================================================================
@ -156,29 +162,42 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
requires=["db_session"],
),
# =========================================================================
# USER MEMORY TOOLS - Claude-like memory feature
# USER MEMORY TOOLS - private or team store by thread_visibility
# =========================================================================
# Save memory tool - stores facts/preferences about the user
ToolDefinition(
name="save_memory",
description="Save facts, preferences, or context about the user for personalized responses",
factory=lambda deps: create_save_memory_tool(
user_id=deps["user_id"],
search_space_id=deps["search_space_id"],
db_session=deps["db_session"],
description="Save facts, preferences, or context for personalized or team responses",
factory=lambda deps: (
create_save_shared_memory_tool(
search_space_id=deps["search_space_id"],
created_by_id=deps["user_id"],
db_session=deps["db_session"],
)
if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE
else create_save_memory_tool(
user_id=deps["user_id"],
search_space_id=deps["search_space_id"],
db_session=deps["db_session"],
)
),
requires=["user_id", "search_space_id", "db_session"],
requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
),
# Recall memory tool - retrieves relevant user memories
ToolDefinition(
name="recall_memory",
description="Recall user memories for personalized and contextual responses",
factory=lambda deps: create_recall_memory_tool(
user_id=deps["user_id"],
search_space_id=deps["search_space_id"],
db_session=deps["db_session"],
description="Recall relevant memories (personal or team) for context",
factory=lambda deps: (
create_recall_shared_memory_tool(
search_space_id=deps["search_space_id"],
db_session=deps["db_session"],
)
if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE
else create_recall_memory_tool(
user_id=deps["user_id"],
search_space_id=deps["search_space_id"],
db_session=deps["db_session"],
)
),
requires=["user_id", "search_space_id", "db_session"],
requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
),
# =========================================================================
# ADD YOUR CUSTOM TOOLS BELOW

View file

@ -0,0 +1,280 @@
"""Shared (team) memory backend for search-space-scoped AI context."""
import logging
from typing import Any
from uuid import UUID
from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import MemoryCategory, SharedMemory, User
logger = logging.getLogger(__name__)
# Default number of memories recall_shared_memory() returns when the caller
# does not pass top_k.
DEFAULT_RECALL_TOP_K = 5
# Per-search-space cap: once the stored count reaches this value,
# save_shared_memory() deletes the oldest memory (by updated_at) before
# inserting a new one.
MAX_MEMORIES_PER_SEARCH_SPACE = 250
async def get_shared_memory_count(
    db_session: AsyncSession,
    search_space_id: int,
) -> int:
    """Return how many shared memories are stored for a search space.

    The count is computed by the database with ``SELECT count(*)`` instead of
    loading every row (including its embedding) into Python just to measure
    the length of the result list.

    Args:
        db_session: Async SQLAlchemy session used to run the query.
        search_space_id: Search space whose shared memories are counted.

    Returns:
        The number of ``SharedMemory`` rows belonging to ``search_space_id``.
    """
    # Local import: the module-level sqlalchemy import only brings in ``select``.
    from sqlalchemy import func

    count_stmt = (
        select(func.count())
        .select_from(SharedMemory)
        .where(SharedMemory.search_space_id == search_space_id)
    )
    result = await db_session.execute(count_stmt)
    return result.scalar_one()
async def delete_oldest_shared_memory(
    db_session: AsyncSession,
    search_space_id: int,
) -> None:
    """Delete the least recently updated shared memory of a search space.

    Args:
        db_session: Async SQLAlchemy session used for the query and delete.
        search_space_id: Search space whose oldest memory is removed.

    Does nothing when the search space has no shared memories.
    """
    oldest_first = (
        select(SharedMemory)
        .where(SharedMemory.search_space_id == search_space_id)
        .order_by(SharedMemory.updated_at.asc())
        .limit(1)
    )
    candidate = (await db_session.execute(oldest_first)).scalars().first()
    if not candidate:
        return

    # NOTE(review): the commit is issued only when a row was actually deleted;
    # confirm this matches the intended transaction handling.
    await db_session.delete(candidate)
    await db_session.commit()
def _to_uuid(value: str | UUID) -> UUID:
if isinstance(value, UUID):
return value
return UUID(value)
async def save_shared_memory(
    db_session: AsyncSession,
    search_space_id: int,
    created_by_id: str | UUID,
    content: str,
    category: str = "fact",
) -> dict[str, Any]:
    """Store one team memory for a search space.

    The category is lowercased and falls back to ``"fact"`` when it is empty
    or not one of the known categories. When the search space already holds
    ``MAX_MEMORIES_PER_SEARCH_SPACE`` memories, the oldest one is deleted
    before the new row is inserted. The content is embedded with the
    configured embedding model and saved together with its text.

    Args:
        db_session: Async database session used for all reads and writes.
        search_space_id: Search space the memory belongs to.
        created_by_id: ID of the user adding the memory (string or ``UUID``).
        content: Text of the memory.
        category: Memory category; unknown values are stored as ``"fact"``.

    Returns:
        A dict with ``status`` ``"saved"`` plus the new row's id, text,
        category and a confirmation message; or, on any exception, a dict with
        ``status`` ``"error"``, the error string and a retry message (the
        session is rolled back in that case).
    """
    allowed_categories = ("preference", "fact", "instruction", "context")
    category = category.lower() if category else "fact"
    if category not in allowed_categories:
        category = "fact"

    try:
        existing_total = await get_shared_memory_count(db_session, search_space_id)
        # Make room first so the per-space cap is never exceeded.
        if existing_total >= MAX_MEMORIES_PER_SEARCH_SPACE:
            await delete_oldest_shared_memory(db_session, search_space_id)

        vector = config.embedding_model_instance.embed(content)
        memory_row = SharedMemory(
            search_space_id=search_space_id,
            created_by_id=_to_uuid(created_by_id),
            memory_text=content,
            category=MemoryCategory(category),
            embedding=vector,
        )
        db_session.add(memory_row)
        await db_session.commit()
        # Refresh so database-generated values (e.g. the id) are loaded.
        await db_session.refresh(memory_row)
    except Exception as exc:
        logger.exception("Failed to save shared memory: %s", exc)
        await db_session.rollback()
        return {
            "status": "error",
            "error": str(exc),
            "message": "Failed to save memory. Please try again.",
        }
    else:
        return {
            "status": "saved",
            "memory_id": memory_row.id,
            "memory_text": content,
            "category": category,
            "message": f"I'll remember: {content}",
        }
async def recall_shared_memory(
    db_session: AsyncSession,
    search_space_id: int,
    query: str | None = None,
    category: str | None = None,
    top_k: int = DEFAULT_RECALL_TOP_K,
) -> dict[str, Any]:
    """Fetch team memories of a search space, optionally ranked against a query.

    With a ``query``, memories are ordered by the ``<=>`` operator applied to
    the stored embedding and the query embedding (presumably pgvector cosine
    distance, to be confirmed against the column type). Without a query, the
    most recently updated memories are returned. Display names of the authors
    are looked up so the formatted context can attribute each memory.

    Args:
        db_session: Async database session used for the queries.
        search_space_id: Search space whose memories are searched.
        query: Optional free-text query used for similarity ordering.
        category: Optional category filter. Matched case-insensitively, the
            same way ``save_shared_memory`` normalises categories; unknown
            values are ignored (no filter is applied).
        top_k: Maximum number of memories to return, clamped to 1..20.

    Returns:
        On success, a dict with ``status`` ``"success"``, ``count``,
        ``memories`` (list of plain dicts) and ``formatted_context``. On any
        exception the session is rolled back and a dict with ``status``
        ``"error"``, the error string, an empty ``memories`` list and a
        failure ``formatted_context`` is returned.
    """
    top_k = min(max(top_k, 1), 20)
    try:
        valid_categories = ["preference", "fact", "instruction", "context"]
        stmt = select(SharedMemory).where(
            SharedMemory.search_space_id == search_space_id
        )

        # Saved categories are always lowercase, so normalise the filter too;
        # otherwise "Fact" would silently return unfiltered results.
        normalized_category = category.lower() if category else None
        if normalized_category in valid_categories:
            stmt = stmt.where(
                SharedMemory.category == MemoryCategory(normalized_category)
            )

        if query:
            query_embedding = config.embedding_model_instance.embed(query)
            stmt = stmt.order_by(
                SharedMemory.embedding.op("<=>")(query_embedding)
            ).limit(top_k)
        else:
            stmt = stmt.order_by(SharedMemory.updated_at.desc()).limit(top_k)

        result = await db_session.execute(stmt)
        rows = result.scalars().all()
        memory_list = [
            {
                "id": m.id,
                "memory_text": m.memory_text,
                "category": m.category.value if m.category else "unknown",
                "updated_at": m.updated_at.isoformat() if m.updated_at else None,
                "created_by_id": str(m.created_by_id) if m.created_by_id else None,
            }
            for m in rows
        ]

        # Resolve each distinct author once to label memories with a name.
        created_by_ids = {
            m["created_by_id"] for m in memory_list if m["created_by_id"]
        }
        created_by_map: dict[str, str] = {}
        if created_by_ids:
            uuids = [UUID(uid) for uid in created_by_ids]
            users_result = await db_session.execute(
                select(User).where(User.id.in_(uuids))
            )
            for u in users_result.scalars().all():
                created_by_map[str(u.id)] = u.display_name or "A team member"

        formatted_context = format_shared_memories_for_context(
            memory_list, created_by_map
        )
        return {
            "status": "success",
            "count": len(memory_list),
            "memories": memory_list,
            "formatted_context": formatted_context,
        }
    except Exception as e:
        logger.exception("Failed to recall shared memory: %s", e)
        await db_session.rollback()
        return {
            "status": "error",
            "error": str(e),
            "memories": [],
            "formatted_context": "Failed to recall memories.",
        }
def format_shared_memories_for_context(
memories: list[dict[str, Any]],
created_by_map: dict[str, str] | None = None,
) -> str:
if not memories:
return "No relevant team memories found."
created_by_map = created_by_map or {}
parts = ["<team_memories>"]
for memory in memories:
category = memory.get("category", "unknown")
text = memory.get("memory_text", "")
updated = memory.get("updated_at", "")
created_by_id = memory.get("created_by_id")
added_by = (
created_by_map.get(str(created_by_id), "A team member")
if created_by_id is not None
else "A team member"
)
parts.append(
f" <memory category='{category}' updated='{updated}' added_by='{added_by}'>{text}</memory>"
)
parts.append("</team_memories>")
return "\n".join(parts)
def create_save_shared_memory_tool(
    search_space_id: int,
    created_by_id: str | UUID,
    db_session: AsyncSession,
):
    """
    Factory function to create the save_memory tool for shared (team) chats.

    The returned tool closes over the three arguments below, so the model only
    supplies ``content`` and ``category``; the search space, author and
    database session are fixed at creation time.

    Args:
        search_space_id: The search space ID
        created_by_id: The user ID of the person adding the memory
        db_session: Database session for executing queries

    Returns:
        A configured tool function for saving team memories
    """

    # NOTE: the docstring below is wrapped by langchain's @tool decorator and
    # presumably serves as the tool description shown to the model, so treat
    # its wording as behaviour rather than as ordinary documentation.
    @tool
    async def save_memory(
        content: str,
        category: str = "fact",
    ) -> dict[str, Any]:
        """
        Save a fact, preference, or context to the team's shared memory for future reference.

        Use this tool when:
        - User or a team member says "remember this", "keep this in mind", or similar in this shared chat
        - The team agrees on something to remember (e.g., decisions, conventions, where things live)
        - Someone shares a preference or fact that should be visible to the whole team

        The saved information will be available in future shared conversations in this space.

        Args:
            content: The fact/preference/context to remember.
                Phrase it clearly, e.g., "API keys are stored in Vault",
                "The team prefers weekly demos on Fridays"
            category: Type of memory. One of:
                - "preference": Team or workspace preferences
                - "fact": Facts the team agreed on (e.g., processes, locations)
                - "instruction": Standing instructions for the team
                - "context": Current context (e.g., ongoing projects, goals)

        Returns:
            A dictionary with the save status and memory details
        """
        # Delegate to the backend function with the values captured above.
        return await save_shared_memory(
            db_session, search_space_id, created_by_id, content, category
        )

    return save_memory
def create_recall_shared_memory_tool(
    search_space_id: int,
    db_session: AsyncSession,
):
    """
    Factory function to create the recall_memory tool for shared (team) chats.

    The returned tool closes over the search space and database session, so
    the model only supplies the optional ``query``, ``category`` and ``top_k``.

    Args:
        search_space_id: The search space ID
        db_session: Database session for executing queries

    Returns:
        A configured tool function for recalling team memories
    """

    # NOTE: the docstring below is wrapped by langchain's @tool decorator and
    # presumably serves as the tool description shown to the model, so treat
    # its wording as behaviour rather than as ordinary documentation.
    @tool
    async def recall_memory(
        query: str | None = None,
        category: str | None = None,
        top_k: int = DEFAULT_RECALL_TOP_K,
    ) -> dict[str, Any]:
        """
        Recall relevant team memories for this space to provide contextual responses.

        Use this tool when:
        - You need team context to answer (e.g., "where do we store X?", "what did we decide about Y?")
        - Someone asks about something the team agreed to remember
        - Team preferences or conventions would improve the response

        Args:
            query: Optional search query to find specific memories.
                If not provided, returns the most recent memories.
            category: Optional category filter. One of:
                "preference", "fact", "instruction", "context"
            top_k: Number of memories to retrieve (default: 5, max: 20)

        Returns:
            A dictionary containing relevant memories and formatted context
        """
        # Delegate to the backend function with the values captured above.
        return await recall_shared_memory(
            db_session, search_space_id, query, category, top_k
        )

    return recall_memory

View file

@ -213,6 +213,17 @@ class Config:
# Database
DATABASE_URL = os.getenv("DATABASE_URL")
# Celery / Redis
CELERY_BROKER_URL = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
CELERY_RESULT_BACKEND = os.getenv(
"CELERY_RESULT_BACKEND", "redis://localhost:6379/0"
)
CELERY_TASK_DEFAULT_QUEUE = os.getenv("CELERY_TASK_DEFAULT_QUEUE", "surfsense")
REDIS_APP_URL = os.getenv("REDIS_APP_URL", CELERY_BROKER_URL)
CONNECTOR_INDEXING_LOCK_TTL_SECONDS = int(
os.getenv("CONNECTOR_INDEXING_LOCK_TTL_SECONDS", str(8 * 60 * 60))
)
NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL")
# Backend URL to override the http to https in the OAuth redirect URI
BACKEND_URL = os.getenv("BACKEND_URL")

View file

@ -27,6 +27,12 @@ T = TypeVar("T")
MAX_RETRIES = 5
BASE_RETRY_DELAY = 1.0 # seconds
MAX_RETRY_DELAY = 60.0 # seconds (Notion's max request timeout)
MAX_RATE_LIMIT_WAIT_SECONDS = float(
getattr(config, "NOTION_MAX_RETRY_AFTER_SECONDS", 30.0)
)
MAX_TOTAL_RETRY_WAIT_SECONDS = float(
getattr(config, "NOTION_MAX_TOTAL_RETRY_WAIT_SECONDS", 120.0)
)
# Type alias for retry callback function
# Signature: async callback(retry_reason, attempt, max_attempts, wait_seconds) -> None
@ -292,6 +298,7 @@ class NotionHistoryConnector:
"""
last_exception: APIResponseError | None = None
retry_delay = BASE_RETRY_DELAY
total_wait_time = 0.0
for attempt in range(MAX_RETRIES):
try:
@ -325,6 +332,15 @@ class NotionHistoryConnector:
wait_time = retry_delay
else:
wait_time = retry_delay
# Avoid very long worker sleeps from external Retry-After values.
if wait_time > MAX_RATE_LIMIT_WAIT_SECONDS:
logger.warning(
f"Notion Retry-After ({wait_time}s) exceeds cap "
f"({MAX_RATE_LIMIT_WAIT_SECONDS}s). Clamping wait time."
)
wait_time = MAX_RATE_LIMIT_WAIT_SECONDS
logger.warning(
f"Notion API rate limited (429). "
f"Waiting {wait_time}s. Attempt {attempt + 1}/{MAX_RETRIES}"
@ -348,6 +364,14 @@ class NotionHistoryConnector:
# Notify about retry via callback (for user notifications)
# Call before sleeping so user sees the message while we wait
if total_wait_time + wait_time > MAX_TOTAL_RETRY_WAIT_SECONDS:
logger.error(
"Notion API retry budget exceeded "
f"({total_wait_time + wait_time:.1f}s > "
f"{MAX_TOTAL_RETRY_WAIT_SECONDS:.1f}s). Failing fast."
)
raise
if on_retry:
try:
await on_retry(
@ -362,6 +386,7 @@ class NotionHistoryConnector:
# Wait before retrying
await asyncio.sleep(wait_time)
total_wait_time += wait_time
# Exponential backoff for next attempt
retry_delay = min(retry_delay * 2, MAX_RETRY_DELAY)

View file

@ -211,6 +211,7 @@ class LiteLLMProvider(str, Enum):
DATABRICKS = "DATABRICKS"
COMETAPI = "COMETAPI"
HUGGINGFACE = "HUGGINGFACE"
GITHUB_MODELS = "GITHUB_MODELS"
CUSTOM = "CUSTOM"
@ -272,19 +273,19 @@ INCENTIVE_TASKS_CONFIG = {
IncentiveTaskType.GITHUB_STAR: {
"title": "Star our GitHub repository",
"description": "Show your support by starring SurfSense on GitHub",
"pages_reward": 100,
"pages_reward": 30,
"action_url": "https://github.com/MODSetter/SurfSense",
},
IncentiveTaskType.REDDIT_FOLLOW: {
"title": "Join our Subreddit",
"description": "Join the SurfSense community on Reddit",
"pages_reward": 100,
"pages_reward": 30,
"action_url": "https://www.reddit.com/r/SurfSense/",
},
IncentiveTaskType.DISCORD_JOIN: {
"title": "Join our Discord",
"description": "Join the SurfSense community on Discord",
"pages_reward": 100,
"pages_reward": 40,
"action_url": "https://discord.gg/ejRNvftDp9",
},
# Future tasks can be configured here:
@ -801,9 +802,8 @@ class MemoryCategory(str, Enum):
class UserMemory(BaseModel, TimestampMixin):
"""
Stores facts, preferences, and context about users for personalized AI responses.
Similar to Claude's memory feature - enables the AI to remember user information
across conversations.
Private memory: facts, preferences, context per user per search space.
Used only for private chats (not shared/team chats).
"""
__tablename__ = "user_memories"
@ -847,6 +847,40 @@ class UserMemory(BaseModel, TimestampMixin):
search_space = relationship("SearchSpace", back_populates="user_memories")
class SharedMemory(BaseModel, TimestampMixin):
    """
    Shared (team) memory: one remembered text per row, scoped to a search space.

    Each row records the search space it belongs to, the user who created it,
    the memory text, its category and an embedding of the text.
    """

    __tablename__ = "shared_memories"

    # Owning search space; rows are removed when the search space is deleted.
    search_space_id = Column(
        Integer,
        ForeignKey("searchspaces.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Author of the memory; rows are removed when that user is deleted.
    created_by_id = Column(
        UUID(as_uuid=True),
        ForeignKey("user.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    memory_text = Column(Text, nullable=False)
    # Defaults to MemoryCategory.fact when no category is given.
    category = Column(
        SQLAlchemyEnum(MemoryCategory),
        nullable=False,
        default=MemoryCategory.fact,
    )
    # Vector width follows the configured embedding model's dimension.
    embedding = Column(Vector(config.embedding_model_instance.dimension))
    # Set to the current UTC time on insert and on every update; indexed.
    updated_at = Column(
        TIMESTAMP(timezone=True),
        nullable=False,
        default=lambda: datetime.now(UTC),
        onupdate=lambda: datetime.now(UTC),
        index=True,
    )

    # Paired with the "shared_memories" relationship named in back_populates.
    search_space = relationship("SearchSpace", back_populates="shared_memories")
    # One-directional link to the author (no back_populates is declared).
    created_by = relationship("User")
class Document(BaseModel, TimestampMixin):
__tablename__ = "documents"
@ -1209,6 +1243,12 @@ class SearchSpace(BaseModel, TimestampMixin):
order_by="UserMemory.updated_at.desc()",
cascade="all, delete-orphan",
)
shared_memories = relationship(
"SharedMemory",
back_populates="search_space",
order_by="SharedMemory.updated_at.desc()",
cascade="all, delete-orphan",
)
class SearchSourceConnector(BaseModel, TimestampMixin):
@ -1258,7 +1298,7 @@ class NewLLMConfig(BaseModel, TimestampMixin):
- Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS)
- Citation toggle (enable/disable citation instructions)
Note: SURFSENSE_TOOLS_INSTRUCTIONS is always used and not configurable.
Note: Tools instructions are built by get_tools_instructions(thread_visibility) (personal vs shared memory).
"""
__tablename__ = "new_llm_configs"

View file

@ -19,6 +19,8 @@ from app.db import (
from app.schemas import (
DocumentRead,
DocumentsCreate,
DocumentStatusBatchResponse,
DocumentStatusItemRead,
DocumentStatusSchema,
DocumentTitleRead,
DocumentTitleSearchResponse,
@ -148,6 +150,7 @@ async def create_documents_file_upload(
tuple[Document, str, str]
] = [] # (document, temp_path, filename)
skipped_duplicates = 0
duplicate_document_ids: list[int] = []
# ===== PHASE 1: Create pending documents for all files =====
# This makes ALL documents visible in the UI immediately with pending status
@ -182,6 +185,7 @@ async def create_documents_file_upload(
# True duplicate — content already indexed, skip
os.unlink(temp_path)
skipped_duplicates += 1
duplicate_document_ids.append(existing.id)
continue
# Existing document is stuck (failed/pending/processing)
@ -255,6 +259,7 @@ async def create_documents_file_upload(
return {
"message": "Files uploaded for processing",
"document_ids": [doc.id for doc in created_documents],
"duplicate_document_ids": duplicate_document_ids,
"total_files": len(files),
"pending_files": len(files_to_process),
"skipped_duplicates": skipped_duplicates,
@ -678,6 +683,74 @@ async def search_document_titles(
) from e
@router.get("/documents/status", response_model=DocumentStatusBatchResponse)
async def get_documents_status(
    search_space_id: int,
    document_ids: str,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Batch status endpoint for documents in a search space.

    Returns lightweight status info for the provided document IDs, intended for
    polling async ETL progress in chat upload flows.
    """
    try:
        await check_permission(
            session,
            user,
            search_space_id,
            Permission.DOCUMENTS_READ.value,
            "You don't have permission to read documents in this search space",
        )

        # document_ids is a comma-separated list such as "1,2,3"; blank
        # entries are skipped and any non-integer entry is a client error.
        requested_ids: list[int] = []
        for token in (piece.strip() for piece in document_ids.split(",")):
            if token:
                try:
                    requested_ids.append(int(token))
                except ValueError:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Invalid document id: {token}",
                    ) from None

        if not requested_ids:
            return DocumentStatusBatchResponse(items=[])

        # Restrict to the checked search space so IDs from other spaces
        # are never returned.
        lookup = select(Document).filter(
            Document.search_space_id == search_space_id,
            Document.id.in_(requested_ids),
        )
        found = (await session.execute(lookup)).scalars().all()

        items = []
        for record in found:
            # A missing status payload is reported as state "ready".
            status_payload = record.status or {}
            items.append(
                DocumentStatusItemRead(
                    id=record.id,
                    title=record.title,
                    document_type=record.document_type,
                    status=DocumentStatusSchema(
                        state=status_payload.get("state", "ready"),
                        reason=status_payload.get("reason"),
                    ),
                )
            )
        return DocumentStatusBatchResponse(items=items)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch document status: {e!s}"
        ) from e
@router.get("/documents/type-counts")
async def get_document_type_counts(
search_space_id: int | None = None,

View file

@ -8,16 +8,11 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
- PUT /threads/{thread_id} - Update thread (rename, archive)
- DELETE /threads/{thread_id} - Delete thread
- POST /threads/{thread_id}/messages - Append message
- POST /attachments/process - Process attachments for chat context
"""
import contextlib
import os
import tempfile
import uuid
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from sqlalchemy import func, or_
from sqlalchemy.exc import IntegrityError, OperationalError
@ -1045,12 +1040,13 @@ async def handle_new_chat(
search_space_id=request.search_space_id,
chat_id=request.chat_id,
session=session,
user_id=str(user.id), # Pass user ID for memory tools and session state
user_id=str(user.id),
llm_config_id=llm_config_id,
attachments=request.attachments,
mentioned_document_ids=request.mentioned_document_ids,
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
needs_history_bootstrap=thread.needs_history_bootstrap,
thread_visibility=thread.visibility,
current_user_display_name=user.display_name or "A team member",
),
media_type="text/event-stream",
headers={
@ -1276,11 +1272,12 @@ async def regenerate_response(
session=session,
user_id=str(user.id),
llm_config_id=llm_config_id,
attachments=request.attachments,
mentioned_document_ids=request.mentioned_document_ids,
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
checkpoint_id=target_checkpoint_id,
needs_history_bootstrap=thread.needs_history_bootstrap,
thread_visibility=thread.visibility,
current_user_display_name=user.display_name or "A team member",
):
yield chunk
# If we get here, streaming completed successfully
@ -1329,185 +1326,3 @@ async def regenerate_response(
status_code=500,
detail=f"An unexpected error occurred during regeneration: {e!s}",
) from None
# =============================================================================
# Attachment Processing Endpoint
# =============================================================================
@router.post("/attachments/process")
async def process_attachment(
    file: UploadFile = File(...),
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Process an attachment file and extract its content as markdown.

    This endpoint uses the configured ETL service to parse files and return
    the extracted content that can be used as context in chat messages.

    Supported file types depend on the configured ETL_SERVICE:
    - Markdown/Text files: .md, .markdown, .txt (always supported)
    - Audio files: .mp3, .mp4, .mpeg, .mpga, .m4a, .wav, .webm (if STT configured)
    - Documents: .pdf, .docx, .doc, .pptx, .xlsx (depends on ETL service)

    Returns:
        JSON with attachment id, name, type, and extracted content
    """
    from app.config import config as app_config

    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename provided")

    filename = file.filename
    attachment_id = str(uuid.uuid4())
    # Tracked outside the try so the finally block can always clean up,
    # including when an HTTPException is raised part-way through.
    temp_path: str | None = None

    try:
        # Save file to a temporary location
        file_ext = os.path.splitext(filename)[1].lower()
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as temp_file:
            temp_path = temp_file.name
            content = await file.read()
            temp_file.write(content)

        extracted_content = ""

        # Process based on file type
        if file_ext in (".md", ".markdown", ".txt"):
            # For text/markdown files, read content directly
            with open(temp_path, encoding="utf-8") as f:
                extracted_content = f.read()
        elif file_ext in (".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm"):
            # Audio files - transcribe if STT service is configured
            if not app_config.STT_SERVICE:
                raise HTTPException(
                    status_code=422,
                    detail="Audio transcription is not configured. Please set STT_SERVICE.",
                )

            stt_service_type = (
                "local" if app_config.STT_SERVICE.startswith("local/") else "external"
            )
            if stt_service_type == "local":
                from app.services.stt_service import stt_service

                result = stt_service.transcribe_file(temp_path)
                extracted_content = result.get("text", "")
            else:
                from litellm import atranscription

                with open(temp_path, "rb") as audio_file:
                    transcription_kwargs = {
                        "model": app_config.STT_SERVICE,
                        "file": audio_file,
                        "api_key": app_config.STT_SERVICE_API_KEY,
                    }
                    if app_config.STT_SERVICE_API_BASE:
                        transcription_kwargs["api_base"] = (
                            app_config.STT_SERVICE_API_BASE
                        )
                    transcription_response = await atranscription(
                        **transcription_kwargs
                    )
                    extracted_content = transcription_response.get("text", "")

            if extracted_content:
                extracted_content = (
                    f"# Transcription of {filename}\n\n{extracted_content}"
                )
        else:
            # Document files - use configured ETL service
            if app_config.ETL_SERVICE == "UNSTRUCTURED":
                from langchain_unstructured import UnstructuredLoader

                from app.utils.document_converters import convert_document_to_markdown

                loader = UnstructuredLoader(
                    temp_path,
                    mode="elements",
                    post_processors=[],
                    languages=["eng"],
                    include_orig_elements=False,
                    include_metadata=False,
                    strategy="auto",
                )
                docs = await loader.aload()
                extracted_content = await convert_document_to_markdown(docs)
            elif app_config.ETL_SERVICE == "LLAMACLOUD":
                from llama_cloud_services import LlamaParse
                from llama_cloud_services.parse.utils import ResultType

                parser = LlamaParse(
                    api_key=app_config.LLAMA_CLOUD_API_KEY,
                    num_workers=1,
                    verbose=False,
                    language="en",
                    result_type=ResultType.MD,
                )
                result = await parser.aparse(temp_path)
                markdown_documents = await result.aget_markdown_documents(
                    split_by_page=False
                )
                if markdown_documents:
                    extracted_content = "\n\n".join(
                        doc.text for doc in markdown_documents
                    )
            elif app_config.ETL_SERVICE == "DOCLING":
                from app.services.docling_service import create_docling_service

                docling_service = create_docling_service()
                result = await docling_service.process_document(temp_path, filename)
                extracted_content = result.get("content", "")
            else:
                raise HTTPException(
                    status_code=422,
                    detail=f"ETL service not configured or unsupported file type: {file_ext}",
                )

        if not extracted_content:
            raise HTTPException(
                status_code=422,
                detail=f"Could not extract content from file: {filename}",
            )

        # Determine attachment type (must be one of: "image", "document", "file")
        # assistant-ui only supports these three types
        if file_ext in (".png", ".jpg", ".jpeg", ".gif", ".webp"):
            attachment_type = "image"
        else:
            # All other files (including audio, documents, text) are treated as "document"
            attachment_type = "document"

        return {
            "id": attachment_id,
            "name": filename,
            "type": attachment_type,
            "content": extracted_content,
            "contentLength": len(extracted_content),
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to process attachment: {e!s}",
        ) from e
    finally:
        # Best-effort removal of the temp file on every exit path; a failed
        # unlink must not mask the response or the original error.
        if temp_path is not None:
            with contextlib.suppress(OSError):
                os.unlink(temp_path)

View file

@ -19,7 +19,7 @@ Non-OAuth connectors (BookStack, GitHub, etc.) are limited to one per search spa
"""
import logging
import os
from contextlib import suppress
from datetime import UTC, datetime, timedelta
from typing import Any
@ -32,6 +32,7 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.config import config
from app.connectors.github_connector import GitHubConnector
from app.db import (
Permission,
@ -70,6 +71,10 @@ from app.tasks.connector_indexers import (
index_slack_messages,
)
from app.users import current_active_user
from app.utils.indexing_locks import (
acquire_connector_indexing_lock,
release_connector_indexing_lock,
)
from app.utils.periodic_scheduler import (
create_periodic_schedule,
delete_periodic_schedule,
@ -91,11 +96,9 @@ def get_heartbeat_redis_client() -> redis.Redis:
"""Get or create Redis client for heartbeat tracking."""
global _heartbeat_redis_client
if _heartbeat_redis_client is None:
redis_url = os.getenv(
"REDIS_APP_URL",
os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"),
_heartbeat_redis_client = redis.from_url(
config.REDIS_APP_URL, decode_responses=True
)
_heartbeat_redis_client = redis.from_url(redis_url, decode_responses=True)
return _heartbeat_redis_client
@ -1229,10 +1232,19 @@ async def _run_indexing_with_notifications(
from celery.exceptions import SoftTimeLimitExceeded
notification = None
connector_lock_acquired = False
# Track indexed count for retry notifications and heartbeat
current_indexed_count = 0
try:
connector_lock_acquired = acquire_connector_indexing_lock(connector_id)
if not connector_lock_acquired:
logger.info(
f"Skipping indexing for connector {connector_id} "
"(another worker already holds Redis connector lock)"
)
return
# Get connector info for notification
connector_result = await session.execute(
select(SearchSourceConnector).where(
@ -1558,6 +1570,9 @@ async def _run_indexing_with_notifications(
get_heartbeat_redis_client().delete(heartbeat_key)
except Exception:
pass # Ignore cleanup errors - key will expire anyway
if connector_lock_acquired:
with suppress(Exception):
release_connector_indexing_lock(connector_id)
async def run_notion_indexing_with_new_session(

View file

@ -11,6 +11,8 @@ from .documents import (
DocumentBase,
DocumentRead,
DocumentsCreate,
DocumentStatusBatchResponse,
DocumentStatusItemRead,
DocumentStatusSchema,
DocumentTitleRead,
DocumentTitleSearchResponse,
@ -105,6 +107,8 @@ __all__ = [
# Document schemas
"DocumentBase",
"DocumentRead",
"DocumentStatusBatchResponse",
"DocumentStatusItemRead",
"DocumentStatusSchema",
"DocumentTitleRead",
"DocumentTitleSearchResponse",

View file

@ -99,3 +99,20 @@ class DocumentTitleSearchResponse(BaseModel):
items: list[DocumentTitleRead]
has_more: bool
class DocumentStatusItemRead(BaseModel):
    """Lightweight document status payload for batch status polling."""

    id: int  # Document primary key
    title: str  # Document title
    document_type: DocumentType  # Kind of document, as a DocumentType value
    status: DocumentStatusSchema  # Processing status details for the document

    # from_attributes lets pydantic build this model from objects that expose
    # the fields as attributes (e.g. ORM rows), not only from dicts.
    model_config = ConfigDict(from_attributes=True)
class DocumentStatusBatchResponse(BaseModel):
    """Batch status response for a set of document IDs."""

    # One entry per document returned; may be empty.
    items: list[DocumentStatusItemRead]

View file

@ -159,15 +159,6 @@ class ChatMessage(BaseModel):
content: str
class ChatAttachment(BaseModel):
    """An attachment with its extracted content for chat context."""

    id: str  # Unique attachment ID
    name: str  # Original filename
    # Plain string: the schema itself does not restrict the allowed values.
    type: str  # Attachment type: document, image, audio
    content: str  # Extracted markdown content from the file
class NewChatRequest(BaseModel):
"""Request schema for the deep agent chat endpoint."""
@ -175,9 +166,6 @@ class NewChatRequest(BaseModel):
user_query: str
search_space_id: int
messages: list[ChatMessage] | None = None # Optional chat history from frontend
attachments: list[ChatAttachment] | None = (
None # Optional attachments with extracted content
)
mentioned_document_ids: list[int] | None = (
None # Optional document IDs mentioned with @ in the chat
)
@ -201,7 +189,6 @@ class RegenerateRequest(BaseModel):
user_query: str | None = (
None # New user query (for edit). None = reload with same query
)
attachments: list[ChatAttachment] | None = None
mentioned_document_ids: list[int] | None = None
mentioned_surfsense_doc_ids: list[int] | None = None

View file

@ -56,6 +56,7 @@ PROVIDER_MAP = {
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai",
"GITHUB_MODELS": "github",
"HUGGINGFACE": "huggingface",
"CUSTOM": "custom",
}

View file

@ -119,6 +119,7 @@ async def validate_llm_config(
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai", # GLM needs special handling
"GITHUB_MODELS": "github",
}
provider_prefix = provider_map.get(provider, provider.lower())
model_string = f"{provider_prefix}/{model_name}"
@ -335,6 +336,7 @@ async def get_search_space_llm_instance(
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai",
"GITHUB_MODELS": "github",
}
provider_prefix = provider_map.get(
llm_config.provider.value, llm_config.provider.value.lower()

View file

@ -36,11 +36,9 @@ def _get_doc_heartbeat_redis():
global _doc_heartbeat_redis
if _doc_heartbeat_redis is None:
redis_url = os.getenv(
"REDIS_APP_URL",
os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"),
_doc_heartbeat_redis = redis.from_url(
config.REDIS_APP_URL, decode_responses=True
)
_doc_heartbeat_redis = redis.from_url(redis_url, decode_responses=True)
return _doc_heartbeat_redis

View file

@ -46,16 +46,10 @@ def get_celery_session_maker():
def _clear_generating_podcast(search_space_id: int) -> None:
"""Clear the generating podcast marker from Redis when task completes."""
import os
import redis
try:
redis_url = os.getenv(
"REDIS_APP_URL",
os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"),
)
client = redis.from_url(redis_url, decode_responses=True)
client = redis.from_url(config.REDIS_APP_URL, decode_responses=True)
key = f"podcast:generating:{search_space_id}"
client.delete(key)
logger.info(

View file

@ -9,7 +9,8 @@ from sqlalchemy.pool import NullPool
from app.celery_app import celery_app
from app.config import config
from app.db import SearchSourceConnector, SearchSourceConnectorType
from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType
from app.utils.indexing_locks import is_connector_indexing_locked
logger = logging.getLogger(__name__)
@ -107,6 +108,32 @@ async def _check_and_trigger_schedules():
# Trigger indexing for each due connector
for connector in due_connectors:
# Primary guard: Redis lock indicates a task is currently running.
if is_connector_indexing_locked(connector.id):
logger.info(
f"Skipping periodic indexing for connector {connector.id} "
"(Redis lock indicates indexing is already in progress)"
)
continue
# Skip scheduling if a sync for this connector is already in progress.
# This prevents duplicate tasks from piling up under slow/rate-limited providers.
in_progress_result = await session.execute(
select(Notification.id).where(
Notification.type == "connector_indexing",
Notification.notification_metadata["connector_id"].astext
== str(connector.id),
Notification.notification_metadata["status"].astext
== "in_progress",
)
)
if in_progress_result.first():
logger.info(
f"Skipping periodic indexing for connector {connector.id} "
"(already has in-progress indexing notification)"
)
continue
task = task_map.get(connector.connector_type)
if task:
logger.info(

View file

@ -25,7 +25,6 @@ Detection mechanism:
import contextlib
import json
import logging
import os
from datetime import UTC, datetime
import redis
@ -52,11 +51,7 @@ def get_redis_client() -> redis.Redis:
"""Get or create Redis client for heartbeat checking."""
global _redis_client
if _redis_client is None:
redis_url = os.getenv(
"REDIS_APP_URL",
os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"),
)
_redis_client = redis.from_url(redis_url, decode_responses=True)
_redis_client = redis.from_url(config.REDIS_APP_URL, decode_responses=True)
return _redis_client

View file

@ -26,9 +26,8 @@ from app.agents.new_chat.llm_config import (
load_agent_config,
load_llm_config_from_yaml,
)
from app.db import Document, SurfsenseDocsDocument
from app.db import ChatVisibility, Document, SurfsenseDocsDocument
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
from app.schemas.new_chat import ChatAttachment
from app.services.chat_session_state_service import (
clear_ai_responding,
set_ai_responding,
@ -38,23 +37,6 @@ from app.services.new_streaming_service import VercelStreamingService
from app.utils.content_utils import bootstrap_history_from_db
def format_attachments_as_context(attachments: list[ChatAttachment]) -> str:
    """Render attachments as a ``<user_attachments>`` block for the agent.

    Each attachment becomes an ``<attachment>`` element carrying its 1-based
    index, name and type, with the extracted content wrapped in a CDATA
    section. Returns an empty string when there are no attachments.
    """
    if not attachments:
        return ""

    rendered = "\n".join(
        f"<attachment index='{position}' name='{item.name}' type='{item.type}'>\n"
        f"<![CDATA[{item.content}]]>\n"
        "</attachment>"
        for position, item in enumerate(attachments, start=1)
    )
    return f"<user_attachments>\n{rendered}\n</user_attachments>"
def format_mentioned_documents_as_context(documents: list[Document]) -> str:
"""
Format mentioned documents as context for the agent.
@ -203,11 +185,12 @@ async def stream_new_chat(
session: AsyncSession,
user_id: str | None = None,
llm_config_id: int = -1,
attachments: list[ChatAttachment] | None = None,
mentioned_document_ids: list[int] | None = None,
mentioned_surfsense_doc_ids: list[int] | None = None,
checkpoint_id: str | None = None,
needs_history_bootstrap: bool = False,
thread_visibility: ChatVisibility | None = None,
current_user_display_name: str | None = None,
) -> AsyncGenerator[str, None]:
"""
Stream chat responses from the new SurfSense deep agent.
@ -222,7 +205,6 @@ async def stream_new_chat(
session: The database session
user_id: The current user's UUID string (for memory tools and session state)
llm_config_id: The LLM configuration ID (default: -1 for first global config)
attachments: Optional attachments with extracted content
needs_history_bootstrap: If True, load message history from DB (for cloned chats)
mentioned_document_ids: Optional list of document IDs mentioned with @ in the chat
mentioned_surfsense_doc_ids: Optional list of SurfSense doc IDs mentioned with @ in the chat
@ -295,17 +277,18 @@ async def stream_new_chat(
# Get the PostgreSQL checkpointer for persistent conversation memory
checkpointer = await get_checkpointer()
# Create the deep agent with checkpointer and configurable prompts
visibility = thread_visibility or ChatVisibility.PRIVATE
agent = await create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id, # Pass user ID for memory tools
thread_id=chat_id, # Pass chat ID for podcast association
agent_config=agent_config, # Pass prompt configuration
firecrawl_api_key=firecrawl_api_key, # Pass Firecrawl API key if configured
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
)
# Build input with message history
@ -313,7 +296,9 @@ async def stream_new_chat(
# Bootstrap history for cloned chats (no LangGraph checkpoint exists yet)
if needs_history_bootstrap:
langchain_messages = await bootstrap_history_from_db(session, chat_id)
langchain_messages = await bootstrap_history_from_db(
session, chat_id, thread_visibility=visibility
)
# Clear the flag so we don't bootstrap again on next message
from app.db import NewChatThread
@ -355,13 +340,10 @@ async def stream_new_chat(
)
mentioned_surfsense_docs = list(result.scalars().all())
# Format the user query with context (attachments + mentioned documents + surfsense docs)
# Format the user query with context (mentioned documents + SurfSense docs)
final_query = user_query
context_parts = []
if attachments:
context_parts.append(format_attachments_as_context(attachments))
if mentioned_documents:
context_parts.append(
format_mentioned_documents_as_context(mentioned_documents)
@ -376,6 +358,9 @@ async def stream_new_chat(
context = "\n\n".join(context_parts)
final_query = f"{context}\n\n<user_query>{user_query}</user_query>"
if visibility == ChatVisibility.SEARCH_SPACE and current_user_display_name:
final_query = f"**[{current_user_display_name}]:** {final_query}"
# if messages:
# # Convert frontend messages to LangChain format
# for msg in messages:
@ -451,39 +436,20 @@ async def stream_new_chat(
last_active_step_id = analyze_step_id
# Determine step title and action verb based on context
if attachments and (mentioned_documents or mentioned_surfsense_docs):
last_active_step_title = "Analyzing your content"
action_verb = "Reading"
elif attachments:
last_active_step_title = "Reading your content"
action_verb = "Reading"
elif mentioned_documents or mentioned_surfsense_docs:
if mentioned_documents or mentioned_surfsense_docs:
last_active_step_title = "Analyzing referenced content"
action_verb = "Analyzing"
else:
last_active_step_title = "Understanding your request"
action_verb = "Processing"
# Build the message with inline context about attachments/documents
# Build the message with inline context about referenced documents
processing_parts = []
# Add the user query
query_text = user_query[:80] + ("..." if len(user_query) > 80 else "")
processing_parts.append(query_text)
# Add file attachment names inline
if attachments:
attachment_names = []
for attachment in attachments:
name = attachment.name
if len(name) > 30:
name = name[:27] + "..."
attachment_names.append(name)
if len(attachment_names) == 1:
processing_parts.append(f"[{attachment_names[0]}]")
else:
processing_parts.append(f"[{len(attachment_names)} files]")
# Add mentioned document names inline
if mentioned_documents:
doc_names = []

View file

@ -52,10 +52,22 @@ def safe_set_chunks(document: Document, chunks: list) -> None:
# Instead of: document.chunks = chunks (DANGEROUS!)
safe_set_chunks(document, chunks) # Always safe
"""
from sqlalchemy.orm import object_session
from sqlalchemy.orm.attributes import set_committed_value
# Keep relationship assignment lazy-load-safe.
set_committed_value(document, "chunks", chunks)
# Ensure chunk rows are actually persisted.
# set_committed_value bypasses normal unit-of-work tracking, so we need to
# explicitly attach chunk objects to the current session.
session = object_session(document)
if session is not None:
if document.id is not None:
for chunk in chunks:
chunk.document_id = document.id
session.add_all(chunks)
def parse_date_flexible(date_str: str) -> datetime:
"""

View file

@ -1,7 +1,10 @@
"""
Discord connector indexer.
Implements 2-phase document status updates for real-time UI feedback:
Implements batch indexing: groups up to DISCORD_BATCH_SIZE messages per channel
into a single document for efficient indexing and better conversational context.
Uses 2-phase document status updates for real-time UI feedback:
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
- Phase 2: Process each document: pending → processing → ready/failed
"""
@ -41,6 +44,72 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Heartbeat interval in seconds - update notification every 30 seconds
HEARTBEAT_INTERVAL_SECONDS = 30
# Number of messages to combine into a single document for batch indexing.
# Grouping messages improves conversational context in embeddings/chunks and
# drastically reduces the number of documents, embedding calls, and DB overhead.
DISCORD_BATCH_SIZE = 100
def _build_batch_document_string(
    guild_name: str,
    guild_id: str,
    channel_name: str,
    channel_id: str,
    messages: list[dict],
) -> str:
    """
    Combine multiple Discord messages into a single document string.

    Each message is rendered as ``[timestamp] author: content`` and the lines
    are joined into one conversation-style CONTENT section, preceded by a
    METADATA section describing the guild, channel and batch bounds.
    Downstream chunking is expected to split the result into windows of
    consecutive messages (that step is not visible here).

    Args:
        guild_name: Name of the Discord guild
        guild_id: ID of the Discord guild
        channel_name: Name of the channel
        channel_id: ID of the channel
        messages: List of message dicts with 'author_name', 'created_at', 'content'

    Returns:
        Formatted document string with metadata and conversation content
    """
    # Guard the bounds lookup: an empty batch used to raise IndexError on
    # messages[0]; it now yields a document with "Unknown" bounds and count 0.
    if messages:
        first_msg_time = messages[0].get("created_at", "Unknown")
        last_msg_time = messages[-1].get("created_at", "Unknown")
    else:
        first_msg_time = last_msg_time = "Unknown"

    metadata_lines = [
        f"GUILD_NAME: {guild_name}",
        f"GUILD_ID: {guild_id}",
        f"CHANNEL_NAME: {channel_name}",
        f"CHANNEL_ID: {channel_id}",
        f"MESSAGE_COUNT: {len(messages)}",
        f"FIRST_MESSAGE_TIME: {first_msg_time}",
        f"LAST_MESSAGE_TIME: {last_msg_time}",
    ]

    # One transcript line per message, in the order received.
    conversation_lines = [
        f"[{msg.get('created_at', 'Unknown Time')}] "
        f"{msg.get('author_name', 'Unknown User')}: {msg.get('content', '')}"
        for msg in messages
    ]

    metadata_sections = [
        ("METADATA", metadata_lines),
        (
            "CONTENT",
            [
                "FORMAT: markdown",
                "TEXT_START",
                "\n".join(conversation_lines),
                "TEXT_END",
            ],
        ),
    ]

    return build_document_metadata_markdown(metadata_sections)
async def index_discord_messages(
session: AsyncSession,
@ -55,6 +124,12 @@ async def index_discord_messages(
"""
Index Discord messages from the configured guild's channels.
Messages are grouped into batches of DISCORD_BATCH_SIZE per channel,
so each document contains up to 100 consecutive messages with full
conversational context. This reduces document count, embedding calls,
and DB overhead by ~100x while improving search quality through
context-aware chunk embeddings.
Implements 2-phase document status updates for real-time UI feedback:
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
- Phase 2: Process each document: pending → processing → ready/failed
@ -324,6 +399,7 @@ async def index_discord_messages(
documents_skipped = 0
documents_failed = 0
duplicate_content_count = 0
total_messages_collected = 0
skipped_channels: list[str] = []
# Heartbeat tracking - update notification periodically to prevent appearing stuck
@ -340,10 +416,12 @@ async def index_discord_messages(
)
# =======================================================================
# PHASE 1: Collect all messages and create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# PHASE 1: Collect messages, group into batches, and create pending documents
# Messages are grouped into batches of DISCORD_BATCH_SIZE per channel.
# Each batch becomes a single document with full conversational context.
# All documents are visible in the UI immediately with pending status.
# =======================================================================
messages_to_process = [] # List of dicts with document and message data
batches_to_process = [] # List of dicts with document and batch data
new_documents_created = False
try:
@ -394,44 +472,35 @@ async def index_discord_messages(
)
continue
# Process each message as an individual document (like Slack)
for msg in formatted_messages:
msg_id = msg.get("id", "")
msg_user_name = msg.get("author_name", "Unknown User")
msg_timestamp = msg.get("created_at", "Unknown Time")
msg_text = msg.get("content", "")
total_messages_collected += len(formatted_messages)
# Format document metadata (similar to Slack)
metadata_sections = [
(
"METADATA",
[
f"GUILD_NAME: {guild_name}",
f"GUILD_ID: {guild_id}",
f"CHANNEL_NAME: {channel_name}",
f"CHANNEL_ID: {channel_id}",
f"MESSAGE_TIMESTAMP: {msg_timestamp}",
f"MESSAGE_USER_NAME: {msg_user_name}",
],
),
(
"CONTENT",
[
"FORMAT: markdown",
"TEXT_START",
msg_text,
"TEXT_END",
],
),
# =======================================================
# Group messages into batches of DISCORD_BATCH_SIZE
# Each batch becomes a single document with conversation context
# =======================================================
for batch_start in range(
0, len(formatted_messages), DISCORD_BATCH_SIZE
):
batch = formatted_messages[
batch_start : batch_start + DISCORD_BATCH_SIZE
]
# Build the document string
combined_document_string = build_document_metadata_markdown(
metadata_sections
# Build combined document string from all messages in this batch
combined_document_string = _build_batch_document_string(
guild_name=guild_name,
guild_id=guild_id,
channel_name=channel_name,
channel_id=channel_id,
messages=batch,
)
# Generate unique identifier hash for this Discord message
unique_identifier = f"{channel_id}_{msg_id}"
# Generate unique identifier for this batch using
# channel_id + first message ID + last message ID
first_msg_id = batch[0].get("id", "")
last_msg_id = batch[-1].get("id", "")
unique_identifier = (
f"{channel_id}_{first_msg_id}_{last_msg_id}"
)
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.DISCORD_CONNECTOR,
unique_identifier,
@ -464,7 +533,7 @@ async def index_discord_messages(
continue
# Queue existing document for update (will be set to processing in Phase 2)
messages_to_process.append(
batches_to_process.append(
{
"document": existing_document,
"is_new": False,
@ -474,9 +543,15 @@ async def index_discord_messages(
"guild_id": guild_id,
"channel_name": channel_name,
"channel_id": channel_id,
"message_id": msg_id,
"message_timestamp": msg_timestamp,
"message_user_name": msg_user_name,
"first_message_id": first_msg_id,
"last_message_id": last_msg_id,
"first_message_time": batch[0].get(
"created_at", "Unknown"
),
"last_message_time": batch[-1].get(
"created_at", "Unknown"
),
"message_count": len(batch),
}
)
continue
@ -492,7 +567,7 @@ async def index_discord_messages(
if duplicate_by_content:
logger.info(
f"Discord message {msg_id} in {guild_name}#{channel_name} already indexed by another connector "
f"Discord batch ({len(batch)} msgs) in {guild_name}#{channel_name} already indexed by another connector "
f"(existing document ID: {duplicate_by_content.id}, "
f"type: {duplicate_by_content.document_type}). Skipping."
)
@ -510,7 +585,9 @@ async def index_discord_messages(
"guild_id": guild_id,
"channel_name": channel_name,
"channel_id": channel_id,
"message_id": msg_id,
"first_message_id": first_msg_id,
"last_message_id": last_msg_id,
"message_count": len(batch),
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
@ -526,7 +603,7 @@ async def index_discord_messages(
session.add(document)
new_documents_created = True
messages_to_process.append(
batches_to_process.append(
{
"document": document,
"is_new": True,
@ -536,12 +613,23 @@ async def index_discord_messages(
"guild_id": guild_id,
"channel_name": channel_name,
"channel_id": channel_id,
"message_id": msg_id,
"message_timestamp": msg_timestamp,
"message_user_name": msg_user_name,
"first_message_id": first_msg_id,
"last_message_id": last_msg_id,
"first_message_time": batch[0].get(
"created_at", "Unknown"
),
"last_message_time": batch[-1].get(
"created_at", "Unknown"
),
"message_count": len(batch),
}
)
logger.info(
f"Phase 1: Collected {len(formatted_messages)} messages from channel {channel_name}, "
f"grouped into {(len(formatted_messages) + DISCORD_BATCH_SIZE - 1) // DISCORD_BATCH_SIZE} batch(es)"
)
except Exception as e:
logger.error(
f"Error processing guild {guild_name}: {e!s}", exc_info=True
@ -554,17 +642,18 @@ async def index_discord_messages(
# Commit all pending documents - they all appear in UI now
if new_documents_created:
logger.info(
f"Phase 1: Committing {len([m for m in messages_to_process if m['is_new']])} pending documents"
f"Phase 1: Committing {len([b for b in batches_to_process if b['is_new']])} pending batch documents "
f"({total_messages_collected} total messages across all channels)"
)
await session.commit()
# =======================================================================
# PHASE 2: Process each document one by one
# PHASE 2: Process each batch document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(messages_to_process)} documents")
logger.info(f"Phase 2: Processing {len(batches_to_process)} batch documents")
for item in messages_to_process:
for item in batches_to_process:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
@ -594,9 +683,11 @@ async def index_discord_messages(
"guild_id": item["guild_id"],
"channel_name": item["channel_name"],
"channel_id": item["channel_id"],
"message_id": item["message_id"],
"message_timestamp": item["message_timestamp"],
"message_user_name": item["message_user_name"],
"first_message_id": item["first_message_id"],
"last_message_id": item["last_message_id"],
"first_message_time": item["first_message_time"],
"last_message_time": item["last_message_time"],
"message_count": item["message_count"],
"indexed_at": datetime.now(UTC).strftime("%Y-%m-%d %H:%M:%S"),
"connector_id": connector_id,
}
@ -609,12 +700,14 @@ async def index_discord_messages(
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
f"Committing batch: {documents_indexed} Discord messages processed so far"
f"Committing batch: {documents_indexed} batch documents processed so far"
)
await session.commit()
except Exception as e:
logger.error(f"Error processing Discord message: {e!s}", exc_info=True)
logger.error(
f"Error processing Discord batch document: {e!s}", exc_info=True
)
# Mark document as failed with reason (visible in UI)
try:
document.status = DocumentStatus.failed(str(e))
@ -631,7 +724,8 @@ async def index_discord_messages(
# Final commit for any remaining documents not yet committed in batches
logger.info(
f"Final commit: Total {documents_indexed} Discord messages processed"
f"Final commit: Total {documents_indexed} batch documents processed "
f"(from {total_messages_collected} messages)"
)
try:
await session.commit()
@ -672,14 +766,18 @@ async def index_discord_messages(
"documents_failed": documents_failed,
"duplicate_content_count": duplicate_content_count,
"skipped_channels_count": len(skipped_channels),
"total_messages_collected": total_messages_collected,
"batch_size": DISCORD_BATCH_SIZE,
"guild_id": guild_id,
"guild_name": guild_name,
},
)
logger.info(
f"Discord indexing completed for guild {guild_name}: {documents_indexed} ready, {documents_skipped} skipped, "
f"{documents_failed} failed ({duplicate_content_count} duplicate content)"
f"Discord indexing completed for guild {guild_name}: "
f"{documents_indexed} batch docs ready (from {total_messages_collected} messages), "
f"{documents_skipped} skipped, {documents_failed} failed "
f"({duplicate_content_count} duplicate content)"
)
return documents_indexed, warning_message

View file

@ -1,7 +1,10 @@
"""
Slack connector indexer.
Implements 2-phase document status updates for real-time UI feedback:
Implements batch indexing: groups up to SLACK_BATCH_SIZE messages per channel
into a single document for efficient indexing and better conversational context.
Uses 2-phase document status updates for real-time UI feedback:
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
- Phase 2: Process each document: pending → processing → ready/failed
"""
@ -42,6 +45,72 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Heartbeat interval in seconds - update notification every 30 seconds
HEARTBEAT_INTERVAL_SECONDS = 30
# Number of messages to combine into a single document for batch indexing.
# Grouping messages improves conversational context in embeddings/chunks and
# drastically reduces the number of documents, embedding calls, and DB overhead.
SLACK_BATCH_SIZE = 100
def _build_batch_document_string(
    team_name: str,
    team_id: str,
    channel_name: str,
    channel_id: str,
    messages: list[dict],
) -> str:
    """
    Combine multiple Slack messages into a single document string.

    Each message is rendered as ``[datetime] user_name: text`` and the lines
    are joined into one conversation-style CONTENT section, preceded by a
    METADATA section describing the workspace, channel and batch bounds.
    Downstream chunking is expected to split the result into windows of
    consecutive messages (that step is not visible here).

    Args:
        team_name: Name of the Slack workspace
        team_id: ID of the Slack workspace
        channel_name: Name of the channel
        channel_id: ID of the channel
        messages: List of formatted message dicts with 'user_name', 'datetime', 'text'

    Returns:
        Formatted document string with metadata and conversation content
    """
    # Guard the bounds lookup: an empty batch used to raise IndexError on
    # messages[0]; it now yields a document with "Unknown" bounds and count 0.
    if messages:
        first_msg_time = messages[0].get("datetime", "Unknown")
        last_msg_time = messages[-1].get("datetime", "Unknown")
    else:
        first_msg_time = last_msg_time = "Unknown"

    metadata_lines = [
        f"WORKSPACE_NAME: {team_name}",
        f"WORKSPACE_ID: {team_id}",
        f"CHANNEL_NAME: {channel_name}",
        f"CHANNEL_ID: {channel_id}",
        f"MESSAGE_COUNT: {len(messages)}",
        f"FIRST_MESSAGE_TIME: {first_msg_time}",
        f"LAST_MESSAGE_TIME: {last_msg_time}",
    ]

    # One transcript line per message, in the order received.
    conversation_lines = [
        f"[{msg.get('datetime', 'Unknown Time')}] "
        f"{msg.get('user_name', 'Unknown User')}: {msg.get('text', '')}"
        for msg in messages
    ]

    metadata_sections = [
        ("METADATA", metadata_lines),
        (
            "CONTENT",
            [
                "FORMAT: markdown",
                "TEXT_START",
                "\n".join(conversation_lines),
                "TEXT_END",
            ],
        ),
    ]

    return build_document_metadata_markdown(metadata_sections)
async def index_slack_messages(
session: AsyncSession,
@ -56,6 +125,16 @@ async def index_slack_messages(
"""
Index Slack messages from all accessible channels.
Messages are grouped into batches of SLACK_BATCH_SIZE per channel,
so each document contains up to 100 consecutive messages with full
conversational context. This reduces document count, embedding calls,
and DB overhead by ~100x while improving search quality through
context-aware chunk embeddings.
Implements 2-phase document status updates for real-time UI feedback:
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
- Phase 2: Process each document: pending → processing → ready/failed
Args:
session: Database session
connector_id: ID of the Slack connector
@ -109,6 +188,10 @@ async def index_slack_messages(
f"Connector with ID {connector_id} not found or is not a Slack connector",
)
# Extract workspace info from connector config
team_id = connector.config.get("team_id", "")
team_name = connector.config.get("team_name", "Unknown Workspace")
# Note: Token handling is now done automatically by SlackHistory
# with auto-refresh support. We just need to pass session and connector_id.
@ -182,6 +265,8 @@ async def index_slack_messages(
documents_indexed = 0
documents_skipped = 0
documents_failed = 0 # Track messages that failed processing
duplicate_content_count = 0
total_messages_collected = 0
skipped_channels = []
# Heartbeat tracking - update notification periodically to prevent appearing stuck
@ -194,10 +279,12 @@ async def index_slack_messages(
)
# =======================================================================
# PHASE 1: Collect all messages from all channels, create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# PHASE 1: Collect messages, group into batches, and create pending documents
# Messages are grouped into batches of SLACK_BATCH_SIZE per channel.
# Each batch becomes a single document with full conversational context.
# All documents are visible in the UI immediately with pending status.
# =======================================================================
messages_to_process = [] # List of dicts with document and message data
batches_to_process = [] # List of dicts with document and batch data
new_documents_created = False
for channel_obj in channels:
@ -264,40 +351,35 @@ async def index_slack_messages(
documents_skipped += 1
continue # Skip if no valid messages after filtering
for msg in formatted_messages:
timestamp = msg.get("datetime", "Unknown Time")
msg_ts = msg.get("ts", timestamp) # Get original Slack timestamp
msg_user_name = msg.get("user_name", "Unknown User")
msg_user_email = msg.get("user_email", "Unknown Email")
msg_text = msg.get("text", "")
total_messages_collected += len(formatted_messages)
# Format document metadata
metadata_sections = [
(
"METADATA",
[
f"CHANNEL_NAME: {channel_name}",
f"CHANNEL_ID: {channel_id}",
f"MESSAGE_TIMESTAMP: {timestamp}",
f"MESSAGE_USER_NAME: {msg_user_name}",
f"MESSAGE_USER_EMAIL: {msg_user_email}",
],
),
(
"CONTENT",
["FORMAT: markdown", "TEXT_START", msg_text, "TEXT_END"],
),
# =======================================================
# Group messages into batches of SLACK_BATCH_SIZE
# Each batch becomes a single document with conversation context
# =======================================================
for batch_start in range(0, len(formatted_messages), SLACK_BATCH_SIZE):
batch = formatted_messages[
batch_start : batch_start + SLACK_BATCH_SIZE
]
# Build the document string
combined_document_string = build_document_metadata_markdown(
metadata_sections
# Build combined document string from all messages in this batch
combined_document_string = _build_batch_document_string(
team_name=team_name,
team_id=team_id,
channel_name=channel_name,
channel_id=channel_id,
messages=batch,
)
# Generate unique identifier hash for this Slack message
unique_identifier = f"{channel_id}_{msg_ts}"
# Generate unique identifier for this batch using
# channel_id + first message ts + last message ts
first_msg_ts = batch[0].get("timestamp", "")
last_msg_ts = batch[-1].get("timestamp", "")
unique_identifier = f"{channel_id}_{first_msg_ts}_{last_msg_ts}"
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.SLACK_CONNECTOR, unique_identifier, search_space_id
DocumentType.SLACK_CONNECTOR,
unique_identifier,
search_space_id,
)
# Generate content hash
@ -318,25 +400,31 @@ async def index_slack_messages(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
logger.info(
f"Document for Slack message {msg_ts} in channel {channel_name} unchanged. Skipping."
)
documents_skipped += 1
continue
# Queue existing document for update (will be set to processing in Phase 2)
messages_to_process.append(
batches_to_process.append(
{
"document": existing_document,
"is_new": False,
"combined_document_string": combined_document_string,
"content_hash": content_hash,
"team_name": team_name,
"team_id": team_id,
"channel_name": channel_name,
"channel_id": channel_id,
"msg_ts": msg_ts,
"first_message_ts": first_msg_ts,
"last_message_ts": last_msg_ts,
"first_message_time": batch[0].get(
"datetime", "Unknown"
),
"last_message_time": batch[-1].get(
"datetime", "Unknown"
),
"message_count": len(batch),
"start_date": start_date_str,
"end_date": end_date_str,
"message_count": len(formatted_messages),
}
)
continue
@ -350,22 +438,27 @@ async def index_slack_messages(
if duplicate_by_content:
logger.info(
f"Slack message {msg_ts} in channel {channel_name} already indexed by another connector "
f"Slack batch ({len(batch)} msgs) in {team_name}#{channel_name} already indexed by another connector "
f"(existing document ID: {duplicate_by_content.id}, "
f"type: {duplicate_by_content.document_type}). Skipping."
)
duplicate_content_count += 1
documents_skipped += 1
continue
# Create new document with PENDING status (visible in UI immediately)
document = Document(
search_space_id=search_space_id,
title=channel_name,
title=f"{team_name}#{channel_name}",
document_type=DocumentType.SLACK_CONNECTOR,
document_metadata={
"team_name": team_name,
"team_id": team_id,
"channel_name": channel_name,
"channel_id": channel_id,
"msg_ts": msg_ts,
"first_message_ts": first_msg_ts,
"last_message_ts": last_msg_ts,
"message_count": len(batch),
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
@ -381,23 +474,29 @@ async def index_slack_messages(
session.add(document)
new_documents_created = True
messages_to_process.append(
batches_to_process.append(
{
"document": document,
"is_new": True,
"combined_document_string": combined_document_string,
"content_hash": content_hash,
"team_name": team_name,
"team_id": team_id,
"channel_name": channel_name,
"channel_id": channel_id,
"msg_ts": msg_ts,
"first_message_ts": first_msg_ts,
"last_message_ts": last_msg_ts,
"first_message_time": batch[0].get("datetime", "Unknown"),
"last_message_time": batch[-1].get("datetime", "Unknown"),
"message_count": len(batch),
"start_date": start_date_str,
"end_date": end_date_str,
"message_count": len(formatted_messages),
}
)
logger.info(
f"Phase 1: Collected {len(formatted_messages)} messages from channel {channel_name}"
f"Phase 1: Collected {len(formatted_messages)} messages from channel {channel_name}, "
f"grouped into {(len(formatted_messages) + SLACK_BATCH_SIZE - 1) // SLACK_BATCH_SIZE} batch(es)"
)
except SlackApiError as slack_error:
@ -416,17 +515,18 @@ async def index_slack_messages(
# Commit all pending documents - they all appear in UI now
if new_documents_created:
logger.info(
f"Phase 1: Committing {len([m for m in messages_to_process if m['is_new']])} pending documents"
f"Phase 1: Committing {len([b for b in batches_to_process if b['is_new']])} pending batch documents "
f"({total_messages_collected} total messages across all channels)"
)
await session.commit()
# =======================================================================
# PHASE 2: Process each document one by one
# PHASE 2: Process each batch document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(messages_to_process)} documents")
logger.info(f"Phase 2: Processing {len(batches_to_process)} batch documents")
for item in messages_to_process:
for item in batches_to_process:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
@ -447,16 +547,22 @@ async def index_slack_messages(
)
# Update document to READY with actual content
document.title = item["channel_name"]
document.title = f"{item['team_name']}#{item['channel_name']}"
document.content = item["combined_document_string"]
document.content_hash = item["content_hash"]
document.embedding = doc_embedding
document.document_metadata = {
"team_name": item["team_name"],
"team_id": item["team_id"],
"channel_name": item["channel_name"],
"channel_id": item["channel_id"],
"first_message_ts": item["first_message_ts"],
"last_message_ts": item["last_message_ts"],
"first_message_time": item["first_message_time"],
"last_message_time": item["last_message_time"],
"message_count": item["message_count"],
"start_date": item["start_date"],
"end_date": item["end_date"],
"message_count": item["message_count"],
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"connector_id": connector_id,
}
@ -469,13 +575,13 @@ async def index_slack_messages(
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
f"Committing batch: {documents_indexed} Slack messages processed so far"
f"Committing batch: {documents_indexed} batch documents processed so far"
)
await session.commit()
except Exception as e:
logger.error(
f"Error processing Slack message {item.get('msg_ts', 'Unknown')}: {e!s}",
f"Error processing Slack batch document: {e!s}",
exc_info=True,
)
# Mark document as failed with reason (visible in UI)
@ -493,7 +599,10 @@ async def index_slack_messages(
await update_connector_last_indexed(session, connector, update_last_indexed)
# Final commit for any remaining documents not yet committed in batches
logger.info(f"Final commit: Total {documents_indexed} Slack messages processed")
logger.info(
f"Final commit: Total {documents_indexed} batch documents processed "
f"(from {total_messages_collected} messages)"
)
try:
await session.commit()
logger.info("Successfully committed all Slack document changes to database")
@ -514,8 +623,12 @@ async def index_slack_messages(
# Build warning message if there were issues
warning_parts = []
if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate")
if documents_failed > 0:
warning_parts.append(f"{documents_failed} failed")
if skipped_channels:
warning_parts.append(f"{len(skipped_channels)} channels skipped")
warning_message = ", ".join(warning_parts) if warning_parts else None
# Log success
@ -527,13 +640,20 @@ async def index_slack_messages(
"documents_indexed": documents_indexed,
"documents_skipped": documents_skipped,
"documents_failed": documents_failed,
"duplicate_content_count": duplicate_content_count,
"skipped_channels_count": len(skipped_channels),
"total_messages_collected": total_messages_collected,
"batch_size": SLACK_BATCH_SIZE,
"team_id": team_id,
"team_name": team_name,
},
)
logger.info(
f"Slack indexing completed: {documents_indexed} ready, "
f"{documents_skipped} skipped, {documents_failed} failed"
f"Slack indexing completed for workspace {team_name}: "
f"{documents_indexed} batch docs ready (from {total_messages_collected} messages), "
f"{documents_skipped} skipped, {documents_failed} failed "
f"({duplicate_content_count} duplicate content)"
)
return documents_indexed, warning_message

View file

@ -1,7 +1,10 @@
"""
Microsoft Teams connector indexer.
Implements 2-phase document status updates for real-time UI feedback:
Implements batch indexing: groups up to TEAMS_BATCH_SIZE messages per channel
into a single document for efficient indexing and better conversational context.
Uses 2-phase document status updates for real-time UI feedback:
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
- Phase 2: Process each document: pending → processing → ready/failed
"""
@ -41,6 +44,72 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Heartbeat interval in seconds - update notification every 30 seconds
HEARTBEAT_INTERVAL_SECONDS = 30
# Number of messages to combine into a single document for batch indexing.
# Grouping messages improves conversational context in embeddings/chunks and
# drastically reduces the number of documents, embedding calls, and DB overhead.
TEAMS_BATCH_SIZE = 100
def _build_batch_document_string(
    team_name: str,
    team_id: str,
    channel_name: str,
    channel_id: str,
    messages: list[dict],
) -> str:
    """
    Combine multiple Teams messages into a single document string.

    Each message is rendered as ``[timestamp] author: content`` and the
    rendered lines are joined with newlines into one conversation-style
    CONTENT section. A METADATA section precedes it with the team/channel
    identifiers, the number of messages in the batch, and the timestamps of
    the first and last message.

    NOTE(review): the original author's note says the downstream chunker later
    splits this text into overlapping windows of consecutive messages; that
    behaviour is not visible from this function - confirm against the chunker.

    Args:
        team_name: Name of the Microsoft Team
        team_id: ID of the Microsoft Team
        channel_name: Name of the channel
        channel_id: ID of the channel
        messages: List of formatted message dicts with 'user_name',
            'created_datetime', 'content'. An empty list is tolerated: the
            first/last message times are then reported as "Unknown" and the
            conversation body is empty.

    Returns:
        The string returned by build_document_metadata_markdown for the
        METADATA and CONTENT sections.
    """
    # Guard the boundary lookups: indexing an empty batch would otherwise
    # raise IndexError before any document text could be produced.
    if messages:
        first_msg_time = messages[0].get("created_datetime", "Unknown")
        last_msg_time = messages[-1].get("created_datetime", "Unknown")
    else:
        first_msg_time = "Unknown"
        last_msg_time = "Unknown"

    metadata_lines = [
        f"TEAM_NAME: {team_name}",
        f"TEAM_ID: {team_id}",
        f"CHANNEL_NAME: {channel_name}",
        f"CHANNEL_ID: {channel_id}",
        f"MESSAGE_COUNT: {len(messages)}",
        f"FIRST_MESSAGE_TIME: {first_msg_time}",
        f"LAST_MESSAGE_TIME: {last_msg_time}",
    ]

    # One "[timestamp] author: content" line per message, in the order given.
    conversation_lines = [
        "[{}] {}: {}".format(
            msg.get("created_datetime", "Unknown Time"),
            msg.get("user_name", "Unknown User"),
            msg.get("content", ""),
        )
        for msg in messages
    ]

    metadata_sections = [
        ("METADATA", metadata_lines),
        (
            "CONTENT",
            [
                "FORMAT: markdown",
                "TEXT_START",
                "\n".join(conversation_lines),
                "TEXT_END",
            ],
        ),
    ]

    return build_document_metadata_markdown(metadata_sections)
async def index_teams_messages(
session: AsyncSession,
@ -55,6 +124,12 @@ async def index_teams_messages(
"""
Index Microsoft Teams messages from all accessible teams and channels.
Messages are grouped into batches of TEAMS_BATCH_SIZE per channel,
so each document contains up to 100 consecutive messages with full
conversational context. This reduces document count, embedding calls,
and DB overhead by ~100x while improving search quality through
context-aware chunk embeddings.
Implements 2-phase document status updates for real-time UI feedback:
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
- Phase 2: Process each document: pending → processing → ready/failed
@ -184,6 +259,7 @@ async def index_teams_messages(
documents_skipped = 0
documents_failed = 0
duplicate_content_count = 0
total_messages_collected = 0
skipped_channels = []
# Heartbeat tracking - update notification periodically to prevent appearing stuck
@ -199,21 +275,21 @@ async def index_teams_messages(
start_datetime = None
end_datetime = None
if start_date_str:
# Parse as naive datetime and make it timezone-aware (UTC)
start_datetime = datetime.strptime(start_date_str, "%Y-%m-%d").replace(
tzinfo=UTC
)
if end_date_str:
# Parse as naive datetime, set to end of day, and make it timezone-aware (UTC)
end_datetime = datetime.strptime(end_date_str, "%Y-%m-%d").replace(
hour=23, minute=59, second=59, tzinfo=UTC
)
# =======================================================================
# PHASE 1: Collect all messages and create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# PHASE 1: Collect messages, group into batches, and create pending documents
# Messages are grouped into batches of TEAMS_BATCH_SIZE per channel.
# Each batch becomes a single document with full conversational context.
# All documents are visible in the UI immediately with pending status.
# =======================================================================
messages_to_process = [] # List of dicts with document and message data
batches_to_process = [] # List of dicts with document and batch data
new_documents_created = False
for team in teams:
@ -251,65 +327,72 @@ async def index_teams_messages(
)
continue
# Process each message
# Format messages for batching
formatted_messages = []
for msg in messages:
# Skip deleted messages or empty content
if msg.get("deletedDateTime"):
continue
# Extract message details
message_id = msg.get("id", "")
created_datetime = msg.get("createdDateTime", "")
from_user = msg.get("from", {})
user_name = from_user.get("user", {}).get(
"displayName", "Unknown User"
)
user_email = from_user.get("user", {}).get(
"userPrincipalName", "Unknown Email"
)
# Extract message content
body = msg.get("body", {})
content_type = body.get("contentType", "text")
msg_text = body.get("content", "")
# Skip empty messages
if not msg_text or msg_text.strip() == "":
continue
# Format document metadata
metadata_sections = [
(
"METADATA",
[
f"TEAM_NAME: {team_name}",
f"TEAM_ID: {team_id}",
f"CHANNEL_NAME: {channel_name}",
f"CHANNEL_ID: {channel_id}",
f"MESSAGE_TIMESTAMP: {created_datetime}",
f"MESSAGE_USER_NAME: {user_name}",
f"MESSAGE_USER_EMAIL: {user_email}",
f"CONTENT_TYPE: {content_type}",
],
),
(
"CONTENT",
[
f"FORMAT: {content_type}",
"TEXT_START",
msg_text,
"TEXT_END",
],
),
]
# Build the document string
combined_document_string = build_document_metadata_markdown(
metadata_sections
formatted_messages.append(
{
"message_id": msg.get("id", ""),
"created_datetime": msg.get("createdDateTime", ""),
"user_name": user_name,
"content": msg_text,
}
)
# Generate unique identifier hash for this Teams message
unique_identifier = f"{team_id}_{channel_id}_{message_id}"
if not formatted_messages:
logger.info(
"No valid messages found in channel %s of team %s after filtering.",
channel_name,
team_name,
)
documents_skipped += 1
continue
total_messages_collected += len(formatted_messages)
# =======================================================
# Group messages into batches of TEAMS_BATCH_SIZE
# Each batch becomes a single document with conversation context
# =======================================================
for batch_start in range(
0, len(formatted_messages), TEAMS_BATCH_SIZE
):
batch = formatted_messages[
batch_start : batch_start + TEAMS_BATCH_SIZE
]
# Build combined document string from all messages in this batch
combined_document_string = _build_batch_document_string(
team_name=team_name,
team_id=team_id,
channel_name=channel_name,
channel_id=channel_id,
messages=batch,
)
# Generate unique identifier for this batch using
# team_id + channel_id + first message id + last message id
first_msg_id = batch[0].get("message_id", "")
last_msg_id = batch[-1].get("message_id", "")
unique_identifier = (
f"{team_id}_{channel_id}_{first_msg_id}_{last_msg_id}"
)
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.TEAMS_CONNECTOR,
unique_identifier,
@ -331,7 +414,6 @@ async def index_teams_messages(
if existing_document:
# Document exists - check if content has changed
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
@ -342,7 +424,7 @@ async def index_teams_messages(
continue
# Queue existing document for update (will be set to processing in Phase 2)
messages_to_process.append(
batches_to_process.append(
{
"document": existing_document,
"is_new": False,
@ -352,14 +434,21 @@ async def index_teams_messages(
"team_id": team_id,
"channel_name": channel_name,
"channel_id": channel_id,
"message_id": message_id,
"first_message_id": first_msg_id,
"last_message_id": last_msg_id,
"first_message_time": batch[0].get(
"created_datetime", "Unknown"
),
"last_message_time": batch[-1].get(
"created_datetime", "Unknown"
),
"message_count": len(batch),
"start_date": start_date_str,
"end_date": end_date_str,
}
)
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from another connector)
with session.no_autoflush:
duplicate_by_content = (
@ -370,9 +459,10 @@ async def index_teams_messages(
if duplicate_by_content:
logger.info(
"Teams message %s in channel %s already indexed by another connector "
"Teams batch (%s msgs) in %s/%s already indexed by another connector "
"(existing document ID: %s, type: %s). Skipping.",
message_id,
len(batch),
team_name,
channel_name,
duplicate_by_content.id,
duplicate_by_content.document_type,
@ -391,6 +481,9 @@ async def index_teams_messages(
"team_id": team_id,
"channel_name": channel_name,
"channel_id": channel_id,
"first_message_id": first_msg_id,
"last_message_id": last_msg_id,
"message_count": len(batch),
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
@ -406,7 +499,7 @@ async def index_teams_messages(
session.add(document)
new_documents_created = True
messages_to_process.append(
batches_to_process.append(
{
"document": document,
"is_new": True,
@ -416,12 +509,30 @@ async def index_teams_messages(
"team_id": team_id,
"channel_name": channel_name,
"channel_id": channel_id,
"message_id": message_id,
"first_message_id": first_msg_id,
"last_message_id": last_msg_id,
"first_message_time": batch[0].get(
"created_datetime", "Unknown"
),
"last_message_time": batch[-1].get(
"created_datetime", "Unknown"
),
"message_count": len(batch),
"start_date": start_date_str,
"end_date": end_date_str,
}
)
logger.info(
"Phase 1: Collected %s messages from %s/%s, "
"grouped into %s batch(es)",
len(formatted_messages),
team_name,
channel_name,
(len(formatted_messages) + TEAMS_BATCH_SIZE - 1)
// TEAMS_BATCH_SIZE,
)
except Exception as e:
logger.error(
"Error processing channel %s in team %s: %s",
@ -441,17 +552,20 @@ async def index_teams_messages(
# Commit all pending documents - they all appear in UI now
if new_documents_created:
logger.info(
f"Phase 1: Committing {len([m for m in messages_to_process if m['is_new']])} pending documents"
"Phase 1: Committing %s pending batch documents "
"(%s total messages across all channels)",
len([b for b in batches_to_process if b["is_new"]]),
total_messages_collected,
)
await session.commit()
# =======================================================================
# PHASE 2: Process each document one by one
# PHASE 2: Process each batch document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(messages_to_process)} documents")
logger.info("Phase 2: Processing %s batch documents", len(batches_to_process))
for item in messages_to_process:
for item in batches_to_process:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
@ -481,6 +595,11 @@ async def index_teams_messages(
"team_id": item["team_id"],
"channel_name": item["channel_name"],
"channel_id": item["channel_id"],
"first_message_id": item["first_message_id"],
"last_message_id": item["last_message_id"],
"first_message_time": item["first_message_time"],
"last_message_time": item["last_message_time"],
"message_count": item["message_count"],
"start_date": item["start_date"],
"end_date": item["end_date"],
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
@ -495,20 +614,25 @@ async def index_teams_messages(
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
"Committing batch: %s Teams messages processed so far",
"Committing batch: %s Teams batch documents processed so far",
documents_indexed,
)
await session.commit()
except Exception as e:
logger.error(f"Error processing Teams message: {e!s}", exc_info=True)
logger.error(
"Error processing Teams batch document: %s",
str(e),
exc_info=True,
)
# Mark document as failed with reason (visible in UI)
try:
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp()
except Exception as status_error:
logger.error(
f"Failed to update document status to failed: {status_error}"
"Failed to update document status to failed: %s",
str(status_error),
)
documents_failed += 1
continue
@ -518,7 +642,9 @@ async def index_teams_messages(
# Final commit for any remaining documents not yet committed in batches
logger.info(
"Final commit: Total %s Teams messages processed", documents_indexed
"Final commit: Total %s Teams batch documents processed (from %s messages)",
documents_indexed,
total_messages_collected,
)
try:
await session.commit()
@ -530,8 +656,9 @@ async def index_teams_messages(
or "uniqueviolationerror" in str(e).lower()
):
logger.warning(
f"Duplicate content_hash detected during final commit. "
f"Rolling back and continuing. Error: {e!s}"
"Duplicate content_hash detected during final commit. "
"Rolling back and continuing. Error: %s",
str(e),
)
await session.rollback()
else:
@ -557,13 +684,16 @@ async def index_teams_messages(
"documents_failed": documents_failed,
"duplicate_content_count": duplicate_content_count,
"skipped_channels_count": len(skipped_channels),
"total_messages_collected": total_messages_collected,
"batch_size": TEAMS_BATCH_SIZE,
},
)
logger.info(
"Teams indexing completed: %s ready, %s skipped, %s failed "
"(%s duplicate content)",
"Teams indexing completed: %s batch docs ready (from %s messages), "
"%s skipped, %s failed (%s duplicate content)",
documents_indexed,
total_messages_collected,
documents_skipped,
documents_failed,
duplicate_content_count,

View file

@ -38,10 +38,22 @@ def safe_set_chunks(document: Document, chunks: list) -> None:
# Instead of: document.chunks = chunks (DANGEROUS!)
safe_set_chunks(document, chunks) # Always safe
"""
from sqlalchemy.orm import object_session
from sqlalchemy.orm.attributes import set_committed_value
# Keep relationship assignment lazy-load-safe.
set_committed_value(document, "chunks", chunks)
# Ensure chunk rows are actually persisted.
# set_committed_value bypasses normal unit-of-work tracking, so we need to
# explicitly attach chunk objects to the current session.
session = object_session(document)
if session is not None:
if document.id is not None:
for chunk in chunks:
chunk.document_id = document.id
session.add_all(chunks)
def get_current_timestamp() -> datetime:
"""

View file

@ -9,9 +9,17 @@ Message content in new_chat_messages can be stored in various formats:
These utilities help extract and transform content for different use cases.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from langchain_core.messages import AIMessage, HumanMessage
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
if TYPE_CHECKING:
from app.db import ChatVisibility
def extract_text_content(content: str | dict | list) -> str:
@ -38,6 +46,7 @@ def extract_text_content(content: str | dict | list) -> str:
async def bootstrap_history_from_db(
session: AsyncSession,
thread_id: int,
thread_visibility: ChatVisibility | None = None,
) -> list[HumanMessage | AIMessage]:
"""
Load message history from database and convert to LangChain format.
@ -45,20 +54,28 @@ async def bootstrap_history_from_db(
Used for cloned chats where the LangGraph checkpointer has no state,
but we have messages in the database that should be used as context.
When thread_visibility is SEARCH_SPACE, user messages are prefixed with
the author's display name so the LLM sees who said what.
Args:
session: Database session
thread_id: The chat thread ID
thread_visibility: When SEARCH_SPACE, user messages get author prefix
Returns:
List of LangChain messages (HumanMessage/AIMessage)
"""
from app.db import NewChatMessage
from app.db import ChatVisibility, NewChatMessage
result = await session.execute(
is_shared = thread_visibility == ChatVisibility.SEARCH_SPACE
stmt = (
select(NewChatMessage)
.filter(NewChatMessage.thread_id == thread_id)
.order_by(NewChatMessage.created_at)
)
if is_shared:
stmt = stmt.options(selectinload(NewChatMessage.author))
result = await session.execute(stmt)
db_messages = result.scalars().all()
langchain_messages: list[HumanMessage | AIMessage] = []
@ -68,6 +85,11 @@ async def bootstrap_history_from_db(
if not text_content:
continue
if msg.role == "user":
if is_shared:
author_name = (
msg.author.display_name if msg.author else None
) or "A team member"
text_content = f"**[{author_name}]:** {text_content}"
langchain_messages.append(HumanMessage(content=text_content))
elif msg.role == "assistant":
langchain_messages.append(AIMessage(content=text_content))

View file

@ -0,0 +1,46 @@
"""Redis-based connector indexing locks to prevent duplicate sync tasks."""
import redis
from app.config import config
# Module-level cache for the Redis client; starts empty and is populated by
# get_indexing_lock_redis_client() the first time a client is requested.
_redis_client: redis.Redis | None = None
# Expiry, taken from application config, that is passed as `ex=` when a
# connector indexing lock key is set.
LOCK_TTL_SECONDS = config.CONNECTOR_INDEXING_LOCK_TTL_SECONDS
def get_indexing_lock_redis_client() -> redis.Redis:
    """Return the shared Redis client used for connector indexing locks.

    The client is created from ``config.REDIS_APP_URL`` (with
    ``decode_responses=True``) on the first call and cached in the
    module-level ``_redis_client`` so later calls reuse the same object.
    """
    global _redis_client

    # Fast path: a client was already built by an earlier call.
    if _redis_client is not None:
        return _redis_client

    _redis_client = redis.from_url(config.REDIS_APP_URL, decode_responses=True)
    return _redis_client
def _get_connector_lock_key(connector_id: int) -> str:
    """Build the Redis key under which a connector's indexing lock is stored.

    Args:
        connector_id: ID of the connector being indexed.

    Returns:
        The key ``indexing:connector_lock:<connector_id>``.
    """
    namespace = "indexing:connector_lock"
    return f"{namespace}:{connector_id}"
def acquire_connector_indexing_lock(connector_id: int) -> bool:
    """Try to take the indexing lock for a connector.

    Issues a Redis SET on the connector's lock key with ``nx=True`` and an
    expiry of ``LOCK_TTL_SECONDS``, and reports the truthiness of the reply.

    Args:
        connector_id: ID of the connector to lock.

    Returns:
        True if the lock was acquired, False otherwise.
    """
    lock_key = _get_connector_lock_key(connector_id)
    client = get_indexing_lock_redis_client()
    # nx=True: presumably only sets the key when it is absent, which is what
    # makes this usable as a lock - see the redis-py `set` documentation.
    reply = client.set(lock_key, "1", nx=True, ex=LOCK_TTL_SECONDS)
    return bool(reply)
def release_connector_indexing_lock(connector_id: int) -> None:
    """Release a connector's indexing lock by deleting its Redis key.

    Args:
        connector_id: ID of the connector whose lock should be removed.
    """
    lock_key = _get_connector_lock_key(connector_id)
    client = get_indexing_lock_redis_client()
    # NOTE(review): the key is deleted unconditionally; there is no check
    # that the caller is the one who set it.
    client.delete(lock_key)
def is_connector_indexing_locked(connector_id: int) -> bool:
    """Report whether an indexing lock key currently exists for a connector.

    Args:
        connector_id: ID of the connector to check.

    Returns:
        True if the connector's lock key exists in Redis, False otherwise.
    """
    lock_key = _get_connector_lock_key(connector_id)
    found = get_indexing_lock_redis_client().exists(lock_key)
    return bool(found)