mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-25 19:15:18 +02:00
Merge remote-tracking branch 'upstream/dev' into fix/auth
This commit is contained in:
commit
2dec643cb4
80 changed files with 2968 additions and 2379 deletions
|
|
@ -22,6 +22,7 @@ from app.agents.new_chat.system_prompt import (
|
|||
build_surfsense_system_prompt,
|
||||
)
|
||||
from app.agents.new_chat.tools.registry import build_tools_async
|
||||
from app.db import ChatVisibility
|
||||
from app.services.connector_service import ConnectorService
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -126,6 +127,7 @@ async def create_surfsense_deep_agent(
|
|||
disabled_tools: list[str] | None = None,
|
||||
additional_tools: Sequence[BaseTool] | None = None,
|
||||
firecrawl_api_key: str | None = None,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
):
|
||||
"""
|
||||
Create a SurfSense deep agent with configurable tools and prompts.
|
||||
|
|
@ -228,14 +230,15 @@ async def create_surfsense_deep_agent(
|
|||
logging.warning(f"Failed to discover available connectors/document types: {e}")
|
||||
|
||||
# Build dependencies dict for the tools registry
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
dependencies = {
|
||||
"search_space_id": search_space_id,
|
||||
"db_session": db_session,
|
||||
"connector_service": connector_service,
|
||||
"firecrawl_api_key": firecrawl_api_key,
|
||||
"user_id": user_id, # Required for memory tools
|
||||
"thread_id": thread_id, # For podcast tool
|
||||
# Dynamic connector/document type discovery for knowledge base tool
|
||||
"user_id": user_id,
|
||||
"thread_id": thread_id,
|
||||
"thread_visibility": visibility,
|
||||
"available_connectors": available_connectors,
|
||||
"available_document_types": available_document_types,
|
||||
}
|
||||
|
|
@ -255,10 +258,12 @@ async def create_surfsense_deep_agent(
|
|||
custom_system_instructions=agent_config.system_instructions,
|
||||
use_default_system_instructions=agent_config.use_default_system_instructions,
|
||||
citations_enabled=agent_config.citations_enabled,
|
||||
thread_visibility=thread_visibility,
|
||||
)
|
||||
else:
|
||||
# Use default prompt (with citations enabled)
|
||||
system_prompt = build_surfsense_system_prompt()
|
||||
system_prompt = build_surfsense_system_prompt(
|
||||
thread_visibility=thread_visibility,
|
||||
)
|
||||
|
||||
# Create the deep agent with system prompt and checkpointer
|
||||
# Note: TodoListMiddleware (write_todos) is included by default in create_deep_agent
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ PROVIDER_MAP = {
|
|||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai",
|
||||
"GITHUB_MODELS": "github",
|
||||
"REPLICATE": "replicate",
|
||||
"PERPLEXITY": "perplexity",
|
||||
"ANYSCALE": "anyscale",
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ The prompt is composed of three parts:
|
|||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.db import ChatVisibility
|
||||
|
||||
# Default system instructions - can be overridden via NewLLMConfig.system_instructions
|
||||
SURFSENSE_SYSTEM_INSTRUCTIONS = """
|
||||
<system_instruction>
|
||||
|
|
@ -22,7 +24,34 @@ Today's date (UTC): {resolved_today}
|
|||
</system_instruction>
|
||||
"""
|
||||
|
||||
SURFSENSE_TOOLS_INSTRUCTIONS = """
|
||||
# Default system instructions for shared (team) threads: team context + message format for attribution
|
||||
_SYSTEM_INSTRUCTIONS_SHARED = """
|
||||
<system_instruction>
|
||||
You are SurfSense, a reasoning and acting AI agent designed to answer questions in this team space using the team's shared knowledge base.
|
||||
|
||||
In this team thread, each message is prefixed with **[DisplayName of the author]**. Use this to attribute and reference the author of anything in the discussion (who asked a question, made a suggestion, or contributed an idea) and to cite who said what in your answers.
|
||||
|
||||
Today's date (UTC): {resolved_today}
|
||||
|
||||
</system_instruction>
|
||||
"""
|
||||
|
||||
|
||||
def _get_system_instructions(
|
||||
thread_visibility: ChatVisibility | None = None, today: datetime | None = None
|
||||
) -> str:
|
||||
"""Build system instructions based on thread visibility (private vs shared)."""
|
||||
|
||||
resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat()
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
if visibility == ChatVisibility.SEARCH_SPACE:
|
||||
return _SYSTEM_INSTRUCTIONS_SHARED.format(resolved_today=resolved_today)
|
||||
else:
|
||||
return SURFSENSE_SYSTEM_INSTRUCTIONS.format(resolved_today=resolved_today)
|
||||
|
||||
|
||||
# Tools 0-6 (common to both private and shared prompts)
|
||||
_TOOLS_INSTRUCTIONS_COMMON = """
|
||||
<tools>
|
||||
You have access to the following tools:
|
||||
|
||||
|
|
@ -138,7 +167,11 @@ You have access to the following tools:
|
|||
* Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content.
|
||||
* Don't show every image - just the most relevant 1-3 images that enhance understanding.
|
||||
|
||||
7. save_memory: Save facts, preferences, or context about the user for personalized responses.
|
||||
"""
|
||||
|
||||
# Private (user) memory: tools 7-8 + memory-specific examples
|
||||
_TOOLS_INSTRUCTIONS_MEMORY_PRIVATE = """
|
||||
7. save_memory: Save facts, preferences, or context for personalized responses.
|
||||
- Use this when the user explicitly or implicitly shares information worth remembering.
|
||||
- Trigger scenarios:
|
||||
* User says "remember this", "keep this in mind", "note that", or similar
|
||||
|
|
@ -178,6 +211,75 @@ You have access to the following tools:
|
|||
stating "Based on your memory..." - integrate the context seamlessly.
|
||||
</tools>
|
||||
<tool_call_examples>
|
||||
- User: "Remember that I prefer TypeScript over JavaScript"
|
||||
- Call: `save_memory(content="User prefers TypeScript over JavaScript for development", category="preference")`
|
||||
|
||||
- User: "I'm a data scientist working on ML pipelines"
|
||||
- Call: `save_memory(content="User is a data scientist working on ML pipelines", category="fact")`
|
||||
|
||||
- User: "Always give me code examples in Python"
|
||||
- Call: `save_memory(content="User wants code examples to be written in Python", category="instruction")`
|
||||
|
||||
- User: "What programming language should I use for this project?"
|
||||
- First recall: `recall_memory(query="programming language preferences")`
|
||||
- Then provide a personalized recommendation based on their preferences
|
||||
|
||||
- User: "What do you know about me?"
|
||||
- Call: `recall_memory(top_k=10)`
|
||||
- Then summarize the stored memories
|
||||
|
||||
"""
|
||||
|
||||
# Shared (team) memory: tools 7-8 + team memory examples
|
||||
_TOOLS_INSTRUCTIONS_MEMORY_SHARED = """
|
||||
7. save_memory: Save a fact, preference, or context to the team's shared memory for future reference.
|
||||
- Use this when the user or a team member says "remember this", "keep this in mind", or similar in this shared chat.
|
||||
- Use when the team agrees on something to remember (e.g., decisions, conventions).
|
||||
- Someone shares a preference or fact that should be visible to the whole team.
|
||||
- The saved information will be available in future shared conversations in this space.
|
||||
- Args:
|
||||
- content: The fact/preference/context to remember. Phrase it clearly, e.g. "API keys are stored in Vault", "The team prefers weekly demos on Fridays"
|
||||
- category: Type of memory. One of:
|
||||
* "preference": Team or workspace preferences
|
||||
* "fact": Facts the team agreed on (e.g., processes, locations)
|
||||
* "instruction": Standing instructions for the team
|
||||
* "context": Current context (e.g., ongoing projects, goals)
|
||||
- Returns: Confirmation of saved memory; returned context may include who added it (added_by).
|
||||
- IMPORTANT: Only save information that would be genuinely useful for future team conversations in this space.
|
||||
|
||||
8. recall_memory: Recall relevant team memories for this space to provide contextual responses.
|
||||
- Use when you need team context to answer (e.g., "where do we store X?", "what did we decide about Y?").
|
||||
- Use when someone asks about something the team agreed to remember.
|
||||
- Use when team preferences or conventions would improve the response.
|
||||
- Args:
|
||||
- query: Optional search query to find specific memories. If not provided, returns the most recent memories.
|
||||
- category: Optional filter by category ("preference", "fact", "instruction", "context")
|
||||
- top_k: Number of memories to retrieve (default: 5, max: 20)
|
||||
- Returns: Relevant team memories and formatted context (may include added_by). Integrate naturally without saying "Based on team memory...".
|
||||
</tools>
|
||||
<tool_call_examples>
|
||||
- User: "Remember that API keys are stored in Vault"
|
||||
- Call: `save_memory(content="API keys are stored in Vault", category="fact")`
|
||||
|
||||
- User: "Let's remember that the team prefers weekly demos on Fridays"
|
||||
- Call: `save_memory(content="The team prefers weekly demos on Fridays", category="preference")`
|
||||
|
||||
- User: "What did we decide about the release date?"
|
||||
- First recall: `recall_memory(query="release date decision")`
|
||||
- Then answer based on the team memories
|
||||
|
||||
- User: "Where do we document onboarding?"
|
||||
- Call: `recall_memory(query="onboarding documentation")`
|
||||
- Then answer using the recalled team context
|
||||
|
||||
- User: "What have we agreed to remember about our deployment process?"
|
||||
- Call: `recall_memory(query="deployment process", top_k=10)`
|
||||
- Then summarize the relevant team memories
|
||||
|
||||
"""
|
||||
|
||||
# Examples shared by both private and shared prompts (knowledge base, docs, podcast, links, images, etc.)
|
||||
_TOOLS_INSTRUCTIONS_EXAMPLES_COMMON = """
|
||||
- User: "What time is the team meeting today?"
|
||||
- Call: `search_knowledge_base(query="team meeting time today")` (searches ALL sources - calendar, notes, Obsidian, etc.)
|
||||
- DO NOT limit to just calendar - the info might be in notes!
|
||||
|
|
@ -209,23 +311,6 @@ You have access to the following tools:
|
|||
- User: "What's in my Obsidian vault about project ideas?"
|
||||
- Call: `search_knowledge_base(query="project ideas", connectors_to_search=["OBSIDIAN_CONNECTOR"])`
|
||||
|
||||
- User: "Remember that I prefer TypeScript over JavaScript"
|
||||
- Call: `save_memory(content="User prefers TypeScript over JavaScript for development", category="preference")`
|
||||
|
||||
- User: "I'm a data scientist working on ML pipelines"
|
||||
- Call: `save_memory(content="User is a data scientist working on ML pipelines", category="fact")`
|
||||
|
||||
- User: "Always give me code examples in Python"
|
||||
- Call: `save_memory(content="User wants code examples to be written in Python", category="instruction")`
|
||||
|
||||
- User: "What programming language should I use for this project?"
|
||||
- First recall: `recall_memory(query="programming language preferences")`
|
||||
- Then provide a personalized recommendation based on their preferences
|
||||
|
||||
- User: "What do you know about me?"
|
||||
- Call: `recall_memory(top_k=10)`
|
||||
- Then summarize the stored memories
|
||||
|
||||
- User: "Give me a podcast about AI trends based on what we discussed"
|
||||
- First search for relevant content, then call: `generate_podcast(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", podcast_title="AI Trends Podcast")`
|
||||
|
||||
|
|
@ -315,6 +400,31 @@ You have access to the following tools:
|
|||
</tool_call_examples>
|
||||
"""
|
||||
|
||||
# Reassemble so existing callers see no change (same full prompt)
|
||||
SURFSENSE_TOOLS_INSTRUCTIONS = (
|
||||
_TOOLS_INSTRUCTIONS_COMMON
|
||||
+ _TOOLS_INSTRUCTIONS_MEMORY_PRIVATE
|
||||
+ _TOOLS_INSTRUCTIONS_EXAMPLES_COMMON
|
||||
)
|
||||
|
||||
|
||||
def _get_tools_instructions(thread_visibility: ChatVisibility | None = None) -> str:
|
||||
"""Build tools instructions based on thread visibility (private vs shared).
|
||||
|
||||
For private chats: use user-focused memory wording and examples.
|
||||
For shared chats: use team memory wording and examples.
|
||||
"""
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
memory_block = (
|
||||
_TOOLS_INSTRUCTIONS_MEMORY_SHARED
|
||||
if visibility == ChatVisibility.SEARCH_SPACE
|
||||
else _TOOLS_INSTRUCTIONS_MEMORY_PRIVATE
|
||||
)
|
||||
return (
|
||||
_TOOLS_INSTRUCTIONS_COMMON + memory_block + _TOOLS_INSTRUCTIONS_EXAMPLES_COMMON
|
||||
)
|
||||
|
||||
|
||||
SURFSENSE_CITATION_INSTRUCTIONS = """
|
||||
<citation_instructions>
|
||||
CRITICAL CITATION REQUIREMENTS:
|
||||
|
|
@ -413,6 +523,7 @@ Your goal is to provide helpful, informative answers in a clean, readable format
|
|||
|
||||
def build_surfsense_system_prompt(
|
||||
today: datetime | None = None,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build the SurfSense system prompt with default settings.
|
||||
|
|
@ -424,17 +535,17 @@ def build_surfsense_system_prompt(
|
|||
|
||||
Args:
|
||||
today: Optional datetime for today's date (defaults to current UTC date)
|
||||
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
||||
|
||||
Returns:
|
||||
Complete system prompt string
|
||||
"""
|
||||
resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat()
|
||||
|
||||
return (
|
||||
SURFSENSE_SYSTEM_INSTRUCTIONS.format(resolved_today=resolved_today)
|
||||
+ SURFSENSE_TOOLS_INSTRUCTIONS
|
||||
+ SURFSENSE_CITATION_INSTRUCTIONS
|
||||
)
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
system_instructions = _get_system_instructions(visibility, today)
|
||||
tools_instructions = _get_tools_instructions(visibility)
|
||||
citation_instructions = SURFSENSE_CITATION_INSTRUCTIONS
|
||||
return system_instructions + tools_instructions + citation_instructions
|
||||
|
||||
|
||||
def build_configurable_system_prompt(
|
||||
|
|
@ -442,6 +553,7 @@ def build_configurable_system_prompt(
|
|||
use_default_system_instructions: bool = True,
|
||||
citations_enabled: bool = True,
|
||||
today: datetime | None = None,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
|
||||
|
|
@ -460,6 +572,7 @@ def build_configurable_system_prompt(
|
|||
citations_enabled: Whether to include citation instructions (True) or
|
||||
anti-citation instructions (False).
|
||||
today: Optional datetime for today's date (defaults to current UTC date)
|
||||
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
||||
|
||||
Returns:
|
||||
Complete system prompt string
|
||||
|
|
@ -473,16 +586,14 @@ def build_configurable_system_prompt(
|
|||
resolved_today=resolved_today
|
||||
)
|
||||
elif use_default_system_instructions:
|
||||
# Use default instructions
|
||||
system_instructions = SURFSENSE_SYSTEM_INSTRUCTIONS.format(
|
||||
resolved_today=resolved_today
|
||||
)
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
system_instructions = _get_system_instructions(visibility, today)
|
||||
else:
|
||||
# No system instructions (edge case)
|
||||
system_instructions = ""
|
||||
|
||||
# Tools instructions are always included
|
||||
tools_instructions = SURFSENSE_TOOLS_INSTRUCTIONS
|
||||
# Tools instructions: conditional on thread_visibility (private vs shared memory wording)
|
||||
tools_instructions = _get_tools_instructions(thread_visibility)
|
||||
|
||||
# Citation instructions based on toggle
|
||||
citation_instructions = (
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ This module provides:
|
|||
- Tool factory for creating search_knowledge_base tools
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
|
@ -16,6 +17,7 @@ from langchain_core.tools import StructuredTool
|
|||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import async_session_maker
|
||||
from app.services.connector_service import ConnectorService
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -333,7 +335,7 @@ async def search_knowledge_base_async(
|
|||
Returns:
|
||||
Formatted string with search results
|
||||
"""
|
||||
all_documents = []
|
||||
all_documents: list[dict[str, Any]] = []
|
||||
|
||||
# Resolve date range (default last 2 years)
|
||||
from app.agents.new_chat.utils import resolve_date_range
|
||||
|
|
@ -345,323 +347,131 @@ async def search_knowledge_base_async(
|
|||
|
||||
connectors = _normalize_connectors(connectors_to_search, available_connectors)
|
||||
|
||||
for connector in connectors:
|
||||
connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
|
||||
"YOUTUBE_VIDEO": ("search_youtube", True, True, {}),
|
||||
"EXTENSION": ("search_extension", True, True, {}),
|
||||
"CRAWLED_URL": ("search_crawled_urls", True, True, {}),
|
||||
"FILE": ("search_files", True, True, {}),
|
||||
"SLACK_CONNECTOR": ("search_slack", True, True, {}),
|
||||
"TEAMS_CONNECTOR": ("search_teams", True, True, {}),
|
||||
"NOTION_CONNECTOR": ("search_notion", True, True, {}),
|
||||
"GITHUB_CONNECTOR": ("search_github", True, True, {}),
|
||||
"LINEAR_CONNECTOR": ("search_linear", True, True, {}),
|
||||
"TAVILY_API": ("search_tavily", False, True, {}),
|
||||
"SEARXNG_API": ("search_searxng", False, True, {}),
|
||||
"LINKUP_API": ("search_linkup", False, False, {"mode": "standard"}),
|
||||
"BAIDU_SEARCH_API": ("search_baidu", False, True, {}),
|
||||
"DISCORD_CONNECTOR": ("search_discord", True, True, {}),
|
||||
"JIRA_CONNECTOR": ("search_jira", True, True, {}),
|
||||
"GOOGLE_CALENDAR_CONNECTOR": ("search_google_calendar", True, True, {}),
|
||||
"AIRTABLE_CONNECTOR": ("search_airtable", True, True, {}),
|
||||
"GOOGLE_GMAIL_CONNECTOR": ("search_google_gmail", True, True, {}),
|
||||
"GOOGLE_DRIVE_FILE": ("search_google_drive", True, True, {}),
|
||||
"CONFLUENCE_CONNECTOR": ("search_confluence", True, True, {}),
|
||||
"CLICKUP_CONNECTOR": ("search_clickup", True, True, {}),
|
||||
"LUMA_CONNECTOR": ("search_luma", True, True, {}),
|
||||
"ELASTICSEARCH_CONNECTOR": ("search_elasticsearch", True, True, {}),
|
||||
"NOTE": ("search_notes", True, True, {}),
|
||||
"BOOKSTACK_CONNECTOR": ("search_bookstack", True, True, {}),
|
||||
"CIRCLEBACK": ("search_circleback", True, True, {}),
|
||||
"OBSIDIAN_CONNECTOR": ("search_obsidian", True, True, {}),
|
||||
# Composio connectors
|
||||
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR": (
|
||||
"search_composio_google_drive",
|
||||
True,
|
||||
True,
|
||||
{},
|
||||
),
|
||||
"COMPOSIO_GMAIL_CONNECTOR": ("search_composio_gmail", True, True, {}),
|
||||
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": (
|
||||
"search_composio_google_calendar",
|
||||
True,
|
||||
True,
|
||||
{},
|
||||
),
|
||||
}
|
||||
|
||||
# Keep a conservative cap to avoid overloading DB/external services.
|
||||
max_parallel_searches = 4
|
||||
semaphore = asyncio.Semaphore(max_parallel_searches)
|
||||
|
||||
async def _search_one_connector(connector: str) -> list[dict[str, Any]]:
|
||||
spec = connector_specs.get(connector)
|
||||
if spec is None:
|
||||
return []
|
||||
|
||||
method_name, includes_date_range, includes_top_k, extra_kwargs = spec
|
||||
kwargs: dict[str, Any] = {
|
||||
"user_query": query,
|
||||
"search_space_id": search_space_id,
|
||||
**extra_kwargs,
|
||||
}
|
||||
if includes_top_k:
|
||||
kwargs["top_k"] = top_k
|
||||
if includes_date_range:
|
||||
kwargs["start_date"] = resolved_start_date
|
||||
kwargs["end_date"] = resolved_end_date
|
||||
|
||||
try:
|
||||
if connector == "YOUTUBE_VIDEO":
|
||||
_, chunks = await connector_service.search_youtube(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
# Use isolated session per connector. Shared AsyncSession cannot safely
|
||||
# run concurrent DB operations.
|
||||
async with semaphore, async_session_maker() as isolated_session:
|
||||
isolated_connector_service = ConnectorService(
|
||||
isolated_session, search_space_id
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "EXTENSION":
|
||||
_, chunks = await connector_service.search_extension(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "CRAWLED_URL":
|
||||
_, chunks = await connector_service.search_crawled_urls(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "FILE":
|
||||
_, chunks = await connector_service.search_files(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "SLACK_CONNECTOR":
|
||||
_, chunks = await connector_service.search_slack(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "TEAMS_CONNECTOR":
|
||||
_, chunks = await connector_service.search_teams(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "NOTION_CONNECTOR":
|
||||
_, chunks = await connector_service.search_notion(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "GITHUB_CONNECTOR":
|
||||
_, chunks = await connector_service.search_github(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "LINEAR_CONNECTOR":
|
||||
_, chunks = await connector_service.search_linear(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "TAVILY_API":
|
||||
_, chunks = await connector_service.search_tavily(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "SEARXNG_API":
|
||||
_, chunks = await connector_service.search_searxng(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "LINKUP_API":
|
||||
# Keep behavior aligned with researcher: default "standard"
|
||||
_, chunks = await connector_service.search_linkup(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
mode="standard",
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "BAIDU_SEARCH_API":
|
||||
_, chunks = await connector_service.search_baidu(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "DISCORD_CONNECTOR":
|
||||
_, chunks = await connector_service.search_discord(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "JIRA_CONNECTOR":
|
||||
_, chunks = await connector_service.search_jira(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "GOOGLE_CALENDAR_CONNECTOR":
|
||||
_, chunks = await connector_service.search_google_calendar(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "AIRTABLE_CONNECTOR":
|
||||
_, chunks = await connector_service.search_airtable(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "GOOGLE_GMAIL_CONNECTOR":
|
||||
_, chunks = await connector_service.search_google_gmail(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "GOOGLE_DRIVE_FILE":
|
||||
_, chunks = await connector_service.search_google_drive(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "CONFLUENCE_CONNECTOR":
|
||||
_, chunks = await connector_service.search_confluence(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "CLICKUP_CONNECTOR":
|
||||
_, chunks = await connector_service.search_clickup(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "LUMA_CONNECTOR":
|
||||
_, chunks = await connector_service.search_luma(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "ELASTICSEARCH_CONNECTOR":
|
||||
_, chunks = await connector_service.search_elasticsearch(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "NOTE":
|
||||
_, chunks = await connector_service.search_notes(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "BOOKSTACK_CONNECTOR":
|
||||
_, chunks = await connector_service.search_bookstack(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "CIRCLEBACK":
|
||||
_, chunks = await connector_service.search_circleback(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "OBSIDIAN_CONNECTOR":
|
||||
_, chunks = await connector_service.search_obsidian(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
# =========================================================
|
||||
# Composio Connectors
|
||||
# =========================================================
|
||||
elif connector == "COMPOSIO_GOOGLE_DRIVE_CONNECTOR":
|
||||
_, chunks = await connector_service.search_composio_google_drive(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "COMPOSIO_GMAIL_CONNECTOR":
|
||||
_, chunks = await connector_service.search_composio_gmail(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
elif connector == "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR":
|
||||
_, chunks = await connector_service.search_composio_google_calendar(
|
||||
user_query=query,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
all_documents.extend(chunks)
|
||||
|
||||
connector_method = getattr(isolated_connector_service, method_name)
|
||||
_, chunks = await connector_method(**kwargs)
|
||||
return chunks
|
||||
except Exception as e:
|
||||
print(f"Error searching connector {connector}: {e}")
|
||||
continue
|
||||
return []
|
||||
|
||||
# Deduplicate by content hash
|
||||
connector_results = await asyncio.gather(
|
||||
*[_search_one_connector(connector) for connector in connectors]
|
||||
)
|
||||
for chunks in connector_results:
|
||||
all_documents.extend(chunks)
|
||||
|
||||
# Deduplicate primarily by document ID. Only fall back to content hashing
|
||||
# when a document has no ID.
|
||||
seen_doc_ids: set[Any] = set()
|
||||
seen_hashes: set[int] = set()
|
||||
seen_content_hashes: set[int] = set()
|
||||
deduplicated: list[dict[str, Any]] = []
|
||||
|
||||
def _content_fingerprint(document: dict[str, Any]) -> int | None:
|
||||
chunks = document.get("chunks")
|
||||
if isinstance(chunks, list):
|
||||
chunk_texts = []
|
||||
for chunk in chunks:
|
||||
if not isinstance(chunk, dict):
|
||||
continue
|
||||
chunk_content = (chunk.get("content") or "").strip()
|
||||
if chunk_content:
|
||||
chunk_texts.append(chunk_content)
|
||||
if chunk_texts:
|
||||
return hash("||".join(chunk_texts))
|
||||
|
||||
flat_content = (document.get("content") or "").strip()
|
||||
if flat_content:
|
||||
return hash(flat_content)
|
||||
return None
|
||||
|
||||
for doc in all_documents:
|
||||
doc_id = (doc.get("document", {}) or {}).get("id")
|
||||
content = (doc.get("content", "") or "").strip()
|
||||
content_hash = hash(content)
|
||||
|
||||
if (doc_id and doc_id in seen_doc_ids) or content_hash in seen_hashes:
|
||||
if doc_id is not None:
|
||||
if doc_id in seen_doc_ids:
|
||||
continue
|
||||
seen_doc_ids.add(doc_id)
|
||||
deduplicated.append(doc)
|
||||
continue
|
||||
|
||||
if doc_id:
|
||||
seen_doc_ids.add(doc_id)
|
||||
seen_hashes.add(content_hash)
|
||||
content_hash = _content_fingerprint(doc)
|
||||
if content_hash is not None:
|
||||
if content_hash in seen_content_hashes:
|
||||
continue
|
||||
seen_content_hashes.add(content_hash)
|
||||
|
||||
deduplicated.append(doc)
|
||||
|
||||
return format_documents_for_context(deduplicated)
|
||||
|
|
|
|||
|
|
@ -11,21 +11,18 @@ Duplicate request prevention:
|
|||
- Returns a friendly message if a podcast is already being generated
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import Podcast, PodcastStatus
|
||||
|
||||
# Redis connection for tracking active podcast tasks
|
||||
# Defaults to the Celery broker when REDIS_APP_URL is not set
|
||||
REDIS_URL = os.getenv(
|
||||
"REDIS_APP_URL",
|
||||
os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"),
|
||||
)
|
||||
REDIS_URL = config.REDIS_APP_URL
|
||||
_redis_client: redis.Redis | None = None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -43,6 +43,8 @@ from typing import Any
|
|||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from .display_image import create_display_image_tool
|
||||
from .generate_image import create_generate_image_tool
|
||||
from .knowledge_base import create_search_knowledge_base_tool
|
||||
|
|
@ -51,6 +53,10 @@ from .mcp_tool import load_mcp_tools
|
|||
from .podcast import create_generate_podcast_tool
|
||||
from .scrape_webpage import create_scrape_webpage_tool
|
||||
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
||||
from .shared_memory import (
|
||||
create_recall_shared_memory_tool,
|
||||
create_save_shared_memory_tool,
|
||||
)
|
||||
from .user_memory import create_recall_memory_tool, create_save_memory_tool
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -156,29 +162,42 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
requires=["db_session"],
|
||||
),
|
||||
# =========================================================================
|
||||
# USER MEMORY TOOLS - Claude-like memory feature
|
||||
# USER MEMORY TOOLS - private or team store by thread_visibility
|
||||
# =========================================================================
|
||||
# Save memory tool - stores facts/preferences about the user
|
||||
ToolDefinition(
|
||||
name="save_memory",
|
||||
description="Save facts, preferences, or context about the user for personalized responses",
|
||||
factory=lambda deps: create_save_memory_tool(
|
||||
user_id=deps["user_id"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
description="Save facts, preferences, or context for personalized or team responses",
|
||||
factory=lambda deps: (
|
||||
create_save_shared_memory_tool(
|
||||
search_space_id=deps["search_space_id"],
|
||||
created_by_id=deps["user_id"],
|
||||
db_session=deps["db_session"],
|
||||
)
|
||||
if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE
|
||||
else create_save_memory_tool(
|
||||
user_id=deps["user_id"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
)
|
||||
),
|
||||
requires=["user_id", "search_space_id", "db_session"],
|
||||
requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
|
||||
),
|
||||
# Recall memory tool - retrieves relevant user memories
|
||||
ToolDefinition(
|
||||
name="recall_memory",
|
||||
description="Recall user memories for personalized and contextual responses",
|
||||
factory=lambda deps: create_recall_memory_tool(
|
||||
user_id=deps["user_id"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
description="Recall relevant memories (personal or team) for context",
|
||||
factory=lambda deps: (
|
||||
create_recall_shared_memory_tool(
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
)
|
||||
if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE
|
||||
else create_recall_memory_tool(
|
||||
user_id=deps["user_id"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
)
|
||||
),
|
||||
requires=["user_id", "search_space_id", "db_session"],
|
||||
requires=["user_id", "search_space_id", "db_session", "thread_visibility"],
|
||||
),
|
||||
# =========================================================================
|
||||
# ADD YOUR CUSTOM TOOLS BELOW
|
||||
|
|
|
|||
280
surfsense_backend/app/agents/new_chat/tools/shared_memory.py
Normal file
280
surfsense_backend/app/agents/new_chat/tools/shared_memory.py
Normal file
|
|
@ -0,0 +1,280 @@
|
|||
"""Shared (team) memory backend for search-space-scoped AI context."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import MemoryCategory, SharedMemory, User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_RECALL_TOP_K = 5
|
||||
MAX_MEMORIES_PER_SEARCH_SPACE = 250
|
||||
|
||||
|
||||
async def get_shared_memory_count(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> int:
|
||||
result = await db_session.execute(
|
||||
select(SharedMemory).where(SharedMemory.search_space_id == search_space_id)
|
||||
)
|
||||
return len(result.scalars().all())
|
||||
|
||||
|
||||
async def delete_oldest_shared_memory(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> None:
|
||||
result = await db_session.execute(
|
||||
select(SharedMemory)
|
||||
.where(SharedMemory.search_space_id == search_space_id)
|
||||
.order_by(SharedMemory.updated_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
oldest = result.scalars().first()
|
||||
if oldest:
|
||||
await db_session.delete(oldest)
|
||||
await db_session.commit()
|
||||
|
||||
|
||||
def _to_uuid(value: str | UUID) -> UUID:
|
||||
if isinstance(value, UUID):
|
||||
return value
|
||||
return UUID(value)
|
||||
|
||||
|
||||
async def save_shared_memory(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
created_by_id: str | UUID,
|
||||
content: str,
|
||||
category: str = "fact",
|
||||
) -> dict[str, Any]:
|
||||
category = category.lower() if category else "fact"
|
||||
valid = ["preference", "fact", "instruction", "context"]
|
||||
if category not in valid:
|
||||
category = "fact"
|
||||
try:
|
||||
count = await get_shared_memory_count(db_session, search_space_id)
|
||||
if count >= MAX_MEMORIES_PER_SEARCH_SPACE:
|
||||
await delete_oldest_shared_memory(db_session, search_space_id)
|
||||
embedding = config.embedding_model_instance.embed(content)
|
||||
row = SharedMemory(
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=_to_uuid(created_by_id),
|
||||
memory_text=content,
|
||||
category=MemoryCategory(category),
|
||||
embedding=embedding,
|
||||
)
|
||||
db_session.add(row)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(row)
|
||||
return {
|
||||
"status": "saved",
|
||||
"memory_id": row.id,
|
||||
"memory_text": content,
|
||||
"category": category,
|
||||
"message": f"I'll remember: {content}",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to save shared memory: %s", e)
|
||||
await db_session.rollback()
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"message": "Failed to save memory. Please try again.",
|
||||
}
|
||||
|
||||
|
||||
async def recall_shared_memory(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
query: str | None = None,
|
||||
category: str | None = None,
|
||||
top_k: int = DEFAULT_RECALL_TOP_K,
|
||||
) -> dict[str, Any]:
|
||||
top_k = min(max(top_k, 1), 20)
|
||||
try:
|
||||
valid_categories = ["preference", "fact", "instruction", "context"]
|
||||
stmt = select(SharedMemory).where(
|
||||
SharedMemory.search_space_id == search_space_id
|
||||
)
|
||||
if category and category in valid_categories:
|
||||
stmt = stmt.where(SharedMemory.category == MemoryCategory(category))
|
||||
if query:
|
||||
query_embedding = config.embedding_model_instance.embed(query)
|
||||
stmt = stmt.order_by(
|
||||
SharedMemory.embedding.op("<=>")(query_embedding)
|
||||
).limit(top_k)
|
||||
else:
|
||||
stmt = stmt.order_by(SharedMemory.updated_at.desc()).limit(top_k)
|
||||
result = await db_session.execute(stmt)
|
||||
rows = result.scalars().all()
|
||||
memory_list = [
|
||||
{
|
||||
"id": m.id,
|
||||
"memory_text": m.memory_text,
|
||||
"category": m.category.value if m.category else "unknown",
|
||||
"updated_at": m.updated_at.isoformat() if m.updated_at else None,
|
||||
"created_by_id": str(m.created_by_id) if m.created_by_id else None,
|
||||
}
|
||||
for m in rows
|
||||
]
|
||||
created_by_ids = list(
|
||||
{m["created_by_id"] for m in memory_list if m["created_by_id"]}
|
||||
)
|
||||
created_by_map: dict[str, str] = {}
|
||||
if created_by_ids:
|
||||
uuids = [UUID(uid) for uid in created_by_ids]
|
||||
users_result = await db_session.execute(
|
||||
select(User).where(User.id.in_(uuids))
|
||||
)
|
||||
for u in users_result.scalars().all():
|
||||
created_by_map[str(u.id)] = u.display_name or "A team member"
|
||||
formatted_context = format_shared_memories_for_context(
|
||||
memory_list, created_by_map
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"count": len(memory_list),
|
||||
"memories": memory_list,
|
||||
"formatted_context": formatted_context,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to recall shared memory: %s", e)
|
||||
await db_session.rollback()
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"memories": [],
|
||||
"formatted_context": "Failed to recall memories.",
|
||||
}
|
||||
|
||||
|
||||
def format_shared_memories_for_context(
|
||||
memories: list[dict[str, Any]],
|
||||
created_by_map: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
if not memories:
|
||||
return "No relevant team memories found."
|
||||
created_by_map = created_by_map or {}
|
||||
parts = ["<team_memories>"]
|
||||
for memory in memories:
|
||||
category = memory.get("category", "unknown")
|
||||
text = memory.get("memory_text", "")
|
||||
updated = memory.get("updated_at", "")
|
||||
created_by_id = memory.get("created_by_id")
|
||||
added_by = (
|
||||
created_by_map.get(str(created_by_id), "A team member")
|
||||
if created_by_id is not None
|
||||
else "A team member"
|
||||
)
|
||||
parts.append(
|
||||
f" <memory category='{category}' updated='{updated}' added_by='{added_by}'>{text}</memory>"
|
||||
)
|
||||
parts.append("</team_memories>")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def create_save_shared_memory_tool(
|
||||
search_space_id: int,
|
||||
created_by_id: str | UUID,
|
||||
db_session: AsyncSession,
|
||||
):
|
||||
"""
|
||||
Factory function to create the save_memory tool for shared (team) chats.
|
||||
|
||||
Args:
|
||||
search_space_id: The search space ID
|
||||
created_by_id: The user ID of the person adding the memory
|
||||
db_session: Database session for executing queries
|
||||
|
||||
Returns:
|
||||
A configured tool function for saving team memories
|
||||
"""
|
||||
|
||||
@tool
|
||||
async def save_memory(
|
||||
content: str,
|
||||
category: str = "fact",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Save a fact, preference, or context to the team's shared memory for future reference.
|
||||
|
||||
Use this tool when:
|
||||
- User or a team member says "remember this", "keep this in mind", or similar in this shared chat
|
||||
- The team agrees on something to remember (e.g., decisions, conventions, where things live)
|
||||
- Someone shares a preference or fact that should be visible to the whole team
|
||||
|
||||
The saved information will be available in future shared conversations in this space.
|
||||
|
||||
Args:
|
||||
content: The fact/preference/context to remember.
|
||||
Phrase it clearly, e.g., "API keys are stored in Vault",
|
||||
"The team prefers weekly demos on Fridays"
|
||||
category: Type of memory. One of:
|
||||
- "preference": Team or workspace preferences
|
||||
- "fact": Facts the team agreed on (e.g., processes, locations)
|
||||
- "instruction": Standing instructions for the team
|
||||
- "context": Current context (e.g., ongoing projects, goals)
|
||||
|
||||
Returns:
|
||||
A dictionary with the save status and memory details
|
||||
"""
|
||||
return await save_shared_memory(
|
||||
db_session, search_space_id, created_by_id, content, category
|
||||
)
|
||||
|
||||
return save_memory
|
||||
|
||||
|
||||
def create_recall_shared_memory_tool(
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
):
|
||||
"""
|
||||
Factory function to create the recall_memory tool for shared (team) chats.
|
||||
|
||||
Args:
|
||||
search_space_id: The search space ID
|
||||
db_session: Database session for executing queries
|
||||
|
||||
Returns:
|
||||
A configured tool function for recalling team memories
|
||||
"""
|
||||
|
||||
@tool
|
||||
async def recall_memory(
|
||||
query: str | None = None,
|
||||
category: str | None = None,
|
||||
top_k: int = DEFAULT_RECALL_TOP_K,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Recall relevant team memories for this space to provide contextual responses.
|
||||
|
||||
Use this tool when:
|
||||
- You need team context to answer (e.g., "where do we store X?", "what did we decide about Y?")
|
||||
- Someone asks about something the team agreed to remember
|
||||
- Team preferences or conventions would improve the response
|
||||
|
||||
Args:
|
||||
query: Optional search query to find specific memories.
|
||||
If not provided, returns the most recent memories.
|
||||
category: Optional category filter. One of:
|
||||
"preference", "fact", "instruction", "context"
|
||||
top_k: Number of memories to retrieve (default: 5, max: 20)
|
||||
|
||||
Returns:
|
||||
A dictionary containing relevant memories and formatted context
|
||||
"""
|
||||
return await recall_shared_memory(
|
||||
db_session, search_space_id, query, category, top_k
|
||||
)
|
||||
|
||||
return recall_memory
|
||||
|
|
@ -213,6 +213,17 @@ class Config:
|
|||
# Database
|
||||
DATABASE_URL = os.getenv("DATABASE_URL")
|
||||
|
||||
# Celery / Redis
|
||||
CELERY_BROKER_URL = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
|
||||
CELERY_RESULT_BACKEND = os.getenv(
|
||||
"CELERY_RESULT_BACKEND", "redis://localhost:6379/0"
|
||||
)
|
||||
CELERY_TASK_DEFAULT_QUEUE = os.getenv("CELERY_TASK_DEFAULT_QUEUE", "surfsense")
|
||||
REDIS_APP_URL = os.getenv("REDIS_APP_URL", CELERY_BROKER_URL)
|
||||
CONNECTOR_INDEXING_LOCK_TTL_SECONDS = int(
|
||||
os.getenv("CONNECTOR_INDEXING_LOCK_TTL_SECONDS", str(8 * 60 * 60))
|
||||
)
|
||||
|
||||
NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL")
|
||||
# Backend URL to override the http to https in the OAuth redirect URI
|
||||
BACKEND_URL = os.getenv("BACKEND_URL")
|
||||
|
|
|
|||
|
|
@ -27,6 +27,12 @@ T = TypeVar("T")
|
|||
MAX_RETRIES = 5
|
||||
BASE_RETRY_DELAY = 1.0 # seconds
|
||||
MAX_RETRY_DELAY = 60.0 # seconds (Notion's max request timeout)
|
||||
MAX_RATE_LIMIT_WAIT_SECONDS = float(
|
||||
getattr(config, "NOTION_MAX_RETRY_AFTER_SECONDS", 30.0)
|
||||
)
|
||||
MAX_TOTAL_RETRY_WAIT_SECONDS = float(
|
||||
getattr(config, "NOTION_MAX_TOTAL_RETRY_WAIT_SECONDS", 120.0)
|
||||
)
|
||||
|
||||
# Type alias for retry callback function
|
||||
# Signature: async callback(retry_reason, attempt, max_attempts, wait_seconds) -> None
|
||||
|
|
@ -292,6 +298,7 @@ class NotionHistoryConnector:
|
|||
"""
|
||||
last_exception: APIResponseError | None = None
|
||||
retry_delay = BASE_RETRY_DELAY
|
||||
total_wait_time = 0.0
|
||||
|
||||
for attempt in range(MAX_RETRIES):
|
||||
try:
|
||||
|
|
@ -325,6 +332,15 @@ class NotionHistoryConnector:
|
|||
wait_time = retry_delay
|
||||
else:
|
||||
wait_time = retry_delay
|
||||
|
||||
# Avoid very long worker sleeps from external Retry-After values.
|
||||
if wait_time > MAX_RATE_LIMIT_WAIT_SECONDS:
|
||||
logger.warning(
|
||||
f"Notion Retry-After ({wait_time}s) exceeds cap "
|
||||
f"({MAX_RATE_LIMIT_WAIT_SECONDS}s). Clamping wait time."
|
||||
)
|
||||
wait_time = MAX_RATE_LIMIT_WAIT_SECONDS
|
||||
|
||||
logger.warning(
|
||||
f"Notion API rate limited (429). "
|
||||
f"Waiting {wait_time}s. Attempt {attempt + 1}/{MAX_RETRIES}"
|
||||
|
|
@ -348,6 +364,14 @@ class NotionHistoryConnector:
|
|||
|
||||
# Notify about retry via callback (for user notifications)
|
||||
# Call before sleeping so user sees the message while we wait
|
||||
if total_wait_time + wait_time > MAX_TOTAL_RETRY_WAIT_SECONDS:
|
||||
logger.error(
|
||||
"Notion API retry budget exceeded "
|
||||
f"({total_wait_time + wait_time:.1f}s > "
|
||||
f"{MAX_TOTAL_RETRY_WAIT_SECONDS:.1f}s). Failing fast."
|
||||
)
|
||||
raise
|
||||
|
||||
if on_retry:
|
||||
try:
|
||||
await on_retry(
|
||||
|
|
@ -362,6 +386,7 @@ class NotionHistoryConnector:
|
|||
|
||||
# Wait before retrying
|
||||
await asyncio.sleep(wait_time)
|
||||
total_wait_time += wait_time
|
||||
|
||||
# Exponential backoff for next attempt
|
||||
retry_delay = min(retry_delay * 2, MAX_RETRY_DELAY)
|
||||
|
|
|
|||
|
|
@ -211,6 +211,7 @@ class LiteLLMProvider(str, Enum):
|
|||
DATABRICKS = "DATABRICKS"
|
||||
COMETAPI = "COMETAPI"
|
||||
HUGGINGFACE = "HUGGINGFACE"
|
||||
GITHUB_MODELS = "GITHUB_MODELS"
|
||||
CUSTOM = "CUSTOM"
|
||||
|
||||
|
||||
|
|
@ -272,19 +273,19 @@ INCENTIVE_TASKS_CONFIG = {
|
|||
IncentiveTaskType.GITHUB_STAR: {
|
||||
"title": "Star our GitHub repository",
|
||||
"description": "Show your support by starring SurfSense on GitHub",
|
||||
"pages_reward": 100,
|
||||
"pages_reward": 30,
|
||||
"action_url": "https://github.com/MODSetter/SurfSense",
|
||||
},
|
||||
IncentiveTaskType.REDDIT_FOLLOW: {
|
||||
"title": "Join our Subreddit",
|
||||
"description": "Join the SurfSense community on Reddit",
|
||||
"pages_reward": 100,
|
||||
"pages_reward": 30,
|
||||
"action_url": "https://www.reddit.com/r/SurfSense/",
|
||||
},
|
||||
IncentiveTaskType.DISCORD_JOIN: {
|
||||
"title": "Join our Discord",
|
||||
"description": "Join the SurfSense community on Discord",
|
||||
"pages_reward": 100,
|
||||
"pages_reward": 40,
|
||||
"action_url": "https://discord.gg/ejRNvftDp9",
|
||||
},
|
||||
# Future tasks can be configured here:
|
||||
|
|
@ -801,9 +802,8 @@ class MemoryCategory(str, Enum):
|
|||
|
||||
class UserMemory(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Stores facts, preferences, and context about users for personalized AI responses.
|
||||
Similar to Claude's memory feature - enables the AI to remember user information
|
||||
across conversations.
|
||||
Private memory: facts, preferences, context per user per search space.
|
||||
Used only for private chats (not shared/team chats).
|
||||
"""
|
||||
|
||||
__tablename__ = "user_memories"
|
||||
|
|
@ -847,6 +847,40 @@ class UserMemory(BaseModel, TimestampMixin):
|
|||
search_space = relationship("SearchSpace", back_populates="user_memories")
|
||||
|
||||
|
||||
class SharedMemory(BaseModel, TimestampMixin):
|
||||
__tablename__ = "shared_memories"
|
||||
|
||||
search_space_id = Column(
|
||||
Integer,
|
||||
ForeignKey("searchspaces.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
created_by_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
memory_text = Column(Text, nullable=False)
|
||||
category = Column(
|
||||
SQLAlchemyEnum(MemoryCategory),
|
||||
nullable=False,
|
||||
default=MemoryCategory.fact,
|
||||
)
|
||||
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
||||
updated_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
index=True,
|
||||
)
|
||||
|
||||
search_space = relationship("SearchSpace", back_populates="shared_memories")
|
||||
created_by = relationship("User")
|
||||
|
||||
|
||||
class Document(BaseModel, TimestampMixin):
|
||||
__tablename__ = "documents"
|
||||
|
||||
|
|
@ -1209,6 +1243,12 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
order_by="UserMemory.updated_at.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
shared_memories = relationship(
|
||||
"SharedMemory",
|
||||
back_populates="search_space",
|
||||
order_by="SharedMemory.updated_at.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
|
||||
class SearchSourceConnector(BaseModel, TimestampMixin):
|
||||
|
|
@ -1258,7 +1298,7 @@ class NewLLMConfig(BaseModel, TimestampMixin):
|
|||
- Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS)
|
||||
- Citation toggle (enable/disable citation instructions)
|
||||
|
||||
Note: SURFSENSE_TOOLS_INSTRUCTIONS is always used and not configurable.
|
||||
Note: Tools instructions are built by get_tools_instructions(thread_visibility) (personal vs shared memory).
|
||||
"""
|
||||
|
||||
__tablename__ = "new_llm_configs"
|
||||
|
|
|
|||
|
|
@ -19,6 +19,8 @@ from app.db import (
|
|||
from app.schemas import (
|
||||
DocumentRead,
|
||||
DocumentsCreate,
|
||||
DocumentStatusBatchResponse,
|
||||
DocumentStatusItemRead,
|
||||
DocumentStatusSchema,
|
||||
DocumentTitleRead,
|
||||
DocumentTitleSearchResponse,
|
||||
|
|
@ -148,6 +150,7 @@ async def create_documents_file_upload(
|
|||
tuple[Document, str, str]
|
||||
] = [] # (document, temp_path, filename)
|
||||
skipped_duplicates = 0
|
||||
duplicate_document_ids: list[int] = []
|
||||
|
||||
# ===== PHASE 1: Create pending documents for all files =====
|
||||
# This makes ALL documents visible in the UI immediately with pending status
|
||||
|
|
@ -182,6 +185,7 @@ async def create_documents_file_upload(
|
|||
# True duplicate — content already indexed, skip
|
||||
os.unlink(temp_path)
|
||||
skipped_duplicates += 1
|
||||
duplicate_document_ids.append(existing.id)
|
||||
continue
|
||||
|
||||
# Existing document is stuck (failed/pending/processing)
|
||||
|
|
@ -255,6 +259,7 @@ async def create_documents_file_upload(
|
|||
return {
|
||||
"message": "Files uploaded for processing",
|
||||
"document_ids": [doc.id for doc in created_documents],
|
||||
"duplicate_document_ids": duplicate_document_ids,
|
||||
"total_files": len(files),
|
||||
"pending_files": len(files_to_process),
|
||||
"skipped_duplicates": skipped_duplicates,
|
||||
|
|
@ -678,6 +683,74 @@ async def search_document_titles(
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/documents/status", response_model=DocumentStatusBatchResponse)
|
||||
async def get_documents_status(
|
||||
search_space_id: int,
|
||||
document_ids: str,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Batch status endpoint for documents in a search space.
|
||||
|
||||
Returns lightweight status info for the provided document IDs, intended for
|
||||
polling async ETL progress in chat upload flows.
|
||||
"""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read documents in this search space",
|
||||
)
|
||||
|
||||
# Parse comma-separated IDs (e.g. "1,2,3")
|
||||
parsed_ids = []
|
||||
for raw_id in document_ids.split(","):
|
||||
value = raw_id.strip()
|
||||
if not value:
|
||||
continue
|
||||
try:
|
||||
parsed_ids.append(int(value))
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid document id: {value}",
|
||||
) from None
|
||||
|
||||
if not parsed_ids:
|
||||
return DocumentStatusBatchResponse(items=[])
|
||||
|
||||
result = await session.execute(
|
||||
select(Document).filter(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.id.in_(parsed_ids),
|
||||
)
|
||||
)
|
||||
docs = result.scalars().all()
|
||||
|
||||
items = [
|
||||
DocumentStatusItemRead(
|
||||
id=doc.id,
|
||||
title=doc.title,
|
||||
document_type=doc.document_type,
|
||||
status=DocumentStatusSchema(
|
||||
state=(doc.status or {}).get("state", "ready"),
|
||||
reason=(doc.status or {}).get("reason"),
|
||||
),
|
||||
)
|
||||
for doc in docs
|
||||
]
|
||||
return DocumentStatusBatchResponse(items=items)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch document status: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/documents/type-counts")
|
||||
async def get_document_type_counts(
|
||||
search_space_id: int | None = None,
|
||||
|
|
|
|||
|
|
@ -8,16 +8,11 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
|
|||
- PUT /threads/{thread_id} - Update thread (rename, archive)
|
||||
- DELETE /threads/{thread_id} - Delete thread
|
||||
- POST /threads/{thread_id}/messages - Append message
|
||||
- POST /attachments/process - Process attachments for chat context
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import tempfile
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
|
|
@ -1045,12 +1040,13 @@ async def handle_new_chat(
|
|||
search_space_id=request.search_space_id,
|
||||
chat_id=request.chat_id,
|
||||
session=session,
|
||||
user_id=str(user.id), # Pass user ID for memory tools and session state
|
||||
user_id=str(user.id),
|
||||
llm_config_id=llm_config_id,
|
||||
attachments=request.attachments,
|
||||
mentioned_document_ids=request.mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
||||
needs_history_bootstrap=thread.needs_history_bootstrap,
|
||||
thread_visibility=thread.visibility,
|
||||
current_user_display_name=user.display_name or "A team member",
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
|
|
@ -1276,11 +1272,12 @@ async def regenerate_response(
|
|||
session=session,
|
||||
user_id=str(user.id),
|
||||
llm_config_id=llm_config_id,
|
||||
attachments=request.attachments,
|
||||
mentioned_document_ids=request.mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
||||
checkpoint_id=target_checkpoint_id,
|
||||
needs_history_bootstrap=thread.needs_history_bootstrap,
|
||||
thread_visibility=thread.visibility,
|
||||
current_user_display_name=user.display_name or "A team member",
|
||||
):
|
||||
yield chunk
|
||||
# If we get here, streaming completed successfully
|
||||
|
|
@ -1329,185 +1326,3 @@ async def regenerate_response(
|
|||
status_code=500,
|
||||
detail=f"An unexpected error occurred during regeneration: {e!s}",
|
||||
) from None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Attachment Processing Endpoint
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/attachments/process")
|
||||
async def process_attachment(
|
||||
file: UploadFile = File(...),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Process an attachment file and extract its content as markdown.
|
||||
|
||||
This endpoint uses the configured ETL service to parse files and return
|
||||
the extracted content that can be used as context in chat messages.
|
||||
|
||||
Supported file types depend on the configured ETL_SERVICE:
|
||||
- Markdown/Text files: .md, .markdown, .txt (always supported)
|
||||
- Audio files: .mp3, .mp4, .mpeg, .mpga, .m4a, .wav, .webm (if STT configured)
|
||||
- Documents: .pdf, .docx, .doc, .pptx, .xlsx (depends on ETL service)
|
||||
|
||||
Returns:
|
||||
JSON with attachment id, name, type, and extracted content
|
||||
"""
|
||||
from app.config import config as app_config
|
||||
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="No filename provided")
|
||||
|
||||
filename = file.filename
|
||||
attachment_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
# Save file to a temporary location
|
||||
file_ext = os.path.splitext(filename)[1].lower()
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as temp_file:
|
||||
temp_path = temp_file.name
|
||||
content = await file.read()
|
||||
temp_file.write(content)
|
||||
|
||||
extracted_content = ""
|
||||
|
||||
# Process based on file type
|
||||
if file_ext in (".md", ".markdown", ".txt"):
|
||||
# For text/markdown files, read content directly
|
||||
with open(temp_path, encoding="utf-8") as f:
|
||||
extracted_content = f.read()
|
||||
|
||||
elif file_ext in (".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm"):
|
||||
# Audio files - transcribe if STT service is configured
|
||||
if not app_config.STT_SERVICE:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="Audio transcription is not configured. Please set STT_SERVICE.",
|
||||
)
|
||||
|
||||
stt_service_type = (
|
||||
"local" if app_config.STT_SERVICE.startswith("local/") else "external"
|
||||
)
|
||||
|
||||
if stt_service_type == "local":
|
||||
from app.services.stt_service import stt_service
|
||||
|
||||
result = stt_service.transcribe_file(temp_path)
|
||||
extracted_content = result.get("text", "")
|
||||
else:
|
||||
from litellm import atranscription
|
||||
|
||||
with open(temp_path, "rb") as audio_file:
|
||||
transcription_kwargs = {
|
||||
"model": app_config.STT_SERVICE,
|
||||
"file": audio_file,
|
||||
"api_key": app_config.STT_SERVICE_API_KEY,
|
||||
}
|
||||
if app_config.STT_SERVICE_API_BASE:
|
||||
transcription_kwargs["api_base"] = (
|
||||
app_config.STT_SERVICE_API_BASE
|
||||
)
|
||||
|
||||
transcription_response = await atranscription(
|
||||
**transcription_kwargs
|
||||
)
|
||||
extracted_content = transcription_response.get("text", "")
|
||||
|
||||
if extracted_content:
|
||||
extracted_content = (
|
||||
f"# Transcription of {filename}\n\n{extracted_content}"
|
||||
)
|
||||
|
||||
else:
|
||||
# Document files - use configured ETL service
|
||||
if app_config.ETL_SERVICE == "UNSTRUCTURED":
|
||||
from langchain_unstructured import UnstructuredLoader
|
||||
|
||||
from app.utils.document_converters import convert_document_to_markdown
|
||||
|
||||
loader = UnstructuredLoader(
|
||||
temp_path,
|
||||
mode="elements",
|
||||
post_processors=[],
|
||||
languages=["eng"],
|
||||
include_orig_elements=False,
|
||||
include_metadata=False,
|
||||
strategy="auto",
|
||||
)
|
||||
docs = await loader.aload()
|
||||
extracted_content = await convert_document_to_markdown(docs)
|
||||
|
||||
elif app_config.ETL_SERVICE == "LLAMACLOUD":
|
||||
from llama_cloud_services import LlamaParse
|
||||
from llama_cloud_services.parse.utils import ResultType
|
||||
|
||||
parser = LlamaParse(
|
||||
api_key=app_config.LLAMA_CLOUD_API_KEY,
|
||||
num_workers=1,
|
||||
verbose=False,
|
||||
language="en",
|
||||
result_type=ResultType.MD,
|
||||
)
|
||||
result = await parser.aparse(temp_path)
|
||||
markdown_documents = await result.aget_markdown_documents(
|
||||
split_by_page=False
|
||||
)
|
||||
|
||||
if markdown_documents:
|
||||
extracted_content = "\n\n".join(
|
||||
doc.text for doc in markdown_documents
|
||||
)
|
||||
|
||||
elif app_config.ETL_SERVICE == "DOCLING":
|
||||
from app.services.docling_service import create_docling_service
|
||||
|
||||
docling_service = create_docling_service()
|
||||
result = await docling_service.process_document(temp_path, filename)
|
||||
extracted_content = result.get("content", "")
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"ETL service not configured or unsupported file type: {file_ext}",
|
||||
)
|
||||
|
||||
# Clean up temp file
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(temp_path)
|
||||
|
||||
if not extracted_content:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Could not extract content from file: {filename}",
|
||||
)
|
||||
|
||||
# Determine attachment type (must be one of: "image", "document", "file")
|
||||
# assistant-ui only supports these three types
|
||||
if file_ext in (".png", ".jpg", ".jpeg", ".gif", ".webp"):
|
||||
attachment_type = "image"
|
||||
else:
|
||||
# All other files (including audio, documents, text) are treated as "document"
|
||||
attachment_type = "document"
|
||||
|
||||
return {
|
||||
"id": attachment_id,
|
||||
"name": filename,
|
||||
"type": attachment_type,
|
||||
"content": extracted_content,
|
||||
"contentLength": len(extracted_content),
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
# Clean up temp file on error
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(temp_path)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to process attachment: {e!s}",
|
||||
) from e
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ Non-OAuth connectors (BookStack, GitHub, etc.) are limited to one per search spa
|
|||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from contextlib import suppress
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -32,6 +32,7 @@ from sqlalchemy.exc import IntegrityError
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.github_connector import GitHubConnector
|
||||
from app.db import (
|
||||
Permission,
|
||||
|
|
@ -70,6 +71,10 @@ from app.tasks.connector_indexers import (
|
|||
index_slack_messages,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.indexing_locks import (
|
||||
acquire_connector_indexing_lock,
|
||||
release_connector_indexing_lock,
|
||||
)
|
||||
from app.utils.periodic_scheduler import (
|
||||
create_periodic_schedule,
|
||||
delete_periodic_schedule,
|
||||
|
|
@ -91,11 +96,9 @@ def get_heartbeat_redis_client() -> redis.Redis:
|
|||
"""Get or create Redis client for heartbeat tracking."""
|
||||
global _heartbeat_redis_client
|
||||
if _heartbeat_redis_client is None:
|
||||
redis_url = os.getenv(
|
||||
"REDIS_APP_URL",
|
||||
os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"),
|
||||
_heartbeat_redis_client = redis.from_url(
|
||||
config.REDIS_APP_URL, decode_responses=True
|
||||
)
|
||||
_heartbeat_redis_client = redis.from_url(redis_url, decode_responses=True)
|
||||
return _heartbeat_redis_client
|
||||
|
||||
|
||||
|
|
@ -1229,10 +1232,19 @@ async def _run_indexing_with_notifications(
|
|||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
|
||||
notification = None
|
||||
connector_lock_acquired = False
|
||||
# Track indexed count for retry notifications and heartbeat
|
||||
current_indexed_count = 0
|
||||
|
||||
try:
|
||||
connector_lock_acquired = acquire_connector_indexing_lock(connector_id)
|
||||
if not connector_lock_acquired:
|
||||
logger.info(
|
||||
f"Skipping indexing for connector {connector_id} "
|
||||
"(another worker already holds Redis connector lock)"
|
||||
)
|
||||
return
|
||||
|
||||
# Get connector info for notification
|
||||
connector_result = await session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
|
|
@ -1558,6 +1570,9 @@ async def _run_indexing_with_notifications(
|
|||
get_heartbeat_redis_client().delete(heartbeat_key)
|
||||
except Exception:
|
||||
pass # Ignore cleanup errors - key will expire anyway
|
||||
if connector_lock_acquired:
|
||||
with suppress(Exception):
|
||||
release_connector_indexing_lock(connector_id)
|
||||
|
||||
|
||||
async def run_notion_indexing_with_new_session(
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ from .documents import (
|
|||
DocumentBase,
|
||||
DocumentRead,
|
||||
DocumentsCreate,
|
||||
DocumentStatusBatchResponse,
|
||||
DocumentStatusItemRead,
|
||||
DocumentStatusSchema,
|
||||
DocumentTitleRead,
|
||||
DocumentTitleSearchResponse,
|
||||
|
|
@ -105,6 +107,8 @@ __all__ = [
|
|||
# Document schemas
|
||||
"DocumentBase",
|
||||
"DocumentRead",
|
||||
"DocumentStatusBatchResponse",
|
||||
"DocumentStatusItemRead",
|
||||
"DocumentStatusSchema",
|
||||
"DocumentTitleRead",
|
||||
"DocumentTitleSearchResponse",
|
||||
|
|
|
|||
|
|
@ -99,3 +99,20 @@ class DocumentTitleSearchResponse(BaseModel):
|
|||
|
||||
items: list[DocumentTitleRead]
|
||||
has_more: bool
|
||||
|
||||
|
||||
class DocumentStatusItemRead(BaseModel):
|
||||
"""Lightweight document status payload for batch status polling."""
|
||||
|
||||
id: int
|
||||
title: str
|
||||
document_type: DocumentType
|
||||
status: DocumentStatusSchema
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class DocumentStatusBatchResponse(BaseModel):
|
||||
"""Batch status response for a set of document IDs."""
|
||||
|
||||
items: list[DocumentStatusItemRead]
|
||||
|
|
|
|||
|
|
@ -159,15 +159,6 @@ class ChatMessage(BaseModel):
|
|||
content: str
|
||||
|
||||
|
||||
class ChatAttachment(BaseModel):
|
||||
"""An attachment with its extracted content for chat context."""
|
||||
|
||||
id: str # Unique attachment ID
|
||||
name: str # Original filename
|
||||
type: str # Attachment type: document, image, audio
|
||||
content: str # Extracted markdown content from the file
|
||||
|
||||
|
||||
class NewChatRequest(BaseModel):
|
||||
"""Request schema for the deep agent chat endpoint."""
|
||||
|
||||
|
|
@ -175,9 +166,6 @@ class NewChatRequest(BaseModel):
|
|||
user_query: str
|
||||
search_space_id: int
|
||||
messages: list[ChatMessage] | None = None # Optional chat history from frontend
|
||||
attachments: list[ChatAttachment] | None = (
|
||||
None # Optional attachments with extracted content
|
||||
)
|
||||
mentioned_document_ids: list[int] | None = (
|
||||
None # Optional document IDs mentioned with @ in the chat
|
||||
)
|
||||
|
|
@ -201,7 +189,6 @@ class RegenerateRequest(BaseModel):
|
|||
user_query: str | None = (
|
||||
None # New user query (for edit). None = reload with same query
|
||||
)
|
||||
attachments: list[ChatAttachment] | None = None
|
||||
mentioned_document_ids: list[int] | None = None
|
||||
mentioned_surfsense_doc_ids: list[int] | None = None
|
||||
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ PROVIDER_MAP = {
|
|||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai",
|
||||
"GITHUB_MODELS": "github",
|
||||
"HUGGINGFACE": "huggingface",
|
||||
"CUSTOM": "custom",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -119,6 +119,7 @@ async def validate_llm_config(
|
|||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai", # GLM needs special handling
|
||||
"GITHUB_MODELS": "github",
|
||||
}
|
||||
provider_prefix = provider_map.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{model_name}"
|
||||
|
|
@ -335,6 +336,7 @@ async def get_search_space_llm_instance(
|
|||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai",
|
||||
"GITHUB_MODELS": "github",
|
||||
}
|
||||
provider_prefix = provider_map.get(
|
||||
llm_config.provider.value, llm_config.provider.value.lower()
|
||||
|
|
|
|||
|
|
@ -36,11 +36,9 @@ def _get_doc_heartbeat_redis():
|
|||
|
||||
global _doc_heartbeat_redis
|
||||
if _doc_heartbeat_redis is None:
|
||||
redis_url = os.getenv(
|
||||
"REDIS_APP_URL",
|
||||
os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"),
|
||||
_doc_heartbeat_redis = redis.from_url(
|
||||
config.REDIS_APP_URL, decode_responses=True
|
||||
)
|
||||
_doc_heartbeat_redis = redis.from_url(redis_url, decode_responses=True)
|
||||
return _doc_heartbeat_redis
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -46,16 +46,10 @@ def get_celery_session_maker():
|
|||
|
||||
def _clear_generating_podcast(search_space_id: int) -> None:
|
||||
"""Clear the generating podcast marker from Redis when task completes."""
|
||||
import os
|
||||
|
||||
import redis
|
||||
|
||||
try:
|
||||
redis_url = os.getenv(
|
||||
"REDIS_APP_URL",
|
||||
os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"),
|
||||
)
|
||||
client = redis.from_url(redis_url, decode_responses=True)
|
||||
client = redis.from_url(config.REDIS_APP_URL, decode_responses=True)
|
||||
key = f"podcast:generating:{search_space_id}"
|
||||
client.delete(key)
|
||||
logger.info(
|
||||
|
|
|
|||
|
|
@ -9,7 +9,8 @@ from sqlalchemy.pool import NullPool
|
|||
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.utils.indexing_locks import is_connector_indexing_locked
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -107,6 +108,32 @@ async def _check_and_trigger_schedules():
|
|||
|
||||
# Trigger indexing for each due connector
|
||||
for connector in due_connectors:
|
||||
# Primary guard: Redis lock indicates a task is currently running.
|
||||
if is_connector_indexing_locked(connector.id):
|
||||
logger.info(
|
||||
f"Skipping periodic indexing for connector {connector.id} "
|
||||
"(Redis lock indicates indexing is already in progress)"
|
||||
)
|
||||
continue
|
||||
|
||||
# Skip scheduling if a sync for this connector is already in progress.
|
||||
# This prevents duplicate tasks from piling up under slow/rate-limited providers.
|
||||
in_progress_result = await session.execute(
|
||||
select(Notification.id).where(
|
||||
Notification.type == "connector_indexing",
|
||||
Notification.notification_metadata["connector_id"].astext
|
||||
== str(connector.id),
|
||||
Notification.notification_metadata["status"].astext
|
||||
== "in_progress",
|
||||
)
|
||||
)
|
||||
if in_progress_result.first():
|
||||
logger.info(
|
||||
f"Skipping periodic indexing for connector {connector.id} "
|
||||
"(already has in-progress indexing notification)"
|
||||
)
|
||||
continue
|
||||
|
||||
task = task_map.get(connector.connector_type)
|
||||
if task:
|
||||
logger.info(
|
||||
|
|
|
|||
|
|
@ -25,7 +25,6 @@ Detection mechanism:
|
|||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import redis
|
||||
|
|
@ -52,11 +51,7 @@ def get_redis_client() -> redis.Redis:
|
|||
"""Get or create Redis client for heartbeat checking."""
|
||||
global _redis_client
|
||||
if _redis_client is None:
|
||||
redis_url = os.getenv(
|
||||
"REDIS_APP_URL",
|
||||
os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"),
|
||||
)
|
||||
_redis_client = redis.from_url(redis_url, decode_responses=True)
|
||||
_redis_client = redis.from_url(config.REDIS_APP_URL, decode_responses=True)
|
||||
return _redis_client
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -26,9 +26,8 @@ from app.agents.new_chat.llm_config import (
|
|||
load_agent_config,
|
||||
load_llm_config_from_yaml,
|
||||
)
|
||||
from app.db import Document, SurfsenseDocsDocument
|
||||
from app.db import ChatVisibility, Document, SurfsenseDocsDocument
|
||||
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
from app.schemas.new_chat import ChatAttachment
|
||||
from app.services.chat_session_state_service import (
|
||||
clear_ai_responding,
|
||||
set_ai_responding,
|
||||
|
|
@ -38,23 +37,6 @@ from app.services.new_streaming_service import VercelStreamingService
|
|||
from app.utils.content_utils import bootstrap_history_from_db
|
||||
|
||||
|
||||
def format_attachments_as_context(attachments: list[ChatAttachment]) -> str:
|
||||
"""Format attachments as context for the agent."""
|
||||
if not attachments:
|
||||
return ""
|
||||
|
||||
context_parts = ["<user_attachments>"]
|
||||
for i, attachment in enumerate(attachments, 1):
|
||||
context_parts.append(
|
||||
f"<attachment index='{i}' name='{attachment.name}' type='{attachment.type}'>"
|
||||
)
|
||||
context_parts.append(f"<![CDATA[{attachment.content}]]>")
|
||||
context_parts.append("</attachment>")
|
||||
context_parts.append("</user_attachments>")
|
||||
|
||||
return "\n".join(context_parts)
|
||||
|
||||
|
||||
def format_mentioned_documents_as_context(documents: list[Document]) -> str:
|
||||
"""
|
||||
Format mentioned documents as context for the agent.
|
||||
|
|
@ -203,11 +185,12 @@ async def stream_new_chat(
|
|||
session: AsyncSession,
|
||||
user_id: str | None = None,
|
||||
llm_config_id: int = -1,
|
||||
attachments: list[ChatAttachment] | None = None,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
mentioned_surfsense_doc_ids: list[int] | None = None,
|
||||
checkpoint_id: str | None = None,
|
||||
needs_history_bootstrap: bool = False,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
current_user_display_name: str | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Stream chat responses from the new SurfSense deep agent.
|
||||
|
|
@ -222,7 +205,6 @@ async def stream_new_chat(
|
|||
session: The database session
|
||||
user_id: The current user's UUID string (for memory tools and session state)
|
||||
llm_config_id: The LLM configuration ID (default: -1 for first global config)
|
||||
attachments: Optional attachments with extracted content
|
||||
needs_history_bootstrap: If True, load message history from DB (for cloned chats)
|
||||
mentioned_document_ids: Optional list of document IDs mentioned with @ in the chat
|
||||
mentioned_surfsense_doc_ids: Optional list of SurfSense doc IDs mentioned with @ in the chat
|
||||
|
|
@ -295,17 +277,18 @@ async def stream_new_chat(
|
|||
# Get the PostgreSQL checkpointer for persistent conversation memory
|
||||
checkpointer = await get_checkpointer()
|
||||
|
||||
# Create the deep agent with checkpointer and configurable prompts
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
agent = await create_surfsense_deep_agent(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
connector_service=connector_service,
|
||||
checkpointer=checkpointer,
|
||||
user_id=user_id, # Pass user ID for memory tools
|
||||
thread_id=chat_id, # Pass chat ID for podcast association
|
||||
agent_config=agent_config, # Pass prompt configuration
|
||||
firecrawl_api_key=firecrawl_api_key, # Pass Firecrawl API key if configured
|
||||
user_id=user_id,
|
||||
thread_id=chat_id,
|
||||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
)
|
||||
|
||||
# Build input with message history
|
||||
|
|
@ -313,7 +296,9 @@ async def stream_new_chat(
|
|||
|
||||
# Bootstrap history for cloned chats (no LangGraph checkpoint exists yet)
|
||||
if needs_history_bootstrap:
|
||||
langchain_messages = await bootstrap_history_from_db(session, chat_id)
|
||||
langchain_messages = await bootstrap_history_from_db(
|
||||
session, chat_id, thread_visibility=visibility
|
||||
)
|
||||
|
||||
# Clear the flag so we don't bootstrap again on next message
|
||||
from app.db import NewChatThread
|
||||
|
|
@ -355,13 +340,10 @@ async def stream_new_chat(
|
|||
)
|
||||
mentioned_surfsense_docs = list(result.scalars().all())
|
||||
|
||||
# Format the user query with context (attachments + mentioned documents + surfsense docs)
|
||||
# Format the user query with context (mentioned documents + SurfSense docs)
|
||||
final_query = user_query
|
||||
context_parts = []
|
||||
|
||||
if attachments:
|
||||
context_parts.append(format_attachments_as_context(attachments))
|
||||
|
||||
if mentioned_documents:
|
||||
context_parts.append(
|
||||
format_mentioned_documents_as_context(mentioned_documents)
|
||||
|
|
@ -376,6 +358,9 @@ async def stream_new_chat(
|
|||
context = "\n\n".join(context_parts)
|
||||
final_query = f"{context}\n\n<user_query>{user_query}</user_query>"
|
||||
|
||||
if visibility == ChatVisibility.SEARCH_SPACE and current_user_display_name:
|
||||
final_query = f"**[{current_user_display_name}]:** {final_query}"
|
||||
|
||||
# if messages:
|
||||
# # Convert frontend messages to LangChain format
|
||||
# for msg in messages:
|
||||
|
|
@ -451,39 +436,20 @@ async def stream_new_chat(
|
|||
last_active_step_id = analyze_step_id
|
||||
|
||||
# Determine step title and action verb based on context
|
||||
if attachments and (mentioned_documents or mentioned_surfsense_docs):
|
||||
last_active_step_title = "Analyzing your content"
|
||||
action_verb = "Reading"
|
||||
elif attachments:
|
||||
last_active_step_title = "Reading your content"
|
||||
action_verb = "Reading"
|
||||
elif mentioned_documents or mentioned_surfsense_docs:
|
||||
if mentioned_documents or mentioned_surfsense_docs:
|
||||
last_active_step_title = "Analyzing referenced content"
|
||||
action_verb = "Analyzing"
|
||||
else:
|
||||
last_active_step_title = "Understanding your request"
|
||||
action_verb = "Processing"
|
||||
|
||||
# Build the message with inline context about attachments/documents
|
||||
# Build the message with inline context about referenced documents
|
||||
processing_parts = []
|
||||
|
||||
# Add the user query
|
||||
query_text = user_query[:80] + ("..." if len(user_query) > 80 else "")
|
||||
processing_parts.append(query_text)
|
||||
|
||||
# Add file attachment names inline
|
||||
if attachments:
|
||||
attachment_names = []
|
||||
for attachment in attachments:
|
||||
name = attachment.name
|
||||
if len(name) > 30:
|
||||
name = name[:27] + "..."
|
||||
attachment_names.append(name)
|
||||
if len(attachment_names) == 1:
|
||||
processing_parts.append(f"[{attachment_names[0]}]")
|
||||
else:
|
||||
processing_parts.append(f"[{len(attachment_names)} files]")
|
||||
|
||||
# Add mentioned document names inline
|
||||
if mentioned_documents:
|
||||
doc_names = []
|
||||
|
|
|
|||
|
|
@ -52,10 +52,22 @@ def safe_set_chunks(document: Document, chunks: list) -> None:
|
|||
# Instead of: document.chunks = chunks (DANGEROUS!)
|
||||
safe_set_chunks(document, chunks) # Always safe
|
||||
"""
|
||||
from sqlalchemy.orm import object_session
|
||||
from sqlalchemy.orm.attributes import set_committed_value
|
||||
|
||||
# Keep relationship assignment lazy-load-safe.
|
||||
set_committed_value(document, "chunks", chunks)
|
||||
|
||||
# Ensure chunk rows are actually persisted.
|
||||
# set_committed_value bypasses normal unit-of-work tracking, so we need to
|
||||
# explicitly attach chunk objects to the current session.
|
||||
session = object_session(document)
|
||||
if session is not None:
|
||||
if document.id is not None:
|
||||
for chunk in chunks:
|
||||
chunk.document_id = document.id
|
||||
session.add_all(chunks)
|
||||
|
||||
|
||||
def parse_date_flexible(date_str: str) -> datetime:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
"""
|
||||
Discord connector indexer.
|
||||
|
||||
Implements 2-phase document status updates for real-time UI feedback:
|
||||
Implements batch indexing: groups up to DISCORD_BATCH_SIZE messages per channel
|
||||
into a single document for efficient indexing and better conversational context.
|
||||
|
||||
Uses 2-phase document status updates for real-time UI feedback:
|
||||
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
|
||||
- Phase 2: Process each document: pending → processing → ready/failed
|
||||
"""
|
||||
|
|
@ -41,6 +44,72 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
|||
# Heartbeat interval in seconds - update notification every 30 seconds
|
||||
HEARTBEAT_INTERVAL_SECONDS = 30
|
||||
|
||||
# Number of messages to combine into a single document for batch indexing.
|
||||
# Grouping messages improves conversational context in embeddings/chunks and
|
||||
# drastically reduces the number of documents, embedding calls, and DB overhead.
|
||||
DISCORD_BATCH_SIZE = 100
|
||||
|
||||
|
||||
def _build_batch_document_string(
|
||||
guild_name: str,
|
||||
guild_id: str,
|
||||
channel_name: str,
|
||||
channel_id: str,
|
||||
messages: list[dict],
|
||||
) -> str:
|
||||
"""
|
||||
Combine multiple Discord messages into a single document string.
|
||||
|
||||
Each message is formatted with its timestamp and author, and all messages
|
||||
are concatenated into a conversation-style document. The chunker will
|
||||
later split this into overlapping windows of ~8-10 consecutive messages,
|
||||
preserving conversational context in each chunk's embedding.
|
||||
|
||||
Args:
|
||||
guild_name: Name of the Discord guild
|
||||
guild_id: ID of the Discord guild
|
||||
channel_name: Name of the channel
|
||||
channel_id: ID of the channel
|
||||
messages: List of message dicts with 'author_name', 'created_at', 'content'
|
||||
|
||||
Returns:
|
||||
Formatted document string with metadata and conversation content
|
||||
"""
|
||||
first_msg_time = messages[0].get("created_at", "Unknown")
|
||||
last_msg_time = messages[-1].get("created_at", "Unknown")
|
||||
|
||||
metadata_lines = [
|
||||
f"GUILD_NAME: {guild_name}",
|
||||
f"GUILD_ID: {guild_id}",
|
||||
f"CHANNEL_NAME: {channel_name}",
|
||||
f"CHANNEL_ID: {channel_id}",
|
||||
f"MESSAGE_COUNT: {len(messages)}",
|
||||
f"FIRST_MESSAGE_TIME: {first_msg_time}",
|
||||
f"LAST_MESSAGE_TIME: {last_msg_time}",
|
||||
]
|
||||
|
||||
conversation_lines = []
|
||||
for msg in messages:
|
||||
author = msg.get("author_name", "Unknown User")
|
||||
timestamp = msg.get("created_at", "Unknown Time")
|
||||
content = msg.get("content", "")
|
||||
conversation_lines.append(f"[{timestamp}] {author}: {content}")
|
||||
|
||||
metadata_sections = [
|
||||
("METADATA", metadata_lines),
|
||||
(
|
||||
"CONTENT",
|
||||
[
|
||||
"FORMAT: markdown",
|
||||
"TEXT_START",
|
||||
"\n".join(conversation_lines),
|
||||
"TEXT_END",
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
return build_document_metadata_markdown(metadata_sections)
|
||||
|
||||
|
||||
async def index_discord_messages(
|
||||
session: AsyncSession,
|
||||
|
|
@ -55,6 +124,12 @@ async def index_discord_messages(
|
|||
"""
|
||||
Index Discord messages from the configured guild's channels.
|
||||
|
||||
Messages are grouped into batches of DISCORD_BATCH_SIZE per channel,
|
||||
so each document contains up to 100 consecutive messages with full
|
||||
conversational context. This reduces document count, embedding calls,
|
||||
and DB overhead by ~100x while improving search quality through
|
||||
context-aware chunk embeddings.
|
||||
|
||||
Implements 2-phase document status updates for real-time UI feedback:
|
||||
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
|
||||
- Phase 2: Process each document: pending → processing → ready/failed
|
||||
|
|
@ -324,6 +399,7 @@ async def index_discord_messages(
|
|||
documents_skipped = 0
|
||||
documents_failed = 0
|
||||
duplicate_content_count = 0
|
||||
total_messages_collected = 0
|
||||
skipped_channels: list[str] = []
|
||||
|
||||
# Heartbeat tracking - update notification periodically to prevent appearing stuck
|
||||
|
|
@ -340,10 +416,12 @@ async def index_discord_messages(
|
|||
)
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 1: Collect all messages and create pending documents
|
||||
# This makes ALL documents visible in the UI immediately with pending status
|
||||
# PHASE 1: Collect messages, group into batches, and create pending documents
|
||||
# Messages are grouped into batches of DISCORD_BATCH_SIZE per channel.
|
||||
# Each batch becomes a single document with full conversational context.
|
||||
# All documents are visible in the UI immediately with pending status.
|
||||
# =======================================================================
|
||||
messages_to_process = [] # List of dicts with document and message data
|
||||
batches_to_process = [] # List of dicts with document and batch data
|
||||
new_documents_created = False
|
||||
|
||||
try:
|
||||
|
|
@ -394,44 +472,35 @@ async def index_discord_messages(
|
|||
)
|
||||
continue
|
||||
|
||||
# Process each message as an individual document (like Slack)
|
||||
for msg in formatted_messages:
|
||||
msg_id = msg.get("id", "")
|
||||
msg_user_name = msg.get("author_name", "Unknown User")
|
||||
msg_timestamp = msg.get("created_at", "Unknown Time")
|
||||
msg_text = msg.get("content", "")
|
||||
total_messages_collected += len(formatted_messages)
|
||||
|
||||
# Format document metadata (similar to Slack)
|
||||
metadata_sections = [
|
||||
(
|
||||
"METADATA",
|
||||
[
|
||||
f"GUILD_NAME: {guild_name}",
|
||||
f"GUILD_ID: {guild_id}",
|
||||
f"CHANNEL_NAME: {channel_name}",
|
||||
f"CHANNEL_ID: {channel_id}",
|
||||
f"MESSAGE_TIMESTAMP: {msg_timestamp}",
|
||||
f"MESSAGE_USER_NAME: {msg_user_name}",
|
||||
],
|
||||
),
|
||||
(
|
||||
"CONTENT",
|
||||
[
|
||||
"FORMAT: markdown",
|
||||
"TEXT_START",
|
||||
msg_text,
|
||||
"TEXT_END",
|
||||
],
|
||||
),
|
||||
# =======================================================
|
||||
# Group messages into batches of DISCORD_BATCH_SIZE
|
||||
# Each batch becomes a single document with conversation context
|
||||
# =======================================================
|
||||
for batch_start in range(
|
||||
0, len(formatted_messages), DISCORD_BATCH_SIZE
|
||||
):
|
||||
batch = formatted_messages[
|
||||
batch_start : batch_start + DISCORD_BATCH_SIZE
|
||||
]
|
||||
|
||||
# Build the document string
|
||||
combined_document_string = build_document_metadata_markdown(
|
||||
metadata_sections
|
||||
# Build combined document string from all messages in this batch
|
||||
combined_document_string = _build_batch_document_string(
|
||||
guild_name=guild_name,
|
||||
guild_id=guild_id,
|
||||
channel_name=channel_name,
|
||||
channel_id=channel_id,
|
||||
messages=batch,
|
||||
)
|
||||
|
||||
# Generate unique identifier hash for this Discord message
|
||||
unique_identifier = f"{channel_id}_{msg_id}"
|
||||
# Generate unique identifier for this batch using
|
||||
# channel_id + first message ID + last message ID
|
||||
first_msg_id = batch[0].get("id", "")
|
||||
last_msg_id = batch[-1].get("id", "")
|
||||
unique_identifier = (
|
||||
f"{channel_id}_{first_msg_id}_{last_msg_id}"
|
||||
)
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
DocumentType.DISCORD_CONNECTOR,
|
||||
unique_identifier,
|
||||
|
|
@ -464,7 +533,7 @@ async def index_discord_messages(
|
|||
continue
|
||||
|
||||
# Queue existing document for update (will be set to processing in Phase 2)
|
||||
messages_to_process.append(
|
||||
batches_to_process.append(
|
||||
{
|
||||
"document": existing_document,
|
||||
"is_new": False,
|
||||
|
|
@ -474,9 +543,15 @@ async def index_discord_messages(
|
|||
"guild_id": guild_id,
|
||||
"channel_name": channel_name,
|
||||
"channel_id": channel_id,
|
||||
"message_id": msg_id,
|
||||
"message_timestamp": msg_timestamp,
|
||||
"message_user_name": msg_user_name,
|
||||
"first_message_id": first_msg_id,
|
||||
"last_message_id": last_msg_id,
|
||||
"first_message_time": batch[0].get(
|
||||
"created_at", "Unknown"
|
||||
),
|
||||
"last_message_time": batch[-1].get(
|
||||
"created_at", "Unknown"
|
||||
),
|
||||
"message_count": len(batch),
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
|
@ -492,7 +567,7 @@ async def index_discord_messages(
|
|||
|
||||
if duplicate_by_content:
|
||||
logger.info(
|
||||
f"Discord message {msg_id} in {guild_name}#{channel_name} already indexed by another connector "
|
||||
f"Discord batch ({len(batch)} msgs) in {guild_name}#{channel_name} already indexed by another connector "
|
||||
f"(existing document ID: {duplicate_by_content.id}, "
|
||||
f"type: {duplicate_by_content.document_type}). Skipping."
|
||||
)
|
||||
|
|
@ -510,7 +585,9 @@ async def index_discord_messages(
|
|||
"guild_id": guild_id,
|
||||
"channel_name": channel_name,
|
||||
"channel_id": channel_id,
|
||||
"message_id": msg_id,
|
||||
"first_message_id": first_msg_id,
|
||||
"last_message_id": last_msg_id,
|
||||
"message_count": len(batch),
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
content="Pending...", # Placeholder until processed
|
||||
|
|
@ -526,7 +603,7 @@ async def index_discord_messages(
|
|||
session.add(document)
|
||||
new_documents_created = True
|
||||
|
||||
messages_to_process.append(
|
||||
batches_to_process.append(
|
||||
{
|
||||
"document": document,
|
||||
"is_new": True,
|
||||
|
|
@ -536,12 +613,23 @@ async def index_discord_messages(
|
|||
"guild_id": guild_id,
|
||||
"channel_name": channel_name,
|
||||
"channel_id": channel_id,
|
||||
"message_id": msg_id,
|
||||
"message_timestamp": msg_timestamp,
|
||||
"message_user_name": msg_user_name,
|
||||
"first_message_id": first_msg_id,
|
||||
"last_message_id": last_msg_id,
|
||||
"first_message_time": batch[0].get(
|
||||
"created_at", "Unknown"
|
||||
),
|
||||
"last_message_time": batch[-1].get(
|
||||
"created_at", "Unknown"
|
||||
),
|
||||
"message_count": len(batch),
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Phase 1: Collected {len(formatted_messages)} messages from channel {channel_name}, "
|
||||
f"grouped into {(len(formatted_messages) + DISCORD_BATCH_SIZE - 1) // DISCORD_BATCH_SIZE} batch(es)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing guild {guild_name}: {e!s}", exc_info=True
|
||||
|
|
@ -554,17 +642,18 @@ async def index_discord_messages(
|
|||
# Commit all pending documents - they all appear in UI now
|
||||
if new_documents_created:
|
||||
logger.info(
|
||||
f"Phase 1: Committing {len([m for m in messages_to_process if m['is_new']])} pending documents"
|
||||
f"Phase 1: Committing {len([b for b in batches_to_process if b['is_new']])} pending batch documents "
|
||||
f"({total_messages_collected} total messages across all channels)"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 2: Process each document one by one
|
||||
# PHASE 2: Process each batch document one by one
|
||||
# Each document transitions: pending → processing → ready/failed
|
||||
# =======================================================================
|
||||
logger.info(f"Phase 2: Processing {len(messages_to_process)} documents")
|
||||
logger.info(f"Phase 2: Processing {len(batches_to_process)} batch documents")
|
||||
|
||||
for item in messages_to_process:
|
||||
for item in batches_to_process:
|
||||
# Send heartbeat periodically
|
||||
if on_heartbeat_callback:
|
||||
current_time = time.time()
|
||||
|
|
@ -594,9 +683,11 @@ async def index_discord_messages(
|
|||
"guild_id": item["guild_id"],
|
||||
"channel_name": item["channel_name"],
|
||||
"channel_id": item["channel_id"],
|
||||
"message_id": item["message_id"],
|
||||
"message_timestamp": item["message_timestamp"],
|
||||
"message_user_name": item["message_user_name"],
|
||||
"first_message_id": item["first_message_id"],
|
||||
"last_message_id": item["last_message_id"],
|
||||
"first_message_time": item["first_message_time"],
|
||||
"last_message_time": item["last_message_time"],
|
||||
"message_count": item["message_count"],
|
||||
"indexed_at": datetime.now(UTC).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"connector_id": connector_id,
|
||||
}
|
||||
|
|
@ -609,12 +700,14 @@ async def index_discord_messages(
|
|||
# Batch commit every 10 documents (for ready status updates)
|
||||
if documents_indexed % 10 == 0:
|
||||
logger.info(
|
||||
f"Committing batch: {documents_indexed} Discord messages processed so far"
|
||||
f"Committing batch: {documents_indexed} batch documents processed so far"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Discord message: {e!s}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error processing Discord batch document: {e!s}", exc_info=True
|
||||
)
|
||||
# Mark document as failed with reason (visible in UI)
|
||||
try:
|
||||
document.status = DocumentStatus.failed(str(e))
|
||||
|
|
@ -631,7 +724,8 @@ async def index_discord_messages(
|
|||
|
||||
# Final commit for any remaining documents not yet committed in batches
|
||||
logger.info(
|
||||
f"Final commit: Total {documents_indexed} Discord messages processed"
|
||||
f"Final commit: Total {documents_indexed} batch documents processed "
|
||||
f"(from {total_messages_collected} messages)"
|
||||
)
|
||||
try:
|
||||
await session.commit()
|
||||
|
|
@ -672,14 +766,18 @@ async def index_discord_messages(
|
|||
"documents_failed": documents_failed,
|
||||
"duplicate_content_count": duplicate_content_count,
|
||||
"skipped_channels_count": len(skipped_channels),
|
||||
"total_messages_collected": total_messages_collected,
|
||||
"batch_size": DISCORD_BATCH_SIZE,
|
||||
"guild_id": guild_id,
|
||||
"guild_name": guild_name,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Discord indexing completed for guild {guild_name}: {documents_indexed} ready, {documents_skipped} skipped, "
|
||||
f"{documents_failed} failed ({duplicate_content_count} duplicate content)"
|
||||
f"Discord indexing completed for guild {guild_name}: "
|
||||
f"{documents_indexed} batch docs ready (from {total_messages_collected} messages), "
|
||||
f"{documents_skipped} skipped, {documents_failed} failed "
|
||||
f"({duplicate_content_count} duplicate content)"
|
||||
)
|
||||
return documents_indexed, warning_message
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
"""
|
||||
Slack connector indexer.
|
||||
|
||||
Implements 2-phase document status updates for real-time UI feedback:
|
||||
Implements batch indexing: groups up to SLACK_BATCH_SIZE messages per channel
|
||||
into a single document for efficient indexing and better conversational context.
|
||||
|
||||
Uses 2-phase document status updates for real-time UI feedback:
|
||||
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
|
||||
- Phase 2: Process each document: pending → processing → ready/failed
|
||||
"""
|
||||
|
|
@ -42,6 +45,72 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
|||
# Heartbeat interval in seconds - update notification every 30 seconds
|
||||
HEARTBEAT_INTERVAL_SECONDS = 30
|
||||
|
||||
# Number of messages to combine into a single document for batch indexing.
|
||||
# Grouping messages improves conversational context in embeddings/chunks and
|
||||
# drastically reduces the number of documents, embedding calls, and DB overhead.
|
||||
SLACK_BATCH_SIZE = 100
|
||||
|
||||
|
||||
def _build_batch_document_string(
|
||||
team_name: str,
|
||||
team_id: str,
|
||||
channel_name: str,
|
||||
channel_id: str,
|
||||
messages: list[dict],
|
||||
) -> str:
|
||||
"""
|
||||
Combine multiple Slack messages into a single document string.
|
||||
|
||||
Each message is formatted with its timestamp and author, and all messages
|
||||
are concatenated into a conversation-style document. The chunker will
|
||||
later split this into overlapping windows of ~8-10 consecutive messages,
|
||||
preserving conversational context in each chunk's embedding.
|
||||
|
||||
Args:
|
||||
team_name: Name of the Slack workspace
|
||||
team_id: ID of the Slack workspace
|
||||
channel_name: Name of the channel
|
||||
channel_id: ID of the channel
|
||||
messages: List of formatted message dicts with 'user_name', 'datetime', 'text'
|
||||
|
||||
Returns:
|
||||
Formatted document string with metadata and conversation content
|
||||
"""
|
||||
first_msg_time = messages[0].get("datetime", "Unknown")
|
||||
last_msg_time = messages[-1].get("datetime", "Unknown")
|
||||
|
||||
metadata_lines = [
|
||||
f"WORKSPACE_NAME: {team_name}",
|
||||
f"WORKSPACE_ID: {team_id}",
|
||||
f"CHANNEL_NAME: {channel_name}",
|
||||
f"CHANNEL_ID: {channel_id}",
|
||||
f"MESSAGE_COUNT: {len(messages)}",
|
||||
f"FIRST_MESSAGE_TIME: {first_msg_time}",
|
||||
f"LAST_MESSAGE_TIME: {last_msg_time}",
|
||||
]
|
||||
|
||||
conversation_lines = []
|
||||
for msg in messages:
|
||||
author = msg.get("user_name", "Unknown User")
|
||||
timestamp = msg.get("datetime", "Unknown Time")
|
||||
content = msg.get("text", "")
|
||||
conversation_lines.append(f"[{timestamp}] {author}: {content}")
|
||||
|
||||
metadata_sections = [
|
||||
("METADATA", metadata_lines),
|
||||
(
|
||||
"CONTENT",
|
||||
[
|
||||
"FORMAT: markdown",
|
||||
"TEXT_START",
|
||||
"\n".join(conversation_lines),
|
||||
"TEXT_END",
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
return build_document_metadata_markdown(metadata_sections)
|
||||
|
||||
|
||||
async def index_slack_messages(
|
||||
session: AsyncSession,
|
||||
|
|
@ -56,6 +125,16 @@ async def index_slack_messages(
|
|||
"""
|
||||
Index Slack messages from all accessible channels.
|
||||
|
||||
Messages are grouped into batches of SLACK_BATCH_SIZE per channel,
|
||||
so each document contains up to 100 consecutive messages with full
|
||||
conversational context. This reduces document count, embedding calls,
|
||||
and DB overhead by ~100x while improving search quality through
|
||||
context-aware chunk embeddings.
|
||||
|
||||
Implements 2-phase document status updates for real-time UI feedback:
|
||||
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
|
||||
- Phase 2: Process each document: pending → processing → ready/failed
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Slack connector
|
||||
|
|
@ -109,6 +188,10 @@ async def index_slack_messages(
|
|||
f"Connector with ID {connector_id} not found or is not a Slack connector",
|
||||
)
|
||||
|
||||
# Extract workspace info from connector config
|
||||
team_id = connector.config.get("team_id", "")
|
||||
team_name = connector.config.get("team_name", "Unknown Workspace")
|
||||
|
||||
# Note: Token handling is now done automatically by SlackHistory
|
||||
# with auto-refresh support. We just need to pass session and connector_id.
|
||||
|
||||
|
|
@ -182,6 +265,8 @@ async def index_slack_messages(
|
|||
documents_indexed = 0
|
||||
documents_skipped = 0
|
||||
documents_failed = 0 # Track messages that failed processing
|
||||
duplicate_content_count = 0
|
||||
total_messages_collected = 0
|
||||
skipped_channels = []
|
||||
|
||||
# Heartbeat tracking - update notification periodically to prevent appearing stuck
|
||||
|
|
@ -194,10 +279,12 @@ async def index_slack_messages(
|
|||
)
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 1: Collect all messages from all channels, create pending documents
|
||||
# This makes ALL documents visible in the UI immediately with pending status
|
||||
# PHASE 1: Collect messages, group into batches, and create pending documents
|
||||
# Messages are grouped into batches of SLACK_BATCH_SIZE per channel.
|
||||
# Each batch becomes a single document with full conversational context.
|
||||
# All documents are visible in the UI immediately with pending status.
|
||||
# =======================================================================
|
||||
messages_to_process = [] # List of dicts with document and message data
|
||||
batches_to_process = [] # List of dicts with document and batch data
|
||||
new_documents_created = False
|
||||
|
||||
for channel_obj in channels:
|
||||
|
|
@ -264,40 +351,35 @@ async def index_slack_messages(
|
|||
documents_skipped += 1
|
||||
continue # Skip if no valid messages after filtering
|
||||
|
||||
for msg in formatted_messages:
|
||||
timestamp = msg.get("datetime", "Unknown Time")
|
||||
msg_ts = msg.get("ts", timestamp) # Get original Slack timestamp
|
||||
msg_user_name = msg.get("user_name", "Unknown User")
|
||||
msg_user_email = msg.get("user_email", "Unknown Email")
|
||||
msg_text = msg.get("text", "")
|
||||
total_messages_collected += len(formatted_messages)
|
||||
|
||||
# Format document metadata
|
||||
metadata_sections = [
|
||||
(
|
||||
"METADATA",
|
||||
[
|
||||
f"CHANNEL_NAME: {channel_name}",
|
||||
f"CHANNEL_ID: {channel_id}",
|
||||
f"MESSAGE_TIMESTAMP: {timestamp}",
|
||||
f"MESSAGE_USER_NAME: {msg_user_name}",
|
||||
f"MESSAGE_USER_EMAIL: {msg_user_email}",
|
||||
],
|
||||
),
|
||||
(
|
||||
"CONTENT",
|
||||
["FORMAT: markdown", "TEXT_START", msg_text, "TEXT_END"],
|
||||
),
|
||||
# =======================================================
|
||||
# Group messages into batches of SLACK_BATCH_SIZE
|
||||
# Each batch becomes a single document with conversation context
|
||||
# =======================================================
|
||||
for batch_start in range(0, len(formatted_messages), SLACK_BATCH_SIZE):
|
||||
batch = formatted_messages[
|
||||
batch_start : batch_start + SLACK_BATCH_SIZE
|
||||
]
|
||||
|
||||
# Build the document string
|
||||
combined_document_string = build_document_metadata_markdown(
|
||||
metadata_sections
|
||||
# Build combined document string from all messages in this batch
|
||||
combined_document_string = _build_batch_document_string(
|
||||
team_name=team_name,
|
||||
team_id=team_id,
|
||||
channel_name=channel_name,
|
||||
channel_id=channel_id,
|
||||
messages=batch,
|
||||
)
|
||||
|
||||
# Generate unique identifier hash for this Slack message
|
||||
unique_identifier = f"{channel_id}_{msg_ts}"
|
||||
# Generate unique identifier for this batch using
|
||||
# channel_id + first message ts + last message ts
|
||||
first_msg_ts = batch[0].get("timestamp", "")
|
||||
last_msg_ts = batch[-1].get("timestamp", "")
|
||||
unique_identifier = f"{channel_id}_{first_msg_ts}_{last_msg_ts}"
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
DocumentType.SLACK_CONNECTOR, unique_identifier, search_space_id
|
||||
DocumentType.SLACK_CONNECTOR,
|
||||
unique_identifier,
|
||||
search_space_id,
|
||||
)
|
||||
|
||||
# Generate content hash
|
||||
|
|
@ -318,25 +400,31 @@ async def index_slack_messages(
|
|||
existing_document.status, DocumentStatus.READY
|
||||
):
|
||||
existing_document.status = DocumentStatus.ready()
|
||||
logger.info(
|
||||
f"Document for Slack message {msg_ts} in channel {channel_name} unchanged. Skipping."
|
||||
)
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Queue existing document for update (will be set to processing in Phase 2)
|
||||
messages_to_process.append(
|
||||
batches_to_process.append(
|
||||
{
|
||||
"document": existing_document,
|
||||
"is_new": False,
|
||||
"combined_document_string": combined_document_string,
|
||||
"content_hash": content_hash,
|
||||
"team_name": team_name,
|
||||
"team_id": team_id,
|
||||
"channel_name": channel_name,
|
||||
"channel_id": channel_id,
|
||||
"msg_ts": msg_ts,
|
||||
"first_message_ts": first_msg_ts,
|
||||
"last_message_ts": last_msg_ts,
|
||||
"first_message_time": batch[0].get(
|
||||
"datetime", "Unknown"
|
||||
),
|
||||
"last_message_time": batch[-1].get(
|
||||
"datetime", "Unknown"
|
||||
),
|
||||
"message_count": len(batch),
|
||||
"start_date": start_date_str,
|
||||
"end_date": end_date_str,
|
||||
"message_count": len(formatted_messages),
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
|
@ -350,22 +438,27 @@ async def index_slack_messages(
|
|||
|
||||
if duplicate_by_content:
|
||||
logger.info(
|
||||
f"Slack message {msg_ts} in channel {channel_name} already indexed by another connector "
|
||||
f"Slack batch ({len(batch)} msgs) in {team_name}#{channel_name} already indexed by another connector "
|
||||
f"(existing document ID: {duplicate_by_content.id}, "
|
||||
f"type: {duplicate_by_content.document_type}). Skipping."
|
||||
)
|
||||
duplicate_content_count += 1
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# Create new document with PENDING status (visible in UI immediately)
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=channel_name,
|
||||
title=f"{team_name}#{channel_name}",
|
||||
document_type=DocumentType.SLACK_CONNECTOR,
|
||||
document_metadata={
|
||||
"team_name": team_name,
|
||||
"team_id": team_id,
|
||||
"channel_name": channel_name,
|
||||
"channel_id": channel_id,
|
||||
"msg_ts": msg_ts,
|
||||
"first_message_ts": first_msg_ts,
|
||||
"last_message_ts": last_msg_ts,
|
||||
"message_count": len(batch),
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
content="Pending...", # Placeholder until processed
|
||||
|
|
@ -381,23 +474,29 @@ async def index_slack_messages(
|
|||
session.add(document)
|
||||
new_documents_created = True
|
||||
|
||||
messages_to_process.append(
|
||||
batches_to_process.append(
|
||||
{
|
||||
"document": document,
|
||||
"is_new": True,
|
||||
"combined_document_string": combined_document_string,
|
||||
"content_hash": content_hash,
|
||||
"team_name": team_name,
|
||||
"team_id": team_id,
|
||||
"channel_name": channel_name,
|
||||
"channel_id": channel_id,
|
||||
"msg_ts": msg_ts,
|
||||
"first_message_ts": first_msg_ts,
|
||||
"last_message_ts": last_msg_ts,
|
||||
"first_message_time": batch[0].get("datetime", "Unknown"),
|
||||
"last_message_time": batch[-1].get("datetime", "Unknown"),
|
||||
"message_count": len(batch),
|
||||
"start_date": start_date_str,
|
||||
"end_date": end_date_str,
|
||||
"message_count": len(formatted_messages),
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Phase 1: Collected {len(formatted_messages)} messages from channel {channel_name}"
|
||||
f"Phase 1: Collected {len(formatted_messages)} messages from channel {channel_name}, "
|
||||
f"grouped into {(len(formatted_messages) + SLACK_BATCH_SIZE - 1) // SLACK_BATCH_SIZE} batch(es)"
|
||||
)
|
||||
|
||||
except SlackApiError as slack_error:
|
||||
|
|
@ -416,17 +515,18 @@ async def index_slack_messages(
|
|||
# Commit all pending documents - they all appear in UI now
|
||||
if new_documents_created:
|
||||
logger.info(
|
||||
f"Phase 1: Committing {len([m for m in messages_to_process if m['is_new']])} pending documents"
|
||||
f"Phase 1: Committing {len([b for b in batches_to_process if b['is_new']])} pending batch documents "
|
||||
f"({total_messages_collected} total messages across all channels)"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 2: Process each document one by one
|
||||
# PHASE 2: Process each batch document one by one
|
||||
# Each document transitions: pending → processing → ready/failed
|
||||
# =======================================================================
|
||||
logger.info(f"Phase 2: Processing {len(messages_to_process)} documents")
|
||||
logger.info(f"Phase 2: Processing {len(batches_to_process)} batch documents")
|
||||
|
||||
for item in messages_to_process:
|
||||
for item in batches_to_process:
|
||||
# Send heartbeat periodically
|
||||
if on_heartbeat_callback:
|
||||
current_time = time.time()
|
||||
|
|
@ -447,16 +547,22 @@ async def index_slack_messages(
|
|||
)
|
||||
|
||||
# Update document to READY with actual content
|
||||
document.title = item["channel_name"]
|
||||
document.title = f"{item['team_name']}#{item['channel_name']}"
|
||||
document.content = item["combined_document_string"]
|
||||
document.content_hash = item["content_hash"]
|
||||
document.embedding = doc_embedding
|
||||
document.document_metadata = {
|
||||
"team_name": item["team_name"],
|
||||
"team_id": item["team_id"],
|
||||
"channel_name": item["channel_name"],
|
||||
"channel_id": item["channel_id"],
|
||||
"first_message_ts": item["first_message_ts"],
|
||||
"last_message_ts": item["last_message_ts"],
|
||||
"first_message_time": item["first_message_time"],
|
||||
"last_message_time": item["last_message_time"],
|
||||
"message_count": item["message_count"],
|
||||
"start_date": item["start_date"],
|
||||
"end_date": item["end_date"],
|
||||
"message_count": item["message_count"],
|
||||
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"connector_id": connector_id,
|
||||
}
|
||||
|
|
@ -469,13 +575,13 @@ async def index_slack_messages(
|
|||
# Batch commit every 10 documents (for ready status updates)
|
||||
if documents_indexed % 10 == 0:
|
||||
logger.info(
|
||||
f"Committing batch: {documents_indexed} Slack messages processed so far"
|
||||
f"Committing batch: {documents_indexed} batch documents processed so far"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing Slack message {item.get('msg_ts', 'Unknown')}: {e!s}",
|
||||
f"Error processing Slack batch document: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Mark document as failed with reason (visible in UI)
|
||||
|
|
@ -493,7 +599,10 @@ async def index_slack_messages(
|
|||
await update_connector_last_indexed(session, connector, update_last_indexed)
|
||||
|
||||
# Final commit for any remaining documents not yet committed in batches
|
||||
logger.info(f"Final commit: Total {documents_indexed} Slack messages processed")
|
||||
logger.info(
|
||||
f"Final commit: Total {documents_indexed} batch documents processed "
|
||||
f"(from {total_messages_collected} messages)"
|
||||
)
|
||||
try:
|
||||
await session.commit()
|
||||
logger.info("Successfully committed all Slack document changes to database")
|
||||
|
|
@ -514,8 +623,12 @@ async def index_slack_messages(
|
|||
|
||||
# Build warning message if there were issues
|
||||
warning_parts = []
|
||||
if duplicate_content_count > 0:
|
||||
warning_parts.append(f"{duplicate_content_count} duplicate")
|
||||
if documents_failed > 0:
|
||||
warning_parts.append(f"{documents_failed} failed")
|
||||
if skipped_channels:
|
||||
warning_parts.append(f"{len(skipped_channels)} channels skipped")
|
||||
warning_message = ", ".join(warning_parts) if warning_parts else None
|
||||
|
||||
# Log success
|
||||
|
|
@ -527,13 +640,20 @@ async def index_slack_messages(
|
|||
"documents_indexed": documents_indexed,
|
||||
"documents_skipped": documents_skipped,
|
||||
"documents_failed": documents_failed,
|
||||
"duplicate_content_count": duplicate_content_count,
|
||||
"skipped_channels_count": len(skipped_channels),
|
||||
"total_messages_collected": total_messages_collected,
|
||||
"batch_size": SLACK_BATCH_SIZE,
|
||||
"team_id": team_id,
|
||||
"team_name": team_name,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Slack indexing completed: {documents_indexed} ready, "
|
||||
f"{documents_skipped} skipped, {documents_failed} failed"
|
||||
f"Slack indexing completed for workspace {team_name}: "
|
||||
f"{documents_indexed} batch docs ready (from {total_messages_collected} messages), "
|
||||
f"{documents_skipped} skipped, {documents_failed} failed "
|
||||
f"({duplicate_content_count} duplicate content)"
|
||||
)
|
||||
return documents_indexed, warning_message
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
"""
|
||||
Microsoft Teams connector indexer.
|
||||
|
||||
Implements 2-phase document status updates for real-time UI feedback:
|
||||
Implements batch indexing: groups up to TEAMS_BATCH_SIZE messages per channel
|
||||
into a single document for efficient indexing and better conversational context.
|
||||
|
||||
Uses 2-phase document status updates for real-time UI feedback:
|
||||
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
|
||||
- Phase 2: Process each document: pending → processing → ready/failed
|
||||
"""
|
||||
|
|
@ -41,6 +44,72 @@ HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
|||
# Heartbeat interval in seconds - update notification every 30 seconds
|
||||
HEARTBEAT_INTERVAL_SECONDS = 30
|
||||
|
||||
# Number of messages to combine into a single document for batch indexing.
|
||||
# Grouping messages improves conversational context in embeddings/chunks and
|
||||
# drastically reduces the number of documents, embedding calls, and DB overhead.
|
||||
TEAMS_BATCH_SIZE = 100
|
||||
|
||||
|
||||
def _build_batch_document_string(
|
||||
team_name: str,
|
||||
team_id: str,
|
||||
channel_name: str,
|
||||
channel_id: str,
|
||||
messages: list[dict],
|
||||
) -> str:
|
||||
"""
|
||||
Combine multiple Teams messages into a single document string.
|
||||
|
||||
Each message is formatted with its timestamp and author, and all messages
|
||||
are concatenated into a conversation-style document. The chunker will
|
||||
later split this into overlapping windows of ~8-10 consecutive messages,
|
||||
preserving conversational context in each chunk's embedding.
|
||||
|
||||
Args:
|
||||
team_name: Name of the Microsoft Team
|
||||
team_id: ID of the Microsoft Team
|
||||
channel_name: Name of the channel
|
||||
channel_id: ID of the channel
|
||||
messages: List of formatted message dicts with 'user_name', 'created_datetime', 'content'
|
||||
|
||||
Returns:
|
||||
Formatted document string with metadata and conversation content
|
||||
"""
|
||||
first_msg_time = messages[0].get("created_datetime", "Unknown")
|
||||
last_msg_time = messages[-1].get("created_datetime", "Unknown")
|
||||
|
||||
metadata_lines = [
|
||||
f"TEAM_NAME: {team_name}",
|
||||
f"TEAM_ID: {team_id}",
|
||||
f"CHANNEL_NAME: {channel_name}",
|
||||
f"CHANNEL_ID: {channel_id}",
|
||||
f"MESSAGE_COUNT: {len(messages)}",
|
||||
f"FIRST_MESSAGE_TIME: {first_msg_time}",
|
||||
f"LAST_MESSAGE_TIME: {last_msg_time}",
|
||||
]
|
||||
|
||||
conversation_lines = []
|
||||
for msg in messages:
|
||||
author = msg.get("user_name", "Unknown User")
|
||||
timestamp = msg.get("created_datetime", "Unknown Time")
|
||||
content = msg.get("content", "")
|
||||
conversation_lines.append(f"[{timestamp}] {author}: {content}")
|
||||
|
||||
metadata_sections = [
|
||||
("METADATA", metadata_lines),
|
||||
(
|
||||
"CONTENT",
|
||||
[
|
||||
"FORMAT: markdown",
|
||||
"TEXT_START",
|
||||
"\n".join(conversation_lines),
|
||||
"TEXT_END",
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
return build_document_metadata_markdown(metadata_sections)
|
||||
|
||||
|
||||
async def index_teams_messages(
|
||||
session: AsyncSession,
|
||||
|
|
@ -55,6 +124,12 @@ async def index_teams_messages(
|
|||
"""
|
||||
Index Microsoft Teams messages from all accessible teams and channels.
|
||||
|
||||
Messages are grouped into batches of TEAMS_BATCH_SIZE per channel,
|
||||
so each document contains up to 100 consecutive messages with full
|
||||
conversational context. This reduces document count, embedding calls,
|
||||
and DB overhead by ~100x while improving search quality through
|
||||
context-aware chunk embeddings.
|
||||
|
||||
Implements 2-phase document status updates for real-time UI feedback:
|
||||
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
|
||||
- Phase 2: Process each document: pending → processing → ready/failed
|
||||
|
|
@ -184,6 +259,7 @@ async def index_teams_messages(
|
|||
documents_skipped = 0
|
||||
documents_failed = 0
|
||||
duplicate_content_count = 0
|
||||
total_messages_collected = 0
|
||||
skipped_channels = []
|
||||
|
||||
# Heartbeat tracking - update notification periodically to prevent appearing stuck
|
||||
|
|
@ -199,21 +275,21 @@ async def index_teams_messages(
|
|||
start_datetime = None
|
||||
end_datetime = None
|
||||
if start_date_str:
|
||||
# Parse as naive datetime and make it timezone-aware (UTC)
|
||||
start_datetime = datetime.strptime(start_date_str, "%Y-%m-%d").replace(
|
||||
tzinfo=UTC
|
||||
)
|
||||
if end_date_str:
|
||||
# Parse as naive datetime, set to end of day, and make it timezone-aware (UTC)
|
||||
end_datetime = datetime.strptime(end_date_str, "%Y-%m-%d").replace(
|
||||
hour=23, minute=59, second=59, tzinfo=UTC
|
||||
)
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 1: Collect all messages and create pending documents
|
||||
# This makes ALL documents visible in the UI immediately with pending status
|
||||
# PHASE 1: Collect messages, group into batches, and create pending documents
|
||||
# Messages are grouped into batches of TEAMS_BATCH_SIZE per channel.
|
||||
# Each batch becomes a single document with full conversational context.
|
||||
# All documents are visible in the UI immediately with pending status.
|
||||
# =======================================================================
|
||||
messages_to_process = [] # List of dicts with document and message data
|
||||
batches_to_process = [] # List of dicts with document and batch data
|
||||
new_documents_created = False
|
||||
|
||||
for team in teams:
|
||||
|
|
@ -251,65 +327,72 @@ async def index_teams_messages(
|
|||
)
|
||||
continue
|
||||
|
||||
# Process each message
|
||||
# Format messages for batching
|
||||
formatted_messages = []
|
||||
for msg in messages:
|
||||
# Skip deleted messages or empty content
|
||||
if msg.get("deletedDateTime"):
|
||||
continue
|
||||
|
||||
# Extract message details
|
||||
message_id = msg.get("id", "")
|
||||
created_datetime = msg.get("createdDateTime", "")
|
||||
from_user = msg.get("from", {})
|
||||
user_name = from_user.get("user", {}).get(
|
||||
"displayName", "Unknown User"
|
||||
)
|
||||
user_email = from_user.get("user", {}).get(
|
||||
"userPrincipalName", "Unknown Email"
|
||||
)
|
||||
|
||||
# Extract message content
|
||||
body = msg.get("body", {})
|
||||
content_type = body.get("contentType", "text")
|
||||
msg_text = body.get("content", "")
|
||||
|
||||
# Skip empty messages
|
||||
if not msg_text or msg_text.strip() == "":
|
||||
continue
|
||||
|
||||
# Format document metadata
|
||||
metadata_sections = [
|
||||
(
|
||||
"METADATA",
|
||||
[
|
||||
f"TEAM_NAME: {team_name}",
|
||||
f"TEAM_ID: {team_id}",
|
||||
f"CHANNEL_NAME: {channel_name}",
|
||||
f"CHANNEL_ID: {channel_id}",
|
||||
f"MESSAGE_TIMESTAMP: {created_datetime}",
|
||||
f"MESSAGE_USER_NAME: {user_name}",
|
||||
f"MESSAGE_USER_EMAIL: {user_email}",
|
||||
f"CONTENT_TYPE: {content_type}",
|
||||
],
|
||||
),
|
||||
(
|
||||
"CONTENT",
|
||||
[
|
||||
f"FORMAT: {content_type}",
|
||||
"TEXT_START",
|
||||
msg_text,
|
||||
"TEXT_END",
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
# Build the document string
|
||||
combined_document_string = build_document_metadata_markdown(
|
||||
metadata_sections
|
||||
formatted_messages.append(
|
||||
{
|
||||
"message_id": msg.get("id", ""),
|
||||
"created_datetime": msg.get("createdDateTime", ""),
|
||||
"user_name": user_name,
|
||||
"content": msg_text,
|
||||
}
|
||||
)
|
||||
|
||||
# Generate unique identifier hash for this Teams message
|
||||
unique_identifier = f"{team_id}_{channel_id}_{message_id}"
|
||||
if not formatted_messages:
|
||||
logger.info(
|
||||
"No valid messages found in channel %s of team %s after filtering.",
|
||||
channel_name,
|
||||
team_name,
|
||||
)
|
||||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
total_messages_collected += len(formatted_messages)
|
||||
|
||||
# =======================================================
|
||||
# Group messages into batches of TEAMS_BATCH_SIZE
|
||||
# Each batch becomes a single document with conversation context
|
||||
# =======================================================
|
||||
for batch_start in range(
|
||||
0, len(formatted_messages), TEAMS_BATCH_SIZE
|
||||
):
|
||||
batch = formatted_messages[
|
||||
batch_start : batch_start + TEAMS_BATCH_SIZE
|
||||
]
|
||||
|
||||
# Build combined document string from all messages in this batch
|
||||
combined_document_string = _build_batch_document_string(
|
||||
team_name=team_name,
|
||||
team_id=team_id,
|
||||
channel_name=channel_name,
|
||||
channel_id=channel_id,
|
||||
messages=batch,
|
||||
)
|
||||
|
||||
# Generate unique identifier for this batch using
|
||||
# team_id + channel_id + first message id + last message id
|
||||
first_msg_id = batch[0].get("message_id", "")
|
||||
last_msg_id = batch[-1].get("message_id", "")
|
||||
unique_identifier = (
|
||||
f"{team_id}_{channel_id}_{first_msg_id}_{last_msg_id}"
|
||||
)
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
DocumentType.TEAMS_CONNECTOR,
|
||||
unique_identifier,
|
||||
|
|
@ -331,7 +414,6 @@ async def index_teams_messages(
|
|||
if existing_document:
|
||||
# Document exists - check if content has changed
|
||||
if existing_document.content_hash == content_hash:
|
||||
# Ensure status is ready (might have been stuck in processing/pending)
|
||||
if not DocumentStatus.is_state(
|
||||
existing_document.status, DocumentStatus.READY
|
||||
):
|
||||
|
|
@ -342,7 +424,7 @@ async def index_teams_messages(
|
|||
continue
|
||||
|
||||
# Queue existing document for update (will be set to processing in Phase 2)
|
||||
messages_to_process.append(
|
||||
batches_to_process.append(
|
||||
{
|
||||
"document": existing_document,
|
||||
"is_new": False,
|
||||
|
|
@ -352,14 +434,21 @@ async def index_teams_messages(
|
|||
"team_id": team_id,
|
||||
"channel_name": channel_name,
|
||||
"channel_id": channel_id,
|
||||
"message_id": message_id,
|
||||
"first_message_id": first_msg_id,
|
||||
"last_message_id": last_msg_id,
|
||||
"first_message_time": batch[0].get(
|
||||
"created_datetime", "Unknown"
|
||||
),
|
||||
"last_message_time": batch[-1].get(
|
||||
"created_datetime", "Unknown"
|
||||
),
|
||||
"message_count": len(batch),
|
||||
"start_date": start_date_str,
|
||||
"end_date": end_date_str,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Document doesn't exist by unique_identifier_hash
|
||||
# Check if a document with the same content_hash exists (from another connector)
|
||||
with session.no_autoflush:
|
||||
duplicate_by_content = (
|
||||
|
|
@ -370,9 +459,10 @@ async def index_teams_messages(
|
|||
|
||||
if duplicate_by_content:
|
||||
logger.info(
|
||||
"Teams message %s in channel %s already indexed by another connector "
|
||||
"Teams batch (%s msgs) in %s/%s already indexed by another connector "
|
||||
"(existing document ID: %s, type: %s). Skipping.",
|
||||
message_id,
|
||||
len(batch),
|
||||
team_name,
|
||||
channel_name,
|
||||
duplicate_by_content.id,
|
||||
duplicate_by_content.document_type,
|
||||
|
|
@ -391,6 +481,9 @@ async def index_teams_messages(
|
|||
"team_id": team_id,
|
||||
"channel_name": channel_name,
|
||||
"channel_id": channel_id,
|
||||
"first_message_id": first_msg_id,
|
||||
"last_message_id": last_msg_id,
|
||||
"message_count": len(batch),
|
||||
"connector_id": connector_id,
|
||||
},
|
||||
content="Pending...", # Placeholder until processed
|
||||
|
|
@ -406,7 +499,7 @@ async def index_teams_messages(
|
|||
session.add(document)
|
||||
new_documents_created = True
|
||||
|
||||
messages_to_process.append(
|
||||
batches_to_process.append(
|
||||
{
|
||||
"document": document,
|
||||
"is_new": True,
|
||||
|
|
@ -416,12 +509,30 @@ async def index_teams_messages(
|
|||
"team_id": team_id,
|
||||
"channel_name": channel_name,
|
||||
"channel_id": channel_id,
|
||||
"message_id": message_id,
|
||||
"first_message_id": first_msg_id,
|
||||
"last_message_id": last_msg_id,
|
||||
"first_message_time": batch[0].get(
|
||||
"created_datetime", "Unknown"
|
||||
),
|
||||
"last_message_time": batch[-1].get(
|
||||
"created_datetime", "Unknown"
|
||||
),
|
||||
"message_count": len(batch),
|
||||
"start_date": start_date_str,
|
||||
"end_date": end_date_str,
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Phase 1: Collected %s messages from %s/%s, "
|
||||
"grouped into %s batch(es)",
|
||||
len(formatted_messages),
|
||||
team_name,
|
||||
channel_name,
|
||||
(len(formatted_messages) + TEAMS_BATCH_SIZE - 1)
|
||||
// TEAMS_BATCH_SIZE,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error processing channel %s in team %s: %s",
|
||||
|
|
@ -441,17 +552,20 @@ async def index_teams_messages(
|
|||
# Commit all pending documents - they all appear in UI now
|
||||
if new_documents_created:
|
||||
logger.info(
|
||||
f"Phase 1: Committing {len([m for m in messages_to_process if m['is_new']])} pending documents"
|
||||
"Phase 1: Committing %s pending batch documents "
|
||||
"(%s total messages across all channels)",
|
||||
len([b for b in batches_to_process if b["is_new"]]),
|
||||
total_messages_collected,
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# =======================================================================
|
||||
# PHASE 2: Process each document one by one
|
||||
# PHASE 2: Process each batch document one by one
|
||||
# Each document transitions: pending → processing → ready/failed
|
||||
# =======================================================================
|
||||
logger.info(f"Phase 2: Processing {len(messages_to_process)} documents")
|
||||
logger.info("Phase 2: Processing %s batch documents", len(batches_to_process))
|
||||
|
||||
for item in messages_to_process:
|
||||
for item in batches_to_process:
|
||||
# Send heartbeat periodically
|
||||
if on_heartbeat_callback:
|
||||
current_time = time.time()
|
||||
|
|
@ -481,6 +595,11 @@ async def index_teams_messages(
|
|||
"team_id": item["team_id"],
|
||||
"channel_name": item["channel_name"],
|
||||
"channel_id": item["channel_id"],
|
||||
"first_message_id": item["first_message_id"],
|
||||
"last_message_id": item["last_message_id"],
|
||||
"first_message_time": item["first_message_time"],
|
||||
"last_message_time": item["last_message_time"],
|
||||
"message_count": item["message_count"],
|
||||
"start_date": item["start_date"],
|
||||
"end_date": item["end_date"],
|
||||
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
|
|
@ -495,20 +614,25 @@ async def index_teams_messages(
|
|||
# Batch commit every 10 documents (for ready status updates)
|
||||
if documents_indexed % 10 == 0:
|
||||
logger.info(
|
||||
"Committing batch: %s Teams messages processed so far",
|
||||
"Committing batch: %s Teams batch documents processed so far",
|
||||
documents_indexed,
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Teams message: {e!s}", exc_info=True)
|
||||
logger.error(
|
||||
"Error processing Teams batch document: %s",
|
||||
str(e),
|
||||
exc_info=True,
|
||||
)
|
||||
# Mark document as failed with reason (visible in UI)
|
||||
try:
|
||||
document.status = DocumentStatus.failed(str(e))
|
||||
document.updated_at = get_current_timestamp()
|
||||
except Exception as status_error:
|
||||
logger.error(
|
||||
f"Failed to update document status to failed: {status_error}"
|
||||
"Failed to update document status to failed: %s",
|
||||
str(status_error),
|
||||
)
|
||||
documents_failed += 1
|
||||
continue
|
||||
|
|
@ -518,7 +642,9 @@ async def index_teams_messages(
|
|||
|
||||
# Final commit for any remaining documents not yet committed in batches
|
||||
logger.info(
|
||||
"Final commit: Total %s Teams messages processed", documents_indexed
|
||||
"Final commit: Total %s Teams batch documents processed (from %s messages)",
|
||||
documents_indexed,
|
||||
total_messages_collected,
|
||||
)
|
||||
try:
|
||||
await session.commit()
|
||||
|
|
@ -530,8 +656,9 @@ async def index_teams_messages(
|
|||
or "uniqueviolationerror" in str(e).lower()
|
||||
):
|
||||
logger.warning(
|
||||
f"Duplicate content_hash detected during final commit. "
|
||||
f"Rolling back and continuing. Error: {e!s}"
|
||||
"Duplicate content_hash detected during final commit. "
|
||||
"Rolling back and continuing. Error: %s",
|
||||
str(e),
|
||||
)
|
||||
await session.rollback()
|
||||
else:
|
||||
|
|
@ -557,13 +684,16 @@ async def index_teams_messages(
|
|||
"documents_failed": documents_failed,
|
||||
"duplicate_content_count": duplicate_content_count,
|
||||
"skipped_channels_count": len(skipped_channels),
|
||||
"total_messages_collected": total_messages_collected,
|
||||
"batch_size": TEAMS_BATCH_SIZE,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Teams indexing completed: %s ready, %s skipped, %s failed "
|
||||
"(%s duplicate content)",
|
||||
"Teams indexing completed: %s batch docs ready (from %s messages), "
|
||||
"%s skipped, %s failed (%s duplicate content)",
|
||||
documents_indexed,
|
||||
total_messages_collected,
|
||||
documents_skipped,
|
||||
documents_failed,
|
||||
duplicate_content_count,
|
||||
|
|
|
|||
|
|
@ -38,10 +38,22 @@ def safe_set_chunks(document: Document, chunks: list) -> None:
|
|||
# Instead of: document.chunks = chunks (DANGEROUS!)
|
||||
safe_set_chunks(document, chunks) # Always safe
|
||||
"""
|
||||
from sqlalchemy.orm import object_session
|
||||
from sqlalchemy.orm.attributes import set_committed_value
|
||||
|
||||
# Keep relationship assignment lazy-load-safe.
|
||||
set_committed_value(document, "chunks", chunks)
|
||||
|
||||
# Ensure chunk rows are actually persisted.
|
||||
# set_committed_value bypasses normal unit-of-work tracking, so we need to
|
||||
# explicitly attach chunk objects to the current session.
|
||||
session = object_session(document)
|
||||
if session is not None:
|
||||
if document.id is not None:
|
||||
for chunk in chunks:
|
||||
chunk.document_id = document.id
|
||||
session.add_all(chunks)
|
||||
|
||||
|
||||
def get_current_timestamp() -> datetime:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -9,9 +9,17 @@ Message content in new_chat_messages can be stored in various formats:
|
|||
These utilities help extract and transform content for different use cases.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.db import ChatVisibility
|
||||
|
||||
|
||||
def extract_text_content(content: str | dict | list) -> str:
|
||||
|
|
@ -38,6 +46,7 @@ def extract_text_content(content: str | dict | list) -> str:
|
|||
async def bootstrap_history_from_db(
|
||||
session: AsyncSession,
|
||||
thread_id: int,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
) -> list[HumanMessage | AIMessage]:
|
||||
"""
|
||||
Load message history from database and convert to LangChain format.
|
||||
|
|
@ -45,20 +54,28 @@ async def bootstrap_history_from_db(
|
|||
Used for cloned chats where the LangGraph checkpointer has no state,
|
||||
but we have messages in the database that should be used as context.
|
||||
|
||||
When thread_visibility is SEARCH_SPACE, user messages are prefixed with
|
||||
the author's display name so the LLM sees who said what.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
thread_id: The chat thread ID
|
||||
thread_visibility: When SEARCH_SPACE, user messages get author prefix
|
||||
|
||||
Returns:
|
||||
List of LangChain messages (HumanMessage/AIMessage)
|
||||
"""
|
||||
from app.db import NewChatMessage
|
||||
from app.db import ChatVisibility, NewChatMessage
|
||||
|
||||
result = await session.execute(
|
||||
is_shared = thread_visibility == ChatVisibility.SEARCH_SPACE
|
||||
stmt = (
|
||||
select(NewChatMessage)
|
||||
.filter(NewChatMessage.thread_id == thread_id)
|
||||
.order_by(NewChatMessage.created_at)
|
||||
)
|
||||
if is_shared:
|
||||
stmt = stmt.options(selectinload(NewChatMessage.author))
|
||||
result = await session.execute(stmt)
|
||||
db_messages = result.scalars().all()
|
||||
|
||||
langchain_messages: list[HumanMessage | AIMessage] = []
|
||||
|
|
@ -68,6 +85,11 @@ async def bootstrap_history_from_db(
|
|||
if not text_content:
|
||||
continue
|
||||
if msg.role == "user":
|
||||
if is_shared:
|
||||
author_name = (
|
||||
msg.author.display_name if msg.author else None
|
||||
) or "A team member"
|
||||
text_content = f"**[{author_name}]:** {text_content}"
|
||||
langchain_messages.append(HumanMessage(content=text_content))
|
||||
elif msg.role == "assistant":
|
||||
langchain_messages.append(AIMessage(content=text_content))
|
||||
|
|
|
|||
46
surfsense_backend/app/utils/indexing_locks.py
Normal file
46
surfsense_backend/app/utils/indexing_locks.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
"""Redis-based connector indexing locks to prevent duplicate sync tasks."""
|
||||
|
||||
import redis
|
||||
|
||||
from app.config import config
|
||||
|
||||
_redis_client: redis.Redis | None = None
|
||||
LOCK_TTL_SECONDS = config.CONNECTOR_INDEXING_LOCK_TTL_SECONDS
|
||||
|
||||
|
||||
def get_indexing_lock_redis_client() -> redis.Redis:
|
||||
"""Get or create Redis client for connector indexing locks."""
|
||||
global _redis_client
|
||||
if _redis_client is None:
|
||||
_redis_client = redis.from_url(config.REDIS_APP_URL, decode_responses=True)
|
||||
return _redis_client
|
||||
|
||||
|
||||
def _get_connector_lock_key(connector_id: int) -> str:
|
||||
"""Generate Redis key for a connector indexing lock."""
|
||||
return f"indexing:connector_lock:{connector_id}"
|
||||
|
||||
|
||||
def acquire_connector_indexing_lock(connector_id: int) -> bool:
|
||||
"""Acquire lock for connector indexing. Returns True if acquired."""
|
||||
key = _get_connector_lock_key(connector_id)
|
||||
return bool(
|
||||
get_indexing_lock_redis_client().set(
|
||||
key,
|
||||
"1",
|
||||
nx=True,
|
||||
ex=LOCK_TTL_SECONDS,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def release_connector_indexing_lock(connector_id: int) -> None:
|
||||
"""Release lock for connector indexing."""
|
||||
key = _get_connector_lock_key(connector_id)
|
||||
get_indexing_lock_redis_client().delete(key)
|
||||
|
||||
|
||||
def is_connector_indexing_locked(connector_id: int) -> bool:
|
||||
"""Check if connector indexing lock exists."""
|
||||
key = _get_connector_lock_key(connector_id)
|
||||
return bool(get_indexing_lock_redis_client().exists(key))
|
||||
Loading…
Add table
Add a link
Reference in a new issue