mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-22 21:28:12 +02:00
Merge pull request #391 from unitagain/feature/chinese-llm-support
feat: add Chinese LLM providers support with auto-fill API Base URL
This commit is contained in:
commit
c99469bfdf
9 changed files with 565 additions and 5 deletions
|
|
@ -20,6 +20,12 @@ from app.db import Base # Assuming your Base is defined in app.db
|
|||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Override SQLAlchemy URL from environment variables when available
|
||||
# 如果环境变量提供了数据库连接字符串,则优先使用该配置
|
||||
database_url = os.getenv("DATABASE_URL")
|
||||
if database_url:
|
||||
config.set_main_option("sqlalchemy.url", database_url)
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,118 @@
|
|||
"""Add Chinese LLM providers to LiteLLMProvider enum
|
||||
添加国产 LLM 提供商到 LiteLLMProvider 枚举
|
||||
|
||||
Revision ID: 26
|
||||
Revises: 25
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "26"
|
||||
down_revision: str | None = "25"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""
|
||||
Add Chinese LLM providers to LiteLLMProvider enum.
|
||||
添加国产 LLM 提供商到 LiteLLMProvider 枚举。
|
||||
|
||||
Adds support for:
|
||||
- DEEPSEEK: DeepSeek AI models
|
||||
- ALIBABA_QWEN: Alibaba Qwen (通义千问) models
|
||||
- MOONSHOT: Moonshot AI (月之暗面 Kimi) models
|
||||
- ZHIPU: Zhipu AI (智谱 GLM) models
|
||||
"""
|
||||
|
||||
# Add DEEPSEEK to the enum if it doesn't already exist
|
||||
# 如果不存在则添加 DEEPSEEK 到枚举
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_enum
|
||||
WHERE enumtypid = 'litellmprovider'::regtype
|
||||
AND enumlabel = 'DEEPSEEK'
|
||||
) THEN
|
||||
ALTER TYPE litellmprovider ADD VALUE 'DEEPSEEK';
|
||||
END IF;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
|
||||
# Add ALIBABA_QWEN to the enum if it doesn't already exist
|
||||
# 如果不存在则添加 ALIBABA_QWEN 到枚举
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_enum
|
||||
WHERE enumtypid = 'litellmprovider'::regtype
|
||||
AND enumlabel = 'ALIBABA_QWEN'
|
||||
) THEN
|
||||
ALTER TYPE litellmprovider ADD VALUE 'ALIBABA_QWEN';
|
||||
END IF;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
|
||||
# Add MOONSHOT to the enum if it doesn't already exist
|
||||
# 如果不存在则添加 MOONSHOT 到枚举
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_enum
|
||||
WHERE enumtypid = 'litellmprovider'::regtype
|
||||
AND enumlabel = 'MOONSHOT'
|
||||
) THEN
|
||||
ALTER TYPE litellmprovider ADD VALUE 'MOONSHOT';
|
||||
END IF;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
|
||||
# Add ZHIPU to the enum if it doesn't already exist
|
||||
# 如果不存在则添加 ZHIPU 到枚举
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_enum
|
||||
WHERE enumtypid = 'litellmprovider'::regtype
|
||||
AND enumlabel = 'ZHIPU'
|
||||
) THEN
|
||||
ALTER TYPE litellmprovider ADD VALUE 'ZHIPU';
|
||||
END IF;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""
|
||||
Remove Chinese LLM providers from LiteLLMProvider enum.
|
||||
从 LiteLLMProvider 枚举中移除国产 LLM 提供商。
|
||||
|
||||
Note: PostgreSQL doesn't support removing enum values directly.
|
||||
This would require recreating the enum type and updating all dependent objects.
|
||||
For safety, this downgrade is a no-op.
|
||||
|
||||
注意:PostgreSQL 不支持直接删除枚举值。
|
||||
这需要重建枚举类型并更新所有依赖对象。
|
||||
为了安全起见,此降级操作为空操作。
|
||||
"""
|
||||
# PostgreSQL doesn't support removing enum values directly
|
||||
# This would require a complex migration recreating the enum
|
||||
# PostgreSQL 不支持直接删除枚举值
|
||||
# 这需要复杂的迁移来重建枚举
|
||||
pass
|
||||
|
||||
|
|
@ -79,6 +79,10 @@ class ChatType(str, Enum):
|
|||
|
||||
|
||||
class LiteLLMProvider(str, Enum):
|
||||
"""
|
||||
Enum for LLM providers supported by LiteLLM.
|
||||
LiteLLM 支持的 LLM 提供商枚举。
|
||||
"""
|
||||
OPENAI = "OPENAI"
|
||||
ANTHROPIC = "ANTHROPIC"
|
||||
GROQ = "GROQ"
|
||||
|
|
@ -102,6 +106,11 @@ class LiteLLMProvider(str, Enum):
|
|||
ALEPH_ALPHA = "ALEPH_ALPHA"
|
||||
PETALS = "PETALS"
|
||||
COMETAPI = "COMETAPI"
|
||||
# Chinese LLM Providers (OpenAI-compatible) / 国产 LLM 提供商(OpenAI 兼容)
|
||||
DEEPSEEK = "DEEPSEEK" # DeepSeek
|
||||
ALIBABA_QWEN = "ALIBABA_QWEN" # 阿里通义千问
|
||||
MOONSHOT = "MOONSHOT" # 月之暗面 (Kimi)
|
||||
ZHIPU = "ZHIPU" # 智谱 AI (GLM)
|
||||
CUSTOM = "CUSTOM"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1070,6 +1070,7 @@ async def process_file_in_background(
|
|||
},
|
||||
)
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to process file: {filename}",
|
||||
|
|
|
|||
|
|
@ -83,11 +83,11 @@ async def get_user_llm_instance(
|
|||
)
|
||||
return None
|
||||
|
||||
# Build the model string for litellm
|
||||
# Build the model string for litellm / 构建 LiteLLM 的模型字符串
|
||||
if llm_config.custom_provider:
|
||||
model_string = f"{llm_config.custom_provider}/{llm_config.model_name}"
|
||||
else:
|
||||
# Map provider enum to litellm format
|
||||
# Map provider enum to litellm format / 将提供商枚举映射为 LiteLLM 格式
|
||||
provider_map = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
|
|
@ -99,6 +99,11 @@ async def get_user_llm_instance(
|
|||
"AZURE_OPENAI": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
"COMETAPI": "cometapi",
|
||||
# Chinese LLM providers (OpenAI-compatible) / 国产 LLM(OpenAI 兼容)
|
||||
"DEEPSEEK": "openai", # DeepSeek uses OpenAI-compatible API
|
||||
"ALIBABA_QWEN": "openai", # Qwen uses OpenAI-compatible API
|
||||
"MOONSHOT": "openai", # Moonshot (Kimi) uses OpenAI-compatible API
|
||||
"ZHIPU": "openai", # Zhipu (GLM) uses OpenAI-compatible API
|
||||
# Add more mappings as needed
|
||||
}
|
||||
provider_prefix = provider_map.get(
|
||||
|
|
|
|||
|
|
@ -73,6 +73,16 @@ class TaskLoggingService:
|
|||
Returns:
|
||||
Log: The updated log entry
|
||||
"""
|
||||
# Ensure session is in a valid state / 确保 session 处于有效状态
|
||||
if not self.session.is_active:
|
||||
await self.session.rollback()
|
||||
|
||||
# Refresh log_entry to avoid expired state / 刷新 log_entry 避免过期状态
|
||||
try:
|
||||
await self.session.refresh(log_entry)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Update the existing log entry
|
||||
log_entry.status = LogStatus.SUCCESS
|
||||
log_entry.message = message
|
||||
|
|
@ -114,6 +124,17 @@ class TaskLoggingService:
|
|||
Returns:
|
||||
Log: The updated log entry
|
||||
"""
|
||||
# Ensure session is in a valid state / 确保 session 处于有效状态
|
||||
if not self.session.is_active:
|
||||
await self.session.rollback()
|
||||
|
||||
# Refresh log_entry to avoid expired state / 刷新 log_entry 避免过期状态
|
||||
try:
|
||||
await self.session.refresh(log_entry)
|
||||
except Exception:
|
||||
# If refresh fails, the object might be detached / 如果刷新失败,对象可能已分离
|
||||
pass
|
||||
|
||||
# Update the existing log entry
|
||||
log_entry.status = LogStatus.FAILED
|
||||
log_entry.level = LogLevel.ERROR
|
||||
|
|
@ -161,6 +182,16 @@ class TaskLoggingService:
|
|||
Returns:
|
||||
Log: The updated log entry
|
||||
"""
|
||||
# Ensure session is in a valid state / 确保 session 处于有效状态
|
||||
if not self.session.is_active:
|
||||
await self.session.rollback()
|
||||
|
||||
# Refresh log_entry to avoid expired state / 刷新 log_entry 避免过期状态
|
||||
try:
|
||||
await self.session.refresh(log_entry)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
log_entry.message = progress_message
|
||||
|
||||
if progress_metadata:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue