mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-19 18:45:15 +02:00
feat: migrated old chat to new chat
This commit is contained in:
parent
b5e20e7515
commit
bb971460fc
25 changed files with 368 additions and 4391 deletions
|
|
@ -0,0 +1,216 @@
|
|||
"""Migrate old chats to new_chat_threads and remove old tables
|
||||
|
||||
Revision ID: 49
|
||||
Revises: 48
|
||||
Create Date: 2025-12-21
|
||||
|
||||
This migration:
|
||||
1. Migrates data from old 'chats' table to 'new_chat_threads' and 'new_chat_messages'
|
||||
2. Drops the 'podcasts' table (podcast data is not migrated as per user request)
|
||||
3. Drops the 'chats' table
|
||||
4. Removes the 'chattype' enum
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "49"
|
||||
down_revision: str | None = "48"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def extract_text_content(content: str | dict | list) -> str:
|
||||
"""Extract plain text content from various message formats."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, dict):
|
||||
# Handle dict with 'text' key
|
||||
if "text" in content:
|
||||
return content["text"]
|
||||
return str(content)
|
||||
if isinstance(content, list):
|
||||
# Handle list of parts (e.g., [{"type": "text", "text": "..."}])
|
||||
texts = []
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
texts.append(part.get("text", ""))
|
||||
elif isinstance(part, str):
|
||||
texts.append(part)
|
||||
return "\n".join(texts) if texts else ""
|
||||
return ""
|
||||
|
||||
|
||||
def parse_timestamp(ts, fallback):
|
||||
"""Parse ISO timestamp string to datetime object."""
|
||||
if ts is None:
|
||||
return fallback
|
||||
if isinstance(ts, datetime):
|
||||
return ts
|
||||
if isinstance(ts, str):
|
||||
try:
|
||||
# Handle ISO format like '2025-11-26T22:43:34.399Z'
|
||||
ts = ts.replace("Z", "+00:00")
|
||||
return datetime.fromisoformat(ts)
|
||||
except (ValueError, TypeError):
|
||||
return fallback
|
||||
return fallback
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Migrate old chats to new_chat_threads and remove old tables."""
|
||||
connection = op.get_bind()
|
||||
|
||||
# Get all old chats
|
||||
old_chats = connection.execute(
|
||||
sa.text("""
|
||||
SELECT id, title, messages, search_space_id, created_at
|
||||
FROM chats
|
||||
ORDER BY created_at ASC
|
||||
""")
|
||||
).fetchall()
|
||||
|
||||
print(f"[Migration 49] Found {len(old_chats)} old chats to migrate")
|
||||
|
||||
migrated_count = 0
|
||||
for chat_id, title, messages_json, search_space_id, created_at in old_chats:
|
||||
try:
|
||||
# Parse messages JSON
|
||||
if isinstance(messages_json, str):
|
||||
messages = json.loads(messages_json)
|
||||
else:
|
||||
messages = messages_json or []
|
||||
|
||||
# Skip empty chats
|
||||
if not messages:
|
||||
print(f"[Migration 49] Skipping empty chat {chat_id}")
|
||||
continue
|
||||
|
||||
# Create new thread
|
||||
result = connection.execute(
|
||||
sa.text("""
|
||||
INSERT INTO new_chat_threads
|
||||
(title, archived, search_space_id, created_at, updated_at)
|
||||
VALUES (:title, FALSE, :search_space_id, :created_at, :created_at)
|
||||
RETURNING id
|
||||
"""),
|
||||
{
|
||||
"title": title or "Migrated Chat",
|
||||
"search_space_id": search_space_id,
|
||||
"created_at": created_at,
|
||||
},
|
||||
)
|
||||
new_thread_id = result.fetchone()[0]
|
||||
|
||||
# Migrate messages - only user and assistant roles, skip SOURCES/TERMINAL_INFO
|
||||
message_count = 0
|
||||
for msg in messages:
|
||||
role_lower = msg.get("role", "").lower()
|
||||
|
||||
# Only migrate user and assistant messages
|
||||
if role_lower not in ("user", "assistant"):
|
||||
continue
|
||||
|
||||
# Convert to uppercase for database enum
|
||||
role = role_lower.upper()
|
||||
|
||||
# Extract content - handle various formats
|
||||
content_raw = msg.get("content", "")
|
||||
content_text = extract_text_content(content_raw)
|
||||
|
||||
# Skip empty messages
|
||||
if not content_text.strip():
|
||||
continue
|
||||
|
||||
# Parse message timestamp
|
||||
msg_created_at = parse_timestamp(msg.get("createdAt"), created_at)
|
||||
|
||||
# Store content as JSONB array format for assistant-ui compatibility
|
||||
content_list = [{"type": "text", "text": content_text}]
|
||||
|
||||
# Use direct SQL with string interpolation for the enum since CAST doesn't work
|
||||
# The enum value comes from trusted source (our own code), not user input
|
||||
connection.execute(
|
||||
sa.text(f"""
|
||||
INSERT INTO new_chat_messages
|
||||
(thread_id, role, content, created_at)
|
||||
VALUES (:thread_id, '{role}', CAST(:content AS jsonb), :created_at)
|
||||
"""),
|
||||
{
|
||||
"thread_id": new_thread_id,
|
||||
"content": json.dumps(content_list),
|
||||
"created_at": msg_created_at,
|
||||
},
|
||||
)
|
||||
message_count += 1
|
||||
|
||||
print(
|
||||
f"[Migration 49] Migrated chat {chat_id} -> thread {new_thread_id} ({message_count} messages)"
|
||||
)
|
||||
migrated_count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Migration 49] Error migrating chat {chat_id}: {e}")
|
||||
# Re-raise to abort migration - we don't want partial data
|
||||
raise
|
||||
|
||||
print(f"[Migration 49] Successfully migrated {migrated_count} chats")
|
||||
|
||||
# Drop podcasts table (FK references chats, so drop first)
|
||||
print("[Migration 49] Dropping podcasts table...")
|
||||
op.drop_table("podcasts")
|
||||
|
||||
# Drop chats table
|
||||
print("[Migration 49] Dropping chats table...")
|
||||
op.drop_table("chats")
|
||||
|
||||
# Drop chattype enum
|
||||
print("[Migration 49] Dropping chattype enum...")
|
||||
op.execute(sa.text("DROP TYPE IF EXISTS chattype"))
|
||||
|
||||
print("[Migration 49] Migration complete!")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Recreate old tables (data cannot be restored)."""
|
||||
# Recreate chattype enum
|
||||
op.execute(
|
||||
sa.text("""
|
||||
CREATE TYPE chattype AS ENUM ('QNA')
|
||||
""")
|
||||
)
|
||||
|
||||
# Recreate chats table
|
||||
op.create_table(
|
||||
"chats",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column("type", sa.Enum("QNA", name="chattype"), nullable=False),
|
||||
sa.Column("title", sa.String(), nullable=False, index=True),
|
||||
sa.Column("initial_connectors", sa.ARRAY(sa.String()), nullable=True),
|
||||
sa.Column("messages", sa.JSON(), nullable=False),
|
||||
sa.Column("state_version", sa.BigInteger(), nullable=False, default=1),
|
||||
sa.Column("search_space_id", sa.Integer(), sa.ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("created_at", sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.func.now()),
|
||||
)
|
||||
|
||||
# Recreate podcasts table
|
||||
op.create_table(
|
||||
"podcasts",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column("title", sa.String(), nullable=False, index=True),
|
||||
sa.Column("podcast_transcript", sa.JSON(), nullable=False, server_default="{}"),
|
||||
sa.Column("file_location", sa.String(500), nullable=False, server_default=""),
|
||||
sa.Column("chat_id", sa.Integer(), sa.ForeignKey("chats.id", ondelete="CASCADE"), nullable=True),
|
||||
sa.Column("chat_state_version", sa.BigInteger(), nullable=True),
|
||||
sa.Column("search_space_id", sa.Integer(), sa.ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("created_at", sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.func.now()),
|
||||
)
|
||||
|
||||
print("[Migration 49 Downgrade] Tables recreated (data not restored)")
|
||||
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
"""Define the configurable parameters for the agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Configuration:
|
||||
"""The configuration for the agent."""
|
||||
|
||||
# Input parameters provided at invocation
|
||||
user_query: str
|
||||
connectors_to_search: list[str]
|
||||
user_id: str
|
||||
search_space_id: int
|
||||
document_ids_to_add_in_context: list[int]
|
||||
language: str | None = None
|
||||
top_k: int = 10
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(
|
||||
cls, config: RunnableConfig | None = None
|
||||
) -> Configuration:
|
||||
"""Create a Configuration instance from a RunnableConfig object."""
|
||||
configurable = (config.get("configurable") or {}) if config else {}
|
||||
_fields = {f.name for f in fields(cls) if f.init}
|
||||
return cls(**{k: v for k, v in configurable.items() if k in _fields})
|
||||
|
|
@ -1,47 +0,0 @@
|
|||
from langgraph.graph import StateGraph
|
||||
|
||||
from .configuration import Configuration
|
||||
from .nodes import (
|
||||
generate_further_questions,
|
||||
handle_qna_workflow,
|
||||
reformulate_user_query,
|
||||
)
|
||||
from .state import State
|
||||
|
||||
|
||||
def build_graph():
|
||||
"""
|
||||
Build and return the LangGraph workflow.
|
||||
|
||||
This function constructs the researcher agent graph for Q&A workflow.
|
||||
The workflow follows a simple path:
|
||||
1. Reformulate user query based on chat history
|
||||
2. Handle QNA workflow (fetch documents and generate answer)
|
||||
3. Generate follow-up questions
|
||||
|
||||
Returns:
|
||||
A compiled LangGraph workflow
|
||||
"""
|
||||
# Define a new graph with state class
|
||||
workflow = StateGraph(State, config_schema=Configuration)
|
||||
|
||||
# Add nodes to the graph
|
||||
workflow.add_node("reformulate_user_query", reformulate_user_query)
|
||||
workflow.add_node("handle_qna_workflow", handle_qna_workflow)
|
||||
workflow.add_node("generate_further_questions", generate_further_questions)
|
||||
|
||||
# Define the edges - simple linear flow for QNA
|
||||
workflow.add_edge("__start__", "reformulate_user_query")
|
||||
workflow.add_edge("reformulate_user_query", "handle_qna_workflow")
|
||||
workflow.add_edge("handle_qna_workflow", "generate_further_questions")
|
||||
workflow.add_edge("generate_further_questions", "__end__")
|
||||
|
||||
# Compile the workflow into an executable graph
|
||||
graph = workflow.compile()
|
||||
graph.name = "Surfsense Researcher" # This defines the custom name in LangSmith
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
# Compile the graph once when the module is loaded
|
||||
graph = build_graph()
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,140 +0,0 @@
|
|||
import datetime
|
||||
|
||||
|
||||
def _build_language_instruction(language: str | None = None):
|
||||
"""Build language instruction for prompts."""
|
||||
if language:
|
||||
return f"\n\nIMPORTANT: Please respond in {language} language. All your responses, explanations, and analysis should be written in {language}."
|
||||
return ""
|
||||
|
||||
|
||||
def get_further_questions_system_prompt():
|
||||
return f"""
|
||||
Today's date: {datetime.datetime.now().strftime("%Y-%m-%d")}
|
||||
<further_questions_system>
|
||||
You are an expert research assistant specializing in generating contextually relevant follow-up questions. Your task is to analyze the chat history and available documents to suggest further questions that would naturally extend the conversation and provide additional value to the user.
|
||||
|
||||
<input>
|
||||
- chat_history: Provided in XML format within <chat_history> tags, containing <user> and <assistant> message pairs that show the chronological conversation flow. This provides context about what has already been discussed.
|
||||
- available_documents: Provided in XML format within <documents> tags, containing individual <document> elements with <document_metadata> and <document_content> sections. Each document contains multiple `<chunk id='...'>...</chunk>` blocks inside <document_content>. This helps understand what information is accessible for answering potential follow-up questions.
|
||||
</input>
|
||||
|
||||
<output_format>
|
||||
A JSON object with the following structure:
|
||||
{{
|
||||
"further_questions": [
|
||||
{{
|
||||
"id": 0,
|
||||
"question": "further qn 1"
|
||||
}},
|
||||
{{
|
||||
"id": 1,
|
||||
"question": "further qn 2"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
</output_format>
|
||||
|
||||
<instructions>
|
||||
1. **Analyze Chat History:** Review the entire conversation flow to understand:
|
||||
* The main topics and themes discussed
|
||||
* The user's interests and areas of focus
|
||||
* Questions that have been asked and answered
|
||||
* Any gaps or areas that could be explored further
|
||||
* The depth level of the current discussion
|
||||
|
||||
2. **Evaluate Available Documents:** Consider the documents in context to identify:
|
||||
* Additional information that hasn't been explored yet
|
||||
* Related topics that could be of interest
|
||||
* Specific details or data points that could warrant deeper investigation
|
||||
* Cross-references or connections between different documents
|
||||
|
||||
3. **Generate Relevant Follow-up Questions:** Create 3-5 further questions that:
|
||||
* Are directly related to the ongoing conversation but explore new angles
|
||||
* Can be reasonably answered using the available documents or knowledge base
|
||||
* Progress the conversation forward rather than repeating previous topics
|
||||
* Match the user's apparent level of interest and expertise
|
||||
* Are specific and actionable rather than overly broad
|
||||
* Consider practical applications, comparisons, deeper analysis, or related concepts
|
||||
|
||||
4. **Ensure Question Quality:** Each question should:
|
||||
* Be clear and well-formulated
|
||||
* Provide genuine value to the user
|
||||
* Be distinct from other suggested questions
|
||||
* Be answerable within the current context
|
||||
* Encourage meaningful exploration of the topic
|
||||
|
||||
5. **Prioritize and Order:** Arrange questions by relevance and natural progression:
|
||||
* Most directly related to the current discussion first
|
||||
* Questions that build upon previous answers
|
||||
* Questions that explore practical applications or implications
|
||||
* Questions that introduce related but new concepts
|
||||
|
||||
6. **Adhere Strictly to Output Format:** Ensure the final output is a valid JSON object with:
|
||||
* Correct field names (`further_questions`, `id`, `question`)
|
||||
* Sequential numbering starting from 0
|
||||
* Proper data types and JSON formatting
|
||||
</instructions>
|
||||
|
||||
<examples>
|
||||
Input:
|
||||
<chat_history>
|
||||
<user>What are the best machine learning algorithms for text classification?</user>
|
||||
<assistant>For text classification, several algorithms work well depending on your specific needs:
|
||||
|
||||
**Traditional Methods:**
|
||||
- **Support Vector Machines (SVM)** - Excellent for high-dimensional text data
|
||||
- **Naive Bayes** - Simple, fast, and works well with small datasets
|
||||
- **Logistic Regression** - Good baseline with interpretable results
|
||||
|
||||
**Modern Deep Learning:**
|
||||
- **Neural Networks** - More complex but can capture intricate patterns
|
||||
- **Transformer models** - State-of-the-art for most text classification tasks
|
||||
|
||||
The choice depends on your dataset size, computational resources, and accuracy requirements.</assistant>
|
||||
</chat_history>
|
||||
|
||||
<documents>
|
||||
<document>
|
||||
<metadata>
|
||||
<source_id>101</source_id>
|
||||
<source_type>FILE</source_type>
|
||||
</metadata>
|
||||
<content>
|
||||
# Machine Learning for Text Classification: A Comprehensive Guide
|
||||
|
||||
## Performance Comparison
|
||||
Recent studies show that transformer-based models achieve 95%+ accuracy on most text classification benchmarks, while traditional methods like SVM typically achieve 85-90% accuracy.
|
||||
|
||||
## Dataset Considerations
|
||||
- Small datasets (< 1000 samples): Naive Bayes, SVM
|
||||
- Large datasets (> 10,000 samples): Neural networks, transformers
|
||||
- Imbalanced datasets: Require special handling with techniques like SMOTE
|
||||
</content>
|
||||
</document>
|
||||
</documents>
|
||||
|
||||
Output:
|
||||
{{
|
||||
"further_questions": [
|
||||
{{
|
||||
"id": 0,
|
||||
"question": "What are the key differences in performance between traditional algorithms like SVM and modern deep learning approaches for text classification?"
|
||||
}},
|
||||
{{
|
||||
"id": 1,
|
||||
"question": "How do you handle imbalanced datasets when training text classification models?"
|
||||
}},
|
||||
{{
|
||||
"id": 2,
|
||||
"question": "What preprocessing techniques are most effective for improving text classification accuracy?"
|
||||
}},
|
||||
{{
|
||||
"id": 3,
|
||||
"question": "Are there specific domains or use cases where certain classification algorithms perform better than others?"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
</examples>
|
||||
</further_questions_system>
|
||||
"""
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
"""QnA Agent."""
|
||||
|
||||
from .graph import graph
|
||||
|
||||
__all__ = ["graph"]
|
||||
|
|
@ -1,31 +0,0 @@
|
|||
"""Define the configurable parameters for the agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Configuration:
|
||||
"""The configuration for the Q&A agent."""
|
||||
|
||||
# Configuration parameters for the Q&A agent
|
||||
user_query: str # The user's question to answer
|
||||
reformulated_query: str # The reformulated query
|
||||
relevant_documents: list[
|
||||
Any
|
||||
] # Documents provided directly to the agent for answering
|
||||
search_space_id: int # Search space identifier
|
||||
language: str | None = None # Language for responses
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(
|
||||
cls, config: RunnableConfig | None = None
|
||||
) -> Configuration:
|
||||
"""Create a Configuration instance from a RunnableConfig object."""
|
||||
configurable = (config.get("configurable") or {}) if config else {}
|
||||
_fields = {f.name for f in fields(cls) if f.init}
|
||||
return cls(**{k: v for k, v in configurable.items() if k in _fields})
|
||||
|
|
@ -1,201 +0,0 @@
|
|||
"""Default system prompts for Q&A agent.
|
||||
|
||||
The prompt system is modular with 3 parts:
|
||||
- Part 1 (Base): Core instructions for answering questions (no citations)
|
||||
- Part 2 (Citations): Citation-specific instructions and formatting rules
|
||||
- Part 3 (Custom): User's custom instructions (empty by default)
|
||||
|
||||
Combinations:
|
||||
- Part 1 only: Answers without citations
|
||||
- Part 1 + Part 2: Answers with citations
|
||||
- Part 1 + Part 2 + Part 3: Answers with citations and custom instructions
|
||||
"""
|
||||
|
||||
# Part 1: Base system prompt for answering without citations
|
||||
DEFAULT_QNA_BASE_PROMPT = """Today's date: {date}
|
||||
You are SurfSense, an advanced AI research assistant that provides detailed, well-researched answers to user questions by synthesizing information from multiple personal knowledge sources.{language_instruction}
|
||||
{chat_history_section}
|
||||
<knowledge_sources>
|
||||
- EXTENSION: "Web content saved via SurfSense browser extension" (personal browsing history)
|
||||
- FILE: "User-uploaded documents (PDFs, Word, etc.)" (personal files)
|
||||
- SLACK_CONNECTOR: "Slack conversations and shared content" (personal workspace communications)
|
||||
- NOTION_CONNECTOR: "Notion workspace pages and databases" (personal knowledge management)
|
||||
- YOUTUBE_VIDEO: "YouTube video transcripts and metadata" (personally saved videos)
|
||||
- GITHUB_CONNECTOR: "GitHub repository content and issues" (personal repositories and interactions)
|
||||
- ELASTICSEARCH_CONNECTOR: "Elasticsearch indexed documents and data" (personal Elasticsearch instances and custom data sources)
|
||||
- LINEAR_CONNECTOR: "Linear project issues and discussions" (personal project management)
|
||||
- JIRA_CONNECTOR: "Jira project issues, tickets, and comments" (personal project tracking)
|
||||
- CONFLUENCE_CONNECTOR: "Confluence pages and comments" (personal project documentation)
|
||||
- CLICKUP_CONNECTOR: "ClickUp tasks and project data" (personal task management)
|
||||
- GOOGLE_CALENDAR_CONNECTOR: "Google Calendar events, meetings, and schedules" (personal calendar and time management)
|
||||
- GOOGLE_GMAIL_CONNECTOR: "Google Gmail emails and conversations" (personal emails and communications)
|
||||
- DISCORD_CONNECTOR: "Discord server conversations and shared content" (personal community communications)
|
||||
- AIRTABLE_CONNECTOR: "Airtable records, tables, and database content" (personal data management and organization)
|
||||
- TAVILY_API: "Tavily search API results" (personalized search results)
|
||||
- LINKUP_API: "Linkup search API results" (personalized search results)
|
||||
- LUMA_CONNECTOR: "Luma events"
|
||||
- WEBCRAWLER_CONNECTOR: "Webpages indexed by SurfSense" (personally selected websites)
|
||||
</knowledge_sources>
|
||||
|
||||
<instructions>
|
||||
1. Review the chat history to understand the conversation context and any previous topics discussed.
|
||||
2. Carefully analyze all provided documents in the <document> sections.
|
||||
3. Extract relevant information that directly addresses the user's question.
|
||||
4. Provide a comprehensive, detailed answer using information from the user's personal knowledge sources.
|
||||
5. Structure your answer logically and conversationally, as if having a detailed discussion with the user.
|
||||
6. Use your own words to synthesize and connect ideas from the documents.
|
||||
7. If documents contain conflicting information, acknowledge this and present both perspectives.
|
||||
8. If the user's question cannot be fully answered with the provided documents, clearly state what information is missing.
|
||||
9. Provide actionable insights and practical information when relevant to the user's question.
|
||||
10. Use the chat history to maintain conversation continuity and refer to previous discussions when relevant.
|
||||
11. Remember that all knowledge sources contain personal information - provide answers that reflect this personal context.
|
||||
12. Be conversational and engaging while maintaining accuracy.
|
||||
</instructions>
|
||||
|
||||
<format>
|
||||
- Write in a clear, conversational tone suitable for detailed Q&A discussions
|
||||
- Provide comprehensive answers that thoroughly address the user's question
|
||||
- Use appropriate paragraphs and structure for readability
|
||||
- ALWAYS provide personalized answers that reflect the user's own knowledge and context
|
||||
- Be thorough and detailed in your explanations while remaining focused on the user's specific question
|
||||
- If asking follow-up questions would be helpful, suggest them at the end of your response
|
||||
</format>
|
||||
|
||||
<user_query_instructions>
|
||||
When you see a user query, focus exclusively on providing a detailed, comprehensive answer using information from the provided documents, which contain the user's personal knowledge and data.
|
||||
|
||||
Make sure your response:
|
||||
1. Considers the chat history for context and conversation continuity
|
||||
2. Directly and thoroughly answers the user's question with personalized information from their own knowledge sources
|
||||
3. Is conversational, engaging, and detailed
|
||||
4. Acknowledges the personal nature of the information being provided
|
||||
5. Offers follow-up suggestions when appropriate
|
||||
</user_query_instructions>
|
||||
"""
|
||||
|
||||
# Part 2: Citation-specific instructions to add citation capabilities
|
||||
DEFAULT_QNA_CITATION_INSTRUCTIONS = """
|
||||
<citation_instructions>
|
||||
CRITICAL CITATION REQUIREMENTS:
|
||||
|
||||
1. For EVERY piece of information you include from the documents, add a citation in the format [citation:chunk_id] where chunk_id is the exact value from the `<chunk id='...'>` tag inside `<document_content>`.
|
||||
2. Make sure ALL factual statements from the documents have proper citations.
|
||||
3. If multiple chunks support the same point, include all relevant citations [citation:chunk_id1], [citation:chunk_id2].
|
||||
4. You MUST use the exact chunk_id values from the `<chunk id='...'>` attributes. Do not create your own citation numbers.
|
||||
5. Every citation MUST be in the format [citation:chunk_id] where chunk_id is the exact chunk id value.
|
||||
6. Never modify or change the chunk_id - always use the original values exactly as provided in the chunk tags.
|
||||
7. Do not return citations as clickable links.
|
||||
8. Never format citations as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only.
|
||||
9. Citations must ONLY appear as [citation:chunk_id] or [citation:chunk_id1], [citation:chunk_id2] format - never with parentheses, hyperlinks, or other formatting.
|
||||
10. Never make up chunk IDs. Only use chunk_id values that are explicitly provided in the `<chunk id='...'>` tags.
|
||||
11. If you are unsure about a chunk_id, do not include a citation rather than guessing or making one up.
|
||||
|
||||
<document_structure_example>
|
||||
The documents you receive are structured like this:
|
||||
|
||||
<document>
|
||||
<document_metadata>
|
||||
<document_id>42</document_id>
|
||||
<document_type>GITHUB_CONNECTOR</document_type>
|
||||
<title><![CDATA[Some repo / file / issue title]]></title>
|
||||
<url><![CDATA[https://example.com]]></url>
|
||||
<metadata_json><![CDATA[{{"any":"other metadata"}}]]></metadata_json>
|
||||
</document_metadata>
|
||||
|
||||
<document_content>
|
||||
<chunk id='123'><![CDATA[First chunk text...]]></chunk>
|
||||
<chunk id='124'><![CDATA[Second chunk text...]]></chunk>
|
||||
</document_content>
|
||||
</document>
|
||||
|
||||
IMPORTANT: You MUST cite using the chunk ids (e.g. 123, 124). Do NOT cite document_id.
|
||||
</document_structure_example>
|
||||
|
||||
<citation_format>
|
||||
- Every fact from the documents must have a citation in the format [citation:chunk_id] where chunk_id is the EXACT id value from a `<chunk id='...'>` tag
|
||||
- Citations should appear at the end of the sentence containing the information they support
|
||||
- Multiple citations should be separated by commas: [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3]
|
||||
- No need to return references section. Just citations in answer.
|
||||
- NEVER create your own citation format - use the exact chunk_id values from the documents in the [citation:chunk_id] format
|
||||
- NEVER format citations as clickable links or as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only
|
||||
- NEVER make up chunk IDs if you are unsure about the chunk_id. It is better to omit the citation than to guess
|
||||
</citation_format>
|
||||
|
||||
<citation_examples>
|
||||
CORRECT citation formats:
|
||||
- [citation:5]
|
||||
- [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3]
|
||||
|
||||
INCORRECT citation formats (DO NOT use):
|
||||
- Using parentheses and markdown links: ([citation:5](https://github.com/MODSetter/SurfSense))
|
||||
- Using parentheses around brackets: ([citation:5])
|
||||
- Using hyperlinked text: [link to source 5](https://example.com)
|
||||
- Using footnote style: ... library¹
|
||||
- Making up source IDs when source_id is unknown
|
||||
- Using old IEEE format: [1], [2], [3]
|
||||
- Using source types instead of IDs: [citation:GITHUB_CONNECTOR] instead of [citation:5]
|
||||
</citation_examples>
|
||||
|
||||
<citation_output_example>
|
||||
Based on your GitHub repositories and video content, Python's asyncio library provides tools for writing concurrent code using the async/await syntax [citation:5]. It's particularly useful for I/O-bound and high-level structured network code [citation:5].
|
||||
|
||||
The key advantage of asyncio is that it can improve performance by allowing other code to run while waiting for I/O operations to complete [citation:12]. This makes it excellent for scenarios like web scraping, API calls, database operations, or any situation where your program spends time waiting for external resources.
|
||||
|
||||
However, from your video learning, it's important to note that asyncio is not suitable for CPU-bound tasks as it runs on a single thread [citation:12]. For computationally intensive work, you'd want to use multiprocessing instead.
|
||||
</citation_output_example>
|
||||
</citation_instructions>
|
||||
"""
|
||||
|
||||
# Part 3: User's custom instructions (empty by default, can be set by user from UI)
|
||||
DEFAULT_QNA_CUSTOM_INSTRUCTIONS = ""
|
||||
|
||||
# Full prompt with all parts combined (for backward compatibility and migration)
|
||||
DEFAULT_QNA_CITATION_PROMPT = (
|
||||
DEFAULT_QNA_BASE_PROMPT
|
||||
+ DEFAULT_QNA_CITATION_INSTRUCTIONS
|
||||
+ DEFAULT_QNA_CUSTOM_INSTRUCTIONS
|
||||
)
|
||||
|
||||
DEFAULT_QNA_NO_DOCUMENTS_PROMPT = """Today's date: {date}
|
||||
You are SurfSense, an advanced AI research assistant that provides helpful, detailed answers to user questions in a conversational manner.{language_instruction}
|
||||
{chat_history_section}
|
||||
<context>
|
||||
The user has asked a question but there are no specific documents from their personal knowledge base available to answer it. You should provide a helpful response based on:
|
||||
1. The conversation history and context
|
||||
2. Your general knowledge and expertise
|
||||
3. Understanding of the user's needs and interests based on our conversation
|
||||
</context>
|
||||
|
||||
<instructions>
|
||||
1. Provide a comprehensive, helpful answer to the user's question
|
||||
2. Draw upon the conversation history to understand context and the user's specific needs
|
||||
3. Use your general knowledge to provide accurate, detailed information
|
||||
4. Be conversational and engaging, as if having a detailed discussion with the user
|
||||
5. Acknowledge when you're drawing from general knowledge rather than their personal sources
|
||||
6. Provide actionable insights and practical information when relevant
|
||||
7. Structure your answer logically and clearly
|
||||
8. If the question would benefit from personalized information from their knowledge base, gently suggest they might want to add relevant content to SurfSense
|
||||
9. Be honest about limitations while still being maximally helpful
|
||||
10. Maintain the helpful, knowledgeable tone that users expect from SurfSense
|
||||
</instructions>
|
||||
|
||||
<format>
|
||||
- Write in a clear, conversational tone suitable for detailed Q&A discussions
|
||||
- Provide comprehensive answers that thoroughly address the user's question
|
||||
- Use appropriate paragraphs and structure for readability
|
||||
- No citations are needed since you're using general knowledge
|
||||
- Be thorough and detailed in your explanations while remaining focused on the user's specific question
|
||||
- If asking follow-up questions would be helpful, suggest them at the end of your response
|
||||
- When appropriate, mention that adding relevant content to their SurfSense knowledge base could provide more personalized answers
|
||||
</format>
|
||||
|
||||
<user_query_instructions>
|
||||
When answering the user's question without access to their personal documents:
|
||||
1. Review the chat history to understand conversation context and maintain continuity
|
||||
2. Provide the most helpful and comprehensive answer possible using general knowledge
|
||||
3. Be conversational and engaging
|
||||
4. Draw upon conversation history for context
|
||||
5. Be clear that you're providing general information
|
||||
6. Suggest ways the user could get more personalized answers by expanding their knowledge base when relevant
|
||||
</user_query_instructions>
|
||||
"""
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
from langgraph.graph import StateGraph
|
||||
|
||||
from .configuration import Configuration
|
||||
from .nodes import answer_question, rerank_documents
|
||||
from .state import State
|
||||
|
||||
# Define a new graph
|
||||
workflow = StateGraph(State, config_schema=Configuration)
|
||||
|
||||
# Add the nodes to the graph
|
||||
workflow.add_node("rerank_documents", rerank_documents)
|
||||
workflow.add_node("answer_question", answer_question)
|
||||
|
||||
# Connect the nodes
|
||||
workflow.add_edge("__start__", "rerank_documents")
|
||||
workflow.add_edge("rerank_documents", "answer_question")
|
||||
workflow.add_edge("answer_question", "__end__")
|
||||
|
||||
# Compile the workflow into an executable graph
|
||||
graph = workflow.compile()
|
||||
graph.name = "SurfSense QnA Agent" # This defines the custom name in LangSmith
|
||||
|
|
@ -1,297 +0,0 @@
|
|||
import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import SearchSpace
|
||||
from app.services.reranker_service import RerankerService
|
||||
|
||||
from ..utils import (
|
||||
calculate_token_count,
|
||||
format_documents_section,
|
||||
langchain_chat_history_to_str,
|
||||
optimize_documents_for_token_limit,
|
||||
)
|
||||
from .configuration import Configuration
|
||||
from .default_prompts import (
|
||||
DEFAULT_QNA_BASE_PROMPT,
|
||||
DEFAULT_QNA_CITATION_INSTRUCTIONS,
|
||||
DEFAULT_QNA_NO_DOCUMENTS_PROMPT,
|
||||
)
|
||||
from .state import State
|
||||
|
||||
|
||||
def _build_language_instruction(language: str | None = None):
|
||||
"""Build language instruction for prompts."""
|
||||
if language:
|
||||
return f"\n\nIMPORTANT: Please respond in {language} language. All your responses, explanations, and analysis should be written in {language}."
|
||||
return ""
|
||||
|
||||
|
||||
def _build_chat_history_section(chat_history: str | None = None):
|
||||
"""Build chat history section for prompts."""
|
||||
if chat_history:
|
||||
return f"""
|
||||
<chat_history>
|
||||
{chat_history if chat_history else "NO CHAT HISTORY PROVIDED"}
|
||||
</chat_history>
|
||||
"""
|
||||
return """
|
||||
<chat_history>
|
||||
NO CHAT HISTORY PROVIDED
|
||||
</chat_history>
|
||||
"""
|
||||
|
||||
|
||||
def _format_system_prompt(
|
||||
prompt_template: str,
|
||||
chat_history: str | None = None,
|
||||
language: str | None = None,
|
||||
):
|
||||
"""Format a system prompt template with dynamic values."""
|
||||
date = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
language_instruction = _build_language_instruction(language)
|
||||
chat_history_section = _build_chat_history_section(chat_history)
|
||||
|
||||
return prompt_template.format(
|
||||
date=date,
|
||||
language_instruction=language_instruction,
|
||||
chat_history_section=chat_history_section,
|
||||
)
|
||||
|
||||
|
||||
async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, Any]:
|
||||
"""
|
||||
Rerank the documents based on relevance to the user's question.
|
||||
|
||||
This node takes the relevant documents provided in the configuration,
|
||||
reranks them using the reranker service based on the user's query,
|
||||
and updates the state with the reranked documents.
|
||||
|
||||
Documents are now document-grouped with a `chunks` list. Reranking is done
|
||||
using the concatenated `content` field, and the full structure (including
|
||||
`chunks`) is preserved for proper citation formatting.
|
||||
|
||||
If reranking is disabled, returns the original documents without processing.
|
||||
|
||||
Returns:
|
||||
Dict containing the reranked documents.
|
||||
"""
|
||||
# Get configuration and relevant documents
|
||||
configuration = Configuration.from_runnable_config(config)
|
||||
documents = configuration.relevant_documents
|
||||
user_query = configuration.user_query
|
||||
reformulated_query = configuration.reformulated_query
|
||||
|
||||
# If no documents were provided, return empty list
|
||||
if not documents or len(documents) == 0:
|
||||
return {"reranked_documents": []}
|
||||
|
||||
# Get reranker service from app config
|
||||
reranker_service = RerankerService.get_reranker_instance()
|
||||
|
||||
# If reranking is not enabled, sort by existing score and return
|
||||
if not reranker_service:
|
||||
print("Reranking is disabled. Sorting documents by existing score.")
|
||||
sorted_documents = sorted(
|
||||
documents, key=lambda x: x.get("score", 0), reverse=True
|
||||
)
|
||||
return {"reranked_documents": sorted_documents}
|
||||
|
||||
# Perform reranking
|
||||
try:
|
||||
# Pass documents directly to reranker - it will use:
|
||||
# - "content" (concatenated chunk text) for scoring
|
||||
# - "chunk_id" (primary chunk id) for matching
|
||||
# The full document structure including "chunks" is preserved
|
||||
reranked_docs = reranker_service.rerank_documents(
|
||||
user_query + "\n" + reformulated_query, documents
|
||||
)
|
||||
|
||||
# Sort by score in descending order
|
||||
reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
|
||||
|
||||
print(f"Reranked {len(reranked_docs)} documents for Q&A query: {user_query}")
|
||||
|
||||
return {"reranked_documents": reranked_docs}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during reranking: {e!s}")
|
||||
# Fall back to original documents if reranking fails
|
||||
return {"reranked_documents": documents}
|
||||
|
||||
|
||||
async def answer_question(
|
||||
state: State, config: RunnableConfig, writer: StreamWriter
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Answer the user's question using the provided documents with real-time streaming.
|
||||
|
||||
This node takes the relevant documents provided in the configuration and uses
|
||||
an LLM to generate a comprehensive answer to the user's question with
|
||||
proper citations. The citations follow [citation:chunk_id] format using chunk IDs from the
|
||||
`<chunk id='...'>` tags in the provided documents. If no documents are provided, it will use chat history to generate
|
||||
an answer.
|
||||
|
||||
The response is streamed token-by-token for real-time updates to the frontend.
|
||||
|
||||
Returns:
|
||||
Dict containing the final answer in the "final_answer" key.
|
||||
"""
|
||||
from app.services.llm_service import get_fast_llm
|
||||
|
||||
# Get configuration and relevant documents from configuration
|
||||
configuration = Configuration.from_runnable_config(config)
|
||||
documents = state.reranked_documents
|
||||
user_query = configuration.user_query
|
||||
search_space_id = configuration.search_space_id
|
||||
language = configuration.language
|
||||
|
||||
# Get streaming service from state
|
||||
streaming_service = state.streaming_service
|
||||
|
||||
# Fetch search space to get QnA configuration
|
||||
result = await state.db_session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalar_one_or_none()
|
||||
|
||||
if not search_space:
|
||||
error_message = f"Search space {search_space_id} not found"
|
||||
print(error_message)
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
# Get QnA configuration from search space
|
||||
citations_enabled = search_space.citations_enabled
|
||||
custom_instructions_text = search_space.qna_custom_instructions or ""
|
||||
|
||||
# Use constants for base prompt and citation instructions
|
||||
qna_base_prompt = DEFAULT_QNA_BASE_PROMPT
|
||||
qna_citation_instructions = (
|
||||
DEFAULT_QNA_CITATION_INSTRUCTIONS if citations_enabled else ""
|
||||
)
|
||||
qna_custom_instructions = (
|
||||
f"\n<special_important_custom_instructions>\n{custom_instructions_text}\n</special_important_custom_instructions>"
|
||||
if custom_instructions_text
|
||||
else ""
|
||||
)
|
||||
|
||||
# Get search space's fast LLM
|
||||
llm = await get_fast_llm(state.db_session, search_space_id)
|
||||
if not llm:
|
||||
error_message = f"No fast LLM configured for search space {search_space_id}"
|
||||
print(error_message)
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
# Determine if we have documents and optimize for token limits
|
||||
has_documents_initially = documents and len(documents) > 0
|
||||
chat_history_str = langchain_chat_history_to_str(state.chat_history)
|
||||
|
||||
if has_documents_initially:
|
||||
# Compose the full citation prompt: base + citation instructions + custom instructions
|
||||
full_citation_prompt_template = (
|
||||
qna_base_prompt + qna_citation_instructions + qna_custom_instructions
|
||||
)
|
||||
|
||||
# Create base message template for token calculation (without documents)
|
||||
base_human_message_template = f"""
|
||||
|
||||
User's question:
|
||||
<user_query>
|
||||
{user_query}
|
||||
</user_query>
|
||||
|
||||
Please provide a detailed, comprehensive answer to the user's question using the information from their personal knowledge sources. Make sure to cite all information appropriately and engage in a conversational manner.
|
||||
"""
|
||||
|
||||
# Use initial system prompt for token calculation
|
||||
initial_system_prompt = _format_system_prompt(
|
||||
full_citation_prompt_template, chat_history_str, language
|
||||
)
|
||||
base_messages = [
|
||||
SystemMessage(content=initial_system_prompt),
|
||||
HumanMessage(content=base_human_message_template),
|
||||
]
|
||||
|
||||
# Optimize documents to fit within token limits
|
||||
optimized_documents, has_optimized_documents = (
|
||||
optimize_documents_for_token_limit(documents, base_messages, llm.model)
|
||||
)
|
||||
|
||||
# Update state based on optimization result
|
||||
documents = optimized_documents
|
||||
has_documents = has_optimized_documents
|
||||
else:
|
||||
has_documents = False
|
||||
|
||||
# Choose system prompt based on final document availability
|
||||
# With documents: use base + citation instructions + custom instructions
|
||||
# Without documents: use the default no-documents prompt from constants
|
||||
if has_documents:
|
||||
full_citation_prompt_template = (
|
||||
qna_base_prompt + qna_citation_instructions + qna_custom_instructions
|
||||
)
|
||||
system_prompt = _format_system_prompt(
|
||||
full_citation_prompt_template, chat_history_str, language
|
||||
)
|
||||
else:
|
||||
system_prompt = _format_system_prompt(
|
||||
DEFAULT_QNA_NO_DOCUMENTS_PROMPT + qna_custom_instructions,
|
||||
chat_history_str,
|
||||
language,
|
||||
)
|
||||
|
||||
# Generate documents section
|
||||
documents_text = (
|
||||
format_documents_section(
|
||||
documents, "Source material from your personal knowledge base"
|
||||
)
|
||||
if has_documents
|
||||
else ""
|
||||
)
|
||||
|
||||
# Create final human message content
|
||||
instruction_text = (
|
||||
"Please provide a detailed, comprehensive answer to the user's question using the information from their personal knowledge sources. Make sure to cite all information appropriately and engage in a conversational manner."
|
||||
if has_documents
|
||||
else "Please provide a helpful answer to the user's question based on our conversation history and your general knowledge. Engage in a conversational manner."
|
||||
)
|
||||
|
||||
human_message_content = f"""
|
||||
{documents_text}
|
||||
|
||||
User's question:
|
||||
<user_query>
|
||||
{user_query}
|
||||
</user_query>
|
||||
|
||||
{instruction_text}
|
||||
"""
|
||||
|
||||
# Create final messages for the LLM
|
||||
messages_with_chat_history = [
|
||||
SystemMessage(content=system_prompt),
|
||||
HumanMessage(content=human_message_content),
|
||||
]
|
||||
|
||||
# Log final token count
|
||||
total_tokens = calculate_token_count(messages_with_chat_history, llm.model)
|
||||
print(f"Final token count: {total_tokens}")
|
||||
|
||||
# Stream the LLM response token by token
|
||||
final_answer = ""
|
||||
|
||||
async for chunk in llm.astream(messages_with_chat_history):
|
||||
# Extract the content from the chunk
|
||||
if hasattr(chunk, "content") and chunk.content:
|
||||
token = chunk.content
|
||||
final_answer += token
|
||||
|
||||
# Stream the token to the frontend via custom stream
|
||||
if streaming_service:
|
||||
writer({"yield_value": streaming_service.format_text_chunk(token)})
|
||||
|
||||
return {"final_answer": final_answer}
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
"""Define the state structures for the agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.streaming_service import StreamingService
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
"""Defines the dynamic state for the Q&A agent during execution.
|
||||
|
||||
This state tracks the database session, chat history, and the outputs
|
||||
generated by the agent's nodes during question answering.
|
||||
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
|
||||
for more information.
|
||||
"""
|
||||
|
||||
# Runtime context
|
||||
db_session: AsyncSession
|
||||
|
||||
# Streaming service for real-time token streaming
|
||||
streaming_service: StreamingService | None = None
|
||||
|
||||
chat_history: list[Any] | None = field(default_factory=list)
|
||||
# OUTPUT: Populated by agent nodes
|
||||
reranked_documents: list[Any] | None = None
|
||||
final_answer: str | None = None
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
"""Define the state structures for the agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.streaming_service import StreamingService
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
"""Defines the dynamic state for the agent during execution.
|
||||
|
||||
This state tracks the database session and the outputs generated by the agent's nodes.
|
||||
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
|
||||
for more information.
|
||||
"""
|
||||
|
||||
# Runtime context (not part of actual graph state)
|
||||
db_session: AsyncSession
|
||||
|
||||
# Streaming service
|
||||
streaming_service: StreamingService
|
||||
|
||||
chat_history: list[Any] | None = field(default_factory=list)
|
||||
|
||||
reformulated_query: str | None = field(default=None)
|
||||
further_questions: Any | None = field(default=None)
|
||||
|
||||
# Temporary field to hold reranked documents from sub-agents for further question generation
|
||||
reranked_documents: list[Any] | None = field(default=None)
|
||||
|
||||
# OUTPUT: Populated by agent nodes
|
||||
# Using field to explicitly mark as part of state
|
||||
final_written_report: str | None = field(default=None)
|
||||
|
|
@ -1,292 +0,0 @@
|
|||
import json
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from litellm import get_model_info, token_counter
|
||||
|
||||
|
||||
class DocumentTokenInfo(NamedTuple):
|
||||
"""Information about a document and its token cost."""
|
||||
|
||||
index: int
|
||||
document: dict[str, Any]
|
||||
formatted_content: str
|
||||
token_count: int
|
||||
|
||||
|
||||
def get_connector_emoji(connector_name: str) -> str:
|
||||
"""Get an appropriate emoji for a connector type."""
|
||||
connector_emojis = {
|
||||
"YOUTUBE_VIDEO": "📹",
|
||||
"EXTENSION": "🧩",
|
||||
"FILE": "📄",
|
||||
"SLACK_CONNECTOR": "💬",
|
||||
"NOTION_CONNECTOR": "📘",
|
||||
"GITHUB_CONNECTOR": "🐙",
|
||||
"LINEAR_CONNECTOR": "📊",
|
||||
"JIRA_CONNECTOR": "🎫",
|
||||
"DISCORD_CONNECTOR": "🗨️",
|
||||
"TAVILY_API": "🔍",
|
||||
"LINKUP_API": "🔗",
|
||||
"BAIDU_SEARCH_API": "🇨🇳",
|
||||
"GOOGLE_CALENDAR_CONNECTOR": "📅",
|
||||
"AIRTABLE_CONNECTOR": "🗃️",
|
||||
"LUMA_CONNECTOR": "✨",
|
||||
"ELASTICSEARCH_CONNECTOR": "⚡",
|
||||
"WEBCRAWLER_CONNECTOR": "🌐",
|
||||
"BOOKSTACK_CONNECTOR": "📚",
|
||||
"NOTE": "📝",
|
||||
}
|
||||
return connector_emojis.get(connector_name, "🔎")
|
||||
|
||||
|
||||
def get_connector_friendly_name(connector_name: str) -> str:
|
||||
"""Convert technical connector IDs to user-friendly names."""
|
||||
connector_friendly_names = {
|
||||
"YOUTUBE_VIDEO": "YouTube",
|
||||
"EXTENSION": "Browser Extension",
|
||||
"FILE": "Files",
|
||||
"SLACK_CONNECTOR": "Slack",
|
||||
"NOTION_CONNECTOR": "Notion",
|
||||
"GITHUB_CONNECTOR": "GitHub",
|
||||
"LINEAR_CONNECTOR": "Linear",
|
||||
"JIRA_CONNECTOR": "Jira",
|
||||
"CONFLUENCE_CONNECTOR": "Confluence",
|
||||
"GOOGLE_CALENDAR_CONNECTOR": "Google Calendar",
|
||||
"DISCORD_CONNECTOR": "Discord",
|
||||
"TAVILY_API": "Tavily Search",
|
||||
"LINKUP_API": "Linkup Search",
|
||||
"BAIDU_SEARCH_API": "Baidu Search",
|
||||
"AIRTABLE_CONNECTOR": "Airtable",
|
||||
"LUMA_CONNECTOR": "Luma",
|
||||
"ELASTICSEARCH_CONNECTOR": "Elasticsearch",
|
||||
"WEBCRAWLER_CONNECTOR": "Web Pages",
|
||||
"BOOKSTACK_CONNECTOR": "BookStack",
|
||||
"NOTE": "Notes",
|
||||
}
|
||||
return connector_friendly_names.get(connector_name, connector_name)
|
||||
|
||||
|
||||
def convert_langchain_messages_to_dict(
|
||||
messages: list[BaseMessage],
|
||||
) -> list[dict[str, str]]:
|
||||
"""Convert LangChain messages to format expected by token_counter."""
|
||||
role_mapping = {"system": "system", "human": "user", "ai": "assistant"}
|
||||
|
||||
converted_messages = []
|
||||
for msg in messages:
|
||||
role = role_mapping.get(getattr(msg, "type", None), "user")
|
||||
converted_messages.append({"role": role, "content": str(msg.content)})
|
||||
|
||||
return converted_messages
|
||||
|
||||
|
||||
def format_document_for_citation(document: dict[str, Any]) -> str:
|
||||
"""Format a single document for citation in the new document+chunks XML format.
|
||||
|
||||
IMPORTANT:
|
||||
- Citations must reference real DB chunk IDs: `[citation:<chunk_id>]`
|
||||
- Document metadata is included under <document_metadata>, but citations are NOT document_id-based.
|
||||
"""
|
||||
|
||||
def _to_cdata(value: Any) -> str:
|
||||
text = "" if value is None else str(value)
|
||||
# Safely nest CDATA even if the content includes "]]>"
|
||||
return "<![CDATA[" + text.replace("]]>", "]]]]><![CDATA[>") + "]]>"
|
||||
|
||||
doc_info = document.get("document", {}) or {}
|
||||
metadata = doc_info.get("metadata", {}) or {}
|
||||
|
||||
doc_id = doc_info.get("id", "")
|
||||
title = doc_info.get("title", "")
|
||||
document_type = doc_info.get("document_type", "CRAWLED_URL")
|
||||
url = (
|
||||
metadata.get("url")
|
||||
or metadata.get("source")
|
||||
or metadata.get("page_url")
|
||||
or metadata.get("VisitedWebPageURL")
|
||||
or ""
|
||||
)
|
||||
|
||||
metadata_json = json.dumps(metadata, ensure_ascii=False)
|
||||
|
||||
chunks = document.get("chunks") or []
|
||||
if not chunks:
|
||||
# Fallback: treat `content` as a single chunk (no chunk_id available for citation)
|
||||
chunks = [{"chunk_id": "", "content": document.get("content", "")}]
|
||||
|
||||
chunks_xml = "\n".join(
|
||||
[
|
||||
f"<chunk id='{chunk.get('chunk_id', '')}'>{_to_cdata(chunk.get('content', ''))}</chunk>"
|
||||
for chunk in chunks
|
||||
]
|
||||
)
|
||||
|
||||
return f"""<document>
|
||||
<document_metadata>
|
||||
<document_id>{doc_id}</document_id>
|
||||
<document_type>{document_type}</document_type>
|
||||
<title>{_to_cdata(title)}</title>
|
||||
<url>{_to_cdata(url)}</url>
|
||||
<metadata_json>{_to_cdata(metadata_json)}</metadata_json>
|
||||
</document_metadata>
|
||||
|
||||
<document_content>
|
||||
{chunks_xml}
|
||||
</document_content>
|
||||
</document>"""
|
||||
|
||||
|
||||
def format_documents_section(
|
||||
documents: list[dict[str, Any]], section_title: str = "Source material"
|
||||
) -> str:
|
||||
"""Format multiple documents into a complete documents section."""
|
||||
if not documents:
|
||||
return ""
|
||||
|
||||
formatted_docs = [format_document_for_citation(doc) for doc in documents]
|
||||
|
||||
return f"""{section_title}:
|
||||
<documents>
|
||||
{chr(10).join(formatted_docs)}
|
||||
</documents>"""
|
||||
|
||||
|
||||
def calculate_document_token_costs(
|
||||
documents: list[dict[str, Any]], model: str
|
||||
) -> list[DocumentTokenInfo]:
|
||||
"""Pre-calculate token costs for each document."""
|
||||
document_token_info = []
|
||||
|
||||
for i, doc in enumerate(documents):
|
||||
formatted_doc = format_document_for_citation(doc)
|
||||
|
||||
# Calculate token count for this document
|
||||
token_count = token_counter(
|
||||
messages=[{"role": "user", "content": formatted_doc}], model=model
|
||||
)
|
||||
|
||||
document_token_info.append(
|
||||
DocumentTokenInfo(
|
||||
index=i,
|
||||
document=doc,
|
||||
formatted_content=formatted_doc,
|
||||
token_count=token_count,
|
||||
)
|
||||
)
|
||||
|
||||
return document_token_info
|
||||
|
||||
|
||||
def find_optimal_documents_with_binary_search(
|
||||
document_tokens: list[DocumentTokenInfo], available_tokens: int
|
||||
) -> list[DocumentTokenInfo]:
|
||||
"""Use binary search to find the maximum number of documents that fit within token limit."""
|
||||
if not document_tokens or available_tokens <= 0:
|
||||
return []
|
||||
|
||||
left, right = 0, len(document_tokens)
|
||||
optimal_docs = []
|
||||
|
||||
while left <= right:
|
||||
mid = (left + right) // 2
|
||||
current_docs = document_tokens[:mid]
|
||||
current_token_sum = sum(doc_info.token_count for doc_info in current_docs)
|
||||
|
||||
if current_token_sum <= available_tokens:
|
||||
optimal_docs = current_docs
|
||||
left = mid + 1
|
||||
else:
|
||||
right = mid - 1
|
||||
|
||||
return optimal_docs
|
||||
|
||||
|
||||
def get_model_context_window(model_name: str) -> int:
|
||||
"""Get the total context window size for a model (input + output tokens)."""
|
||||
try:
|
||||
model_info = get_model_info(model_name)
|
||||
context_window = model_info.get("max_input_tokens", 4096) # Default fallback
|
||||
return context_window
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}"
|
||||
)
|
||||
return 4096 # Conservative fallback
|
||||
|
||||
|
||||
def optimize_documents_for_token_limit(
|
||||
documents: list[dict[str, Any]], base_messages: list[BaseMessage], model_name: str
|
||||
) -> tuple[list[dict[str, Any]], bool]:
|
||||
"""
|
||||
Optimize documents to fit within token limits using binary search.
|
||||
|
||||
Args:
|
||||
documents: List of documents with content and metadata
|
||||
base_messages: Base messages without documents (chat history + system + human message template)
|
||||
model_name: Model name for token counting (required)
|
||||
output_token_buffer: Number of tokens to reserve for model output
|
||||
|
||||
Returns:
|
||||
Tuple of (optimized_documents, has_documents_remaining)
|
||||
"""
|
||||
if not documents:
|
||||
return [], False
|
||||
|
||||
model = model_name
|
||||
context_window = get_model_context_window(model)
|
||||
|
||||
# Calculate base token cost
|
||||
base_messages_dict = convert_langchain_messages_to_dict(base_messages)
|
||||
base_tokens = token_counter(messages=base_messages_dict, model=model)
|
||||
available_tokens_for_docs = context_window - base_tokens
|
||||
|
||||
print(
|
||||
f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}"
|
||||
)
|
||||
|
||||
if available_tokens_for_docs <= 0:
|
||||
print("No tokens available for documents after base content and output buffer")
|
||||
return [], False
|
||||
|
||||
# Calculate token costs for all documents
|
||||
document_token_info = calculate_document_token_costs(documents, model)
|
||||
|
||||
# Find optimal number of documents using binary search
|
||||
optimal_doc_info = find_optimal_documents_with_binary_search(
|
||||
document_token_info, available_tokens_for_docs
|
||||
)
|
||||
|
||||
# Extract the original document objects
|
||||
optimized_documents = [doc_info.document for doc_info in optimal_doc_info]
|
||||
has_documents_remaining = len(optimized_documents) > 0
|
||||
|
||||
print(
|
||||
f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents"
|
||||
)
|
||||
|
||||
return optimized_documents, has_documents_remaining
|
||||
|
||||
|
||||
def calculate_token_count(messages: list[BaseMessage], model_name: str) -> int:
|
||||
"""Calculate token count for a list of LangChain messages."""
|
||||
model = model_name
|
||||
messages_dict = convert_langchain_messages_to_dict(messages)
|
||||
return token_counter(messages=messages_dict, model=model)
|
||||
|
||||
|
||||
def langchain_chat_history_to_str(chat_history: list[BaseMessage]) -> str:
|
||||
"""
|
||||
Convert a list of chat history messages to a string.
|
||||
"""
|
||||
chat_history_str = ""
|
||||
|
||||
for chat_message in chat_history:
|
||||
if isinstance(chat_message, HumanMessage):
|
||||
chat_history_str += f"<user>{chat_message.content}</user>\n"
|
||||
elif isinstance(chat_message, AIMessage):
|
||||
chat_history_str += f"<assistant>{chat_message.content}</assistant>\n"
|
||||
elif isinstance(chat_message, SystemMessage):
|
||||
chat_history_str += f"<system>{chat_message.content}</system>\n"
|
||||
|
||||
return chat_history_str
|
||||
|
|
@ -77,10 +77,6 @@ class SearchSourceConnectorType(str, Enum):
|
|||
BOOKSTACK_CONNECTOR = "BOOKSTACK_CONNECTOR"
|
||||
|
||||
|
||||
class ChatType(str, Enum):
|
||||
QNA = "QNA"
|
||||
|
||||
|
||||
class LiteLLMProvider(str, Enum):
|
||||
"""
|
||||
Enum for LLM providers supported by LiteLLM.
|
||||
|
|
@ -317,21 +313,6 @@ class BaseModel(Base):
|
|||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
|
||||
class Chat(BaseModel, TimestampMixin):
|
||||
__tablename__ = "chats"
|
||||
|
||||
type = Column(SQLAlchemyEnum(ChatType), nullable=False)
|
||||
title = Column(String, nullable=False, index=True)
|
||||
initial_connectors = Column(ARRAY(String), nullable=True)
|
||||
messages = Column(JSON, nullable=False)
|
||||
state_version = Column(BigInteger, nullable=False, default=1)
|
||||
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="chats")
|
||||
|
||||
|
||||
class NewChatMessageRole(str, Enum):
|
||||
"""Role enum for new chat messages."""
|
||||
|
||||
|
|
@ -363,9 +344,6 @@ class NewChatThread(BaseModel, TimestampMixin):
|
|||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# Relationships
|
||||
search_space = relationship("SearchSpace", back_populates="new_chat_threads")
|
||||
|
|
@ -445,23 +423,6 @@ class Chunk(BaseModel, TimestampMixin):
|
|||
document = relationship("Document", back_populates="chunks")
|
||||
|
||||
|
||||
class Podcast(BaseModel, TimestampMixin):
|
||||
__tablename__ = "podcasts"
|
||||
|
||||
title = Column(String, nullable=False, index=True)
|
||||
podcast_transcript = Column(JSON, nullable=False, default={})
|
||||
file_location = Column(String(500), nullable=False, default="")
|
||||
chat_id = Column(
|
||||
Integer, ForeignKey("chats.id", ondelete="CASCADE"), nullable=True
|
||||
) # If generated from a chat, this will be the chat id, else null ( can be from a document or a chat )
|
||||
chat_state_version = Column(BigInteger, nullable=True)
|
||||
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="podcasts")
|
||||
|
||||
|
||||
class SearchSpace(BaseModel, TimestampMixin):
|
||||
__tablename__ = "searchspaces"
|
||||
|
||||
|
|
@ -492,18 +453,6 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
order_by="Document.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
podcasts = relationship(
|
||||
"Podcast",
|
||||
back_populates="search_space",
|
||||
order_by="Podcast.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
chats = relationship(
|
||||
"Chat",
|
||||
back_populates="search_space",
|
||||
order_by="Chat.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
new_chat_threads = relationship(
|
||||
"NewChatThread",
|
||||
back_populates="search_space",
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from fastapi import APIRouter
|
|||
from .airtable_add_connector_route import (
|
||||
router as airtable_add_connector_router,
|
||||
)
|
||||
from .chats_routes import router as chats_router
|
||||
from .documents_routes import router as documents_router
|
||||
from .editor_routes import router as editor_router
|
||||
from .google_calendar_add_connector_route import (
|
||||
|
|
@ -17,7 +16,6 @@ from .logs_routes import router as logs_router
|
|||
from .luma_add_connector_route import router as luma_add_connector_router
|
||||
from .new_chat_routes import router as new_chat_router
|
||||
from .notes_routes import router as notes_router
|
||||
from .podcasts_routes import router as podcasts_router
|
||||
from .rbac_routes import router as rbac_router
|
||||
from .search_source_connectors_routes import router as search_source_connectors_router
|
||||
from .search_spaces_routes import router as search_spaces_router
|
||||
|
|
@ -29,9 +27,7 @@ router.include_router(rbac_router) # RBAC routes for roles, members, invites
|
|||
router.include_router(editor_router)
|
||||
router.include_router(documents_router)
|
||||
router.include_router(notes_router)
|
||||
router.include_router(podcasts_router)
|
||||
router.include_router(chats_router)
|
||||
router.include_router(new_chat_router) # New chat with assistant-ui persistence
|
||||
router.include_router(new_chat_router) # Chat with assistant-ui persistence
|
||||
router.include_router(search_source_connectors_router)
|
||||
router.include_router(google_calendar_add_connector_router)
|
||||
router.include_router(google_gmail_add_connector_router)
|
||||
|
|
|
|||
|
|
@ -1,617 +0,0 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.db import (
|
||||
Chat,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
SearchSpaceMembership,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas import (
|
||||
AISDKChatRequest,
|
||||
ChatCreate,
|
||||
ChatRead,
|
||||
ChatReadWithoutMessages,
|
||||
ChatUpdate,
|
||||
NewChatRequest,
|
||||
)
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.tasks.chat.stream_connector_search_results import (
|
||||
stream_connector_search_results,
|
||||
)
|
||||
from app.tasks.chat.stream_new_chat import stream_new_chat
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
from app.utils.validators import (
|
||||
validate_connectors,
|
||||
validate_document_ids,
|
||||
validate_messages,
|
||||
validate_research_mode,
|
||||
validate_search_space_id,
|
||||
validate_top_k,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
async def handle_chat_data(
|
||||
request: AISDKChatRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
# Validate and sanitize all input data
|
||||
messages = validate_messages(request.messages)
|
||||
|
||||
if messages[-1]["role"] != "user":
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Last message must be a user message"
|
||||
)
|
||||
|
||||
user_query = messages[-1]["content"]
|
||||
|
||||
# Extract and validate data from request
|
||||
request_data = request.data or {}
|
||||
search_space_id = validate_search_space_id(request_data.get("search_space_id"))
|
||||
research_mode = validate_research_mode(request_data.get("research_mode"))
|
||||
selected_connectors = validate_connectors(request_data.get("selected_connectors"))
|
||||
document_ids_to_add_in_context = validate_document_ids(
|
||||
request_data.get("document_ids_to_add_in_context")
|
||||
)
|
||||
top_k = validate_top_k(request_data.get("top_k"))
|
||||
# print("RESQUEST DATA:", request_data)
|
||||
# print("SELECTED CONNECTORS:", selected_connectors)
|
||||
|
||||
# Check if the user has chat access to the search space
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.CHATS_CREATE.value,
|
||||
"You don't have permission to use chat in this search space",
|
||||
)
|
||||
|
||||
# Get search space with LLM configs (preferences are now stored at search space level)
|
||||
search_space_result = await session.execute(
|
||||
select(SearchSpace)
|
||||
.options(selectinload(SearchSpace.llm_configs))
|
||||
.filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = search_space_result.scalars().first()
|
||||
|
||||
language = None
|
||||
llm_configs = [] # Initialize to empty list
|
||||
|
||||
if search_space and search_space.llm_configs:
|
||||
llm_configs = search_space.llm_configs
|
||||
|
||||
# Get language from configured LLM preferences
|
||||
# LLM preferences are now stored on the SearchSpace model
|
||||
from app.config import config as app_config
|
||||
|
||||
for llm_id in [
|
||||
search_space.fast_llm_id,
|
||||
search_space.long_context_llm_id,
|
||||
search_space.strategic_llm_id,
|
||||
]:
|
||||
if llm_id is not None:
|
||||
# Check if it's a global config (negative ID)
|
||||
if llm_id < 0:
|
||||
# Look in global configs
|
||||
for global_cfg in app_config.GLOBAL_LLM_CONFIGS:
|
||||
if global_cfg.get("id") == llm_id:
|
||||
language = global_cfg.get("language")
|
||||
if language:
|
||||
break
|
||||
else:
|
||||
# Look in custom configs
|
||||
for llm_config in llm_configs:
|
||||
if llm_config.id == llm_id and getattr(
|
||||
llm_config, "language", None
|
||||
):
|
||||
language = llm_config.language
|
||||
break
|
||||
if language:
|
||||
break
|
||||
|
||||
if not language and llm_configs:
|
||||
first_llm_config = llm_configs[0]
|
||||
language = getattr(first_llm_config, "language", None)
|
||||
|
||||
except HTTPException:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have access to this search space"
|
||||
) from None
|
||||
|
||||
langchain_chat_history = []
|
||||
for message in messages[:-1]:
|
||||
if message["role"] == "user":
|
||||
langchain_chat_history.append(HumanMessage(content=message["content"]))
|
||||
elif message["role"] == "assistant":
|
||||
langchain_chat_history.append(AIMessage(content=message["content"]))
|
||||
|
||||
response = StreamingResponse(
|
||||
stream_connector_search_results(
|
||||
user_query,
|
||||
user.id,
|
||||
search_space_id,
|
||||
session,
|
||||
research_mode,
|
||||
selected_connectors,
|
||||
langchain_chat_history,
|
||||
document_ids_to_add_in_context,
|
||||
language,
|
||||
top_k,
|
||||
)
|
||||
)
|
||||
|
||||
response.headers["x-vercel-ai-data-stream"] = "v1"
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/new_chat")
|
||||
async def handle_new_chat(
|
||||
request: NewChatRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Handle new chat requests using the SurfSense deep agent.
|
||||
|
||||
This endpoint uses the new deep agent with the Vercel AI SDK
|
||||
Data Stream Protocol (SSE format).
|
||||
|
||||
Args:
|
||||
request: NewChatRequest containing chat_id, user_query, and search_space_id
|
||||
session: Database session
|
||||
user: Current authenticated user
|
||||
|
||||
Returns:
|
||||
StreamingResponse with SSE formatted data
|
||||
"""
|
||||
# Validate the user query
|
||||
if not request.user_query or not request.user_query.strip():
|
||||
raise HTTPException(status_code=400, detail="User query cannot be empty")
|
||||
|
||||
# Check if the user has chat access to the search space
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
request.search_space_id,
|
||||
Permission.CHATS_CREATE.value,
|
||||
"You don't have permission to use chat in this search space",
|
||||
)
|
||||
except HTTPException:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have access to this search space"
|
||||
) from None
|
||||
|
||||
# Get LLM config ID from search space preferences (optional enhancement)
|
||||
# For now, we use the default global config (-1)
|
||||
llm_config_id = -1
|
||||
|
||||
# Optionally load LLM preferences from search space
|
||||
try:
|
||||
search_space_result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
|
||||
)
|
||||
search_space = search_space_result.scalars().first()
|
||||
|
||||
if search_space:
|
||||
# Use strategic_llm_id if available, otherwise fall back to fast_llm_id
|
||||
if search_space.strategic_llm_id is not None:
|
||||
llm_config_id = search_space.strategic_llm_id
|
||||
elif search_space.fast_llm_id is not None:
|
||||
llm_config_id = search_space.fast_llm_id
|
||||
except Exception:
|
||||
# Fall back to default config on any error
|
||||
pass
|
||||
|
||||
# Create the streaming response
|
||||
# chat_id is used as LangGraph's thread_id for automatic chat history management
|
||||
response = StreamingResponse(
|
||||
stream_new_chat(
|
||||
user_query=request.user_query.strip(),
|
||||
user_id=user.id,
|
||||
search_space_id=request.search_space_id,
|
||||
chat_id=request.chat_id,
|
||||
session=session,
|
||||
llm_config_id=llm_config_id,
|
||||
messages=request.messages, # Pass message history from frontend
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
# Set the required headers for Vercel AI SDK
|
||||
headers = VercelStreamingService.get_response_headers()
|
||||
for key, value in headers.items():
|
||||
response.headers[key] = value
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/chats", response_model=ChatRead)
|
||||
async def create_chat(
|
||||
chat: ChatCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Create a new chat.
|
||||
Requires CHATS_CREATE permission.
|
||||
"""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
chat.search_space_id,
|
||||
Permission.CHATS_CREATE.value,
|
||||
"You don't have permission to create chats in this search space",
|
||||
)
|
||||
db_chat = Chat(**chat.model_dump())
|
||||
session.add(db_chat)
|
||||
await session.commit()
|
||||
await session.refresh(db_chat)
|
||||
return db_chat
|
||||
except HTTPException:
|
||||
raise
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Database constraint violation. Please check your input data.",
|
||||
) from None
|
||||
except OperationalError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while creating the chat.",
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/chats", response_model=list[ChatReadWithoutMessages])
|
||||
async def read_chats(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
List chats the user has access to.
|
||||
Requires CHATS_READ permission for the search space(s).
|
||||
"""
|
||||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="skip must be a non-negative integer"
|
||||
)
|
||||
|
||||
if limit <= 0 or limit > 1000: # Reasonable upper limit
|
||||
raise HTTPException(status_code=400, detail="limit must be between 1 and 1000")
|
||||
|
||||
# Validate search_space_id if provided
|
||||
if search_space_id is not None and search_space_id <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="search_space_id must be a positive integer"
|
||||
)
|
||||
try:
|
||||
if search_space_id is not None:
|
||||
# Check permission for specific search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to read chats in this search space",
|
||||
)
|
||||
# Select specific fields excluding messages
|
||||
query = (
|
||||
select(
|
||||
Chat.id,
|
||||
Chat.type,
|
||||
Chat.title,
|
||||
Chat.initial_connectors,
|
||||
Chat.search_space_id,
|
||||
Chat.created_at,
|
||||
Chat.state_version,
|
||||
)
|
||||
.filter(Chat.search_space_id == search_space_id)
|
||||
.order_by(Chat.created_at.desc())
|
||||
)
|
||||
else:
|
||||
# Get chats from all search spaces user has membership in
|
||||
query = (
|
||||
select(
|
||||
Chat.id,
|
||||
Chat.type,
|
||||
Chat.title,
|
||||
Chat.initial_connectors,
|
||||
Chat.search_space_id,
|
||||
Chat.created_at,
|
||||
Chat.state_version,
|
||||
)
|
||||
.join(SearchSpace)
|
||||
.join(SearchSpaceMembership)
|
||||
.filter(SearchSpaceMembership.user_id == user.id)
|
||||
.order_by(Chat.created_at.desc())
|
||||
)
|
||||
|
||||
result = await session.execute(query.offset(skip).limit(limit))
|
||||
return result.all()
|
||||
except HTTPException:
|
||||
raise
|
||||
except OperationalError:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An unexpected error occurred while fetching chats."
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/chats/search", response_model=list[ChatReadWithoutMessages])
|
||||
async def search_chats(
|
||||
title: str,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Search chats by title substring.
|
||||
Requires CHATS_READ permission for the search space(s).
|
||||
|
||||
Args:
|
||||
title: Case-insensitive substring to match against chat titles. Required.
|
||||
skip: Number of items to skip from the beginning. Default: 0.
|
||||
limit: Maximum number of items to return. Default: 100.
|
||||
search_space_id: Filter results to a specific search space. Default: None.
|
||||
session: Database session (injected).
|
||||
user: Current authenticated user (injected).
|
||||
|
||||
Returns:
|
||||
List of chats matching the search query.
|
||||
|
||||
Notes:
|
||||
- Title matching uses ILIKE (case-insensitive).
|
||||
- Results are ordered by creation date (most recent first).
|
||||
"""
|
||||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="skip must be a non-negative integer"
|
||||
)
|
||||
|
||||
if limit <= 0 or limit > 1000:
|
||||
raise HTTPException(status_code=400, detail="limit must be between 1 and 1000")
|
||||
|
||||
# Validate search_space_id if provided
|
||||
if search_space_id is not None and search_space_id <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="search_space_id must be a positive integer"
|
||||
)
|
||||
|
||||
try:
|
||||
if search_space_id is not None:
|
||||
# Check permission for specific search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to read chats in this search space",
|
||||
)
|
||||
# Select specific fields excluding messages
|
||||
query = (
|
||||
select(
|
||||
Chat.id,
|
||||
Chat.type,
|
||||
Chat.title,
|
||||
Chat.initial_connectors,
|
||||
Chat.search_space_id,
|
||||
Chat.created_at,
|
||||
Chat.state_version,
|
||||
)
|
||||
.filter(Chat.search_space_id == search_space_id)
|
||||
.order_by(Chat.created_at.desc())
|
||||
)
|
||||
else:
|
||||
# Get chats from all search spaces user has membership in
|
||||
query = (
|
||||
select(
|
||||
Chat.id,
|
||||
Chat.type,
|
||||
Chat.title,
|
||||
Chat.initial_connectors,
|
||||
Chat.search_space_id,
|
||||
Chat.created_at,
|
||||
Chat.state_version,
|
||||
)
|
||||
.join(SearchSpace)
|
||||
.join(SearchSpaceMembership)
|
||||
.filter(SearchSpaceMembership.user_id == user.id)
|
||||
.order_by(Chat.created_at.desc())
|
||||
)
|
||||
|
||||
# Apply title search filter (case-insensitive)
|
||||
query = query.filter(Chat.title.ilike(f"%{title}%"))
|
||||
|
||||
result = await session.execute(query.offset(skip).limit(limit))
|
||||
return result.all()
|
||||
except HTTPException:
|
||||
raise
|
||||
except OperationalError:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while searching chats.",
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/chats/{chat_id}", response_model=ChatRead)
|
||||
async def read_chat(
|
||||
chat_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Get a specific chat by ID.
|
||||
Requires CHATS_READ permission for the search space.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(select(Chat).filter(Chat.id == chat_id))
|
||||
chat = result.scalars().first()
|
||||
|
||||
if not chat:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Chat not found",
|
||||
)
|
||||
|
||||
# Check permission for the search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
chat.search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to read chats in this search space",
|
||||
)
|
||||
|
||||
return chat
|
||||
except HTTPException:
|
||||
raise
|
||||
except OperationalError:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while fetching the chat.",
|
||||
) from None
|
||||
|
||||
|
||||
@router.put("/chats/{chat_id}", response_model=ChatRead)
|
||||
async def update_chat(
|
||||
chat_id: int,
|
||||
chat_update: ChatUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Update a chat.
|
||||
Requires CHATS_UPDATE permission for the search space.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(select(Chat).filter(Chat.id == chat_id))
|
||||
db_chat = result.scalars().first()
|
||||
|
||||
if not db_chat:
|
||||
raise HTTPException(status_code=404, detail="Chat not found")
|
||||
|
||||
# Check permission for the search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_chat.search_space_id,
|
||||
Permission.CHATS_UPDATE.value,
|
||||
"You don't have permission to update chats in this search space",
|
||||
)
|
||||
|
||||
update_data = chat_update.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
if key == "messages":
|
||||
db_chat.state_version = len(update_data["messages"])
|
||||
setattr(db_chat, key, value)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_chat)
|
||||
return db_chat
|
||||
except HTTPException:
|
||||
raise
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Database constraint violation. Please check your input data.",
|
||||
) from None
|
||||
except OperationalError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while updating the chat.",
|
||||
) from None
|
||||
|
||||
|
||||
@router.delete("/chats/{chat_id}", response_model=dict)
|
||||
async def delete_chat(
|
||||
chat_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Delete a chat.
|
||||
Requires CHATS_DELETE permission for the search space.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(select(Chat).filter(Chat.id == chat_id))
|
||||
db_chat = result.scalars().first()
|
||||
|
||||
if not db_chat:
|
||||
raise HTTPException(status_code=404, detail="Chat not found")
|
||||
|
||||
# Check permission for the search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_chat.search_space_id,
|
||||
Permission.CHATS_DELETE.value,
|
||||
"You don't have permission to delete chats in this search space",
|
||||
)
|
||||
|
||||
await session.delete(db_chat)
|
||||
await session.commit()
|
||||
return {"message": "Chat deleted successfully"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Cannot delete chat due to existing dependencies."
|
||||
) from None
|
||||
except OperationalError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while deleting the chat.",
|
||||
) from None
|
||||
|
|
@ -13,6 +13,7 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
|
|||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
|
@ -23,12 +24,14 @@ from app.db import (
|
|||
NewChatMessageRole,
|
||||
NewChatThread,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas.new_chat import (
|
||||
NewChatMessageAppend,
|
||||
NewChatMessageRead,
|
||||
NewChatRequest,
|
||||
NewChatThreadCreate,
|
||||
NewChatThreadRead,
|
||||
NewChatThreadUpdate,
|
||||
|
|
@ -37,6 +40,7 @@ from app.schemas.new_chat import (
|
|||
ThreadListItem,
|
||||
ThreadListResponse,
|
||||
)
|
||||
from app.tasks.chat.stream_new_chat import stream_new_chat
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
|
@ -74,13 +78,10 @@ async def list_threads(
|
|||
"You don't have permission to read chats in this search space",
|
||||
)
|
||||
|
||||
# Get all threads for this user in this search space
|
||||
# Get all threads in this search space
|
||||
query = (
|
||||
select(NewChatThread)
|
||||
.filter(
|
||||
NewChatThread.search_space_id == search_space_id,
|
||||
NewChatThread.user_id == user.id,
|
||||
)
|
||||
.filter(NewChatThread.search_space_id == search_space_id)
|
||||
.order_by(NewChatThread.updated_at.desc())
|
||||
)
|
||||
|
||||
|
|
@ -153,7 +154,6 @@ async def search_threads(
|
|||
select(NewChatThread)
|
||||
.filter(
|
||||
NewChatThread.search_space_id == search_space_id,
|
||||
NewChatThread.user_id == user.id,
|
||||
NewChatThread.title.ilike(f"%{title}%"),
|
||||
)
|
||||
.order_by(NewChatThread.updated_at.desc())
|
||||
|
|
@ -211,7 +211,6 @@ async def create_thread(
|
|||
title=thread.title,
|
||||
archived=thread.archived,
|
||||
search_space_id=thread.search_space_id,
|
||||
user_id=user.id,
|
||||
updated_at=now,
|
||||
)
|
||||
session.add(db_thread)
|
||||
|
|
@ -273,12 +272,6 @@ async def get_thread_messages(
|
|||
"You don't have permission to read chats in this search space",
|
||||
)
|
||||
|
||||
# Ensure user owns this thread
|
||||
if thread.user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have access to this thread"
|
||||
)
|
||||
|
||||
# Return messages in the format expected by assistant-ui
|
||||
messages = [
|
||||
NewChatMessageRead(
|
||||
|
|
@ -336,11 +329,6 @@ async def get_thread_full(
|
|||
"You don't have permission to read chats in this search space",
|
||||
)
|
||||
|
||||
if thread.user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have access to this thread"
|
||||
)
|
||||
|
||||
return thread
|
||||
|
||||
except HTTPException:
|
||||
|
|
@ -386,11 +374,6 @@ async def update_thread(
|
|||
"You don't have permission to update chats in this search space",
|
||||
)
|
||||
|
||||
if db_thread.user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have access to this thread"
|
||||
)
|
||||
|
||||
# Update fields
|
||||
update_data = thread_update.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
|
|
@ -451,11 +434,6 @@ async def delete_thread(
|
|||
"You don't have permission to delete chats in this search space",
|
||||
)
|
||||
|
||||
if db_thread.user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have access to this thread"
|
||||
)
|
||||
|
||||
await session.delete(db_thread)
|
||||
await session.commit()
|
||||
return {"message": "Thread deleted successfully"}
|
||||
|
|
@ -530,11 +508,6 @@ async def append_message(
|
|||
"You don't have permission to update chats in this search space",
|
||||
)
|
||||
|
||||
if thread.user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have access to this thread"
|
||||
)
|
||||
|
||||
# Convert string role to enum
|
||||
role_str = (
|
||||
message.role.lower() if isinstance(message.role, str) else message.role
|
||||
|
|
@ -639,11 +612,6 @@ async def list_messages(
|
|||
"You don't have permission to read chats in this search space",
|
||||
)
|
||||
|
||||
if thread.user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have access to this thread"
|
||||
)
|
||||
|
||||
# Get messages
|
||||
query = (
|
||||
select(NewChatMessage)
|
||||
|
|
@ -667,3 +635,79 @@ async def list_messages(
|
|||
status_code=500,
|
||||
detail=f"An unexpected error occurred while fetching messages: {e!s}",
|
||||
) from None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Chat Streaming Endpoint
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/new_chat")
|
||||
async def handle_new_chat(
|
||||
request: NewChatRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Stream chat responses from the deep agent.
|
||||
|
||||
This endpoint handles the new chat functionality with streaming responses
|
||||
using Server-Sent Events (SSE) format compatible with Vercel AI SDK.
|
||||
|
||||
Requires CHATS_CREATE permission.
|
||||
"""
|
||||
try:
|
||||
# Verify thread exists and user has permission
|
||||
result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == request.chat_id)
|
||||
)
|
||||
thread = result.scalars().first()
|
||||
|
||||
if not thread:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_CREATE.value,
|
||||
"You don't have permission to chat in this search space",
|
||||
)
|
||||
|
||||
# Get search space to check LLM config preferences
|
||||
search_space_result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
|
||||
)
|
||||
search_space = search_space_result.scalars().first()
|
||||
|
||||
# Determine LLM config ID (use search space preference or default)
|
||||
llm_config_id = -1 # Default to first global config
|
||||
if search_space and search_space.fast_llm_id:
|
||||
llm_config_id = search_space.fast_llm_id
|
||||
|
||||
# Return streaming response
|
||||
return StreamingResponse(
|
||||
stream_new_chat(
|
||||
user_query=request.user_query,
|
||||
user_id=str(user.id),
|
||||
search_space_id=request.search_space_id,
|
||||
chat_id=request.chat_id,
|
||||
session=session,
|
||||
llm_config_id=llm_config_id,
|
||||
messages=request.messages,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An unexpected error occurred: {e!s}",
|
||||
) from None
|
||||
|
|
|
|||
|
|
@ -1,509 +0,0 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import (
|
||||
Chat,
|
||||
Permission,
|
||||
Podcast,
|
||||
SearchSpace,
|
||||
SearchSpaceMembership,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas import (
|
||||
PodcastCreate,
|
||||
PodcastGenerateRequest,
|
||||
PodcastRead,
|
||||
PodcastUpdate,
|
||||
)
|
||||
from app.tasks.podcast_tasks import generate_chat_podcast
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/podcasts", response_model=PodcastRead)
|
||||
async def create_podcast(
|
||||
podcast: PodcastCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Create a new podcast.
|
||||
Requires PODCASTS_CREATE permission.
|
||||
"""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
podcast.search_space_id,
|
||||
Permission.PODCASTS_CREATE.value,
|
||||
"You don't have permission to create podcasts in this search space",
|
||||
)
|
||||
db_podcast = Podcast(**podcast.model_dump())
|
||||
session.add(db_podcast)
|
||||
await session.commit()
|
||||
await session.refresh(db_podcast)
|
||||
return db_podcast
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Podcast creation failed due to constraint violation",
|
||||
) from None
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error occurred while creating podcast"
|
||||
) from None
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An unexpected error occurred"
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/podcasts", response_model=list[PodcastRead])
|
||||
async def read_podcasts(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
List podcasts the user has access to.
|
||||
Requires PODCASTS_READ permission for the search space(s).
|
||||
"""
|
||||
if skip < 0 or limit < 1:
|
||||
raise HTTPException(status_code=400, detail="Invalid pagination parameters")
|
||||
try:
|
||||
if search_space_id is not None:
|
||||
# Check permission for specific search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.PODCASTS_READ.value,
|
||||
"You don't have permission to read podcasts in this search space",
|
||||
)
|
||||
result = await session.execute(
|
||||
select(Podcast)
|
||||
.filter(Podcast.search_space_id == search_space_id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
else:
|
||||
# Get podcasts from all search spaces user has membership in
|
||||
result = await session.execute(
|
||||
select(Podcast)
|
||||
.join(SearchSpace)
|
||||
.join(SearchSpaceMembership)
|
||||
.filter(SearchSpaceMembership.user_id == user.id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return result.scalars().all()
|
||||
except HTTPException:
|
||||
raise
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error occurred while fetching podcasts"
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/podcasts/{podcast_id}", response_model=PodcastRead)
|
||||
async def read_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Get a specific podcast by ID.
|
||||
Requires PODCASTS_READ permission for the search space.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
|
||||
podcast = result.scalars().first()
|
||||
|
||||
if not podcast:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Podcast not found",
|
||||
)
|
||||
|
||||
# Check permission for the search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
podcast.search_space_id,
|
||||
Permission.PODCASTS_READ.value,
|
||||
"You don't have permission to read podcasts in this search space",
|
||||
)
|
||||
|
||||
return podcast
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error occurred while fetching podcast"
|
||||
) from None
|
||||
|
||||
|
||||
@router.put("/podcasts/{podcast_id}", response_model=PodcastRead)
|
||||
async def update_podcast(
|
||||
podcast_id: int,
|
||||
podcast_update: PodcastUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Update a podcast.
|
||||
Requires PODCASTS_UPDATE permission for the search space.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
|
||||
db_podcast = result.scalars().first()
|
||||
|
||||
if not db_podcast:
|
||||
raise HTTPException(status_code=404, detail="Podcast not found")
|
||||
|
||||
# Check permission for the search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_podcast.search_space_id,
|
||||
Permission.PODCASTS_UPDATE.value,
|
||||
"You don't have permission to update podcasts in this search space",
|
||||
)
|
||||
|
||||
update_data = podcast_update.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(db_podcast, key, value)
|
||||
await session.commit()
|
||||
await session.refresh(db_podcast)
|
||||
return db_podcast
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Update failed due to constraint violation"
|
||||
) from None
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error occurred while updating podcast"
|
||||
) from None
|
||||
|
||||
|
||||
@router.delete("/podcasts/{podcast_id}", response_model=dict)
|
||||
async def delete_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Delete a podcast.
|
||||
Requires PODCASTS_DELETE permission for the search space.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
|
||||
db_podcast = result.scalars().first()
|
||||
|
||||
if not db_podcast:
|
||||
raise HTTPException(status_code=404, detail="Podcast not found")
|
||||
|
||||
# Check permission for the search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_podcast.search_space_id,
|
||||
Permission.PODCASTS_DELETE.value,
|
||||
"You don't have permission to delete podcasts in this search space",
|
||||
)
|
||||
|
||||
await session.delete(db_podcast)
|
||||
await session.commit()
|
||||
return {"message": "Podcast deleted successfully"}
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error occurred while deleting podcast"
|
||||
) from None
|
||||
|
||||
|
||||
async def generate_chat_podcast_with_new_session(
|
||||
chat_id: int,
|
||||
search_space_id: int,
|
||||
user_id: int,
|
||||
podcast_title: str | None = None,
|
||||
user_prompt: str | None = None,
|
||||
):
|
||||
"""Create a new session and process chat podcast generation."""
|
||||
from app.db import async_session_maker
|
||||
|
||||
async with async_session_maker() as session:
|
||||
try:
|
||||
await generate_chat_podcast(
|
||||
session, chat_id, search_space_id, user_id, podcast_title, user_prompt
|
||||
)
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.error(f"Error generating podcast from chat: {e!s}")
|
||||
|
||||
|
||||
@router.post("/podcasts/generate")
|
||||
async def generate_podcast(
|
||||
request: PodcastGenerateRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Generate a podcast from a chat or document.
|
||||
Requires PODCASTS_CREATE permission.
|
||||
"""
|
||||
try:
|
||||
# Check if the user has permission to create podcasts
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
request.search_space_id,
|
||||
Permission.PODCASTS_CREATE.value,
|
||||
"You don't have permission to create podcasts in this search space",
|
||||
)
|
||||
|
||||
if request.type == "CHAT":
|
||||
# Verify that all chat IDs belong to this user and search space
|
||||
query = (
|
||||
select(Chat)
|
||||
.filter(
|
||||
Chat.id.in_(request.ids),
|
||||
Chat.search_space_id == request.search_space_id,
|
||||
)
|
||||
.join(SearchSpace)
|
||||
.filter(SearchSpace.user_id == user.id)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
valid_chats = result.scalars().all()
|
||||
valid_chat_ids = [chat.id for chat in valid_chats]
|
||||
|
||||
# If any requested ID is not in valid IDs, raise error immediately
|
||||
if len(valid_chat_ids) != len(request.ids):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="One or more chat IDs do not belong to this user or search space",
|
||||
)
|
||||
|
||||
from app.tasks.celery_tasks.podcast_tasks import (
|
||||
generate_chat_podcast_task,
|
||||
)
|
||||
|
||||
# Add Celery tasks for each chat ID
|
||||
for chat_id in valid_chat_ids:
|
||||
generate_chat_podcast_task.delay(
|
||||
chat_id,
|
||||
request.search_space_id,
|
||||
user.id,
|
||||
request.podcast_title,
|
||||
request.user_prompt,
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "Podcast generation started",
|
||||
}
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Podcast generation failed due to constraint violation",
|
||||
) from None
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error occurred while generating podcast"
|
||||
) from None
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"An unexpected error occurred: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/podcasts/{podcast_id}/stream")
|
||||
async def stream_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Stream a podcast audio file.
|
||||
Requires PODCASTS_READ permission for the search space.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
|
||||
podcast = result.scalars().first()
|
||||
|
||||
if not podcast:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Podcast not found",
|
||||
)
|
||||
|
||||
# Check permission for the search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
podcast.search_space_id,
|
||||
Permission.PODCASTS_READ.value,
|
||||
"You don't have permission to access podcasts in this search space",
|
||||
)
|
||||
|
||||
# Get the file path
|
||||
file_path = podcast.file_location
|
||||
|
||||
# Check if the file exists
|
||||
if not os.path.isfile(file_path):
|
||||
raise HTTPException(status_code=404, detail="Podcast audio file not found")
|
||||
|
||||
# Define a generator function to stream the file
|
||||
def iterfile():
|
||||
with open(file_path, mode="rb") as file_like:
|
||||
yield from file_like
|
||||
|
||||
# Return a streaming response with appropriate headers
|
||||
return StreamingResponse(
|
||||
iterfile(),
|
||||
media_type="audio/mpeg",
|
||||
headers={
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Disposition": f"inline; filename={Path(file_path).name}",
|
||||
},
|
||||
)
|
||||
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error streaming podcast: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/podcasts/by-chat/{chat_id}", response_model=PodcastRead | None)
|
||||
async def get_podcast_by_chat_id(
|
||||
chat_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Get a podcast by its associated chat ID.
|
||||
Requires PODCASTS_READ permission for the search space.
|
||||
"""
|
||||
try:
|
||||
# First get the chat to find its search space
|
||||
chat_result = await session.execute(select(Chat).filter(Chat.id == chat_id))
|
||||
chat = chat_result.scalars().first()
|
||||
|
||||
if not chat:
|
||||
return None
|
||||
|
||||
# Check permission for the search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
chat.search_space_id,
|
||||
Permission.PODCASTS_READ.value,
|
||||
"You don't have permission to read podcasts in this search space",
|
||||
)
|
||||
|
||||
# Get the podcast
|
||||
result = await session.execute(
|
||||
select(Podcast).filter(Podcast.chat_id == chat_id)
|
||||
)
|
||||
podcast = result.scalars().first()
|
||||
|
||||
return podcast
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error fetching podcast: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/podcasts/task/{task_id}/status")
|
||||
async def get_podcast_task_status(
|
||||
task_id: str,
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Get the status of a podcast generation task.
|
||||
Used by new-chat frontend to poll for completion.
|
||||
|
||||
Returns:
|
||||
- status: "processing" | "success" | "error"
|
||||
- podcast_id: (only if status == "success")
|
||||
- title: (only if status == "success")
|
||||
- error: (only if status == "error")
|
||||
"""
|
||||
try:
|
||||
from celery.result import AsyncResult
|
||||
|
||||
from app.celery_app import celery_app
|
||||
|
||||
result = AsyncResult(task_id, app=celery_app)
|
||||
|
||||
if result.ready():
|
||||
# Task completed
|
||||
if result.successful():
|
||||
task_result = result.result
|
||||
if isinstance(task_result, dict):
|
||||
if task_result.get("status") == "success":
|
||||
return {
|
||||
"status": "success",
|
||||
"podcast_id": task_result.get("podcast_id"),
|
||||
"title": task_result.get("title"),
|
||||
"transcript_entries": task_result.get("transcript_entries"),
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": task_result.get("error", "Unknown error"),
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": "Unexpected task result format",
|
||||
}
|
||||
else:
|
||||
# Task failed
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(result.result) if result.result else "Task failed",
|
||||
}
|
||||
else:
|
||||
# Task still processing
|
||||
return {
|
||||
"status": "processing",
|
||||
"state": result.state,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error checking task status: {e!s}"
|
||||
) from e
|
||||
|
|
@ -1,13 +1,4 @@
|
|||
from .base import IDModel, TimestampModel
|
||||
from .chats import (
|
||||
AISDKChatRequest,
|
||||
ChatBase,
|
||||
ChatCreate,
|
||||
ChatRead,
|
||||
ChatReadWithoutMessages,
|
||||
ChatUpdate,
|
||||
NewChatRequest,
|
||||
)
|
||||
from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate
|
||||
from .documents import (
|
||||
DocumentBase,
|
||||
|
|
@ -22,9 +13,11 @@ from .documents import (
|
|||
from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
|
||||
from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate
|
||||
from .new_chat import (
|
||||
ChatMessage,
|
||||
NewChatMessageAppend,
|
||||
NewChatMessageCreate,
|
||||
NewChatMessageRead,
|
||||
NewChatRequest,
|
||||
NewChatThreadCreate,
|
||||
NewChatThreadRead,
|
||||
NewChatThreadUpdate,
|
||||
|
|
@ -33,13 +26,6 @@ from .new_chat import (
|
|||
ThreadListItem,
|
||||
ThreadListResponse,
|
||||
)
|
||||
from .podcasts import (
|
||||
PodcastBase,
|
||||
PodcastCreate,
|
||||
PodcastGenerateRequest,
|
||||
PodcastRead,
|
||||
PodcastUpdate,
|
||||
)
|
||||
from .rbac_schemas import (
|
||||
InviteAcceptRequest,
|
||||
InviteAcceptResponse,
|
||||
|
|
@ -73,44 +59,8 @@ from .search_space import (
|
|||
from .users import UserCreate, UserRead, UserUpdate
|
||||
|
||||
__all__ = [
|
||||
"AISDKChatRequest",
|
||||
"ChatBase",
|
||||
"ChatCreate",
|
||||
"ChatRead",
|
||||
"ChatReadWithoutMessages",
|
||||
"ChatUpdate",
|
||||
"ChunkBase",
|
||||
"ChunkCreate",
|
||||
"ChunkRead",
|
||||
"ChunkUpdate",
|
||||
"DocumentBase",
|
||||
"DocumentRead",
|
||||
"DocumentUpdate",
|
||||
"DocumentWithChunksRead",
|
||||
"DocumentsCreate",
|
||||
"ExtensionDocumentContent",
|
||||
"ExtensionDocumentMetadata",
|
||||
"IDModel",
|
||||
# RBAC schemas
|
||||
"InviteAcceptRequest",
|
||||
"InviteAcceptResponse",
|
||||
"InviteCreate",
|
||||
"InviteInfoResponse",
|
||||
"InviteRead",
|
||||
"InviteUpdate",
|
||||
"LLMConfigBase",
|
||||
"LLMConfigCreate",
|
||||
"LLMConfigRead",
|
||||
"LLMConfigUpdate",
|
||||
"LogBase",
|
||||
"LogCreate",
|
||||
"LogFilter",
|
||||
"LogRead",
|
||||
"LogUpdate",
|
||||
"MembershipRead",
|
||||
"MembershipReadWithUser",
|
||||
"MembershipUpdate",
|
||||
# New chat schemas (assistant-ui integration)
|
||||
# Chat schemas (assistant-ui integration)
|
||||
"ChatMessage",
|
||||
"NewChatMessageAppend",
|
||||
"NewChatMessageCreate",
|
||||
"NewChatMessageRead",
|
||||
|
|
@ -119,30 +69,64 @@ __all__ = [
|
|||
"NewChatThreadRead",
|
||||
"NewChatThreadUpdate",
|
||||
"NewChatThreadWithMessages",
|
||||
"ThreadHistoryLoadResponse",
|
||||
"ThreadListItem",
|
||||
"ThreadListResponse",
|
||||
# Chunk schemas
|
||||
"ChunkBase",
|
||||
"ChunkCreate",
|
||||
"ChunkRead",
|
||||
"ChunkUpdate",
|
||||
# Document schemas
|
||||
"DocumentBase",
|
||||
"DocumentRead",
|
||||
"DocumentUpdate",
|
||||
"DocumentWithChunksRead",
|
||||
"DocumentsCreate",
|
||||
"ExtensionDocumentContent",
|
||||
"ExtensionDocumentMetadata",
|
||||
"PaginatedResponse",
|
||||
# Base schemas
|
||||
"IDModel",
|
||||
"TimestampModel",
|
||||
# LLM Config schemas
|
||||
"LLMConfigBase",
|
||||
"LLMConfigCreate",
|
||||
"LLMConfigRead",
|
||||
"LLMConfigUpdate",
|
||||
# Log schemas
|
||||
"LogBase",
|
||||
"LogCreate",
|
||||
"LogFilter",
|
||||
"LogRead",
|
||||
"LogUpdate",
|
||||
# RBAC schemas
|
||||
"InviteAcceptRequest",
|
||||
"InviteAcceptResponse",
|
||||
"InviteCreate",
|
||||
"InviteInfoResponse",
|
||||
"InviteRead",
|
||||
"InviteUpdate",
|
||||
"MembershipRead",
|
||||
"MembershipReadWithUser",
|
||||
"MembershipUpdate",
|
||||
"PermissionInfo",
|
||||
"PermissionsListResponse",
|
||||
"PodcastBase",
|
||||
"PodcastCreate",
|
||||
"PodcastGenerateRequest",
|
||||
"PodcastRead",
|
||||
"PodcastUpdate",
|
||||
"RoleCreate",
|
||||
"RoleRead",
|
||||
"RoleUpdate",
|
||||
# Search source connector schemas
|
||||
"SearchSourceConnectorBase",
|
||||
"SearchSourceConnectorCreate",
|
||||
"SearchSourceConnectorRead",
|
||||
"SearchSourceConnectorUpdate",
|
||||
# Search space schemas
|
||||
"SearchSpaceBase",
|
||||
"SearchSpaceCreate",
|
||||
"SearchSpaceRead",
|
||||
"SearchSpaceUpdate",
|
||||
"SearchSpaceWithStats",
|
||||
"ThreadHistoryLoadResponse",
|
||||
"ThreadListItem",
|
||||
"ThreadListResponse",
|
||||
"TimestampModel",
|
||||
# User schemas
|
||||
"UserCreate",
|
||||
"UserRead",
|
||||
"UserSearchSpaceAccess",
|
||||
|
|
|
|||
|
|
@ -1,80 +0,0 @@
|
|||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from app.db import ChatType
|
||||
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
|
||||
class ChatBase(BaseModel):
|
||||
type: ChatType
|
||||
title: str
|
||||
initial_connectors: list[str] | None = None
|
||||
messages: list[Any]
|
||||
search_space_id: int
|
||||
state_version: int = 1
|
||||
|
||||
|
||||
class ChatBaseWithoutMessages(BaseModel):
|
||||
type: ChatType
|
||||
title: str
|
||||
search_space_id: int
|
||||
state_version: int = 1
|
||||
|
||||
|
||||
class ClientAttachment(BaseModel):
|
||||
name: str
|
||||
content_type: str
|
||||
url: str
|
||||
|
||||
|
||||
class ToolInvocation(BaseModel):
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
args: dict
|
||||
result: dict
|
||||
|
||||
|
||||
# class ClientMessage(BaseModel):
|
||||
# role: str
|
||||
# content: str
|
||||
# experimental_attachments: Optional[List[ClientAttachment]] = None
|
||||
# toolInvocations: Optional[List[ToolInvocation]] = None
|
||||
|
||||
|
||||
class AISDKChatRequest(BaseModel):
|
||||
messages: list[Any]
|
||||
data: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
"""A single message in the chat history."""
|
||||
|
||||
role: str # "user" or "assistant"
|
||||
content: str
|
||||
|
||||
|
||||
class NewChatRequest(BaseModel):
|
||||
"""Request schema for the new deep agent chat endpoint."""
|
||||
|
||||
chat_id: int
|
||||
user_query: str
|
||||
search_space_id: int
|
||||
messages: list[ChatMessage] | None = None # Optional chat history from frontend
|
||||
|
||||
|
||||
class ChatCreate(ChatBase):
|
||||
pass
|
||||
|
||||
|
||||
class ChatUpdate(ChatBase):
|
||||
pass
|
||||
|
||||
|
||||
class ChatRead(ChatBase, IDModel, TimestampModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ChatReadWithoutMessages(ChatBaseWithoutMessages, IDModel, TimestampModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
|
@ -127,3 +127,24 @@ class ThreadListResponse(BaseModel):
|
|||
|
||||
threads: list[ThreadListItem]
|
||||
archived_threads: list[ThreadListItem]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Chat Request Schemas (for deep agent)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
"""A single message in the chat history."""
|
||||
|
||||
role: str # "user" or "assistant"
|
||||
content: str
|
||||
|
||||
|
||||
class NewChatRequest(BaseModel):
|
||||
"""Request schema for the deep agent chat endpoint."""
|
||||
|
||||
chat_id: int
|
||||
user_query: str
|
||||
search_space_id: int
|
||||
messages: list[ChatMessage] | None = None # Optional chat history from frontend
|
||||
|
|
|
|||
|
|
@ -1,33 +0,0 @@
|
|||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
|
||||
class PodcastBase(BaseModel):
|
||||
title: str
|
||||
podcast_transcript: list[Any]
|
||||
file_location: str = ""
|
||||
search_space_id: int
|
||||
chat_state_version: int | None = None
|
||||
|
||||
|
||||
class PodcastCreate(PodcastBase):
|
||||
pass
|
||||
|
||||
|
||||
class PodcastUpdate(PodcastBase):
|
||||
pass
|
||||
|
||||
|
||||
class PodcastRead(PodcastBase, IDModel, TimestampModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class PodcastGenerateRequest(BaseModel):
|
||||
type: Literal["DOCUMENT", "CHAT"]
|
||||
ids: list[int]
|
||||
search_space_id: int
|
||||
podcast_title: str | None = None
|
||||
user_prompt: str | None = None
|
||||
|
|
@ -1,75 +0,0 @@
|
|||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.researcher.graph import graph as researcher_graph
|
||||
from app.agents.researcher.state import State
|
||||
from app.services.streaming_service import StreamingService
|
||||
|
||||
|
||||
async def stream_connector_search_results(
|
||||
user_query: str,
|
||||
user_id: str | UUID,
|
||||
search_space_id: int,
|
||||
session: AsyncSession,
|
||||
research_mode: str,
|
||||
selected_connectors: list[str],
|
||||
langchain_chat_history: list[Any],
|
||||
document_ids_to_add_in_context: list[int],
|
||||
language: str | None = None,
|
||||
top_k: int = 10,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Stream connector search results to the client
|
||||
|
||||
Args:
|
||||
user_query: The user's query
|
||||
user_id: The user's ID (can be UUID object or string)
|
||||
search_space_id: The search space ID
|
||||
session: The database session
|
||||
research_mode: The research mode
|
||||
selected_connectors: List of selected connectors
|
||||
|
||||
Yields:
|
||||
str: Formatted response strings
|
||||
"""
|
||||
streaming_service = StreamingService()
|
||||
|
||||
# Convert UUID to string if needed
|
||||
user_id_str = str(user_id) if isinstance(user_id, UUID) else user_id
|
||||
|
||||
# Sample configuration
|
||||
config = {
|
||||
"configurable": {
|
||||
"user_query": user_query,
|
||||
"connectors_to_search": selected_connectors,
|
||||
"user_id": user_id_str,
|
||||
"search_space_id": search_space_id,
|
||||
"document_ids_to_add_in_context": document_ids_to_add_in_context,
|
||||
"language": language, # Add language to the configuration
|
||||
"top_k": top_k, # Add top_k to the configuration
|
||||
}
|
||||
}
|
||||
# print(f"Researcher configuration: {config['configurable']}") # Debug print
|
||||
# Initialize state with database session and streaming service
|
||||
initial_state = State(
|
||||
db_session=session,
|
||||
streaming_service=streaming_service,
|
||||
chat_history=langchain_chat_history,
|
||||
)
|
||||
|
||||
# Run the graph directly
|
||||
print("\nRunning the complete researcher workflow...")
|
||||
|
||||
# Use streaming with config parameter
|
||||
async for chunk in researcher_graph.astream(
|
||||
initial_state,
|
||||
config=config,
|
||||
stream_mode="custom",
|
||||
):
|
||||
if isinstance(chunk, dict) and "yield_value" in chunk:
|
||||
yield chunk["yield_value"]
|
||||
|
||||
yield streaming_service.format_completion()
|
||||
|
|
@ -18,7 +18,7 @@ from app.agents.new_chat.llm_config import (
|
|||
create_chat_litellm_from_config,
|
||||
load_llm_config_from_yaml,
|
||||
)
|
||||
from app.schemas.chats import ChatMessage
|
||||
from app.schemas.new_chat import ChatMessage
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue