# SurfSense backend — database models and async session setup (app/db.py).
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
2025-03-14 18:53:14 -07:00
from enum import Enum
from fastapi import Depends
from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
2025-03-14 18:53:14 -07:00
from pgvector.sqlalchemy import Vector
from sqlalchemy import (
ARRAY,
JSON,
TIMESTAMP,
2025-10-23 22:29:31 +02:00
BigInteger,
2025-03-14 18:53:14 -07:00
Boolean,
Column,
Enum as SQLAlchemyEnum,
ForeignKey,
Integer,
String,
Text,
UniqueConstraint,
2025-03-14 18:53:14 -07:00
text,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, declared_attr, relationship
from app.config import config
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
2025-03-14 18:53:14 -07:00
if config.AUTH_TYPE == "GOOGLE":
from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID
2025-03-14 18:53:14 -07:00
# Async SQLAlchemy connection string (e.g. "postgresql+asyncpg://...") from app config.
DATABASE_URL = config.DATABASE_URL
class DocumentType(str, Enum):
    """How a document entered the system (ingestion source/type)."""

    EXTENSION = "EXTENSION"
    CRAWLED_URL = "CRAWLED_URL"
    FILE = "FILE"
    SLACK_CONNECTOR = "SLACK_CONNECTOR"
    NOTION_CONNECTOR = "NOTION_CONNECTOR"
    YOUTUBE_VIDEO = "YOUTUBE_VIDEO"
    GITHUB_CONNECTOR = "GITHUB_CONNECTOR"
    LINEAR_CONNECTOR = "LINEAR_CONNECTOR"
    DISCORD_CONNECTOR = "DISCORD_CONNECTOR"
    JIRA_CONNECTOR = "JIRA_CONNECTOR"
    CONFLUENCE_CONNECTOR = "CONFLUENCE_CONNECTOR"
    CLICKUP_CONNECTOR = "CLICKUP_CONNECTOR"
    GOOGLE_CALENDAR_CONNECTOR = "GOOGLE_CALENDAR_CONNECTOR"
    GOOGLE_GMAIL_CONNECTOR = "GOOGLE_GMAIL_CONNECTOR"
    AIRTABLE_CONNECTOR = "AIRTABLE_CONNECTOR"
    LUMA_CONNECTOR = "LUMA_CONNECTOR"
    ELASTICSEARCH_CONNECTOR = "ELASTICSEARCH_CONNECTOR"
class SearchSourceConnectorType(str, Enum):
    """Type of a configured search-source connector.

    Largely parallels ``DocumentType``: indexable connectors produce documents
    of the matching ``DocumentType``, while the *_API members are live search
    APIs queried at search time.
    """

    SERPER_API = "SERPER_API"  # NOT IMPLEMENTED YET : DON'T REMEMBER WHY : MOST PROBABLY BECAUSE WE NEED TO CRAWL THE RESULTS RETURNED BY IT
    TAVILY_API = "TAVILY_API"
    SEARXNG_API = "SEARXNG_API"
    LINKUP_API = "LINKUP_API"
    BAIDU_SEARCH_API = "BAIDU_SEARCH_API"  # Baidu AI Search API for Chinese web search
    SLACK_CONNECTOR = "SLACK_CONNECTOR"
    NOTION_CONNECTOR = "NOTION_CONNECTOR"
    GITHUB_CONNECTOR = "GITHUB_CONNECTOR"
    LINEAR_CONNECTOR = "LINEAR_CONNECTOR"
    DISCORD_CONNECTOR = "DISCORD_CONNECTOR"
    JIRA_CONNECTOR = "JIRA_CONNECTOR"
    CONFLUENCE_CONNECTOR = "CONFLUENCE_CONNECTOR"
    CLICKUP_CONNECTOR = "CLICKUP_CONNECTOR"
    GOOGLE_CALENDAR_CONNECTOR = "GOOGLE_CALENDAR_CONNECTOR"
    GOOGLE_GMAIL_CONNECTOR = "GOOGLE_GMAIL_CONNECTOR"
    AIRTABLE_CONNECTOR = "AIRTABLE_CONNECTOR"
    LUMA_CONNECTOR = "LUMA_CONNECTOR"
    ELASTICSEARCH_CONNECTOR = "ELASTICSEARCH_CONNECTOR"
class ChatType(str, Enum):
    """Kind of chat session; currently only question-and-answer chats exist."""

    QNA = "QNA"
class LiteLLMProvider(str, Enum):
    """
    Enum for LLM providers supported by LiteLLM.

    ``CUSTOM`` is paired with ``LLMConfig.custom_provider`` for providers not
    listed here.
    """

    OPENAI = "OPENAI"
    ANTHROPIC = "ANTHROPIC"
    GROQ = "GROQ"
    COHERE = "COHERE"
    HUGGINGFACE = "HUGGINGFACE"
    AZURE_OPENAI = "AZURE_OPENAI"
    GOOGLE = "GOOGLE"
    AWS_BEDROCK = "AWS_BEDROCK"
    OLLAMA = "OLLAMA"
    MISTRAL = "MISTRAL"
    TOGETHER_AI = "TOGETHER_AI"
    OPENROUTER = "OPENROUTER"
    REPLICATE = "REPLICATE"
    PALM = "PALM"
    VERTEX_AI = "VERTEX_AI"
    ANYSCALE = "ANYSCALE"
    PERPLEXITY = "PERPLEXITY"
    DEEPINFRA = "DEEPINFRA"
    AI21 = "AI21"
    NLPCLOUD = "NLPCLOUD"
    ALEPH_ALPHA = "ALEPH_ALPHA"
    PETALS = "PETALS"
    COMETAPI = "COMETAPI"
    # Chinese LLM Providers (OpenAI-compatible)
    DEEPSEEK = "DEEPSEEK"
    ALIBABA_QWEN = "ALIBABA_QWEN"
    MOONSHOT = "MOONSHOT"
    ZHIPU = "ZHIPU"
    CUSTOM = "CUSTOM"
class LogLevel(str, Enum):
    """Severity level of a ``Log`` entry (mirrors stdlib logging levels)."""

    DEBUG = "DEBUG"
    INFO = "INFO"
    WARNING = "WARNING"
    ERROR = "ERROR"
    CRITICAL = "CRITICAL"
class LogStatus(str, Enum):
    """Lifecycle status of the operation a ``Log`` entry describes."""

    IN_PROGRESS = "IN_PROGRESS"
    SUCCESS = "SUCCESS"
    FAILED = "FAILED"
class Base(DeclarativeBase):
    """Declarative base class for all ORM models in this module."""

    pass
class TimestampMixin:
    """Mixin adding an indexed, timezone-aware ``created_at`` column."""

    @declared_attr
    def created_at(cls):  # noqa: N805
        # The lambda makes the default evaluate at INSERT time (UTC now),
        # not once at class-definition time.
        return Column(
            TIMESTAMP(timezone=True),
            nullable=False,
            default=lambda: datetime.now(UTC),
            index=True,
        )
class BaseModel(Base):
    """Abstract ORM base adding an auto-increment integer primary key."""

    __abstract__ = True
    __allow_unmapped__ = True

    # Surrogate primary key shared by all concrete models below.
    id = Column(Integer, primary_key=True, index=True)
class Chat(BaseModel, TimestampMixin):
    """A chat session (message history) belonging to a search space."""

    __tablename__ = "chats"

    type = Column(SQLAlchemyEnum(ChatType), nullable=False)
    title = Column(String, nullable=False, index=True)
    # Connector types selected when the chat was started (may be absent).
    initial_connectors = Column(ARRAY(String), nullable=True)
    # Full message history stored as JSON.
    messages = Column(JSON, nullable=False)
    # Monotonic version counter for the chat state; starts at 1.
    state_version = Column(BigInteger, nullable=False, default=1)
    search_space_id = Column(
        Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
    )

    search_space = relationship("SearchSpace", back_populates="chats")
class Document(BaseModel, TimestampMixin):
    """An ingested document plus its summary embedding; chunked for search."""

    __tablename__ = "documents"

    title = Column(String, nullable=False, index=True)
    document_type = Column(SQLAlchemyEnum(DocumentType), nullable=False)
    document_metadata = Column(JSON, nullable=True)
    content = Column(Text, nullable=False)
    # Hash of the content — unique, so identical content is stored once.
    content_hash = Column(String, nullable=False, index=True, unique=True)
    # Optional connector-provided identity hash, also unique when present.
    unique_identifier_hash = Column(String, nullable=True, index=True, unique=True)
    # Dimension comes from the configured embedding model instance.
    embedding = Column(Vector(config.embedding_model_instance.dimension))
    search_space_id = Column(
        Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
    )

    search_space = relationship("SearchSpace", back_populates="documents")
    # Deleting a document deletes its chunks (ORM cascade + DB ON DELETE CASCADE).
    chunks = relationship(
        "Chunk", back_populates="document", cascade="all, delete-orphan"
    )
class Chunk(BaseModel, TimestampMixin):
    """A retrievable chunk of a ``Document`` with its own embedding."""

    __tablename__ = "chunks"

    content = Column(Text, nullable=False)
    # Dimension comes from the configured embedding model instance.
    embedding = Column(Vector(config.embedding_model_instance.dimension))
    document_id = Column(
        Integer, ForeignKey("documents.id", ondelete="CASCADE"), nullable=False
    )

    document = relationship("Document", back_populates="chunks")
class Podcast(BaseModel, TimestampMixin):
    """A generated podcast (transcript + audio file) within a search space."""

    __tablename__ = "podcasts"

    title = Column(String, nullable=False, index=True)
    # ``default=dict`` (a factory) instead of ``default={}``: a literal dict is
    # a single shared mutable object, so in-place mutation would leak between
    # rows. The factory gives every row its own fresh dict.
    podcast_transcript = Column(JSON, nullable=False, default=dict)
    # Path/location of the rendered audio file; empty until generated.
    file_location = Column(String(500), nullable=False, default="")
    chat_id = Column(
        Integer, ForeignKey("chats.id", ondelete="CASCADE"), nullable=True
    )  # If generated from a chat, this will be the chat id, else null ( can be from a document or a chat )
    # Chat state version the podcast was generated from (if chat-sourced).
    chat_state_version = Column(BigInteger, nullable=True)
    search_space_id = Column(
        Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
    )

    search_space = relationship("SearchSpace", back_populates="podcasts")
class SearchSpace(BaseModel, TimestampMixin):
    """A user's workspace: owns documents, chats, podcasts, connectors, logs.

    All child collections use ``cascade="all, delete-orphan"`` so deleting a
    search space removes everything inside it.
    """

    __tablename__ = "searchspaces"

    name = Column(String(100), nullable=False, index=True)
    description = Column(String(500), nullable=True)
    user_id = Column(
        UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
    )

    user = relationship("User", back_populates="search_spaces")
    documents = relationship(
        "Document",
        back_populates="search_space",
        order_by="Document.id",
        cascade="all, delete-orphan",
    )
    podcasts = relationship(
        "Podcast",
        back_populates="search_space",
        order_by="Podcast.id",
        cascade="all, delete-orphan",
    )
    chats = relationship(
        "Chat",
        back_populates="search_space",
        order_by="Chat.id",
        cascade="all, delete-orphan",
    )
    logs = relationship(
        "Log",
        back_populates="search_space",
        order_by="Log.id",
        cascade="all, delete-orphan",
    )
    search_source_connectors = relationship(
        "SearchSourceConnector",
        back_populates="search_space",
        order_by="SearchSourceConnector.id",
        cascade="all, delete-orphan",
    )
    llm_configs = relationship(
        "LLMConfig",
        back_populates="search_space",
        order_by="LLMConfig.id",
        cascade="all, delete-orphan",
    )
    user_preferences = relationship(
        "UserSearchSpacePreference",
        back_populates="search_space",
        cascade="all, delete-orphan",
    )
class SearchSourceConnector(BaseModel, TimestampMixin):
    """A configured external connector (Slack, Notion, search API, ...).

    At most one connector of a given type per (search space, user), enforced
    by the unique constraint below.
    """

    __tablename__ = "search_source_connectors"
    __table_args__ = (
        UniqueConstraint(
            "search_space_id",
            "user_id",
            "connector_type",
            name="uq_searchspace_user_connector_type",
        ),
    )

    name = Column(String(100), nullable=False, index=True)
    connector_type = Column(SQLAlchemyEnum(SearchSourceConnectorType), nullable=False)
    # True for connectors whose content is ingested as documents (vs. live APIs).
    is_indexable = Column(Boolean, nullable=False, default=False)
    last_indexed_at = Column(TIMESTAMP(timezone=True), nullable=True)
    # Connector-specific settings (tokens, endpoints, ...); schema varies by type.
    config = Column(JSON, nullable=False)
    # Periodic indexing fields
    periodic_indexing_enabled = Column(Boolean, nullable=False, default=False)
    indexing_frequency_minutes = Column(Integer, nullable=True)
    next_scheduled_at = Column(TIMESTAMP(timezone=True), nullable=True)
    search_space_id = Column(
        Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
    )
    search_space = relationship(
        "SearchSpace", back_populates="search_source_connectors"
    )
    user_id = Column(
        UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
    )
class LLMConfig(BaseModel, TimestampMixin):
    """A named LLM configuration (provider, model, credentials) per search space."""

    __tablename__ = "llm_configs"

    name = Column(String(100), nullable=False, index=True)
    # Provider from the enum
    provider = Column(SQLAlchemyEnum(LiteLLMProvider), nullable=False)
    # Custom provider name when provider is CUSTOM
    custom_provider = Column(String(100), nullable=True)
    # Just the model name without provider prefix
    model_name = Column(String(100), nullable=False)
    # API Key should be encrypted before storing
    api_key = Column(String, nullable=False)
    api_base = Column(String(500), nullable=True)
    language = Column(String(50), nullable=True, default="English")
    # For any other parameters that litellm supports.
    # ``default=dict`` (a factory) instead of ``default={}``: a literal dict is
    # one shared mutable object across all rows; the factory yields a fresh
    # dict per row.
    litellm_params = Column(JSON, nullable=True, default=dict)
    search_space_id = Column(
        Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
    )
    search_space = relationship("SearchSpace", back_populates="llm_configs")
class UserSearchSpacePreference(BaseModel, TimestampMixin):
    """Per-(user, search space) preferences: which LLM configs to use.

    One row per user per search space (unique constraint). The three LLM
    foreign keys use ``ON DELETE SET NULL`` so deleting an ``LLMConfig``
    clears the preference instead of deleting this row.
    """

    __tablename__ = "user_search_space_preferences"
    __table_args__ = (
        UniqueConstraint(
            "user_id",
            "search_space_id",
            name="uq_user_searchspace",
        ),
    )

    user_id = Column(
        UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
    )
    search_space_id = Column(
        Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
    )
    # User-specific LLM preferences for this search space
    long_context_llm_id = Column(
        Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
    )
    fast_llm_id = Column(
        Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
    )
    strategic_llm_id = Column(
        Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
    )
    # Future RBAC fields can be added here
    # role = Column(String(50), nullable=True)  # e.g., 'owner', 'editor', 'viewer'
    # permissions = Column(JSON, nullable=True)

    user = relationship("User", back_populates="search_space_preferences")
    search_space = relationship("SearchSpace", back_populates="user_preferences")
    # post_update=True breaks the FK dependency cycle between this table and
    # llm_configs during flush.
    long_context_llm = relationship(
        "LLMConfig", foreign_keys=[long_context_llm_id], post_update=True
    )
    fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True)
    strategic_llm = relationship(
        "LLMConfig", foreign_keys=[strategic_llm_id], post_update=True
    )
class Log(BaseModel, TimestampMixin):
    """A structured application log entry scoped to a search space."""

    __tablename__ = "logs"

    level = Column(SQLAlchemyEnum(LogLevel), nullable=False, index=True)
    status = Column(SQLAlchemyEnum(LogStatus), nullable=False, index=True)
    message = Column(Text, nullable=False)
    source = Column(
        String(200), nullable=True, index=True
    )  # Service/component that generated the log
    # ``default=dict`` (a factory) instead of ``default={}``: a literal dict is
    # one shared mutable object across rows; the factory yields a fresh dict.
    log_metadata = Column(JSON, nullable=True, default=dict)  # Additional context data
    search_space_id = Column(
        Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
    )

    search_space = relationship("SearchSpace", back_populates="logs")
# The User model's shape depends on the configured auth mode: with Google
# OAuth an OAuthAccount table and relationship are added; otherwise the plain
# fastapi-users table is used. Both variants share the same extra columns.
if config.AUTH_TYPE == "GOOGLE":

    class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
        """OAuth account rows linked to a ``User`` (fastapi-users schema)."""

        pass

    class User(SQLAlchemyBaseUserTableUUID, Base):
        """Application user with linked OAuth accounts (GOOGLE auth mode)."""

        oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
            "OAuthAccount", lazy="joined"
        )
        search_spaces = relationship("SearchSpace", back_populates="user")
        search_space_preferences = relationship(
            "UserSearchSpacePreference",
            back_populates="user",
            cascade="all, delete-orphan",
        )

        # Page usage tracking for ETL services
        pages_limit = Column(Integer, nullable=False, default=500, server_default="500")
        pages_used = Column(Integer, nullable=False, default=0, server_default="0")

else:

    class User(SQLAlchemyBaseUserTableUUID, Base):
        """Application user (non-OAuth auth mode; no OAuth accounts table)."""

        search_spaces = relationship("SearchSpace", back_populates="user")
        search_space_preferences = relationship(
            "UserSearchSpacePreference",
            back_populates="user",
            cascade="all, delete-orphan",
        )

        # Page usage tracking for ETL services
        pages_limit = Column(Integer, nullable=False, default=500, server_default="500")
        pages_used = Column(Integer, nullable=False, default=0, server_default="0")
# Module-level async engine and session factory shared by the whole app.
engine = create_async_engine(DATABASE_URL)
# expire_on_commit=False keeps attribute values usable after commit, which is
# needed when ORM objects outlive the session in async request handlers.
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
async def setup_indexes():
    """Create the HNSW vector and GIN full-text indexes used by hybrid search.

    Every statement uses ``IF NOT EXISTS`` so this is safe to run on each
    startup. The index names keep the historical "chucks" spelling: renaming
    them now would leave duplicate indexes on existing databases.
    """
    index_statements = (
        # Document summary indexes (documents table)
        "CREATE INDEX IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)",
        "CREATE INDEX IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector('english', content))",
        # Document chunk indexes (chunks table)
        "CREATE INDEX IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)",
        "CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector('english', content))",
    )
    async with engine.begin() as conn:
        for statement in index_statements:
            await conn.execute(text(statement))
async def create_db_and_tables():
    """Initialize the database: extension, tables, then search indexes.

    Order matters: the ``vector`` extension must exist before tables with
    ``Vector`` columns can be created, and the indexes are built last.
    """
    async with engine.begin() as conn:
        await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
        await conn.run_sync(Base.metadata.create_all)
    # setup_indexes opens its own transaction on the shared engine.
    await setup_indexes()
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency: yield one ``AsyncSession`` per request.

    The context manager guarantees the session is closed when the request
    (and thus the generator) finishes.
    """
    async with async_session_maker() as db_session:
        yield db_session
# The fastapi-users database adapter needs the OAuthAccount table only when
# Google OAuth is the configured auth mode.
if config.AUTH_TYPE == "GOOGLE":

    async def get_user_db(session: AsyncSession = Depends(get_async_session)):
        """FastAPI dependency: user DB adapter including the OAuth table."""
        yield SQLAlchemyUserDatabase(session, User, OAuthAccount)

else:

    async def get_user_db(session: AsyncSession = Depends(get_async_session)):
        """FastAPI dependency: user DB adapter (no OAuth table)."""
        yield SQLAlchemyUserDatabase(session, User)
async def get_chucks_hybrid_search_retriever(
    session: AsyncSession = Depends(get_async_session),
):
    """FastAPI dependency: chunk-level hybrid (vector + full-text) retriever.

    NOTE(review): "chucks" is a historical typo for "chunks"; the name is
    kept because callers depend on it.
    """
    return ChucksHybridSearchRetriever(session)
async def get_documents_hybrid_search_retriever(
    session: AsyncSession = Depends(get_async_session),
):
    """FastAPI dependency: document-level hybrid (vector + full-text) retriever."""
    return DocumentHybridSearchRetriever(session)