mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
feat: migrated old chat to new chat
This commit is contained in:
parent
b5e20e7515
commit
bb971460fc
25 changed files with 368 additions and 4391 deletions
|
|
@ -0,0 +1,216 @@
|
||||||
|
"""Migrate old chats to new_chat_threads and remove old tables
|
||||||
|
|
||||||
|
Revision ID: 49
|
||||||
|
Revises: 48
|
||||||
|
Create Date: 2025-12-21
|
||||||
|
|
||||||
|
This migration:
|
||||||
|
1. Migrates data from old 'chats' table to 'new_chat_threads' and 'new_chat_messages'
|
||||||
|
2. Drops the 'podcasts' table (podcast data is not migrated as per user request)
|
||||||
|
3. Drops the 'chats' table
|
||||||
|
4. Removes the 'chattype' enum
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "49"
|
||||||
|
down_revision: str | None = "48"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def extract_text_content(content: str | dict | list) -> str:
|
||||||
|
"""Extract plain text content from various message formats."""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, dict):
|
||||||
|
# Handle dict with 'text' key
|
||||||
|
if "text" in content:
|
||||||
|
return content["text"]
|
||||||
|
return str(content)
|
||||||
|
if isinstance(content, list):
|
||||||
|
# Handle list of parts (e.g., [{"type": "text", "text": "..."}])
|
||||||
|
texts = []
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, dict) and part.get("type") == "text":
|
||||||
|
texts.append(part.get("text", ""))
|
||||||
|
elif isinstance(part, str):
|
||||||
|
texts.append(part)
|
||||||
|
return "\n".join(texts) if texts else ""
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def parse_timestamp(ts, fallback):
|
||||||
|
"""Parse ISO timestamp string to datetime object."""
|
||||||
|
if ts is None:
|
||||||
|
return fallback
|
||||||
|
if isinstance(ts, datetime):
|
||||||
|
return ts
|
||||||
|
if isinstance(ts, str):
|
||||||
|
try:
|
||||||
|
# Handle ISO format like '2025-11-26T22:43:34.399Z'
|
||||||
|
ts = ts.replace("Z", "+00:00")
|
||||||
|
return datetime.fromisoformat(ts)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return fallback
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Migrate old chats to new_chat_threads and remove old tables."""
|
||||||
|
connection = op.get_bind()
|
||||||
|
|
||||||
|
# Get all old chats
|
||||||
|
old_chats = connection.execute(
|
||||||
|
sa.text("""
|
||||||
|
SELECT id, title, messages, search_space_id, created_at
|
||||||
|
FROM chats
|
||||||
|
ORDER BY created_at ASC
|
||||||
|
""")
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
print(f"[Migration 49] Found {len(old_chats)} old chats to migrate")
|
||||||
|
|
||||||
|
migrated_count = 0
|
||||||
|
for chat_id, title, messages_json, search_space_id, created_at in old_chats:
|
||||||
|
try:
|
||||||
|
# Parse messages JSON
|
||||||
|
if isinstance(messages_json, str):
|
||||||
|
messages = json.loads(messages_json)
|
||||||
|
else:
|
||||||
|
messages = messages_json or []
|
||||||
|
|
||||||
|
# Skip empty chats
|
||||||
|
if not messages:
|
||||||
|
print(f"[Migration 49] Skipping empty chat {chat_id}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create new thread
|
||||||
|
result = connection.execute(
|
||||||
|
sa.text("""
|
||||||
|
INSERT INTO new_chat_threads
|
||||||
|
(title, archived, search_space_id, created_at, updated_at)
|
||||||
|
VALUES (:title, FALSE, :search_space_id, :created_at, :created_at)
|
||||||
|
RETURNING id
|
||||||
|
"""),
|
||||||
|
{
|
||||||
|
"title": title or "Migrated Chat",
|
||||||
|
"search_space_id": search_space_id,
|
||||||
|
"created_at": created_at,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
new_thread_id = result.fetchone()[0]
|
||||||
|
|
||||||
|
# Migrate messages - only user and assistant roles, skip SOURCES/TERMINAL_INFO
|
||||||
|
message_count = 0
|
||||||
|
for msg in messages:
|
||||||
|
role_lower = msg.get("role", "").lower()
|
||||||
|
|
||||||
|
# Only migrate user and assistant messages
|
||||||
|
if role_lower not in ("user", "assistant"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Convert to uppercase for database enum
|
||||||
|
role = role_lower.upper()
|
||||||
|
|
||||||
|
# Extract content - handle various formats
|
||||||
|
content_raw = msg.get("content", "")
|
||||||
|
content_text = extract_text_content(content_raw)
|
||||||
|
|
||||||
|
# Skip empty messages
|
||||||
|
if not content_text.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Parse message timestamp
|
||||||
|
msg_created_at = parse_timestamp(msg.get("createdAt"), created_at)
|
||||||
|
|
||||||
|
# Store content as JSONB array format for assistant-ui compatibility
|
||||||
|
content_list = [{"type": "text", "text": content_text}]
|
||||||
|
|
||||||
|
# Use direct SQL with string interpolation for the enum since CAST doesn't work
|
||||||
|
# The enum value comes from trusted source (our own code), not user input
|
||||||
|
connection.execute(
|
||||||
|
sa.text(f"""
|
||||||
|
INSERT INTO new_chat_messages
|
||||||
|
(thread_id, role, content, created_at)
|
||||||
|
VALUES (:thread_id, '{role}', CAST(:content AS jsonb), :created_at)
|
||||||
|
"""),
|
||||||
|
{
|
||||||
|
"thread_id": new_thread_id,
|
||||||
|
"content": json.dumps(content_list),
|
||||||
|
"created_at": msg_created_at,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
message_count += 1
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"[Migration 49] Migrated chat {chat_id} -> thread {new_thread_id} ({message_count} messages)"
|
||||||
|
)
|
||||||
|
migrated_count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Migration 49] Error migrating chat {chat_id}: {e}")
|
||||||
|
# Re-raise to abort migration - we don't want partial data
|
||||||
|
raise
|
||||||
|
|
||||||
|
print(f"[Migration 49] Successfully migrated {migrated_count} chats")
|
||||||
|
|
||||||
|
# Drop podcasts table (FK references chats, so drop first)
|
||||||
|
print("[Migration 49] Dropping podcasts table...")
|
||||||
|
op.drop_table("podcasts")
|
||||||
|
|
||||||
|
# Drop chats table
|
||||||
|
print("[Migration 49] Dropping chats table...")
|
||||||
|
op.drop_table("chats")
|
||||||
|
|
||||||
|
# Drop chattype enum
|
||||||
|
print("[Migration 49] Dropping chattype enum...")
|
||||||
|
op.execute(sa.text("DROP TYPE IF EXISTS chattype"))
|
||||||
|
|
||||||
|
print("[Migration 49] Migration complete!")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Recreate old tables (data cannot be restored)."""
|
||||||
|
# Recreate chattype enum
|
||||||
|
op.execute(
|
||||||
|
sa.text("""
|
||||||
|
CREATE TYPE chattype AS ENUM ('QNA')
|
||||||
|
""")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Recreate chats table
|
||||||
|
op.create_table(
|
||||||
|
"chats",
|
||||||
|
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||||
|
sa.Column("type", sa.Enum("QNA", name="chattype"), nullable=False),
|
||||||
|
sa.Column("title", sa.String(), nullable=False, index=True),
|
||||||
|
sa.Column("initial_connectors", sa.ARRAY(sa.String()), nullable=True),
|
||||||
|
sa.Column("messages", sa.JSON(), nullable=False),
|
||||||
|
sa.Column("state_version", sa.BigInteger(), nullable=False, default=1),
|
||||||
|
sa.Column("search_space_id", sa.Integer(), sa.ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
sa.Column("created_at", sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.func.now()),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Recreate podcasts table
|
||||||
|
op.create_table(
|
||||||
|
"podcasts",
|
||||||
|
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||||
|
sa.Column("title", sa.String(), nullable=False, index=True),
|
||||||
|
sa.Column("podcast_transcript", sa.JSON(), nullable=False, server_default="{}"),
|
||||||
|
sa.Column("file_location", sa.String(500), nullable=False, server_default=""),
|
||||||
|
sa.Column("chat_id", sa.Integer(), sa.ForeignKey("chats.id", ondelete="CASCADE"), nullable=True),
|
||||||
|
sa.Column("chat_state_version", sa.BigInteger(), nullable=True),
|
||||||
|
sa.Column("search_space_id", sa.Integer(), sa.ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
sa.Column("created_at", sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.func.now()),
|
||||||
|
)
|
||||||
|
|
||||||
|
print("[Migration 49 Downgrade] Tables recreated (data not restored)")
|
||||||
|
|
||||||
|
|
@ -1,30 +0,0 @@
|
||||||
"""Define the configurable parameters for the agent."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass, fields
|
|
||||||
|
|
||||||
from langchain_core.runnables import RunnableConfig
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
|
||||||
class Configuration:
|
|
||||||
"""The configuration for the agent."""
|
|
||||||
|
|
||||||
# Input parameters provided at invocation
|
|
||||||
user_query: str
|
|
||||||
connectors_to_search: list[str]
|
|
||||||
user_id: str
|
|
||||||
search_space_id: int
|
|
||||||
document_ids_to_add_in_context: list[int]
|
|
||||||
language: str | None = None
|
|
||||||
top_k: int = 10
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_runnable_config(
|
|
||||||
cls, config: RunnableConfig | None = None
|
|
||||||
) -> Configuration:
|
|
||||||
"""Create a Configuration instance from a RunnableConfig object."""
|
|
||||||
configurable = (config.get("configurable") or {}) if config else {}
|
|
||||||
_fields = {f.name for f in fields(cls) if f.init}
|
|
||||||
return cls(**{k: v for k, v in configurable.items() if k in _fields})
|
|
||||||
|
|
@ -1,47 +0,0 @@
|
||||||
from langgraph.graph import StateGraph
|
|
||||||
|
|
||||||
from .configuration import Configuration
|
|
||||||
from .nodes import (
|
|
||||||
generate_further_questions,
|
|
||||||
handle_qna_workflow,
|
|
||||||
reformulate_user_query,
|
|
||||||
)
|
|
||||||
from .state import State
|
|
||||||
|
|
||||||
|
|
||||||
def build_graph():
|
|
||||||
"""
|
|
||||||
Build and return the LangGraph workflow.
|
|
||||||
|
|
||||||
This function constructs the researcher agent graph for Q&A workflow.
|
|
||||||
The workflow follows a simple path:
|
|
||||||
1. Reformulate user query based on chat history
|
|
||||||
2. Handle QNA workflow (fetch documents and generate answer)
|
|
||||||
3. Generate follow-up questions
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A compiled LangGraph workflow
|
|
||||||
"""
|
|
||||||
# Define a new graph with state class
|
|
||||||
workflow = StateGraph(State, config_schema=Configuration)
|
|
||||||
|
|
||||||
# Add nodes to the graph
|
|
||||||
workflow.add_node("reformulate_user_query", reformulate_user_query)
|
|
||||||
workflow.add_node("handle_qna_workflow", handle_qna_workflow)
|
|
||||||
workflow.add_node("generate_further_questions", generate_further_questions)
|
|
||||||
|
|
||||||
# Define the edges - simple linear flow for QNA
|
|
||||||
workflow.add_edge("__start__", "reformulate_user_query")
|
|
||||||
workflow.add_edge("reformulate_user_query", "handle_qna_workflow")
|
|
||||||
workflow.add_edge("handle_qna_workflow", "generate_further_questions")
|
|
||||||
workflow.add_edge("generate_further_questions", "__end__")
|
|
||||||
|
|
||||||
# Compile the workflow into an executable graph
|
|
||||||
graph = workflow.compile()
|
|
||||||
graph.name = "Surfsense Researcher" # This defines the custom name in LangSmith
|
|
||||||
|
|
||||||
return graph
|
|
||||||
|
|
||||||
|
|
||||||
# Compile the graph once when the module is loaded
|
|
||||||
graph = build_graph()
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,140 +0,0 @@
|
||||||
import datetime
|
|
||||||
|
|
||||||
|
|
||||||
def _build_language_instruction(language: str | None = None):
|
|
||||||
"""Build language instruction for prompts."""
|
|
||||||
if language:
|
|
||||||
return f"\n\nIMPORTANT: Please respond in {language} language. All your responses, explanations, and analysis should be written in {language}."
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def get_further_questions_system_prompt():
|
|
||||||
return f"""
|
|
||||||
Today's date: {datetime.datetime.now().strftime("%Y-%m-%d")}
|
|
||||||
<further_questions_system>
|
|
||||||
You are an expert research assistant specializing in generating contextually relevant follow-up questions. Your task is to analyze the chat history and available documents to suggest further questions that would naturally extend the conversation and provide additional value to the user.
|
|
||||||
|
|
||||||
<input>
|
|
||||||
- chat_history: Provided in XML format within <chat_history> tags, containing <user> and <assistant> message pairs that show the chronological conversation flow. This provides context about what has already been discussed.
|
|
||||||
- available_documents: Provided in XML format within <documents> tags, containing individual <document> elements with <document_metadata> and <document_content> sections. Each document contains multiple `<chunk id='...'>...</chunk>` blocks inside <document_content>. This helps understand what information is accessible for answering potential follow-up questions.
|
|
||||||
</input>
|
|
||||||
|
|
||||||
<output_format>
|
|
||||||
A JSON object with the following structure:
|
|
||||||
{{
|
|
||||||
"further_questions": [
|
|
||||||
{{
|
|
||||||
"id": 0,
|
|
||||||
"question": "further qn 1"
|
|
||||||
}},
|
|
||||||
{{
|
|
||||||
"id": 1,
|
|
||||||
"question": "further qn 2"
|
|
||||||
}}
|
|
||||||
]
|
|
||||||
}}
|
|
||||||
</output_format>
|
|
||||||
|
|
||||||
<instructions>
|
|
||||||
1. **Analyze Chat History:** Review the entire conversation flow to understand:
|
|
||||||
* The main topics and themes discussed
|
|
||||||
* The user's interests and areas of focus
|
|
||||||
* Questions that have been asked and answered
|
|
||||||
* Any gaps or areas that could be explored further
|
|
||||||
* The depth level of the current discussion
|
|
||||||
|
|
||||||
2. **Evaluate Available Documents:** Consider the documents in context to identify:
|
|
||||||
* Additional information that hasn't been explored yet
|
|
||||||
* Related topics that could be of interest
|
|
||||||
* Specific details or data points that could warrant deeper investigation
|
|
||||||
* Cross-references or connections between different documents
|
|
||||||
|
|
||||||
3. **Generate Relevant Follow-up Questions:** Create 3-5 further questions that:
|
|
||||||
* Are directly related to the ongoing conversation but explore new angles
|
|
||||||
* Can be reasonably answered using the available documents or knowledge base
|
|
||||||
* Progress the conversation forward rather than repeating previous topics
|
|
||||||
* Match the user's apparent level of interest and expertise
|
|
||||||
* Are specific and actionable rather than overly broad
|
|
||||||
* Consider practical applications, comparisons, deeper analysis, or related concepts
|
|
||||||
|
|
||||||
4. **Ensure Question Quality:** Each question should:
|
|
||||||
* Be clear and well-formulated
|
|
||||||
* Provide genuine value to the user
|
|
||||||
* Be distinct from other suggested questions
|
|
||||||
* Be answerable within the current context
|
|
||||||
* Encourage meaningful exploration of the topic
|
|
||||||
|
|
||||||
5. **Prioritize and Order:** Arrange questions by relevance and natural progression:
|
|
||||||
* Most directly related to the current discussion first
|
|
||||||
* Questions that build upon previous answers
|
|
||||||
* Questions that explore practical applications or implications
|
|
||||||
* Questions that introduce related but new concepts
|
|
||||||
|
|
||||||
6. **Adhere Strictly to Output Format:** Ensure the final output is a valid JSON object with:
|
|
||||||
* Correct field names (`further_questions`, `id`, `question`)
|
|
||||||
* Sequential numbering starting from 0
|
|
||||||
* Proper data types and JSON formatting
|
|
||||||
</instructions>
|
|
||||||
|
|
||||||
<examples>
|
|
||||||
Input:
|
|
||||||
<chat_history>
|
|
||||||
<user>What are the best machine learning algorithms for text classification?</user>
|
|
||||||
<assistant>For text classification, several algorithms work well depending on your specific needs:
|
|
||||||
|
|
||||||
**Traditional Methods:**
|
|
||||||
- **Support Vector Machines (SVM)** - Excellent for high-dimensional text data
|
|
||||||
- **Naive Bayes** - Simple, fast, and works well with small datasets
|
|
||||||
- **Logistic Regression** - Good baseline with interpretable results
|
|
||||||
|
|
||||||
**Modern Deep Learning:**
|
|
||||||
- **Neural Networks** - More complex but can capture intricate patterns
|
|
||||||
- **Transformer models** - State-of-the-art for most text classification tasks
|
|
||||||
|
|
||||||
The choice depends on your dataset size, computational resources, and accuracy requirements.</assistant>
|
|
||||||
</chat_history>
|
|
||||||
|
|
||||||
<documents>
|
|
||||||
<document>
|
|
||||||
<metadata>
|
|
||||||
<source_id>101</source_id>
|
|
||||||
<source_type>FILE</source_type>
|
|
||||||
</metadata>
|
|
||||||
<content>
|
|
||||||
# Machine Learning for Text Classification: A Comprehensive Guide
|
|
||||||
|
|
||||||
## Performance Comparison
|
|
||||||
Recent studies show that transformer-based models achieve 95%+ accuracy on most text classification benchmarks, while traditional methods like SVM typically achieve 85-90% accuracy.
|
|
||||||
|
|
||||||
## Dataset Considerations
|
|
||||||
- Small datasets (< 1000 samples): Naive Bayes, SVM
|
|
||||||
- Large datasets (> 10,000 samples): Neural networks, transformers
|
|
||||||
- Imbalanced datasets: Require special handling with techniques like SMOTE
|
|
||||||
</content>
|
|
||||||
</document>
|
|
||||||
</documents>
|
|
||||||
|
|
||||||
Output:
|
|
||||||
{{
|
|
||||||
"further_questions": [
|
|
||||||
{{
|
|
||||||
"id": 0,
|
|
||||||
"question": "What are the key differences in performance between traditional algorithms like SVM and modern deep learning approaches for text classification?"
|
|
||||||
}},
|
|
||||||
{{
|
|
||||||
"id": 1,
|
|
||||||
"question": "How do you handle imbalanced datasets when training text classification models?"
|
|
||||||
}},
|
|
||||||
{{
|
|
||||||
"id": 2,
|
|
||||||
"question": "What preprocessing techniques are most effective for improving text classification accuracy?"
|
|
||||||
}},
|
|
||||||
{{
|
|
||||||
"id": 3,
|
|
||||||
"question": "Are there specific domains or use cases where certain classification algorithms perform better than others?"
|
|
||||||
}}
|
|
||||||
]
|
|
||||||
}}
|
|
||||||
</examples>
|
|
||||||
</further_questions_system>
|
|
||||||
"""
|
|
||||||
|
|
@ -1,5 +0,0 @@
|
||||||
"""QnA Agent."""
|
|
||||||
|
|
||||||
from .graph import graph
|
|
||||||
|
|
||||||
__all__ = ["graph"]
|
|
||||||
|
|
@ -1,31 +0,0 @@
|
||||||
"""Define the configurable parameters for the agent."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass, fields
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.runnables import RunnableConfig
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
|
||||||
class Configuration:
|
|
||||||
"""The configuration for the Q&A agent."""
|
|
||||||
|
|
||||||
# Configuration parameters for the Q&A agent
|
|
||||||
user_query: str # The user's question to answer
|
|
||||||
reformulated_query: str # The reformulated query
|
|
||||||
relevant_documents: list[
|
|
||||||
Any
|
|
||||||
] # Documents provided directly to the agent for answering
|
|
||||||
search_space_id: int # Search space identifier
|
|
||||||
language: str | None = None # Language for responses
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_runnable_config(
|
|
||||||
cls, config: RunnableConfig | None = None
|
|
||||||
) -> Configuration:
|
|
||||||
"""Create a Configuration instance from a RunnableConfig object."""
|
|
||||||
configurable = (config.get("configurable") or {}) if config else {}
|
|
||||||
_fields = {f.name for f in fields(cls) if f.init}
|
|
||||||
return cls(**{k: v for k, v in configurable.items() if k in _fields})
|
|
||||||
|
|
@ -1,201 +0,0 @@
|
||||||
"""Default system prompts for Q&A agent.
|
|
||||||
|
|
||||||
The prompt system is modular with 3 parts:
|
|
||||||
- Part 1 (Base): Core instructions for answering questions (no citations)
|
|
||||||
- Part 2 (Citations): Citation-specific instructions and formatting rules
|
|
||||||
- Part 3 (Custom): User's custom instructions (empty by default)
|
|
||||||
|
|
||||||
Combinations:
|
|
||||||
- Part 1 only: Answers without citations
|
|
||||||
- Part 1 + Part 2: Answers with citations
|
|
||||||
- Part 1 + Part 2 + Part 3: Answers with citations and custom instructions
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Part 1: Base system prompt for answering without citations
|
|
||||||
DEFAULT_QNA_BASE_PROMPT = """Today's date: {date}
|
|
||||||
You are SurfSense, an advanced AI research assistant that provides detailed, well-researched answers to user questions by synthesizing information from multiple personal knowledge sources.{language_instruction}
|
|
||||||
{chat_history_section}
|
|
||||||
<knowledge_sources>
|
|
||||||
- EXTENSION: "Web content saved via SurfSense browser extension" (personal browsing history)
|
|
||||||
- FILE: "User-uploaded documents (PDFs, Word, etc.)" (personal files)
|
|
||||||
- SLACK_CONNECTOR: "Slack conversations and shared content" (personal workspace communications)
|
|
||||||
- NOTION_CONNECTOR: "Notion workspace pages and databases" (personal knowledge management)
|
|
||||||
- YOUTUBE_VIDEO: "YouTube video transcripts and metadata" (personally saved videos)
|
|
||||||
- GITHUB_CONNECTOR: "GitHub repository content and issues" (personal repositories and interactions)
|
|
||||||
- ELASTICSEARCH_CONNECTOR: "Elasticsearch indexed documents and data" (personal Elasticsearch instances and custom data sources)
|
|
||||||
- LINEAR_CONNECTOR: "Linear project issues and discussions" (personal project management)
|
|
||||||
- JIRA_CONNECTOR: "Jira project issues, tickets, and comments" (personal project tracking)
|
|
||||||
- CONFLUENCE_CONNECTOR: "Confluence pages and comments" (personal project documentation)
|
|
||||||
- CLICKUP_CONNECTOR: "ClickUp tasks and project data" (personal task management)
|
|
||||||
- GOOGLE_CALENDAR_CONNECTOR: "Google Calendar events, meetings, and schedules" (personal calendar and time management)
|
|
||||||
- GOOGLE_GMAIL_CONNECTOR: "Google Gmail emails and conversations" (personal emails and communications)
|
|
||||||
- DISCORD_CONNECTOR: "Discord server conversations and shared content" (personal community communications)
|
|
||||||
- AIRTABLE_CONNECTOR: "Airtable records, tables, and database content" (personal data management and organization)
|
|
||||||
- TAVILY_API: "Tavily search API results" (personalized search results)
|
|
||||||
- LINKUP_API: "Linkup search API results" (personalized search results)
|
|
||||||
- LUMA_CONNECTOR: "Luma events"
|
|
||||||
- WEBCRAWLER_CONNECTOR: "Webpages indexed by SurfSense" (personally selected websites)
|
|
||||||
</knowledge_sources>
|
|
||||||
|
|
||||||
<instructions>
|
|
||||||
1. Review the chat history to understand the conversation context and any previous topics discussed.
|
|
||||||
2. Carefully analyze all provided documents in the <document> sections.
|
|
||||||
3. Extract relevant information that directly addresses the user's question.
|
|
||||||
4. Provide a comprehensive, detailed answer using information from the user's personal knowledge sources.
|
|
||||||
5. Structure your answer logically and conversationally, as if having a detailed discussion with the user.
|
|
||||||
6. Use your own words to synthesize and connect ideas from the documents.
|
|
||||||
7. If documents contain conflicting information, acknowledge this and present both perspectives.
|
|
||||||
8. If the user's question cannot be fully answered with the provided documents, clearly state what information is missing.
|
|
||||||
9. Provide actionable insights and practical information when relevant to the user's question.
|
|
||||||
10. Use the chat history to maintain conversation continuity and refer to previous discussions when relevant.
|
|
||||||
11. Remember that all knowledge sources contain personal information - provide answers that reflect this personal context.
|
|
||||||
12. Be conversational and engaging while maintaining accuracy.
|
|
||||||
</instructions>
|
|
||||||
|
|
||||||
<format>
|
|
||||||
- Write in a clear, conversational tone suitable for detailed Q&A discussions
|
|
||||||
- Provide comprehensive answers that thoroughly address the user's question
|
|
||||||
- Use appropriate paragraphs and structure for readability
|
|
||||||
- ALWAYS provide personalized answers that reflect the user's own knowledge and context
|
|
||||||
- Be thorough and detailed in your explanations while remaining focused on the user's specific question
|
|
||||||
- If asking follow-up questions would be helpful, suggest them at the end of your response
|
|
||||||
</format>
|
|
||||||
|
|
||||||
<user_query_instructions>
|
|
||||||
When you see a user query, focus exclusively on providing a detailed, comprehensive answer using information from the provided documents, which contain the user's personal knowledge and data.
|
|
||||||
|
|
||||||
Make sure your response:
|
|
||||||
1. Considers the chat history for context and conversation continuity
|
|
||||||
2. Directly and thoroughly answers the user's question with personalized information from their own knowledge sources
|
|
||||||
3. Is conversational, engaging, and detailed
|
|
||||||
4. Acknowledges the personal nature of the information being provided
|
|
||||||
5. Offers follow-up suggestions when appropriate
|
|
||||||
</user_query_instructions>
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Part 2: Citation-specific instructions to add citation capabilities
|
|
||||||
DEFAULT_QNA_CITATION_INSTRUCTIONS = """
|
|
||||||
<citation_instructions>
|
|
||||||
CRITICAL CITATION REQUIREMENTS:
|
|
||||||
|
|
||||||
1. For EVERY piece of information you include from the documents, add a citation in the format [citation:chunk_id] where chunk_id is the exact value from the `<chunk id='...'>` tag inside `<document_content>`.
|
|
||||||
2. Make sure ALL factual statements from the documents have proper citations.
|
|
||||||
3. If multiple chunks support the same point, include all relevant citations [citation:chunk_id1], [citation:chunk_id2].
|
|
||||||
4. You MUST use the exact chunk_id values from the `<chunk id='...'>` attributes. Do not create your own citation numbers.
|
|
||||||
5. Every citation MUST be in the format [citation:chunk_id] where chunk_id is the exact chunk id value.
|
|
||||||
6. Never modify or change the chunk_id - always use the original values exactly as provided in the chunk tags.
|
|
||||||
7. Do not return citations as clickable links.
|
|
||||||
8. Never format citations as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only.
|
|
||||||
9. Citations must ONLY appear as [citation:chunk_id] or [citation:chunk_id1], [citation:chunk_id2] format - never with parentheses, hyperlinks, or other formatting.
|
|
||||||
10. Never make up chunk IDs. Only use chunk_id values that are explicitly provided in the `<chunk id='...'>` tags.
|
|
||||||
11. If you are unsure about a chunk_id, do not include a citation rather than guessing or making one up.
|
|
||||||
|
|
||||||
<document_structure_example>
|
|
||||||
The documents you receive are structured like this:
|
|
||||||
|
|
||||||
<document>
|
|
||||||
<document_metadata>
|
|
||||||
<document_id>42</document_id>
|
|
||||||
<document_type>GITHUB_CONNECTOR</document_type>
|
|
||||||
<title><![CDATA[Some repo / file / issue title]]></title>
|
|
||||||
<url><![CDATA[https://example.com]]></url>
|
|
||||||
<metadata_json><![CDATA[{{"any":"other metadata"}}]]></metadata_json>
|
|
||||||
</document_metadata>
|
|
||||||
|
|
||||||
<document_content>
|
|
||||||
<chunk id='123'><![CDATA[First chunk text...]]></chunk>
|
|
||||||
<chunk id='124'><![CDATA[Second chunk text...]]></chunk>
|
|
||||||
</document_content>
|
|
||||||
</document>
|
|
||||||
|
|
||||||
IMPORTANT: You MUST cite using the chunk ids (e.g. 123, 124). Do NOT cite document_id.
|
|
||||||
</document_structure_example>
|
|
||||||
|
|
||||||
<citation_format>
|
|
||||||
- Every fact from the documents must have a citation in the format [citation:chunk_id] where chunk_id is the EXACT id value from a `<chunk id='...'>` tag
|
|
||||||
- Citations should appear at the end of the sentence containing the information they support
|
|
||||||
- Multiple citations should be separated by commas: [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3]
|
|
||||||
- No need to return references section. Just citations in answer.
|
|
||||||
- NEVER create your own citation format - use the exact chunk_id values from the documents in the [citation:chunk_id] format
|
|
||||||
- NEVER format citations as clickable links or as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only
|
|
||||||
- NEVER make up chunk IDs if you are unsure about the chunk_id. It is better to omit the citation than to guess
|
|
||||||
</citation_format>
|
|
||||||
|
|
||||||
<citation_examples>
|
|
||||||
CORRECT citation formats:
|
|
||||||
- [citation:5]
|
|
||||||
- [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3]
|
|
||||||
|
|
||||||
INCORRECT citation formats (DO NOT use):
|
|
||||||
- Using parentheses and markdown links: ([citation:5](https://github.com/MODSetter/SurfSense))
|
|
||||||
- Using parentheses around brackets: ([citation:5])
|
|
||||||
- Using hyperlinked text: [link to source 5](https://example.com)
|
|
||||||
- Using footnote style: ... library¹
|
|
||||||
- Making up source IDs when source_id is unknown
|
|
||||||
- Using old IEEE format: [1], [2], [3]
|
|
||||||
- Using source types instead of IDs: [citation:GITHUB_CONNECTOR] instead of [citation:5]
|
|
||||||
</citation_examples>
|
|
||||||
|
|
||||||
<citation_output_example>
|
|
||||||
Based on your GitHub repositories and video content, Python's asyncio library provides tools for writing concurrent code using the async/await syntax [citation:5]. It's particularly useful for I/O-bound and high-level structured network code [citation:5].
|
|
||||||
|
|
||||||
The key advantage of asyncio is that it can improve performance by allowing other code to run while waiting for I/O operations to complete [citation:12]. This makes it excellent for scenarios like web scraping, API calls, database operations, or any situation where your program spends time waiting for external resources.
|
|
||||||
|
|
||||||
However, from your video learning, it's important to note that asyncio is not suitable for CPU-bound tasks as it runs on a single thread [citation:12]. For computationally intensive work, you'd want to use multiprocessing instead.
|
|
||||||
</citation_output_example>
|
|
||||||
</citation_instructions>
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Part 3: User's custom instructions (empty by default, can be set by user from UI)
|
|
||||||
DEFAULT_QNA_CUSTOM_INSTRUCTIONS = ""
|
|
||||||
|
|
||||||
# Full prompt with all parts combined (for backward compatibility and migration)
|
|
||||||
DEFAULT_QNA_CITATION_PROMPT = (
|
|
||||||
DEFAULT_QNA_BASE_PROMPT
|
|
||||||
+ DEFAULT_QNA_CITATION_INSTRUCTIONS
|
|
||||||
+ DEFAULT_QNA_CUSTOM_INSTRUCTIONS
|
|
||||||
)
|
|
||||||
|
|
||||||
DEFAULT_QNA_NO_DOCUMENTS_PROMPT = """Today's date: {date}
|
|
||||||
You are SurfSense, an advanced AI research assistant that provides helpful, detailed answers to user questions in a conversational manner.{language_instruction}
|
|
||||||
{chat_history_section}
|
|
||||||
<context>
|
|
||||||
The user has asked a question but there are no specific documents from their personal knowledge base available to answer it. You should provide a helpful response based on:
|
|
||||||
1. The conversation history and context
|
|
||||||
2. Your general knowledge and expertise
|
|
||||||
3. Understanding of the user's needs and interests based on our conversation
|
|
||||||
</context>
|
|
||||||
|
|
||||||
<instructions>
|
|
||||||
1. Provide a comprehensive, helpful answer to the user's question
|
|
||||||
2. Draw upon the conversation history to understand context and the user's specific needs
|
|
||||||
3. Use your general knowledge to provide accurate, detailed information
|
|
||||||
4. Be conversational and engaging, as if having a detailed discussion with the user
|
|
||||||
5. Acknowledge when you're drawing from general knowledge rather than their personal sources
|
|
||||||
6. Provide actionable insights and practical information when relevant
|
|
||||||
7. Structure your answer logically and clearly
|
|
||||||
8. If the question would benefit from personalized information from their knowledge base, gently suggest they might want to add relevant content to SurfSense
|
|
||||||
9. Be honest about limitations while still being maximally helpful
|
|
||||||
10. Maintain the helpful, knowledgeable tone that users expect from SurfSense
|
|
||||||
</instructions>
|
|
||||||
|
|
||||||
<format>
|
|
||||||
- Write in a clear, conversational tone suitable for detailed Q&A discussions
|
|
||||||
- Provide comprehensive answers that thoroughly address the user's question
|
|
||||||
- Use appropriate paragraphs and structure for readability
|
|
||||||
- No citations are needed since you're using general knowledge
|
|
||||||
- Be thorough and detailed in your explanations while remaining focused on the user's specific question
|
|
||||||
- If asking follow-up questions would be helpful, suggest them at the end of your response
|
|
||||||
- When appropriate, mention that adding relevant content to their SurfSense knowledge base could provide more personalized answers
|
|
||||||
</format>
|
|
||||||
|
|
||||||
<user_query_instructions>
|
|
||||||
When answering the user's question without access to their personal documents:
|
|
||||||
1. Review the chat history to understand conversation context and maintain continuity
|
|
||||||
2. Provide the most helpful and comprehensive answer possible using general knowledge
|
|
||||||
3. Be conversational and engaging
|
|
||||||
4. Draw upon conversation history for context
|
|
||||||
5. Be clear that you're providing general information
|
|
||||||
6. Suggest ways the user could get more personalized answers by expanding their knowledge base when relevant
|
|
||||||
</user_query_instructions>
|
|
||||||
"""
|
|
||||||
|
|
@ -1,21 +0,0 @@
|
||||||
from langgraph.graph import StateGraph
|
|
||||||
|
|
||||||
from .configuration import Configuration
|
|
||||||
from .nodes import answer_question, rerank_documents
|
|
||||||
from .state import State
|
|
||||||
|
|
||||||
# Define a new graph
|
|
||||||
workflow = StateGraph(State, config_schema=Configuration)
|
|
||||||
|
|
||||||
# Add the nodes to the graph
|
|
||||||
workflow.add_node("rerank_documents", rerank_documents)
|
|
||||||
workflow.add_node("answer_question", answer_question)
|
|
||||||
|
|
||||||
# Connect the nodes
|
|
||||||
workflow.add_edge("__start__", "rerank_documents")
|
|
||||||
workflow.add_edge("rerank_documents", "answer_question")
|
|
||||||
workflow.add_edge("answer_question", "__end__")
|
|
||||||
|
|
||||||
# Compile the workflow into an executable graph
|
|
||||||
graph = workflow.compile()
|
|
||||||
graph.name = "SurfSense QnA Agent" # This defines the custom name in LangSmith
|
|
||||||
|
|
@ -1,297 +0,0 @@
|
||||||
import datetime
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.runnables import RunnableConfig
|
|
||||||
from langgraph.types import StreamWriter
|
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
from app.db import SearchSpace
|
|
||||||
from app.services.reranker_service import RerankerService
|
|
||||||
|
|
||||||
from ..utils import (
|
|
||||||
calculate_token_count,
|
|
||||||
format_documents_section,
|
|
||||||
langchain_chat_history_to_str,
|
|
||||||
optimize_documents_for_token_limit,
|
|
||||||
)
|
|
||||||
from .configuration import Configuration
|
|
||||||
from .default_prompts import (
|
|
||||||
DEFAULT_QNA_BASE_PROMPT,
|
|
||||||
DEFAULT_QNA_CITATION_INSTRUCTIONS,
|
|
||||||
DEFAULT_QNA_NO_DOCUMENTS_PROMPT,
|
|
||||||
)
|
|
||||||
from .state import State
|
|
||||||
|
|
||||||
|
|
||||||
def _build_language_instruction(language: str | None = None):
|
|
||||||
"""Build language instruction for prompts."""
|
|
||||||
if language:
|
|
||||||
return f"\n\nIMPORTANT: Please respond in {language} language. All your responses, explanations, and analysis should be written in {language}."
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def _build_chat_history_section(chat_history: str | None = None):
|
|
||||||
"""Build chat history section for prompts."""
|
|
||||||
if chat_history:
|
|
||||||
return f"""
|
|
||||||
<chat_history>
|
|
||||||
{chat_history if chat_history else "NO CHAT HISTORY PROVIDED"}
|
|
||||||
</chat_history>
|
|
||||||
"""
|
|
||||||
return """
|
|
||||||
<chat_history>
|
|
||||||
NO CHAT HISTORY PROVIDED
|
|
||||||
</chat_history>
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def _format_system_prompt(
|
|
||||||
prompt_template: str,
|
|
||||||
chat_history: str | None = None,
|
|
||||||
language: str | None = None,
|
|
||||||
):
|
|
||||||
"""Format a system prompt template with dynamic values."""
|
|
||||||
date = datetime.datetime.now().strftime("%Y-%m-%d")
|
|
||||||
language_instruction = _build_language_instruction(language)
|
|
||||||
chat_history_section = _build_chat_history_section(chat_history)
|
|
||||||
|
|
||||||
return prompt_template.format(
|
|
||||||
date=date,
|
|
||||||
language_instruction=language_instruction,
|
|
||||||
chat_history_section=chat_history_section,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Rerank the documents based on relevance to the user's question.
|
|
||||||
|
|
||||||
This node takes the relevant documents provided in the configuration,
|
|
||||||
reranks them using the reranker service based on the user's query,
|
|
||||||
and updates the state with the reranked documents.
|
|
||||||
|
|
||||||
Documents are now document-grouped with a `chunks` list. Reranking is done
|
|
||||||
using the concatenated `content` field, and the full structure (including
|
|
||||||
`chunks`) is preserved for proper citation formatting.
|
|
||||||
|
|
||||||
If reranking is disabled, returns the original documents without processing.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict containing the reranked documents.
|
|
||||||
"""
|
|
||||||
# Get configuration and relevant documents
|
|
||||||
configuration = Configuration.from_runnable_config(config)
|
|
||||||
documents = configuration.relevant_documents
|
|
||||||
user_query = configuration.user_query
|
|
||||||
reformulated_query = configuration.reformulated_query
|
|
||||||
|
|
||||||
# If no documents were provided, return empty list
|
|
||||||
if not documents or len(documents) == 0:
|
|
||||||
return {"reranked_documents": []}
|
|
||||||
|
|
||||||
# Get reranker service from app config
|
|
||||||
reranker_service = RerankerService.get_reranker_instance()
|
|
||||||
|
|
||||||
# If reranking is not enabled, sort by existing score and return
|
|
||||||
if not reranker_service:
|
|
||||||
print("Reranking is disabled. Sorting documents by existing score.")
|
|
||||||
sorted_documents = sorted(
|
|
||||||
documents, key=lambda x: x.get("score", 0), reverse=True
|
|
||||||
)
|
|
||||||
return {"reranked_documents": sorted_documents}
|
|
||||||
|
|
||||||
# Perform reranking
|
|
||||||
try:
|
|
||||||
# Pass documents directly to reranker - it will use:
|
|
||||||
# - "content" (concatenated chunk text) for scoring
|
|
||||||
# - "chunk_id" (primary chunk id) for matching
|
|
||||||
# The full document structure including "chunks" is preserved
|
|
||||||
reranked_docs = reranker_service.rerank_documents(
|
|
||||||
user_query + "\n" + reformulated_query, documents
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sort by score in descending order
|
|
||||||
reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
|
|
||||||
|
|
||||||
print(f"Reranked {len(reranked_docs)} documents for Q&A query: {user_query}")
|
|
||||||
|
|
||||||
return {"reranked_documents": reranked_docs}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error during reranking: {e!s}")
|
|
||||||
# Fall back to original documents if reranking fails
|
|
||||||
return {"reranked_documents": documents}
|
|
||||||
|
|
||||||
|
|
||||||
async def answer_question(
|
|
||||||
state: State, config: RunnableConfig, writer: StreamWriter
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Answer the user's question using the provided documents with real-time streaming.
|
|
||||||
|
|
||||||
This node takes the relevant documents provided in the configuration and uses
|
|
||||||
an LLM to generate a comprehensive answer to the user's question with
|
|
||||||
proper citations. The citations follow [citation:chunk_id] format using chunk IDs from the
|
|
||||||
`<chunk id='...'>` tags in the provided documents. If no documents are provided, it will use chat history to generate
|
|
||||||
an answer.
|
|
||||||
|
|
||||||
The response is streamed token-by-token for real-time updates to the frontend.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict containing the final answer in the "final_answer" key.
|
|
||||||
"""
|
|
||||||
from app.services.llm_service import get_fast_llm
|
|
||||||
|
|
||||||
# Get configuration and relevant documents from configuration
|
|
||||||
configuration = Configuration.from_runnable_config(config)
|
|
||||||
documents = state.reranked_documents
|
|
||||||
user_query = configuration.user_query
|
|
||||||
search_space_id = configuration.search_space_id
|
|
||||||
language = configuration.language
|
|
||||||
|
|
||||||
# Get streaming service from state
|
|
||||||
streaming_service = state.streaming_service
|
|
||||||
|
|
||||||
# Fetch search space to get QnA configuration
|
|
||||||
result = await state.db_session.execute(
|
|
||||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
|
||||||
)
|
|
||||||
search_space = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not search_space:
|
|
||||||
error_message = f"Search space {search_space_id} not found"
|
|
||||||
print(error_message)
|
|
||||||
raise RuntimeError(error_message)
|
|
||||||
|
|
||||||
# Get QnA configuration from search space
|
|
||||||
citations_enabled = search_space.citations_enabled
|
|
||||||
custom_instructions_text = search_space.qna_custom_instructions or ""
|
|
||||||
|
|
||||||
# Use constants for base prompt and citation instructions
|
|
||||||
qna_base_prompt = DEFAULT_QNA_BASE_PROMPT
|
|
||||||
qna_citation_instructions = (
|
|
||||||
DEFAULT_QNA_CITATION_INSTRUCTIONS if citations_enabled else ""
|
|
||||||
)
|
|
||||||
qna_custom_instructions = (
|
|
||||||
f"\n<special_important_custom_instructions>\n{custom_instructions_text}\n</special_important_custom_instructions>"
|
|
||||||
if custom_instructions_text
|
|
||||||
else ""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get search space's fast LLM
|
|
||||||
llm = await get_fast_llm(state.db_session, search_space_id)
|
|
||||||
if not llm:
|
|
||||||
error_message = f"No fast LLM configured for search space {search_space_id}"
|
|
||||||
print(error_message)
|
|
||||||
raise RuntimeError(error_message)
|
|
||||||
|
|
||||||
# Determine if we have documents and optimize for token limits
|
|
||||||
has_documents_initially = documents and len(documents) > 0
|
|
||||||
chat_history_str = langchain_chat_history_to_str(state.chat_history)
|
|
||||||
|
|
||||||
if has_documents_initially:
|
|
||||||
# Compose the full citation prompt: base + citation instructions + custom instructions
|
|
||||||
full_citation_prompt_template = (
|
|
||||||
qna_base_prompt + qna_citation_instructions + qna_custom_instructions
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create base message template for token calculation (without documents)
|
|
||||||
base_human_message_template = f"""
|
|
||||||
|
|
||||||
User's question:
|
|
||||||
<user_query>
|
|
||||||
{user_query}
|
|
||||||
</user_query>
|
|
||||||
|
|
||||||
Please provide a detailed, comprehensive answer to the user's question using the information from their personal knowledge sources. Make sure to cite all information appropriately and engage in a conversational manner.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Use initial system prompt for token calculation
|
|
||||||
initial_system_prompt = _format_system_prompt(
|
|
||||||
full_citation_prompt_template, chat_history_str, language
|
|
||||||
)
|
|
||||||
base_messages = [
|
|
||||||
SystemMessage(content=initial_system_prompt),
|
|
||||||
HumanMessage(content=base_human_message_template),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Optimize documents to fit within token limits
|
|
||||||
optimized_documents, has_optimized_documents = (
|
|
||||||
optimize_documents_for_token_limit(documents, base_messages, llm.model)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update state based on optimization result
|
|
||||||
documents = optimized_documents
|
|
||||||
has_documents = has_optimized_documents
|
|
||||||
else:
|
|
||||||
has_documents = False
|
|
||||||
|
|
||||||
# Choose system prompt based on final document availability
|
|
||||||
# With documents: use base + citation instructions + custom instructions
|
|
||||||
# Without documents: use the default no-documents prompt from constants
|
|
||||||
if has_documents:
|
|
||||||
full_citation_prompt_template = (
|
|
||||||
qna_base_prompt + qna_citation_instructions + qna_custom_instructions
|
|
||||||
)
|
|
||||||
system_prompt = _format_system_prompt(
|
|
||||||
full_citation_prompt_template, chat_history_str, language
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
system_prompt = _format_system_prompt(
|
|
||||||
DEFAULT_QNA_NO_DOCUMENTS_PROMPT + qna_custom_instructions,
|
|
||||||
chat_history_str,
|
|
||||||
language,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate documents section
|
|
||||||
documents_text = (
|
|
||||||
format_documents_section(
|
|
||||||
documents, "Source material from your personal knowledge base"
|
|
||||||
)
|
|
||||||
if has_documents
|
|
||||||
else ""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create final human message content
|
|
||||||
instruction_text = (
|
|
||||||
"Please provide a detailed, comprehensive answer to the user's question using the information from their personal knowledge sources. Make sure to cite all information appropriately and engage in a conversational manner."
|
|
||||||
if has_documents
|
|
||||||
else "Please provide a helpful answer to the user's question based on our conversation history and your general knowledge. Engage in a conversational manner."
|
|
||||||
)
|
|
||||||
|
|
||||||
human_message_content = f"""
|
|
||||||
{documents_text}
|
|
||||||
|
|
||||||
User's question:
|
|
||||||
<user_query>
|
|
||||||
{user_query}
|
|
||||||
</user_query>
|
|
||||||
|
|
||||||
{instruction_text}
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Create final messages for the LLM
|
|
||||||
messages_with_chat_history = [
|
|
||||||
SystemMessage(content=system_prompt),
|
|
||||||
HumanMessage(content=human_message_content),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Log final token count
|
|
||||||
total_tokens = calculate_token_count(messages_with_chat_history, llm.model)
|
|
||||||
print(f"Final token count: {total_tokens}")
|
|
||||||
|
|
||||||
# Stream the LLM response token by token
|
|
||||||
final_answer = ""
|
|
||||||
|
|
||||||
async for chunk in llm.astream(messages_with_chat_history):
|
|
||||||
# Extract the content from the chunk
|
|
||||||
if hasattr(chunk, "content") and chunk.content:
|
|
||||||
token = chunk.content
|
|
||||||
final_answer += token
|
|
||||||
|
|
||||||
# Stream the token to the frontend via custom stream
|
|
||||||
if streaming_service:
|
|
||||||
writer({"yield_value": streaming_service.format_text_chunk(token)})
|
|
||||||
|
|
||||||
return {"final_answer": final_answer}
|
|
||||||
|
|
@ -1,32 +0,0 @@
|
||||||
"""Define the state structures for the agent."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.services.streaming_service import StreamingService
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class State:
|
|
||||||
"""Defines the dynamic state for the Q&A agent during execution.
|
|
||||||
|
|
||||||
This state tracks the database session, chat history, and the outputs
|
|
||||||
generated by the agent's nodes during question answering.
|
|
||||||
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
|
|
||||||
for more information.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Runtime context
|
|
||||||
db_session: AsyncSession
|
|
||||||
|
|
||||||
# Streaming service for real-time token streaming
|
|
||||||
streaming_service: StreamingService | None = None
|
|
||||||
|
|
||||||
chat_history: list[Any] | None = field(default_factory=list)
|
|
||||||
# OUTPUT: Populated by agent nodes
|
|
||||||
reranked_documents: list[Any] | None = None
|
|
||||||
final_answer: str | None = None
|
|
||||||
|
|
@ -1,38 +0,0 @@
|
||||||
"""Define the state structures for the agent."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.services.streaming_service import StreamingService
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class State:
|
|
||||||
"""Defines the dynamic state for the agent during execution.
|
|
||||||
|
|
||||||
This state tracks the database session and the outputs generated by the agent's nodes.
|
|
||||||
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
|
|
||||||
for more information.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Runtime context (not part of actual graph state)
|
|
||||||
db_session: AsyncSession
|
|
||||||
|
|
||||||
# Streaming service
|
|
||||||
streaming_service: StreamingService
|
|
||||||
|
|
||||||
chat_history: list[Any] | None = field(default_factory=list)
|
|
||||||
|
|
||||||
reformulated_query: str | None = field(default=None)
|
|
||||||
further_questions: Any | None = field(default=None)
|
|
||||||
|
|
||||||
# Temporary field to hold reranked documents from sub-agents for further question generation
|
|
||||||
reranked_documents: list[Any] | None = field(default=None)
|
|
||||||
|
|
||||||
# OUTPUT: Populated by agent nodes
|
|
||||||
# Using field to explicitly mark as part of state
|
|
||||||
final_written_report: str | None = field(default=None)
|
|
||||||
|
|
@ -1,292 +0,0 @@
|
||||||
import json
|
|
||||||
from typing import Any, NamedTuple
|
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
|
||||||
from litellm import get_model_info, token_counter
|
|
||||||
|
|
||||||
|
|
||||||
class DocumentTokenInfo(NamedTuple):
|
|
||||||
"""Information about a document and its token cost."""
|
|
||||||
|
|
||||||
index: int
|
|
||||||
document: dict[str, Any]
|
|
||||||
formatted_content: str
|
|
||||||
token_count: int
|
|
||||||
|
|
||||||
|
|
||||||
def get_connector_emoji(connector_name: str) -> str:
|
|
||||||
"""Get an appropriate emoji for a connector type."""
|
|
||||||
connector_emojis = {
|
|
||||||
"YOUTUBE_VIDEO": "📹",
|
|
||||||
"EXTENSION": "🧩",
|
|
||||||
"FILE": "📄",
|
|
||||||
"SLACK_CONNECTOR": "💬",
|
|
||||||
"NOTION_CONNECTOR": "📘",
|
|
||||||
"GITHUB_CONNECTOR": "🐙",
|
|
||||||
"LINEAR_CONNECTOR": "📊",
|
|
||||||
"JIRA_CONNECTOR": "🎫",
|
|
||||||
"DISCORD_CONNECTOR": "🗨️",
|
|
||||||
"TAVILY_API": "🔍",
|
|
||||||
"LINKUP_API": "🔗",
|
|
||||||
"BAIDU_SEARCH_API": "🇨🇳",
|
|
||||||
"GOOGLE_CALENDAR_CONNECTOR": "📅",
|
|
||||||
"AIRTABLE_CONNECTOR": "🗃️",
|
|
||||||
"LUMA_CONNECTOR": "✨",
|
|
||||||
"ELASTICSEARCH_CONNECTOR": "⚡",
|
|
||||||
"WEBCRAWLER_CONNECTOR": "🌐",
|
|
||||||
"BOOKSTACK_CONNECTOR": "📚",
|
|
||||||
"NOTE": "📝",
|
|
||||||
}
|
|
||||||
return connector_emojis.get(connector_name, "🔎")
|
|
||||||
|
|
||||||
|
|
||||||
def get_connector_friendly_name(connector_name: str) -> str:
|
|
||||||
"""Convert technical connector IDs to user-friendly names."""
|
|
||||||
connector_friendly_names = {
|
|
||||||
"YOUTUBE_VIDEO": "YouTube",
|
|
||||||
"EXTENSION": "Browser Extension",
|
|
||||||
"FILE": "Files",
|
|
||||||
"SLACK_CONNECTOR": "Slack",
|
|
||||||
"NOTION_CONNECTOR": "Notion",
|
|
||||||
"GITHUB_CONNECTOR": "GitHub",
|
|
||||||
"LINEAR_CONNECTOR": "Linear",
|
|
||||||
"JIRA_CONNECTOR": "Jira",
|
|
||||||
"CONFLUENCE_CONNECTOR": "Confluence",
|
|
||||||
"GOOGLE_CALENDAR_CONNECTOR": "Google Calendar",
|
|
||||||
"DISCORD_CONNECTOR": "Discord",
|
|
||||||
"TAVILY_API": "Tavily Search",
|
|
||||||
"LINKUP_API": "Linkup Search",
|
|
||||||
"BAIDU_SEARCH_API": "Baidu Search",
|
|
||||||
"AIRTABLE_CONNECTOR": "Airtable",
|
|
||||||
"LUMA_CONNECTOR": "Luma",
|
|
||||||
"ELASTICSEARCH_CONNECTOR": "Elasticsearch",
|
|
||||||
"WEBCRAWLER_CONNECTOR": "Web Pages",
|
|
||||||
"BOOKSTACK_CONNECTOR": "BookStack",
|
|
||||||
"NOTE": "Notes",
|
|
||||||
}
|
|
||||||
return connector_friendly_names.get(connector_name, connector_name)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_langchain_messages_to_dict(
|
|
||||||
messages: list[BaseMessage],
|
|
||||||
) -> list[dict[str, str]]:
|
|
||||||
"""Convert LangChain messages to format expected by token_counter."""
|
|
||||||
role_mapping = {"system": "system", "human": "user", "ai": "assistant"}
|
|
||||||
|
|
||||||
converted_messages = []
|
|
||||||
for msg in messages:
|
|
||||||
role = role_mapping.get(getattr(msg, "type", None), "user")
|
|
||||||
converted_messages.append({"role": role, "content": str(msg.content)})
|
|
||||||
|
|
||||||
return converted_messages
|
|
||||||
|
|
||||||
|
|
||||||
def format_document_for_citation(document: dict[str, Any]) -> str:
|
|
||||||
"""Format a single document for citation in the new document+chunks XML format.
|
|
||||||
|
|
||||||
IMPORTANT:
|
|
||||||
- Citations must reference real DB chunk IDs: `[citation:<chunk_id>]`
|
|
||||||
- Document metadata is included under <document_metadata>, but citations are NOT document_id-based.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _to_cdata(value: Any) -> str:
|
|
||||||
text = "" if value is None else str(value)
|
|
||||||
# Safely nest CDATA even if the content includes "]]>"
|
|
||||||
return "<![CDATA[" + text.replace("]]>", "]]]]><![CDATA[>") + "]]>"
|
|
||||||
|
|
||||||
doc_info = document.get("document", {}) or {}
|
|
||||||
metadata = doc_info.get("metadata", {}) or {}
|
|
||||||
|
|
||||||
doc_id = doc_info.get("id", "")
|
|
||||||
title = doc_info.get("title", "")
|
|
||||||
document_type = doc_info.get("document_type", "CRAWLED_URL")
|
|
||||||
url = (
|
|
||||||
metadata.get("url")
|
|
||||||
or metadata.get("source")
|
|
||||||
or metadata.get("page_url")
|
|
||||||
or metadata.get("VisitedWebPageURL")
|
|
||||||
or ""
|
|
||||||
)
|
|
||||||
|
|
||||||
metadata_json = json.dumps(metadata, ensure_ascii=False)
|
|
||||||
|
|
||||||
chunks = document.get("chunks") or []
|
|
||||||
if not chunks:
|
|
||||||
# Fallback: treat `content` as a single chunk (no chunk_id available for citation)
|
|
||||||
chunks = [{"chunk_id": "", "content": document.get("content", "")}]
|
|
||||||
|
|
||||||
chunks_xml = "\n".join(
|
|
||||||
[
|
|
||||||
f"<chunk id='{chunk.get('chunk_id', '')}'>{_to_cdata(chunk.get('content', ''))}</chunk>"
|
|
||||||
for chunk in chunks
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
return f"""<document>
|
|
||||||
<document_metadata>
|
|
||||||
<document_id>{doc_id}</document_id>
|
|
||||||
<document_type>{document_type}</document_type>
|
|
||||||
<title>{_to_cdata(title)}</title>
|
|
||||||
<url>{_to_cdata(url)}</url>
|
|
||||||
<metadata_json>{_to_cdata(metadata_json)}</metadata_json>
|
|
||||||
</document_metadata>
|
|
||||||
|
|
||||||
<document_content>
|
|
||||||
{chunks_xml}
|
|
||||||
</document_content>
|
|
||||||
</document>"""
|
|
||||||
|
|
||||||
|
|
||||||
def format_documents_section(
|
|
||||||
documents: list[dict[str, Any]], section_title: str = "Source material"
|
|
||||||
) -> str:
|
|
||||||
"""Format multiple documents into a complete documents section."""
|
|
||||||
if not documents:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
formatted_docs = [format_document_for_citation(doc) for doc in documents]
|
|
||||||
|
|
||||||
return f"""{section_title}:
|
|
||||||
<documents>
|
|
||||||
{chr(10).join(formatted_docs)}
|
|
||||||
</documents>"""
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_document_token_costs(
|
|
||||||
documents: list[dict[str, Any]], model: str
|
|
||||||
) -> list[DocumentTokenInfo]:
|
|
||||||
"""Pre-calculate token costs for each document."""
|
|
||||||
document_token_info = []
|
|
||||||
|
|
||||||
for i, doc in enumerate(documents):
|
|
||||||
formatted_doc = format_document_for_citation(doc)
|
|
||||||
|
|
||||||
# Calculate token count for this document
|
|
||||||
token_count = token_counter(
|
|
||||||
messages=[{"role": "user", "content": formatted_doc}], model=model
|
|
||||||
)
|
|
||||||
|
|
||||||
document_token_info.append(
|
|
||||||
DocumentTokenInfo(
|
|
||||||
index=i,
|
|
||||||
document=doc,
|
|
||||||
formatted_content=formatted_doc,
|
|
||||||
token_count=token_count,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return document_token_info
|
|
||||||
|
|
||||||
|
|
||||||
def find_optimal_documents_with_binary_search(
|
|
||||||
document_tokens: list[DocumentTokenInfo], available_tokens: int
|
|
||||||
) -> list[DocumentTokenInfo]:
|
|
||||||
"""Use binary search to find the maximum number of documents that fit within token limit."""
|
|
||||||
if not document_tokens or available_tokens <= 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
left, right = 0, len(document_tokens)
|
|
||||||
optimal_docs = []
|
|
||||||
|
|
||||||
while left <= right:
|
|
||||||
mid = (left + right) // 2
|
|
||||||
current_docs = document_tokens[:mid]
|
|
||||||
current_token_sum = sum(doc_info.token_count for doc_info in current_docs)
|
|
||||||
|
|
||||||
if current_token_sum <= available_tokens:
|
|
||||||
optimal_docs = current_docs
|
|
||||||
left = mid + 1
|
|
||||||
else:
|
|
||||||
right = mid - 1
|
|
||||||
|
|
||||||
return optimal_docs
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_context_window(model_name: str) -> int:
|
|
||||||
"""Get the total context window size for a model (input + output tokens)."""
|
|
||||||
try:
|
|
||||||
model_info = get_model_info(model_name)
|
|
||||||
context_window = model_info.get("max_input_tokens", 4096) # Default fallback
|
|
||||||
return context_window
|
|
||||||
except Exception as e:
|
|
||||||
print(
|
|
||||||
f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}"
|
|
||||||
)
|
|
||||||
return 4096 # Conservative fallback
|
|
||||||
|
|
||||||
|
|
||||||
def optimize_documents_for_token_limit(
|
|
||||||
documents: list[dict[str, Any]], base_messages: list[BaseMessage], model_name: str
|
|
||||||
) -> tuple[list[dict[str, Any]], bool]:
|
|
||||||
"""
|
|
||||||
Optimize documents to fit within token limits using binary search.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
documents: List of documents with content and metadata
|
|
||||||
base_messages: Base messages without documents (chat history + system + human message template)
|
|
||||||
model_name: Model name for token counting (required)
|
|
||||||
output_token_buffer: Number of tokens to reserve for model output
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (optimized_documents, has_documents_remaining)
|
|
||||||
"""
|
|
||||||
if not documents:
|
|
||||||
return [], False
|
|
||||||
|
|
||||||
model = model_name
|
|
||||||
context_window = get_model_context_window(model)
|
|
||||||
|
|
||||||
# Calculate base token cost
|
|
||||||
base_messages_dict = convert_langchain_messages_to_dict(base_messages)
|
|
||||||
base_tokens = token_counter(messages=base_messages_dict, model=model)
|
|
||||||
available_tokens_for_docs = context_window - base_tokens
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if available_tokens_for_docs <= 0:
|
|
||||||
print("No tokens available for documents after base content and output buffer")
|
|
||||||
return [], False
|
|
||||||
|
|
||||||
# Calculate token costs for all documents
|
|
||||||
document_token_info = calculate_document_token_costs(documents, model)
|
|
||||||
|
|
||||||
# Find optimal number of documents using binary search
|
|
||||||
optimal_doc_info = find_optimal_documents_with_binary_search(
|
|
||||||
document_token_info, available_tokens_for_docs
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract the original document objects
|
|
||||||
optimized_documents = [doc_info.document for doc_info in optimal_doc_info]
|
|
||||||
has_documents_remaining = len(optimized_documents) > 0
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents"
|
|
||||||
)
|
|
||||||
|
|
||||||
return optimized_documents, has_documents_remaining
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_token_count(messages: list[BaseMessage], model_name: str) -> int:
|
|
||||||
"""Calculate token count for a list of LangChain messages."""
|
|
||||||
model = model_name
|
|
||||||
messages_dict = convert_langchain_messages_to_dict(messages)
|
|
||||||
return token_counter(messages=messages_dict, model=model)
|
|
||||||
|
|
||||||
|
|
||||||
def langchain_chat_history_to_str(chat_history: list[BaseMessage]) -> str:
|
|
||||||
"""
|
|
||||||
Convert a list of chat history messages to a string.
|
|
||||||
"""
|
|
||||||
chat_history_str = ""
|
|
||||||
|
|
||||||
for chat_message in chat_history:
|
|
||||||
if isinstance(chat_message, HumanMessage):
|
|
||||||
chat_history_str += f"<user>{chat_message.content}</user>\n"
|
|
||||||
elif isinstance(chat_message, AIMessage):
|
|
||||||
chat_history_str += f"<assistant>{chat_message.content}</assistant>\n"
|
|
||||||
elif isinstance(chat_message, SystemMessage):
|
|
||||||
chat_history_str += f"<system>{chat_message.content}</system>\n"
|
|
||||||
|
|
||||||
return chat_history_str
|
|
||||||
|
|
@ -77,10 +77,6 @@ class SearchSourceConnectorType(str, Enum):
|
||||||
BOOKSTACK_CONNECTOR = "BOOKSTACK_CONNECTOR"
|
BOOKSTACK_CONNECTOR = "BOOKSTACK_CONNECTOR"
|
||||||
|
|
||||||
|
|
||||||
class ChatType(str, Enum):
|
|
||||||
QNA = "QNA"
|
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMProvider(str, Enum):
|
class LiteLLMProvider(str, Enum):
|
||||||
"""
|
"""
|
||||||
Enum for LLM providers supported by LiteLLM.
|
Enum for LLM providers supported by LiteLLM.
|
||||||
|
|
@ -317,21 +313,6 @@ class BaseModel(Base):
|
||||||
id = Column(Integer, primary_key=True, index=True)
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
|
||||||
|
|
||||||
class Chat(BaseModel, TimestampMixin):
|
|
||||||
__tablename__ = "chats"
|
|
||||||
|
|
||||||
type = Column(SQLAlchemyEnum(ChatType), nullable=False)
|
|
||||||
title = Column(String, nullable=False, index=True)
|
|
||||||
initial_connectors = Column(ARRAY(String), nullable=True)
|
|
||||||
messages = Column(JSON, nullable=False)
|
|
||||||
state_version = Column(BigInteger, nullable=False, default=1)
|
|
||||||
|
|
||||||
search_space_id = Column(
|
|
||||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
|
||||||
)
|
|
||||||
search_space = relationship("SearchSpace", back_populates="chats")
|
|
||||||
|
|
||||||
|
|
||||||
class NewChatMessageRole(str, Enum):
|
class NewChatMessageRole(str, Enum):
|
||||||
"""Role enum for new chat messages."""
|
"""Role enum for new chat messages."""
|
||||||
|
|
||||||
|
|
@ -363,9 +344,6 @@ class NewChatThread(BaseModel, TimestampMixin):
|
||||||
search_space_id = Column(
|
search_space_id = Column(
|
||||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||||
)
|
)
|
||||||
user_id = Column(
|
|
||||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
search_space = relationship("SearchSpace", back_populates="new_chat_threads")
|
search_space = relationship("SearchSpace", back_populates="new_chat_threads")
|
||||||
|
|
@ -445,23 +423,6 @@ class Chunk(BaseModel, TimestampMixin):
|
||||||
document = relationship("Document", back_populates="chunks")
|
document = relationship("Document", back_populates="chunks")
|
||||||
|
|
||||||
|
|
||||||
class Podcast(BaseModel, TimestampMixin):
|
|
||||||
__tablename__ = "podcasts"
|
|
||||||
|
|
||||||
title = Column(String, nullable=False, index=True)
|
|
||||||
podcast_transcript = Column(JSON, nullable=False, default={})
|
|
||||||
file_location = Column(String(500), nullable=False, default="")
|
|
||||||
chat_id = Column(
|
|
||||||
Integer, ForeignKey("chats.id", ondelete="CASCADE"), nullable=True
|
|
||||||
) # If generated from a chat, this will be the chat id, else null ( can be from a document or a chat )
|
|
||||||
chat_state_version = Column(BigInteger, nullable=True)
|
|
||||||
|
|
||||||
search_space_id = Column(
|
|
||||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
|
||||||
)
|
|
||||||
search_space = relationship("SearchSpace", back_populates="podcasts")
|
|
||||||
|
|
||||||
|
|
||||||
class SearchSpace(BaseModel, TimestampMixin):
|
class SearchSpace(BaseModel, TimestampMixin):
|
||||||
__tablename__ = "searchspaces"
|
__tablename__ = "searchspaces"
|
||||||
|
|
||||||
|
|
@ -492,18 +453,6 @@ class SearchSpace(BaseModel, TimestampMixin):
|
||||||
order_by="Document.id",
|
order_by="Document.id",
|
||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
podcasts = relationship(
|
|
||||||
"Podcast",
|
|
||||||
back_populates="search_space",
|
|
||||||
order_by="Podcast.id",
|
|
||||||
cascade="all, delete-orphan",
|
|
||||||
)
|
|
||||||
chats = relationship(
|
|
||||||
"Chat",
|
|
||||||
back_populates="search_space",
|
|
||||||
order_by="Chat.id",
|
|
||||||
cascade="all, delete-orphan",
|
|
||||||
)
|
|
||||||
new_chat_threads = relationship(
|
new_chat_threads = relationship(
|
||||||
"NewChatThread",
|
"NewChatThread",
|
||||||
back_populates="search_space",
|
back_populates="search_space",
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ from fastapi import APIRouter
|
||||||
from .airtable_add_connector_route import (
|
from .airtable_add_connector_route import (
|
||||||
router as airtable_add_connector_router,
|
router as airtable_add_connector_router,
|
||||||
)
|
)
|
||||||
from .chats_routes import router as chats_router
|
|
||||||
from .documents_routes import router as documents_router
|
from .documents_routes import router as documents_router
|
||||||
from .editor_routes import router as editor_router
|
from .editor_routes import router as editor_router
|
||||||
from .google_calendar_add_connector_route import (
|
from .google_calendar_add_connector_route import (
|
||||||
|
|
@ -17,7 +16,6 @@ from .logs_routes import router as logs_router
|
||||||
from .luma_add_connector_route import router as luma_add_connector_router
|
from .luma_add_connector_route import router as luma_add_connector_router
|
||||||
from .new_chat_routes import router as new_chat_router
|
from .new_chat_routes import router as new_chat_router
|
||||||
from .notes_routes import router as notes_router
|
from .notes_routes import router as notes_router
|
||||||
from .podcasts_routes import router as podcasts_router
|
|
||||||
from .rbac_routes import router as rbac_router
|
from .rbac_routes import router as rbac_router
|
||||||
from .search_source_connectors_routes import router as search_source_connectors_router
|
from .search_source_connectors_routes import router as search_source_connectors_router
|
||||||
from .search_spaces_routes import router as search_spaces_router
|
from .search_spaces_routes import router as search_spaces_router
|
||||||
|
|
@ -29,9 +27,7 @@ router.include_router(rbac_router) # RBAC routes for roles, members, invites
|
||||||
router.include_router(editor_router)
|
router.include_router(editor_router)
|
||||||
router.include_router(documents_router)
|
router.include_router(documents_router)
|
||||||
router.include_router(notes_router)
|
router.include_router(notes_router)
|
||||||
router.include_router(podcasts_router)
|
router.include_router(new_chat_router) # Chat with assistant-ui persistence
|
||||||
router.include_router(chats_router)
|
|
||||||
router.include_router(new_chat_router) # New chat with assistant-ui persistence
|
|
||||||
router.include_router(search_source_connectors_router)
|
router.include_router(search_source_connectors_router)
|
||||||
router.include_router(google_calendar_add_connector_router)
|
router.include_router(google_calendar_add_connector_router)
|
||||||
router.include_router(google_gmail_add_connector_router)
|
router.include_router(google_gmail_add_connector_router)
|
||||||
|
|
|
||||||
|
|
@ -1,617 +0,0 @@
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
|
||||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.future import select
|
|
||||||
from sqlalchemy.orm import selectinload
|
|
||||||
|
|
||||||
from app.db import (
|
|
||||||
Chat,
|
|
||||||
Permission,
|
|
||||||
SearchSpace,
|
|
||||||
SearchSpaceMembership,
|
|
||||||
User,
|
|
||||||
get_async_session,
|
|
||||||
)
|
|
||||||
from app.schemas import (
|
|
||||||
AISDKChatRequest,
|
|
||||||
ChatCreate,
|
|
||||||
ChatRead,
|
|
||||||
ChatReadWithoutMessages,
|
|
||||||
ChatUpdate,
|
|
||||||
NewChatRequest,
|
|
||||||
)
|
|
||||||
from app.services.new_streaming_service import VercelStreamingService
|
|
||||||
from app.tasks.chat.stream_connector_search_results import (
|
|
||||||
stream_connector_search_results,
|
|
||||||
)
|
|
||||||
from app.tasks.chat.stream_new_chat import stream_new_chat
|
|
||||||
from app.users import current_active_user
|
|
||||||
from app.utils.rbac import check_permission
|
|
||||||
from app.utils.validators import (
|
|
||||||
validate_connectors,
|
|
||||||
validate_document_ids,
|
|
||||||
validate_messages,
|
|
||||||
validate_research_mode,
|
|
||||||
validate_search_space_id,
|
|
||||||
validate_top_k,
|
|
||||||
)
|
|
||||||
|
|
||||||
router = APIRouter()
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/chat")
|
|
||||||
async def handle_chat_data(
|
|
||||||
request: AISDKChatRequest,
|
|
||||||
session: AsyncSession = Depends(get_async_session),
|
|
||||||
user: User = Depends(current_active_user),
|
|
||||||
):
|
|
||||||
# Validate and sanitize all input data
|
|
||||||
messages = validate_messages(request.messages)
|
|
||||||
|
|
||||||
if messages[-1]["role"] != "user":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400, detail="Last message must be a user message"
|
|
||||||
)
|
|
||||||
|
|
||||||
user_query = messages[-1]["content"]
|
|
||||||
|
|
||||||
# Extract and validate data from request
|
|
||||||
request_data = request.data or {}
|
|
||||||
search_space_id = validate_search_space_id(request_data.get("search_space_id"))
|
|
||||||
research_mode = validate_research_mode(request_data.get("research_mode"))
|
|
||||||
selected_connectors = validate_connectors(request_data.get("selected_connectors"))
|
|
||||||
document_ids_to_add_in_context = validate_document_ids(
|
|
||||||
request_data.get("document_ids_to_add_in_context")
|
|
||||||
)
|
|
||||||
top_k = validate_top_k(request_data.get("top_k"))
|
|
||||||
# print("RESQUEST DATA:", request_data)
|
|
||||||
# print("SELECTED CONNECTORS:", selected_connectors)
|
|
||||||
|
|
||||||
# Check if the user has chat access to the search space
|
|
||||||
try:
|
|
||||||
await check_permission(
|
|
||||||
session,
|
|
||||||
user,
|
|
||||||
search_space_id,
|
|
||||||
Permission.CHATS_CREATE.value,
|
|
||||||
"You don't have permission to use chat in this search space",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get search space with LLM configs (preferences are now stored at search space level)
|
|
||||||
search_space_result = await session.execute(
|
|
||||||
select(SearchSpace)
|
|
||||||
.options(selectinload(SearchSpace.llm_configs))
|
|
||||||
.filter(SearchSpace.id == search_space_id)
|
|
||||||
)
|
|
||||||
search_space = search_space_result.scalars().first()
|
|
||||||
|
|
||||||
language = None
|
|
||||||
llm_configs = [] # Initialize to empty list
|
|
||||||
|
|
||||||
if search_space and search_space.llm_configs:
|
|
||||||
llm_configs = search_space.llm_configs
|
|
||||||
|
|
||||||
# Get language from configured LLM preferences
|
|
||||||
# LLM preferences are now stored on the SearchSpace model
|
|
||||||
from app.config import config as app_config
|
|
||||||
|
|
||||||
for llm_id in [
|
|
||||||
search_space.fast_llm_id,
|
|
||||||
search_space.long_context_llm_id,
|
|
||||||
search_space.strategic_llm_id,
|
|
||||||
]:
|
|
||||||
if llm_id is not None:
|
|
||||||
# Check if it's a global config (negative ID)
|
|
||||||
if llm_id < 0:
|
|
||||||
# Look in global configs
|
|
||||||
for global_cfg in app_config.GLOBAL_LLM_CONFIGS:
|
|
||||||
if global_cfg.get("id") == llm_id:
|
|
||||||
language = global_cfg.get("language")
|
|
||||||
if language:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# Look in custom configs
|
|
||||||
for llm_config in llm_configs:
|
|
||||||
if llm_config.id == llm_id and getattr(
|
|
||||||
llm_config, "language", None
|
|
||||||
):
|
|
||||||
language = llm_config.language
|
|
||||||
break
|
|
||||||
if language:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not language and llm_configs:
|
|
||||||
first_llm_config = llm_configs[0]
|
|
||||||
language = getattr(first_llm_config, "language", None)
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403, detail="You don't have access to this search space"
|
|
||||||
) from None
|
|
||||||
|
|
||||||
langchain_chat_history = []
|
|
||||||
for message in messages[:-1]:
|
|
||||||
if message["role"] == "user":
|
|
||||||
langchain_chat_history.append(HumanMessage(content=message["content"]))
|
|
||||||
elif message["role"] == "assistant":
|
|
||||||
langchain_chat_history.append(AIMessage(content=message["content"]))
|
|
||||||
|
|
||||||
response = StreamingResponse(
|
|
||||||
stream_connector_search_results(
|
|
||||||
user_query,
|
|
||||||
user.id,
|
|
||||||
search_space_id,
|
|
||||||
session,
|
|
||||||
research_mode,
|
|
||||||
selected_connectors,
|
|
||||||
langchain_chat_history,
|
|
||||||
document_ids_to_add_in_context,
|
|
||||||
language,
|
|
||||||
top_k,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
response.headers["x-vercel-ai-data-stream"] = "v1"
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/new_chat")
|
|
||||||
async def handle_new_chat(
|
|
||||||
request: NewChatRequest,
|
|
||||||
session: AsyncSession = Depends(get_async_session),
|
|
||||||
user: User = Depends(current_active_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Handle new chat requests using the SurfSense deep agent.
|
|
||||||
|
|
||||||
This endpoint uses the new deep agent with the Vercel AI SDK
|
|
||||||
Data Stream Protocol (SSE format).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: NewChatRequest containing chat_id, user_query, and search_space_id
|
|
||||||
session: Database session
|
|
||||||
user: Current authenticated user
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
StreamingResponse with SSE formatted data
|
|
||||||
"""
|
|
||||||
# Validate the user query
|
|
||||||
if not request.user_query or not request.user_query.strip():
|
|
||||||
raise HTTPException(status_code=400, detail="User query cannot be empty")
|
|
||||||
|
|
||||||
# Check if the user has chat access to the search space
|
|
||||||
try:
|
|
||||||
await check_permission(
|
|
||||||
session,
|
|
||||||
user,
|
|
||||||
request.search_space_id,
|
|
||||||
Permission.CHATS_CREATE.value,
|
|
||||||
"You don't have permission to use chat in this search space",
|
|
||||||
)
|
|
||||||
except HTTPException:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403, detail="You don't have access to this search space"
|
|
||||||
) from None
|
|
||||||
|
|
||||||
# Get LLM config ID from search space preferences (optional enhancement)
|
|
||||||
# For now, we use the default global config (-1)
|
|
||||||
llm_config_id = -1
|
|
||||||
|
|
||||||
# Optionally load LLM preferences from search space
|
|
||||||
try:
|
|
||||||
search_space_result = await session.execute(
|
|
||||||
select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
|
|
||||||
)
|
|
||||||
search_space = search_space_result.scalars().first()
|
|
||||||
|
|
||||||
if search_space:
|
|
||||||
# Use strategic_llm_id if available, otherwise fall back to fast_llm_id
|
|
||||||
if search_space.strategic_llm_id is not None:
|
|
||||||
llm_config_id = search_space.strategic_llm_id
|
|
||||||
elif search_space.fast_llm_id is not None:
|
|
||||||
llm_config_id = search_space.fast_llm_id
|
|
||||||
except Exception:
|
|
||||||
# Fall back to default config on any error
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Create the streaming response
|
|
||||||
# chat_id is used as LangGraph's thread_id for automatic chat history management
|
|
||||||
response = StreamingResponse(
|
|
||||||
stream_new_chat(
|
|
||||||
user_query=request.user_query.strip(),
|
|
||||||
user_id=user.id,
|
|
||||||
search_space_id=request.search_space_id,
|
|
||||||
chat_id=request.chat_id,
|
|
||||||
session=session,
|
|
||||||
llm_config_id=llm_config_id,
|
|
||||||
messages=request.messages, # Pass message history from frontend
|
|
||||||
),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set the required headers for Vercel AI SDK
|
|
||||||
headers = VercelStreamingService.get_response_headers()
|
|
||||||
for key, value in headers.items():
|
|
||||||
response.headers[key] = value
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/chats", response_model=ChatRead)
|
|
||||||
async def create_chat(
|
|
||||||
chat: ChatCreate,
|
|
||||||
session: AsyncSession = Depends(get_async_session),
|
|
||||||
user: User = Depends(current_active_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create a new chat.
|
|
||||||
Requires CHATS_CREATE permission.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
await check_permission(
|
|
||||||
session,
|
|
||||||
user,
|
|
||||||
chat.search_space_id,
|
|
||||||
Permission.CHATS_CREATE.value,
|
|
||||||
"You don't have permission to create chats in this search space",
|
|
||||||
)
|
|
||||||
db_chat = Chat(**chat.model_dump())
|
|
||||||
session.add(db_chat)
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(db_chat)
|
|
||||||
return db_chat
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except IntegrityError:
|
|
||||||
await session.rollback()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Database constraint violation. Please check your input data.",
|
|
||||||
) from None
|
|
||||||
except OperationalError:
|
|
||||||
await session.rollback()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=503, detail="Database operation failed. Please try again later."
|
|
||||||
) from None
|
|
||||||
except Exception:
|
|
||||||
await session.rollback()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail="An unexpected error occurred while creating the chat.",
|
|
||||||
) from None
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/chats", response_model=list[ChatReadWithoutMessages])
|
|
||||||
async def read_chats(
|
|
||||||
skip: int = 0,
|
|
||||||
limit: int = 100,
|
|
||||||
search_space_id: int | None = None,
|
|
||||||
session: AsyncSession = Depends(get_async_session),
|
|
||||||
user: User = Depends(current_active_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
List chats the user has access to.
|
|
||||||
Requires CHATS_READ permission for the search space(s).
|
|
||||||
"""
|
|
||||||
# Validate pagination parameters
|
|
||||||
if skip < 0:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400, detail="skip must be a non-negative integer"
|
|
||||||
)
|
|
||||||
|
|
||||||
if limit <= 0 or limit > 1000: # Reasonable upper limit
|
|
||||||
raise HTTPException(status_code=400, detail="limit must be between 1 and 1000")
|
|
||||||
|
|
||||||
# Validate search_space_id if provided
|
|
||||||
if search_space_id is not None and search_space_id <= 0:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400, detail="search_space_id must be a positive integer"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
if search_space_id is not None:
|
|
||||||
# Check permission for specific search space
|
|
||||||
await check_permission(
|
|
||||||
session,
|
|
||||||
user,
|
|
||||||
search_space_id,
|
|
||||||
Permission.CHATS_READ.value,
|
|
||||||
"You don't have permission to read chats in this search space",
|
|
||||||
)
|
|
||||||
# Select specific fields excluding messages
|
|
||||||
query = (
|
|
||||||
select(
|
|
||||||
Chat.id,
|
|
||||||
Chat.type,
|
|
||||||
Chat.title,
|
|
||||||
Chat.initial_connectors,
|
|
||||||
Chat.search_space_id,
|
|
||||||
Chat.created_at,
|
|
||||||
Chat.state_version,
|
|
||||||
)
|
|
||||||
.filter(Chat.search_space_id == search_space_id)
|
|
||||||
.order_by(Chat.created_at.desc())
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Get chats from all search spaces user has membership in
|
|
||||||
query = (
|
|
||||||
select(
|
|
||||||
Chat.id,
|
|
||||||
Chat.type,
|
|
||||||
Chat.title,
|
|
||||||
Chat.initial_connectors,
|
|
||||||
Chat.search_space_id,
|
|
||||||
Chat.created_at,
|
|
||||||
Chat.state_version,
|
|
||||||
)
|
|
||||||
.join(SearchSpace)
|
|
||||||
.join(SearchSpaceMembership)
|
|
||||||
.filter(SearchSpaceMembership.user_id == user.id)
|
|
||||||
.order_by(Chat.created_at.desc())
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await session.execute(query.offset(skip).limit(limit))
|
|
||||||
return result.all()
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except OperationalError:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=503, detail="Database operation failed. Please try again later."
|
|
||||||
) from None
|
|
||||||
except Exception:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500, detail="An unexpected error occurred while fetching chats."
|
|
||||||
) from None
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/chats/search", response_model=list[ChatReadWithoutMessages])
|
|
||||||
async def search_chats(
|
|
||||||
title: str,
|
|
||||||
skip: int = 0,
|
|
||||||
limit: int = 100,
|
|
||||||
search_space_id: int | None = None,
|
|
||||||
session: AsyncSession = Depends(get_async_session),
|
|
||||||
user: User = Depends(current_active_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Search chats by title substring.
|
|
||||||
Requires CHATS_READ permission for the search space(s).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
title: Case-insensitive substring to match against chat titles. Required.
|
|
||||||
skip: Number of items to skip from the beginning. Default: 0.
|
|
||||||
limit: Maximum number of items to return. Default: 100.
|
|
||||||
search_space_id: Filter results to a specific search space. Default: None.
|
|
||||||
session: Database session (injected).
|
|
||||||
user: Current authenticated user (injected).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of chats matching the search query.
|
|
||||||
|
|
||||||
Notes:
|
|
||||||
- Title matching uses ILIKE (case-insensitive).
|
|
||||||
- Results are ordered by creation date (most recent first).
|
|
||||||
"""
|
|
||||||
# Validate pagination parameters
|
|
||||||
if skip < 0:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400, detail="skip must be a non-negative integer"
|
|
||||||
)
|
|
||||||
|
|
||||||
if limit <= 0 or limit > 1000:
|
|
||||||
raise HTTPException(status_code=400, detail="limit must be between 1 and 1000")
|
|
||||||
|
|
||||||
# Validate search_space_id if provided
|
|
||||||
if search_space_id is not None and search_space_id <= 0:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400, detail="search_space_id must be a positive integer"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if search_space_id is not None:
|
|
||||||
# Check permission for specific search space
|
|
||||||
await check_permission(
|
|
||||||
session,
|
|
||||||
user,
|
|
||||||
search_space_id,
|
|
||||||
Permission.CHATS_READ.value,
|
|
||||||
"You don't have permission to read chats in this search space",
|
|
||||||
)
|
|
||||||
# Select specific fields excluding messages
|
|
||||||
query = (
|
|
||||||
select(
|
|
||||||
Chat.id,
|
|
||||||
Chat.type,
|
|
||||||
Chat.title,
|
|
||||||
Chat.initial_connectors,
|
|
||||||
Chat.search_space_id,
|
|
||||||
Chat.created_at,
|
|
||||||
Chat.state_version,
|
|
||||||
)
|
|
||||||
.filter(Chat.search_space_id == search_space_id)
|
|
||||||
.order_by(Chat.created_at.desc())
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Get chats from all search spaces user has membership in
|
|
||||||
query = (
|
|
||||||
select(
|
|
||||||
Chat.id,
|
|
||||||
Chat.type,
|
|
||||||
Chat.title,
|
|
||||||
Chat.initial_connectors,
|
|
||||||
Chat.search_space_id,
|
|
||||||
Chat.created_at,
|
|
||||||
Chat.state_version,
|
|
||||||
)
|
|
||||||
.join(SearchSpace)
|
|
||||||
.join(SearchSpaceMembership)
|
|
||||||
.filter(SearchSpaceMembership.user_id == user.id)
|
|
||||||
.order_by(Chat.created_at.desc())
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply title search filter (case-insensitive)
|
|
||||||
query = query.filter(Chat.title.ilike(f"%{title}%"))
|
|
||||||
|
|
||||||
result = await session.execute(query.offset(skip).limit(limit))
|
|
||||||
return result.all()
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except OperationalError:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=503, detail="Database operation failed. Please try again later."
|
|
||||||
) from None
|
|
||||||
except Exception:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail="An unexpected error occurred while searching chats.",
|
|
||||||
) from None
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/chats/{chat_id}", response_model=ChatRead)
|
|
||||||
async def read_chat(
|
|
||||||
chat_id: int,
|
|
||||||
session: AsyncSession = Depends(get_async_session),
|
|
||||||
user: User = Depends(current_active_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get a specific chat by ID.
|
|
||||||
Requires CHATS_READ permission for the search space.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await session.execute(select(Chat).filter(Chat.id == chat_id))
|
|
||||||
chat = result.scalars().first()
|
|
||||||
|
|
||||||
if not chat:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404,
|
|
||||||
detail="Chat not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check permission for the search space
|
|
||||||
await check_permission(
|
|
||||||
session,
|
|
||||||
user,
|
|
||||||
chat.search_space_id,
|
|
||||||
Permission.CHATS_READ.value,
|
|
||||||
"You don't have permission to read chats in this search space",
|
|
||||||
)
|
|
||||||
|
|
||||||
return chat
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except OperationalError:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=503, detail="Database operation failed. Please try again later."
|
|
||||||
) from None
|
|
||||||
except Exception:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail="An unexpected error occurred while fetching the chat.",
|
|
||||||
) from None
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/chats/{chat_id}", response_model=ChatRead)
|
|
||||||
async def update_chat(
|
|
||||||
chat_id: int,
|
|
||||||
chat_update: ChatUpdate,
|
|
||||||
session: AsyncSession = Depends(get_async_session),
|
|
||||||
user: User = Depends(current_active_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Update a chat.
|
|
||||||
Requires CHATS_UPDATE permission for the search space.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await session.execute(select(Chat).filter(Chat.id == chat_id))
|
|
||||||
db_chat = result.scalars().first()
|
|
||||||
|
|
||||||
if not db_chat:
|
|
||||||
raise HTTPException(status_code=404, detail="Chat not found")
|
|
||||||
|
|
||||||
# Check permission for the search space
|
|
||||||
await check_permission(
|
|
||||||
session,
|
|
||||||
user,
|
|
||||||
db_chat.search_space_id,
|
|
||||||
Permission.CHATS_UPDATE.value,
|
|
||||||
"You don't have permission to update chats in this search space",
|
|
||||||
)
|
|
||||||
|
|
||||||
update_data = chat_update.model_dump(exclude_unset=True)
|
|
||||||
for key, value in update_data.items():
|
|
||||||
if key == "messages":
|
|
||||||
db_chat.state_version = len(update_data["messages"])
|
|
||||||
setattr(db_chat, key, value)
|
|
||||||
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(db_chat)
|
|
||||||
return db_chat
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except IntegrityError:
|
|
||||||
await session.rollback()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Database constraint violation. Please check your input data.",
|
|
||||||
) from None
|
|
||||||
except OperationalError:
|
|
||||||
await session.rollback()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=503, detail="Database operation failed. Please try again later."
|
|
||||||
) from None
|
|
||||||
except Exception:
|
|
||||||
await session.rollback()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail="An unexpected error occurred while updating the chat.",
|
|
||||||
) from None
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/chats/{chat_id}", response_model=dict)
|
|
||||||
async def delete_chat(
|
|
||||||
chat_id: int,
|
|
||||||
session: AsyncSession = Depends(get_async_session),
|
|
||||||
user: User = Depends(current_active_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Delete a chat.
|
|
||||||
Requires CHATS_DELETE permission for the search space.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await session.execute(select(Chat).filter(Chat.id == chat_id))
|
|
||||||
db_chat = result.scalars().first()
|
|
||||||
|
|
||||||
if not db_chat:
|
|
||||||
raise HTTPException(status_code=404, detail="Chat not found")
|
|
||||||
|
|
||||||
# Check permission for the search space
|
|
||||||
await check_permission(
|
|
||||||
session,
|
|
||||||
user,
|
|
||||||
db_chat.search_space_id,
|
|
||||||
Permission.CHATS_DELETE.value,
|
|
||||||
"You don't have permission to delete chats in this search space",
|
|
||||||
)
|
|
||||||
|
|
||||||
await session.delete(db_chat)
|
|
||||||
await session.commit()
|
|
||||||
return {"message": "Chat deleted successfully"}
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except IntegrityError:
|
|
||||||
await session.rollback()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400, detail="Cannot delete chat due to existing dependencies."
|
|
||||||
) from None
|
|
||||||
except OperationalError:
|
|
||||||
await session.rollback()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=503, detail="Database operation failed. Please try again later."
|
|
||||||
) from None
|
|
||||||
except Exception:
|
|
||||||
await session.rollback()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail="An unexpected error occurred while deleting the chat.",
|
|
||||||
) from None
|
|
||||||
|
|
@ -13,6 +13,7 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
|
@ -23,12 +24,14 @@ from app.db import (
|
||||||
NewChatMessageRole,
|
NewChatMessageRole,
|
||||||
NewChatThread,
|
NewChatThread,
|
||||||
Permission,
|
Permission,
|
||||||
|
SearchSpace,
|
||||||
User,
|
User,
|
||||||
get_async_session,
|
get_async_session,
|
||||||
)
|
)
|
||||||
from app.schemas.new_chat import (
|
from app.schemas.new_chat import (
|
||||||
NewChatMessageAppend,
|
NewChatMessageAppend,
|
||||||
NewChatMessageRead,
|
NewChatMessageRead,
|
||||||
|
NewChatRequest,
|
||||||
NewChatThreadCreate,
|
NewChatThreadCreate,
|
||||||
NewChatThreadRead,
|
NewChatThreadRead,
|
||||||
NewChatThreadUpdate,
|
NewChatThreadUpdate,
|
||||||
|
|
@ -37,6 +40,7 @@ from app.schemas.new_chat import (
|
||||||
ThreadListItem,
|
ThreadListItem,
|
||||||
ThreadListResponse,
|
ThreadListResponse,
|
||||||
)
|
)
|
||||||
|
from app.tasks.chat.stream_new_chat import stream_new_chat
|
||||||
from app.users import current_active_user
|
from app.users import current_active_user
|
||||||
from app.utils.rbac import check_permission
|
from app.utils.rbac import check_permission
|
||||||
|
|
||||||
|
|
@ -74,13 +78,10 @@ async def list_threads(
|
||||||
"You don't have permission to read chats in this search space",
|
"You don't have permission to read chats in this search space",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get all threads for this user in this search space
|
# Get all threads in this search space
|
||||||
query = (
|
query = (
|
||||||
select(NewChatThread)
|
select(NewChatThread)
|
||||||
.filter(
|
.filter(NewChatThread.search_space_id == search_space_id)
|
||||||
NewChatThread.search_space_id == search_space_id,
|
|
||||||
NewChatThread.user_id == user.id,
|
|
||||||
)
|
|
||||||
.order_by(NewChatThread.updated_at.desc())
|
.order_by(NewChatThread.updated_at.desc())
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -153,7 +154,6 @@ async def search_threads(
|
||||||
select(NewChatThread)
|
select(NewChatThread)
|
||||||
.filter(
|
.filter(
|
||||||
NewChatThread.search_space_id == search_space_id,
|
NewChatThread.search_space_id == search_space_id,
|
||||||
NewChatThread.user_id == user.id,
|
|
||||||
NewChatThread.title.ilike(f"%{title}%"),
|
NewChatThread.title.ilike(f"%{title}%"),
|
||||||
)
|
)
|
||||||
.order_by(NewChatThread.updated_at.desc())
|
.order_by(NewChatThread.updated_at.desc())
|
||||||
|
|
@ -211,7 +211,6 @@ async def create_thread(
|
||||||
title=thread.title,
|
title=thread.title,
|
||||||
archived=thread.archived,
|
archived=thread.archived,
|
||||||
search_space_id=thread.search_space_id,
|
search_space_id=thread.search_space_id,
|
||||||
user_id=user.id,
|
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
)
|
)
|
||||||
session.add(db_thread)
|
session.add(db_thread)
|
||||||
|
|
@ -273,12 +272,6 @@ async def get_thread_messages(
|
||||||
"You don't have permission to read chats in this search space",
|
"You don't have permission to read chats in this search space",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure user owns this thread
|
|
||||||
if thread.user_id != user.id:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403, detail="You don't have access to this thread"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return messages in the format expected by assistant-ui
|
# Return messages in the format expected by assistant-ui
|
||||||
messages = [
|
messages = [
|
||||||
NewChatMessageRead(
|
NewChatMessageRead(
|
||||||
|
|
@ -336,11 +329,6 @@ async def get_thread_full(
|
||||||
"You don't have permission to read chats in this search space",
|
"You don't have permission to read chats in this search space",
|
||||||
)
|
)
|
||||||
|
|
||||||
if thread.user_id != user.id:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403, detail="You don't have access to this thread"
|
|
||||||
)
|
|
||||||
|
|
||||||
return thread
|
return thread
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|
@ -386,11 +374,6 @@ async def update_thread(
|
||||||
"You don't have permission to update chats in this search space",
|
"You don't have permission to update chats in this search space",
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_thread.user_id != user.id:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403, detail="You don't have access to this thread"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update fields
|
# Update fields
|
||||||
update_data = thread_update.model_dump(exclude_unset=True)
|
update_data = thread_update.model_dump(exclude_unset=True)
|
||||||
for key, value in update_data.items():
|
for key, value in update_data.items():
|
||||||
|
|
@ -451,11 +434,6 @@ async def delete_thread(
|
||||||
"You don't have permission to delete chats in this search space",
|
"You don't have permission to delete chats in this search space",
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_thread.user_id != user.id:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403, detail="You don't have access to this thread"
|
|
||||||
)
|
|
||||||
|
|
||||||
await session.delete(db_thread)
|
await session.delete(db_thread)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return {"message": "Thread deleted successfully"}
|
return {"message": "Thread deleted successfully"}
|
||||||
|
|
@ -530,11 +508,6 @@ async def append_message(
|
||||||
"You don't have permission to update chats in this search space",
|
"You don't have permission to update chats in this search space",
|
||||||
)
|
)
|
||||||
|
|
||||||
if thread.user_id != user.id:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403, detail="You don't have access to this thread"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert string role to enum
|
# Convert string role to enum
|
||||||
role_str = (
|
role_str = (
|
||||||
message.role.lower() if isinstance(message.role, str) else message.role
|
message.role.lower() if isinstance(message.role, str) else message.role
|
||||||
|
|
@ -639,11 +612,6 @@ async def list_messages(
|
||||||
"You don't have permission to read chats in this search space",
|
"You don't have permission to read chats in this search space",
|
||||||
)
|
)
|
||||||
|
|
||||||
if thread.user_id != user.id:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403, detail="You don't have access to this thread"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get messages
|
# Get messages
|
||||||
query = (
|
query = (
|
||||||
select(NewChatMessage)
|
select(NewChatMessage)
|
||||||
|
|
@ -667,3 +635,79 @@ async def list_messages(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail=f"An unexpected error occurred while fetching messages: {e!s}",
|
detail=f"An unexpected error occurred while fetching messages: {e!s}",
|
||||||
) from None
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Chat Streaming Endpoint
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/new_chat")
|
||||||
|
async def handle_new_chat(
|
||||||
|
request: NewChatRequest,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Stream chat responses from the deep agent.
|
||||||
|
|
||||||
|
This endpoint handles the new chat functionality with streaming responses
|
||||||
|
using Server-Sent Events (SSE) format compatible with Vercel AI SDK.
|
||||||
|
|
||||||
|
Requires CHATS_CREATE permission.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Verify thread exists and user has permission
|
||||||
|
result = await session.execute(
|
||||||
|
select(NewChatThread).filter(NewChatThread.id == request.chat_id)
|
||||||
|
)
|
||||||
|
thread = result.scalars().first()
|
||||||
|
|
||||||
|
if not thread:
|
||||||
|
raise HTTPException(status_code=404, detail="Thread not found")
|
||||||
|
|
||||||
|
await check_permission(
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
thread.search_space_id,
|
||||||
|
Permission.CHATS_CREATE.value,
|
||||||
|
"You don't have permission to chat in this search space",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get search space to check LLM config preferences
|
||||||
|
search_space_result = await session.execute(
|
||||||
|
select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
|
||||||
|
)
|
||||||
|
search_space = search_space_result.scalars().first()
|
||||||
|
|
||||||
|
# Determine LLM config ID (use search space preference or default)
|
||||||
|
llm_config_id = -1 # Default to first global config
|
||||||
|
if search_space and search_space.fast_llm_id:
|
||||||
|
llm_config_id = search_space.fast_llm_id
|
||||||
|
|
||||||
|
# Return streaming response
|
||||||
|
return StreamingResponse(
|
||||||
|
stream_new_chat(
|
||||||
|
user_query=request.user_query,
|
||||||
|
user_id=str(user.id),
|
||||||
|
search_space_id=request.search_space_id,
|
||||||
|
chat_id=request.chat_id,
|
||||||
|
session=session,
|
||||||
|
llm_config_id=llm_config_id,
|
||||||
|
messages=request.messages,
|
||||||
|
),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"An unexpected error occurred: {e!s}",
|
||||||
|
) from None
|
||||||
|
|
|
||||||
|
|
@ -1,509 +0,0 @@
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.future import select
|
|
||||||
|
|
||||||
from app.db import (
|
|
||||||
Chat,
|
|
||||||
Permission,
|
|
||||||
Podcast,
|
|
||||||
SearchSpace,
|
|
||||||
SearchSpaceMembership,
|
|
||||||
User,
|
|
||||||
get_async_session,
|
|
||||||
)
|
|
||||||
from app.schemas import (
|
|
||||||
PodcastCreate,
|
|
||||||
PodcastGenerateRequest,
|
|
||||||
PodcastRead,
|
|
||||||
PodcastUpdate,
|
|
||||||
)
|
|
||||||
from app.tasks.podcast_tasks import generate_chat_podcast
|
|
||||||
from app.users import current_active_user
|
|
||||||
from app.utils.rbac import check_permission
|
|
||||||
|
|
||||||
router = APIRouter()
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/podcasts", response_model=PodcastRead)
|
|
||||||
async def create_podcast(
|
|
||||||
podcast: PodcastCreate,
|
|
||||||
session: AsyncSession = Depends(get_async_session),
|
|
||||||
user: User = Depends(current_active_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create a new podcast.
|
|
||||||
Requires PODCASTS_CREATE permission.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
await check_permission(
|
|
||||||
session,
|
|
||||||
user,
|
|
||||||
podcast.search_space_id,
|
|
||||||
Permission.PODCASTS_CREATE.value,
|
|
||||||
"You don't have permission to create podcasts in this search space",
|
|
||||||
)
|
|
||||||
db_podcast = Podcast(**podcast.model_dump())
|
|
||||||
session.add(db_podcast)
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(db_podcast)
|
|
||||||
return db_podcast
|
|
||||||
except HTTPException as he:
|
|
||||||
raise he
|
|
||||||
except IntegrityError:
|
|
||||||
await session.rollback()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Podcast creation failed due to constraint violation",
|
|
||||||
) from None
|
|
||||||
except SQLAlchemyError:
|
|
||||||
await session.rollback()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500, detail="Database error occurred while creating podcast"
|
|
||||||
) from None
|
|
||||||
except Exception:
|
|
||||||
await session.rollback()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500, detail="An unexpected error occurred"
|
|
||||||
) from None
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/podcasts", response_model=list[PodcastRead])
|
|
||||||
async def read_podcasts(
|
|
||||||
skip: int = 0,
|
|
||||||
limit: int = 100,
|
|
||||||
search_space_id: int | None = None,
|
|
||||||
session: AsyncSession = Depends(get_async_session),
|
|
||||||
user: User = Depends(current_active_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
List podcasts the user has access to.
|
|
||||||
Requires PODCASTS_READ permission for the search space(s).
|
|
||||||
"""
|
|
||||||
if skip < 0 or limit < 1:
|
|
||||||
raise HTTPException(status_code=400, detail="Invalid pagination parameters")
|
|
||||||
try:
|
|
||||||
if search_space_id is not None:
|
|
||||||
# Check permission for specific search space
|
|
||||||
await check_permission(
|
|
||||||
session,
|
|
||||||
user,
|
|
||||||
search_space_id,
|
|
||||||
Permission.PODCASTS_READ.value,
|
|
||||||
"You don't have permission to read podcasts in this search space",
|
|
||||||
)
|
|
||||||
result = await session.execute(
|
|
||||||
select(Podcast)
|
|
||||||
.filter(Podcast.search_space_id == search_space_id)
|
|
||||||
.offset(skip)
|
|
||||||
.limit(limit)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Get podcasts from all search spaces user has membership in
|
|
||||||
result = await session.execute(
|
|
||||||
select(Podcast)
|
|
||||||
.join(SearchSpace)
|
|
||||||
.join(SearchSpaceMembership)
|
|
||||||
.filter(SearchSpaceMembership.user_id == user.id)
|
|
||||||
.offset(skip)
|
|
||||||
.limit(limit)
|
|
||||||
)
|
|
||||||
return result.scalars().all()
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except SQLAlchemyError:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500, detail="Database error occurred while fetching podcasts"
|
|
||||||
) from None
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/podcasts/{podcast_id}", response_model=PodcastRead)
|
|
||||||
async def read_podcast(
|
|
||||||
podcast_id: int,
|
|
||||||
session: AsyncSession = Depends(get_async_session),
|
|
||||||
user: User = Depends(current_active_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get a specific podcast by ID.
|
|
||||||
Requires PODCASTS_READ permission for the search space.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
|
|
||||||
podcast = result.scalars().first()
|
|
||||||
|
|
||||||
if not podcast:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404,
|
|
||||||
detail="Podcast not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check permission for the search space
|
|
||||||
await check_permission(
|
|
||||||
session,
|
|
||||||
user,
|
|
||||||
podcast.search_space_id,
|
|
||||||
Permission.PODCASTS_READ.value,
|
|
||||||
"You don't have permission to read podcasts in this search space",
|
|
||||||
)
|
|
||||||
|
|
||||||
return podcast
|
|
||||||
except HTTPException as he:
|
|
||||||
raise he
|
|
||||||
except SQLAlchemyError:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500, detail="Database error occurred while fetching podcast"
|
|
||||||
) from None
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/podcasts/{podcast_id}", response_model=PodcastRead)
|
|
||||||
async def update_podcast(
|
|
||||||
podcast_id: int,
|
|
||||||
podcast_update: PodcastUpdate,
|
|
||||||
session: AsyncSession = Depends(get_async_session),
|
|
||||||
user: User = Depends(current_active_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Update a podcast.
|
|
||||||
Requires PODCASTS_UPDATE permission for the search space.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
|
|
||||||
db_podcast = result.scalars().first()
|
|
||||||
|
|
||||||
if not db_podcast:
|
|
||||||
raise HTTPException(status_code=404, detail="Podcast not found")
|
|
||||||
|
|
||||||
# Check permission for the search space
|
|
||||||
await check_permission(
|
|
||||||
session,
|
|
||||||
user,
|
|
||||||
db_podcast.search_space_id,
|
|
||||||
Permission.PODCASTS_UPDATE.value,
|
|
||||||
"You don't have permission to update podcasts in this search space",
|
|
||||||
)
|
|
||||||
|
|
||||||
update_data = podcast_update.model_dump(exclude_unset=True)
|
|
||||||
for key, value in update_data.items():
|
|
||||||
setattr(db_podcast, key, value)
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(db_podcast)
|
|
||||||
return db_podcast
|
|
||||||
except HTTPException as he:
|
|
||||||
raise he
|
|
||||||
except IntegrityError:
|
|
||||||
await session.rollback()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400, detail="Update failed due to constraint violation"
|
|
||||||
) from None
|
|
||||||
except SQLAlchemyError:
|
|
||||||
await session.rollback()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500, detail="Database error occurred while updating podcast"
|
|
||||||
) from None
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/podcasts/{podcast_id}", response_model=dict)
|
|
||||||
async def delete_podcast(
|
|
||||||
podcast_id: int,
|
|
||||||
session: AsyncSession = Depends(get_async_session),
|
|
||||||
user: User = Depends(current_active_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Delete a podcast.
|
|
||||||
Requires PODCASTS_DELETE permission for the search space.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
|
|
||||||
db_podcast = result.scalars().first()
|
|
||||||
|
|
||||||
if not db_podcast:
|
|
||||||
raise HTTPException(status_code=404, detail="Podcast not found")
|
|
||||||
|
|
||||||
# Check permission for the search space
|
|
||||||
await check_permission(
|
|
||||||
session,
|
|
||||||
user,
|
|
||||||
db_podcast.search_space_id,
|
|
||||||
Permission.PODCASTS_DELETE.value,
|
|
||||||
"You don't have permission to delete podcasts in this search space",
|
|
||||||
)
|
|
||||||
|
|
||||||
await session.delete(db_podcast)
|
|
||||||
await session.commit()
|
|
||||||
return {"message": "Podcast deleted successfully"}
|
|
||||||
except HTTPException as he:
|
|
||||||
raise he
|
|
||||||
except SQLAlchemyError:
|
|
||||||
await session.rollback()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500, detail="Database error occurred while deleting podcast"
|
|
||||||
) from None
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_chat_podcast_with_new_session(
|
|
||||||
chat_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: int,
|
|
||||||
podcast_title: str | None = None,
|
|
||||||
user_prompt: str | None = None,
|
|
||||||
):
|
|
||||||
"""Create a new session and process chat podcast generation."""
|
|
||||||
from app.db import async_session_maker
|
|
||||||
|
|
||||||
async with async_session_maker() as session:
|
|
||||||
try:
|
|
||||||
await generate_chat_podcast(
|
|
||||||
session, chat_id, search_space_id, user_id, podcast_title, user_prompt
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logging.error(f"Error generating podcast from chat: {e!s}")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/podcasts/generate")
|
|
||||||
async def generate_podcast(
|
|
||||||
request: PodcastGenerateRequest,
|
|
||||||
session: AsyncSession = Depends(get_async_session),
|
|
||||||
user: User = Depends(current_active_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Generate a podcast from a chat or document.
|
|
||||||
Requires PODCASTS_CREATE permission.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Check if the user has permission to create podcasts
|
|
||||||
await check_permission(
|
|
||||||
session,
|
|
||||||
user,
|
|
||||||
request.search_space_id,
|
|
||||||
Permission.PODCASTS_CREATE.value,
|
|
||||||
"You don't have permission to create podcasts in this search space",
|
|
||||||
)
|
|
||||||
|
|
||||||
if request.type == "CHAT":
|
|
||||||
# Verify that all chat IDs belong to this user and search space
|
|
||||||
query = (
|
|
||||||
select(Chat)
|
|
||||||
.filter(
|
|
||||||
Chat.id.in_(request.ids),
|
|
||||||
Chat.search_space_id == request.search_space_id,
|
|
||||||
)
|
|
||||||
.join(SearchSpace)
|
|
||||||
.filter(SearchSpace.user_id == user.id)
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await session.execute(query)
|
|
||||||
valid_chats = result.scalars().all()
|
|
||||||
valid_chat_ids = [chat.id for chat in valid_chats]
|
|
||||||
|
|
||||||
# If any requested ID is not in valid IDs, raise error immediately
|
|
||||||
if len(valid_chat_ids) != len(request.ids):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403,
|
|
||||||
detail="One or more chat IDs do not belong to this user or search space",
|
|
||||||
)
|
|
||||||
|
|
||||||
from app.tasks.celery_tasks.podcast_tasks import (
|
|
||||||
generate_chat_podcast_task,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add Celery tasks for each chat ID
|
|
||||||
for chat_id in valid_chat_ids:
|
|
||||||
generate_chat_podcast_task.delay(
|
|
||||||
chat_id,
|
|
||||||
request.search_space_id,
|
|
||||||
user.id,
|
|
||||||
request.podcast_title,
|
|
||||||
request.user_prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"message": "Podcast generation started",
|
|
||||||
}
|
|
||||||
except HTTPException as he:
|
|
||||||
raise he
|
|
||||||
except IntegrityError:
|
|
||||||
await session.rollback()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Podcast generation failed due to constraint violation",
|
|
||||||
) from None
|
|
||||||
except SQLAlchemyError:
|
|
||||||
await session.rollback()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500, detail="Database error occurred while generating podcast"
|
|
||||||
) from None
|
|
||||||
except Exception as e:
|
|
||||||
await session.rollback()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500, detail=f"An unexpected error occurred: {e!s}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/podcasts/{podcast_id}/stream")
|
|
||||||
async def stream_podcast(
|
|
||||||
podcast_id: int,
|
|
||||||
session: AsyncSession = Depends(get_async_session),
|
|
||||||
user: User = Depends(current_active_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Stream a podcast audio file.
|
|
||||||
Requires PODCASTS_READ permission for the search space.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
|
|
||||||
podcast = result.scalars().first()
|
|
||||||
|
|
||||||
if not podcast:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404,
|
|
||||||
detail="Podcast not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check permission for the search space
|
|
||||||
await check_permission(
|
|
||||||
session,
|
|
||||||
user,
|
|
||||||
podcast.search_space_id,
|
|
||||||
Permission.PODCASTS_READ.value,
|
|
||||||
"You don't have permission to access podcasts in this search space",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get the file path
|
|
||||||
file_path = podcast.file_location
|
|
||||||
|
|
||||||
# Check if the file exists
|
|
||||||
if not os.path.isfile(file_path):
|
|
||||||
raise HTTPException(status_code=404, detail="Podcast audio file not found")
|
|
||||||
|
|
||||||
# Define a generator function to stream the file
|
|
||||||
def iterfile():
|
|
||||||
with open(file_path, mode="rb") as file_like:
|
|
||||||
yield from file_like
|
|
||||||
|
|
||||||
# Return a streaming response with appropriate headers
|
|
||||||
return StreamingResponse(
|
|
||||||
iterfile(),
|
|
||||||
media_type="audio/mpeg",
|
|
||||||
headers={
|
|
||||||
"Accept-Ranges": "bytes",
|
|
||||||
"Content-Disposition": f"inline; filename={Path(file_path).name}",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException as he:
|
|
||||||
raise he
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500, detail=f"Error streaming podcast: {e!s}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/podcasts/by-chat/{chat_id}", response_model=PodcastRead | None)
|
|
||||||
async def get_podcast_by_chat_id(
|
|
||||||
chat_id: int,
|
|
||||||
session: AsyncSession = Depends(get_async_session),
|
|
||||||
user: User = Depends(current_active_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get a podcast by its associated chat ID.
|
|
||||||
Requires PODCASTS_READ permission for the search space.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# First get the chat to find its search space
|
|
||||||
chat_result = await session.execute(select(Chat).filter(Chat.id == chat_id))
|
|
||||||
chat = chat_result.scalars().first()
|
|
||||||
|
|
||||||
if not chat:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Check permission for the search space
|
|
||||||
await check_permission(
|
|
||||||
session,
|
|
||||||
user,
|
|
||||||
chat.search_space_id,
|
|
||||||
Permission.PODCASTS_READ.value,
|
|
||||||
"You don't have permission to read podcasts in this search space",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get the podcast
|
|
||||||
result = await session.execute(
|
|
||||||
select(Podcast).filter(Podcast.chat_id == chat_id)
|
|
||||||
)
|
|
||||||
podcast = result.scalars().first()
|
|
||||||
|
|
||||||
return podcast
|
|
||||||
except HTTPException as he:
|
|
||||||
raise he
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500, detail=f"Error fetching podcast: {e!s}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/podcasts/task/{task_id}/status")
|
|
||||||
async def get_podcast_task_status(
|
|
||||||
task_id: str,
|
|
||||||
user: User = Depends(current_active_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get the status of a podcast generation task.
|
|
||||||
Used by new-chat frontend to poll for completion.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- status: "processing" | "success" | "error"
|
|
||||||
- podcast_id: (only if status == "success")
|
|
||||||
- title: (only if status == "success")
|
|
||||||
- error: (only if status == "error")
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from celery.result import AsyncResult
|
|
||||||
|
|
||||||
from app.celery_app import celery_app
|
|
||||||
|
|
||||||
result = AsyncResult(task_id, app=celery_app)
|
|
||||||
|
|
||||||
if result.ready():
|
|
||||||
# Task completed
|
|
||||||
if result.successful():
|
|
||||||
task_result = result.result
|
|
||||||
if isinstance(task_result, dict):
|
|
||||||
if task_result.get("status") == "success":
|
|
||||||
return {
|
|
||||||
"status": "success",
|
|
||||||
"podcast_id": task_result.get("podcast_id"),
|
|
||||||
"title": task_result.get("title"),
|
|
||||||
"transcript_entries": task_result.get("transcript_entries"),
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"error": task_result.get("error", "Unknown error"),
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"error": "Unexpected task result format",
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
# Task failed
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"error": str(result.result) if result.result else "Task failed",
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
# Task still processing
|
|
||||||
return {
|
|
||||||
"status": "processing",
|
|
||||||
"state": result.state,
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500, detail=f"Error checking task status: {e!s}"
|
|
||||||
) from e
|
|
||||||
|
|
@ -1,13 +1,4 @@
|
||||||
from .base import IDModel, TimestampModel
|
from .base import IDModel, TimestampModel
|
||||||
from .chats import (
|
|
||||||
AISDKChatRequest,
|
|
||||||
ChatBase,
|
|
||||||
ChatCreate,
|
|
||||||
ChatRead,
|
|
||||||
ChatReadWithoutMessages,
|
|
||||||
ChatUpdate,
|
|
||||||
NewChatRequest,
|
|
||||||
)
|
|
||||||
from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate
|
from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate
|
||||||
from .documents import (
|
from .documents import (
|
||||||
DocumentBase,
|
DocumentBase,
|
||||||
|
|
@ -22,9 +13,11 @@ from .documents import (
|
||||||
from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
|
from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
|
||||||
from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate
|
from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate
|
||||||
from .new_chat import (
|
from .new_chat import (
|
||||||
|
ChatMessage,
|
||||||
NewChatMessageAppend,
|
NewChatMessageAppend,
|
||||||
NewChatMessageCreate,
|
NewChatMessageCreate,
|
||||||
NewChatMessageRead,
|
NewChatMessageRead,
|
||||||
|
NewChatRequest,
|
||||||
NewChatThreadCreate,
|
NewChatThreadCreate,
|
||||||
NewChatThreadRead,
|
NewChatThreadRead,
|
||||||
NewChatThreadUpdate,
|
NewChatThreadUpdate,
|
||||||
|
|
@ -33,13 +26,6 @@ from .new_chat import (
|
||||||
ThreadListItem,
|
ThreadListItem,
|
||||||
ThreadListResponse,
|
ThreadListResponse,
|
||||||
)
|
)
|
||||||
from .podcasts import (
|
|
||||||
PodcastBase,
|
|
||||||
PodcastCreate,
|
|
||||||
PodcastGenerateRequest,
|
|
||||||
PodcastRead,
|
|
||||||
PodcastUpdate,
|
|
||||||
)
|
|
||||||
from .rbac_schemas import (
|
from .rbac_schemas import (
|
||||||
InviteAcceptRequest,
|
InviteAcceptRequest,
|
||||||
InviteAcceptResponse,
|
InviteAcceptResponse,
|
||||||
|
|
@ -73,44 +59,8 @@ from .search_space import (
|
||||||
from .users import UserCreate, UserRead, UserUpdate
|
from .users import UserCreate, UserRead, UserUpdate
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AISDKChatRequest",
|
# Chat schemas (assistant-ui integration)
|
||||||
"ChatBase",
|
"ChatMessage",
|
||||||
"ChatCreate",
|
|
||||||
"ChatRead",
|
|
||||||
"ChatReadWithoutMessages",
|
|
||||||
"ChatUpdate",
|
|
||||||
"ChunkBase",
|
|
||||||
"ChunkCreate",
|
|
||||||
"ChunkRead",
|
|
||||||
"ChunkUpdate",
|
|
||||||
"DocumentBase",
|
|
||||||
"DocumentRead",
|
|
||||||
"DocumentUpdate",
|
|
||||||
"DocumentWithChunksRead",
|
|
||||||
"DocumentsCreate",
|
|
||||||
"ExtensionDocumentContent",
|
|
||||||
"ExtensionDocumentMetadata",
|
|
||||||
"IDModel",
|
|
||||||
# RBAC schemas
|
|
||||||
"InviteAcceptRequest",
|
|
||||||
"InviteAcceptResponse",
|
|
||||||
"InviteCreate",
|
|
||||||
"InviteInfoResponse",
|
|
||||||
"InviteRead",
|
|
||||||
"InviteUpdate",
|
|
||||||
"LLMConfigBase",
|
|
||||||
"LLMConfigCreate",
|
|
||||||
"LLMConfigRead",
|
|
||||||
"LLMConfigUpdate",
|
|
||||||
"LogBase",
|
|
||||||
"LogCreate",
|
|
||||||
"LogFilter",
|
|
||||||
"LogRead",
|
|
||||||
"LogUpdate",
|
|
||||||
"MembershipRead",
|
|
||||||
"MembershipReadWithUser",
|
|
||||||
"MembershipUpdate",
|
|
||||||
# New chat schemas (assistant-ui integration)
|
|
||||||
"NewChatMessageAppend",
|
"NewChatMessageAppend",
|
||||||
"NewChatMessageCreate",
|
"NewChatMessageCreate",
|
||||||
"NewChatMessageRead",
|
"NewChatMessageRead",
|
||||||
|
|
@ -119,30 +69,64 @@ __all__ = [
|
||||||
"NewChatThreadRead",
|
"NewChatThreadRead",
|
||||||
"NewChatThreadUpdate",
|
"NewChatThreadUpdate",
|
||||||
"NewChatThreadWithMessages",
|
"NewChatThreadWithMessages",
|
||||||
|
"ThreadHistoryLoadResponse",
|
||||||
|
"ThreadListItem",
|
||||||
|
"ThreadListResponse",
|
||||||
|
# Chunk schemas
|
||||||
|
"ChunkBase",
|
||||||
|
"ChunkCreate",
|
||||||
|
"ChunkRead",
|
||||||
|
"ChunkUpdate",
|
||||||
|
# Document schemas
|
||||||
|
"DocumentBase",
|
||||||
|
"DocumentRead",
|
||||||
|
"DocumentUpdate",
|
||||||
|
"DocumentWithChunksRead",
|
||||||
|
"DocumentsCreate",
|
||||||
|
"ExtensionDocumentContent",
|
||||||
|
"ExtensionDocumentMetadata",
|
||||||
"PaginatedResponse",
|
"PaginatedResponse",
|
||||||
|
# Base schemas
|
||||||
|
"IDModel",
|
||||||
|
"TimestampModel",
|
||||||
|
# LLM Config schemas
|
||||||
|
"LLMConfigBase",
|
||||||
|
"LLMConfigCreate",
|
||||||
|
"LLMConfigRead",
|
||||||
|
"LLMConfigUpdate",
|
||||||
|
# Log schemas
|
||||||
|
"LogBase",
|
||||||
|
"LogCreate",
|
||||||
|
"LogFilter",
|
||||||
|
"LogRead",
|
||||||
|
"LogUpdate",
|
||||||
|
# RBAC schemas
|
||||||
|
"InviteAcceptRequest",
|
||||||
|
"InviteAcceptResponse",
|
||||||
|
"InviteCreate",
|
||||||
|
"InviteInfoResponse",
|
||||||
|
"InviteRead",
|
||||||
|
"InviteUpdate",
|
||||||
|
"MembershipRead",
|
||||||
|
"MembershipReadWithUser",
|
||||||
|
"MembershipUpdate",
|
||||||
"PermissionInfo",
|
"PermissionInfo",
|
||||||
"PermissionsListResponse",
|
"PermissionsListResponse",
|
||||||
"PodcastBase",
|
|
||||||
"PodcastCreate",
|
|
||||||
"PodcastGenerateRequest",
|
|
||||||
"PodcastRead",
|
|
||||||
"PodcastUpdate",
|
|
||||||
"RoleCreate",
|
"RoleCreate",
|
||||||
"RoleRead",
|
"RoleRead",
|
||||||
"RoleUpdate",
|
"RoleUpdate",
|
||||||
|
# Search source connector schemas
|
||||||
"SearchSourceConnectorBase",
|
"SearchSourceConnectorBase",
|
||||||
"SearchSourceConnectorCreate",
|
"SearchSourceConnectorCreate",
|
||||||
"SearchSourceConnectorRead",
|
"SearchSourceConnectorRead",
|
||||||
"SearchSourceConnectorUpdate",
|
"SearchSourceConnectorUpdate",
|
||||||
|
# Search space schemas
|
||||||
"SearchSpaceBase",
|
"SearchSpaceBase",
|
||||||
"SearchSpaceCreate",
|
"SearchSpaceCreate",
|
||||||
"SearchSpaceRead",
|
"SearchSpaceRead",
|
||||||
"SearchSpaceUpdate",
|
"SearchSpaceUpdate",
|
||||||
"SearchSpaceWithStats",
|
"SearchSpaceWithStats",
|
||||||
"ThreadHistoryLoadResponse",
|
# User schemas
|
||||||
"ThreadListItem",
|
|
||||||
"ThreadListResponse",
|
|
||||||
"TimestampModel",
|
|
||||||
"UserCreate",
|
"UserCreate",
|
||||||
"UserRead",
|
"UserRead",
|
||||||
"UserSearchSpaceAccess",
|
"UserSearchSpaceAccess",
|
||||||
|
|
|
||||||
|
|
@ -1,80 +0,0 @@
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
|
||||||
|
|
||||||
from app.db import ChatType
|
|
||||||
|
|
||||||
from .base import IDModel, TimestampModel
|
|
||||||
|
|
||||||
|
|
||||||
class ChatBase(BaseModel):
|
|
||||||
type: ChatType
|
|
||||||
title: str
|
|
||||||
initial_connectors: list[str] | None = None
|
|
||||||
messages: list[Any]
|
|
||||||
search_space_id: int
|
|
||||||
state_version: int = 1
|
|
||||||
|
|
||||||
|
|
||||||
class ChatBaseWithoutMessages(BaseModel):
|
|
||||||
type: ChatType
|
|
||||||
title: str
|
|
||||||
search_space_id: int
|
|
||||||
state_version: int = 1
|
|
||||||
|
|
||||||
|
|
||||||
class ClientAttachment(BaseModel):
|
|
||||||
name: str
|
|
||||||
content_type: str
|
|
||||||
url: str
|
|
||||||
|
|
||||||
|
|
||||||
class ToolInvocation(BaseModel):
|
|
||||||
tool_call_id: str
|
|
||||||
tool_name: str
|
|
||||||
args: dict
|
|
||||||
result: dict
|
|
||||||
|
|
||||||
|
|
||||||
# class ClientMessage(BaseModel):
|
|
||||||
# role: str
|
|
||||||
# content: str
|
|
||||||
# experimental_attachments: Optional[List[ClientAttachment]] = None
|
|
||||||
# toolInvocations: Optional[List[ToolInvocation]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class AISDKChatRequest(BaseModel):
|
|
||||||
messages: list[Any]
|
|
||||||
data: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseModel):
|
|
||||||
"""A single message in the chat history."""
|
|
||||||
|
|
||||||
role: str # "user" or "assistant"
|
|
||||||
content: str
|
|
||||||
|
|
||||||
|
|
||||||
class NewChatRequest(BaseModel):
|
|
||||||
"""Request schema for the new deep agent chat endpoint."""
|
|
||||||
|
|
||||||
chat_id: int
|
|
||||||
user_query: str
|
|
||||||
search_space_id: int
|
|
||||||
messages: list[ChatMessage] | None = None # Optional chat history from frontend
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCreate(ChatBase):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ChatUpdate(ChatBase):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ChatRead(ChatBase, IDModel, TimestampModel):
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
|
|
||||||
class ChatReadWithoutMessages(ChatBaseWithoutMessages, IDModel, TimestampModel):
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
@ -127,3 +127,24 @@ class ThreadListResponse(BaseModel):
|
||||||
|
|
||||||
threads: list[ThreadListItem]
|
threads: list[ThreadListItem]
|
||||||
archived_threads: list[ThreadListItem]
|
archived_threads: list[ThreadListItem]
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Chat Request Schemas (for deep agent)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(BaseModel):
|
||||||
|
"""A single message in the chat history."""
|
||||||
|
|
||||||
|
role: str # "user" or "assistant"
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class NewChatRequest(BaseModel):
|
||||||
|
"""Request schema for the deep agent chat endpoint."""
|
||||||
|
|
||||||
|
chat_id: int
|
||||||
|
user_query: str
|
||||||
|
search_space_id: int
|
||||||
|
messages: list[ChatMessage] | None = None # Optional chat history from frontend
|
||||||
|
|
|
||||||
|
|
@ -1,33 +0,0 @@
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
|
||||||
|
|
||||||
from .base import IDModel, TimestampModel
|
|
||||||
|
|
||||||
|
|
||||||
class PodcastBase(BaseModel):
|
|
||||||
title: str
|
|
||||||
podcast_transcript: list[Any]
|
|
||||||
file_location: str = ""
|
|
||||||
search_space_id: int
|
|
||||||
chat_state_version: int | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class PodcastCreate(PodcastBase):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class PodcastUpdate(PodcastBase):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class PodcastRead(PodcastBase, IDModel, TimestampModel):
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
|
|
||||||
class PodcastGenerateRequest(BaseModel):
|
|
||||||
type: Literal["DOCUMENT", "CHAT"]
|
|
||||||
ids: list[int]
|
|
||||||
search_space_id: int
|
|
||||||
podcast_title: str | None = None
|
|
||||||
user_prompt: str | None = None
|
|
||||||
|
|
@ -1,75 +0,0 @@
|
||||||
from collections.abc import AsyncGenerator
|
|
||||||
from typing import Any
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.agents.researcher.graph import graph as researcher_graph
|
|
||||||
from app.agents.researcher.state import State
|
|
||||||
from app.services.streaming_service import StreamingService
|
|
||||||
|
|
||||||
|
|
||||||
async def stream_connector_search_results(
|
|
||||||
user_query: str,
|
|
||||||
user_id: str | UUID,
|
|
||||||
search_space_id: int,
|
|
||||||
session: AsyncSession,
|
|
||||||
research_mode: str,
|
|
||||||
selected_connectors: list[str],
|
|
||||||
langchain_chat_history: list[Any],
|
|
||||||
document_ids_to_add_in_context: list[int],
|
|
||||||
language: str | None = None,
|
|
||||||
top_k: int = 10,
|
|
||||||
) -> AsyncGenerator[str, None]:
|
|
||||||
"""
|
|
||||||
Stream connector search results to the client
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_query: The user's query
|
|
||||||
user_id: The user's ID (can be UUID object or string)
|
|
||||||
search_space_id: The search space ID
|
|
||||||
session: The database session
|
|
||||||
research_mode: The research mode
|
|
||||||
selected_connectors: List of selected connectors
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
str: Formatted response strings
|
|
||||||
"""
|
|
||||||
streaming_service = StreamingService()
|
|
||||||
|
|
||||||
# Convert UUID to string if needed
|
|
||||||
user_id_str = str(user_id) if isinstance(user_id, UUID) else user_id
|
|
||||||
|
|
||||||
# Sample configuration
|
|
||||||
config = {
|
|
||||||
"configurable": {
|
|
||||||
"user_query": user_query,
|
|
||||||
"connectors_to_search": selected_connectors,
|
|
||||||
"user_id": user_id_str,
|
|
||||||
"search_space_id": search_space_id,
|
|
||||||
"document_ids_to_add_in_context": document_ids_to_add_in_context,
|
|
||||||
"language": language, # Add language to the configuration
|
|
||||||
"top_k": top_k, # Add top_k to the configuration
|
|
||||||
}
|
|
||||||
}
|
|
||||||
# print(f"Researcher configuration: {config['configurable']}") # Debug print
|
|
||||||
# Initialize state with database session and streaming service
|
|
||||||
initial_state = State(
|
|
||||||
db_session=session,
|
|
||||||
streaming_service=streaming_service,
|
|
||||||
chat_history=langchain_chat_history,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run the graph directly
|
|
||||||
print("\nRunning the complete researcher workflow...")
|
|
||||||
|
|
||||||
# Use streaming with config parameter
|
|
||||||
async for chunk in researcher_graph.astream(
|
|
||||||
initial_state,
|
|
||||||
config=config,
|
|
||||||
stream_mode="custom",
|
|
||||||
):
|
|
||||||
if isinstance(chunk, dict) and "yield_value" in chunk:
|
|
||||||
yield chunk["yield_value"]
|
|
||||||
|
|
||||||
yield streaming_service.format_completion()
|
|
||||||
|
|
@ -18,7 +18,7 @@ from app.agents.new_chat.llm_config import (
|
||||||
create_chat_litellm_from_config,
|
create_chat_litellm_from_config,
|
||||||
load_llm_config_from_yaml,
|
load_llm_config_from_yaml,
|
||||||
)
|
)
|
||||||
from app.schemas.chats import ChatMessage
|
from app.schemas.new_chat import ChatMessage
|
||||||
from app.services.connector_service import ConnectorService
|
from app.services.connector_service import ConnectorService
|
||||||
from app.services.new_streaming_service import VercelStreamingService
|
from app.services.new_streaming_service import VercelStreamingService
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue