# Mirror of https://github.com/MODSetter/SurfSense.git (synced 2026-04-25).
# Upstream change summary: add support for DeepSeek, Qwen (Alibaba),
# Kimi (Moonshot), and GLM (Zhipu); auto-fill the API base URL when selecting
# Chinese LLM providers; add validation/warnings for missing API endpoints;
# fix session-state management in the task logging service; add Chinese setup
# documentation and a database migration for the new provider enums.
# Closes #383.
from collections.abc import AsyncGenerator
|
||
from datetime import UTC, datetime
|
||
from enum import Enum
|
||
|
||
from fastapi import Depends
|
||
from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
|
||
from pgvector.sqlalchemy import Vector
|
||
from sqlalchemy import (
|
||
ARRAY,
|
||
JSON,
|
||
TIMESTAMP,
|
||
Boolean,
|
||
Column,
|
||
Enum as SQLAlchemyEnum,
|
||
ForeignKey,
|
||
Integer,
|
||
String,
|
||
Text,
|
||
UniqueConstraint,
|
||
text,
|
||
)
|
||
from sqlalchemy.dialects.postgresql import UUID
|
||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||
from sqlalchemy.orm import DeclarativeBase, Mapped, declared_attr, relationship
|
||
|
||
from app.config import config
|
||
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||
from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
|
||
|
||
if config.AUTH_TYPE == "GOOGLE":
|
||
from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID
|
||
|
||
DATABASE_URL = config.DATABASE_URL
|
||
|
||
|
||
class DocumentType(str, Enum):
|
||
EXTENSION = "EXTENSION"
|
||
CRAWLED_URL = "CRAWLED_URL"
|
||
FILE = "FILE"
|
||
SLACK_CONNECTOR = "SLACK_CONNECTOR"
|
||
NOTION_CONNECTOR = "NOTION_CONNECTOR"
|
||
YOUTUBE_VIDEO = "YOUTUBE_VIDEO"
|
||
GITHUB_CONNECTOR = "GITHUB_CONNECTOR"
|
||
LINEAR_CONNECTOR = "LINEAR_CONNECTOR"
|
||
DISCORD_CONNECTOR = "DISCORD_CONNECTOR"
|
||
JIRA_CONNECTOR = "JIRA_CONNECTOR"
|
||
CONFLUENCE_CONNECTOR = "CONFLUENCE_CONNECTOR"
|
||
CLICKUP_CONNECTOR = "CLICKUP_CONNECTOR"
|
||
GOOGLE_CALENDAR_CONNECTOR = "GOOGLE_CALENDAR_CONNECTOR"
|
||
GOOGLE_GMAIL_CONNECTOR = "GOOGLE_GMAIL_CONNECTOR"
|
||
AIRTABLE_CONNECTOR = "AIRTABLE_CONNECTOR"
|
||
LUMA_CONNECTOR = "LUMA_CONNECTOR"
|
||
|
||
|
||
class SearchSourceConnectorType(str, Enum):
|
||
SERPER_API = "SERPER_API" # NOT IMPLEMENTED YET : DON'T REMEMBER WHY : MOST PROBABLY BECAUSE WE NEED TO CRAWL THE RESULTS RETURNED BY IT
|
||
TAVILY_API = "TAVILY_API"
|
||
LINKUP_API = "LINKUP_API"
|
||
SLACK_CONNECTOR = "SLACK_CONNECTOR"
|
||
NOTION_CONNECTOR = "NOTION_CONNECTOR"
|
||
GITHUB_CONNECTOR = "GITHUB_CONNECTOR"
|
||
LINEAR_CONNECTOR = "LINEAR_CONNECTOR"
|
||
DISCORD_CONNECTOR = "DISCORD_CONNECTOR"
|
||
JIRA_CONNECTOR = "JIRA_CONNECTOR"
|
||
CONFLUENCE_CONNECTOR = "CONFLUENCE_CONNECTOR"
|
||
CLICKUP_CONNECTOR = "CLICKUP_CONNECTOR"
|
||
GOOGLE_CALENDAR_CONNECTOR = "GOOGLE_CALENDAR_CONNECTOR"
|
||
GOOGLE_GMAIL_CONNECTOR = "GOOGLE_GMAIL_CONNECTOR"
|
||
AIRTABLE_CONNECTOR = "AIRTABLE_CONNECTOR"
|
||
LUMA_CONNECTOR = "LUMA_CONNECTOR"
|
||
|
||
|
||
class ChatType(str, Enum):
|
||
QNA = "QNA"
|
||
REPORT_GENERAL = "REPORT_GENERAL"
|
||
REPORT_DEEP = "REPORT_DEEP"
|
||
REPORT_DEEPER = "REPORT_DEEPER"
|
||
|
||
|
||
class LiteLLMProvider(str, Enum):
|
||
"""
|
||
Enum for LLM providers supported by LiteLLM.
|
||
LiteLLM 支持的 LLM 提供商枚举。
|
||
"""
|
||
OPENAI = "OPENAI"
|
||
ANTHROPIC = "ANTHROPIC"
|
||
GROQ = "GROQ"
|
||
COHERE = "COHERE"
|
||
HUGGINGFACE = "HUGGINGFACE"
|
||
AZURE_OPENAI = "AZURE_OPENAI"
|
||
GOOGLE = "GOOGLE"
|
||
AWS_BEDROCK = "AWS_BEDROCK"
|
||
OLLAMA = "OLLAMA"
|
||
MISTRAL = "MISTRAL"
|
||
TOGETHER_AI = "TOGETHER_AI"
|
||
OPENROUTER = "OPENROUTER"
|
||
REPLICATE = "REPLICATE"
|
||
PALM = "PALM"
|
||
VERTEX_AI = "VERTEX_AI"
|
||
ANYSCALE = "ANYSCALE"
|
||
PERPLEXITY = "PERPLEXITY"
|
||
DEEPINFRA = "DEEPINFRA"
|
||
AI21 = "AI21"
|
||
NLPCLOUD = "NLPCLOUD"
|
||
ALEPH_ALPHA = "ALEPH_ALPHA"
|
||
PETALS = "PETALS"
|
||
COMETAPI = "COMETAPI"
|
||
# Chinese LLM Providers (OpenAI-compatible) / 国产 LLM 提供商(OpenAI 兼容)
|
||
DEEPSEEK = "DEEPSEEK" # DeepSeek
|
||
ALIBABA_QWEN = "ALIBABA_QWEN" # 阿里通义千问
|
||
MOONSHOT = "MOONSHOT" # 月之暗面 (Kimi)
|
||
ZHIPU = "ZHIPU" # 智谱 AI (GLM)
|
||
CUSTOM = "CUSTOM"
|
||
|
||
|
||
class LogLevel(str, Enum):
|
||
DEBUG = "DEBUG"
|
||
INFO = "INFO"
|
||
WARNING = "WARNING"
|
||
ERROR = "ERROR"
|
||
CRITICAL = "CRITICAL"
|
||
|
||
|
||
class LogStatus(str, Enum):
|
||
IN_PROGRESS = "IN_PROGRESS"
|
||
SUCCESS = "SUCCESS"
|
||
FAILED = "FAILED"
|
||
|
||
|
||
class Base(DeclarativeBase):
|
||
pass
|
||
|
||
|
||
class TimestampMixin:
|
||
@declared_attr
|
||
def created_at(cls): # noqa: N805
|
||
return Column(
|
||
TIMESTAMP(timezone=True),
|
||
nullable=False,
|
||
default=lambda: datetime.now(UTC),
|
||
index=True,
|
||
)
|
||
|
||
|
||
class BaseModel(Base):
|
||
__abstract__ = True
|
||
__allow_unmapped__ = True
|
||
|
||
id = Column(Integer, primary_key=True, index=True)
|
||
|
||
|
||
class Chat(BaseModel, TimestampMixin):
|
||
__tablename__ = "chats"
|
||
|
||
type = Column(SQLAlchemyEnum(ChatType), nullable=False)
|
||
title = Column(String, nullable=False, index=True)
|
||
initial_connectors = Column(ARRAY(String), nullable=True)
|
||
messages = Column(JSON, nullable=False)
|
||
|
||
search_space_id = Column(
|
||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||
)
|
||
search_space = relationship("SearchSpace", back_populates="chats")
|
||
|
||
|
||
class Document(BaseModel, TimestampMixin):
|
||
__tablename__ = "documents"
|
||
|
||
title = Column(String, nullable=False, index=True)
|
||
document_type = Column(SQLAlchemyEnum(DocumentType), nullable=False)
|
||
document_metadata = Column(JSON, nullable=True)
|
||
|
||
content = Column(Text, nullable=False)
|
||
content_hash = Column(String, nullable=False, index=True, unique=True)
|
||
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
||
|
||
search_space_id = Column(
|
||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||
)
|
||
search_space = relationship("SearchSpace", back_populates="documents")
|
||
chunks = relationship(
|
||
"Chunk", back_populates="document", cascade="all, delete-orphan"
|
||
)
|
||
|
||
|
||
class Chunk(BaseModel, TimestampMixin):
|
||
__tablename__ = "chunks"
|
||
|
||
content = Column(Text, nullable=False)
|
||
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
||
|
||
document_id = Column(
|
||
Integer, ForeignKey("documents.id", ondelete="CASCADE"), nullable=False
|
||
)
|
||
document = relationship("Document", back_populates="chunks")
|
||
|
||
|
||
class Podcast(BaseModel, TimestampMixin):
    """A generated podcast (transcript + audio file) in a search space."""

    __tablename__ = "podcasts"

    title = Column(String, nullable=False, index=True)
    # Callable default: `default={}` would reuse one shared mutable dict as the
    # default for every new row, so in-place mutations could leak across
    # instances. `dict` yields a fresh dict per INSERT.
    podcast_transcript = Column(JSON, nullable=False, default=dict)
    file_location = Column(String(500), nullable=False, default="")

    # Owning search space; podcasts are deleted with the search space.
    search_space_id = Column(
        Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
    )
    search_space = relationship("SearchSpace", back_populates="podcasts")
class SearchSpace(BaseModel, TimestampMixin):
|
||
__tablename__ = "searchspaces"
|
||
|
||
name = Column(String(100), nullable=False, index=True)
|
||
description = Column(String(500), nullable=True)
|
||
|
||
user_id = Column(
|
||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||
)
|
||
user = relationship("User", back_populates="search_spaces")
|
||
|
||
documents = relationship(
|
||
"Document",
|
||
back_populates="search_space",
|
||
order_by="Document.id",
|
||
cascade="all, delete-orphan",
|
||
)
|
||
podcasts = relationship(
|
||
"Podcast",
|
||
back_populates="search_space",
|
||
order_by="Podcast.id",
|
||
cascade="all, delete-orphan",
|
||
)
|
||
chats = relationship(
|
||
"Chat",
|
||
back_populates="search_space",
|
||
order_by="Chat.id",
|
||
cascade="all, delete-orphan",
|
||
)
|
||
logs = relationship(
|
||
"Log",
|
||
back_populates="search_space",
|
||
order_by="Log.id",
|
||
cascade="all, delete-orphan",
|
||
)
|
||
search_source_connectors = relationship(
|
||
"SearchSourceConnector",
|
||
back_populates="search_space",
|
||
order_by="SearchSourceConnector.id",
|
||
cascade="all, delete-orphan",
|
||
)
|
||
llm_configs = relationship(
|
||
"LLMConfig",
|
||
back_populates="search_space",
|
||
order_by="LLMConfig.id",
|
||
cascade="all, delete-orphan",
|
||
)
|
||
user_preferences = relationship(
|
||
"UserSearchSpacePreference",
|
||
back_populates="search_space",
|
||
cascade="all, delete-orphan",
|
||
)
|
||
|
||
|
||
class SearchSourceConnector(BaseModel, TimestampMixin):
|
||
__tablename__ = "search_source_connectors"
|
||
__table_args__ = (
|
||
UniqueConstraint(
|
||
"search_space_id",
|
||
"user_id",
|
||
"connector_type",
|
||
name="uq_searchspace_user_connector_type",
|
||
),
|
||
)
|
||
|
||
name = Column(String(100), nullable=False, index=True)
|
||
connector_type = Column(SQLAlchemyEnum(SearchSourceConnectorType), nullable=False)
|
||
is_indexable = Column(Boolean, nullable=False, default=False)
|
||
last_indexed_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||
config = Column(JSON, nullable=False)
|
||
|
||
search_space_id = Column(
|
||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||
)
|
||
search_space = relationship(
|
||
"SearchSpace", back_populates="search_source_connectors"
|
||
)
|
||
|
||
user_id = Column(
|
||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||
)
|
||
|
||
|
||
class LLMConfig(BaseModel, TimestampMixin):
    """Per-search-space configuration for one LiteLLM-backed model."""

    __tablename__ = "llm_configs"

    name = Column(String(100), nullable=False, index=True)
    # Provider from the enum
    provider = Column(SQLAlchemyEnum(LiteLLMProvider), nullable=False)
    # Custom provider name when provider is CUSTOM
    custom_provider = Column(String(100), nullable=True)
    # Just the model name without provider prefix
    model_name = Column(String(100), nullable=False)
    # API Key should be encrypted before storing
    api_key = Column(String, nullable=False)
    # Optional per-provider API endpoint override.
    api_base = Column(String(500), nullable=True)

    # For any other parameters that litellm supports.
    # Callable default (`dict`) instead of `default={}`: a literal dict is one
    # shared mutable object reused as the default for every new row.
    litellm_params = Column(JSON, nullable=True, default=dict)

    # Owning search space; configs are deleted with the search space.
    search_space_id = Column(
        Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
    )
    search_space = relationship("SearchSpace", back_populates="llm_configs")
class UserSearchSpacePreference(BaseModel, TimestampMixin):
|
||
__tablename__ = "user_search_space_preferences"
|
||
__table_args__ = (
|
||
UniqueConstraint(
|
||
"user_id",
|
||
"search_space_id",
|
||
name="uq_user_searchspace",
|
||
),
|
||
)
|
||
|
||
user_id = Column(
|
||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||
)
|
||
search_space_id = Column(
|
||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||
)
|
||
|
||
# User-specific LLM preferences for this search space
|
||
long_context_llm_id = Column(
|
||
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||
)
|
||
fast_llm_id = Column(
|
||
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||
)
|
||
strategic_llm_id = Column(
|
||
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||
)
|
||
|
||
# Future RBAC fields can be added here
|
||
# role = Column(String(50), nullable=True) # e.g., 'owner', 'editor', 'viewer'
|
||
# permissions = Column(JSON, nullable=True)
|
||
|
||
user = relationship("User", back_populates="search_space_preferences")
|
||
search_space = relationship("SearchSpace", back_populates="user_preferences")
|
||
|
||
long_context_llm = relationship(
|
||
"LLMConfig", foreign_keys=[long_context_llm_id], post_update=True
|
||
)
|
||
fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True)
|
||
strategic_llm = relationship(
|
||
"LLMConfig", foreign_keys=[strategic_llm_id], post_update=True
|
||
)
|
||
|
||
|
||
class Log(BaseModel, TimestampMixin):
    """Structured task/activity log entry scoped to a search space."""

    __tablename__ = "logs"

    level = Column(SQLAlchemyEnum(LogLevel), nullable=False, index=True)
    status = Column(SQLAlchemyEnum(LogStatus), nullable=False, index=True)
    message = Column(Text, nullable=False)
    source = Column(
        String(200), nullable=True, index=True
    )  # Service/component that generated the log
    # Additional context data. Callable default (`dict`) so each row gets its
    # own dict; `default={}` would share one mutable object across all rows.
    log_metadata = Column(JSON, nullable=True, default=dict)

    # Owning search space; logs are deleted with the search space.
    search_space_id = Column(
        Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
    )
    search_space = relationship("SearchSpace", back_populates="logs")
if config.AUTH_TYPE == "GOOGLE":
|
||
|
||
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
|
||
pass
|
||
|
||
class User(SQLAlchemyBaseUserTableUUID, Base):
|
||
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
|
||
"OAuthAccount", lazy="joined"
|
||
)
|
||
search_spaces = relationship("SearchSpace", back_populates="user")
|
||
search_space_preferences = relationship(
|
||
"UserSearchSpacePreference",
|
||
back_populates="user",
|
||
cascade="all, delete-orphan",
|
||
)
|
||
|
||
else:
|
||
|
||
class User(SQLAlchemyBaseUserTableUUID, Base):
|
||
search_spaces = relationship("SearchSpace", back_populates="user")
|
||
search_space_preferences = relationship(
|
||
"UserSearchSpacePreference",
|
||
back_populates="user",
|
||
cascade="all, delete-orphan",
|
||
)
|
||
|
||
|
||
engine = create_async_engine(DATABASE_URL)
|
||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
|
||
|
||
|
||
async def setup_indexes():
    """Create the HNSW vector and GIN full-text indexes for documents and chunks.

    Idempotent: every statement uses CREATE INDEX IF NOT EXISTS.
    """
    index_statements = (
        # Document-level vector + full-text search indexes
        "CREATE INDEX IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)",
        "CREATE INDEX IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector('english', content))",
        # Chunk-level vector + full-text search indexes
        "CREATE INDEX IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)",
        "CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector('english', content))",
    )
    async with engine.begin() as conn:
        for statement in index_statements:
            await conn.execute(text(statement))
async def create_db_and_tables():
|
||
async with engine.begin() as conn:
|
||
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
||
await conn.run_sync(Base.metadata.create_all)
|
||
await setup_indexes()
|
||
|
||
|
||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||
async with async_session_maker() as session:
|
||
yield session
|
||
|
||
|
||
if config.AUTH_TYPE == "GOOGLE":
|
||
|
||
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
|
||
yield SQLAlchemyUserDatabase(session, User, OAuthAccount)
|
||
|
||
else:
|
||
|
||
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
|
||
yield SQLAlchemyUserDatabase(session, User)
|
||
|
||
|
||
async def get_chucks_hybrid_search_retriever(
    session: AsyncSession = Depends(get_async_session),
):
    """FastAPI dependency providing a chunk-level hybrid search retriever."""
    retriever = ChucksHybridSearchRetriever(session)
    return retriever
async def get_documents_hybrid_search_retriever(
    session: AsyncSession = Depends(get_async_session),
):
    """FastAPI dependency providing a document-level hybrid search retriever."""
    retriever = DocumentHybridSearchRetriever(session)
    return retriever