mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-09 07:42:39 +02:00
feat: added missing migration
This commit is contained in:
parent
9429c2b06b
commit
a3f50ebc4d
11 changed files with 126 additions and 38 deletions
|
|
@ -2,7 +2,6 @@
|
||||||
|
|
||||||
Revision ID: '23'
|
Revision ID: '23'
|
||||||
Revises: '22'
|
Revises: '22'
|
||||||
Create Date: 2025-01-10 12:00:00.000000
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@
|
||||||
|
|
||||||
Revision ID: 24
|
Revision ID: 24
|
||||||
Revises: 23
|
Revises: 23
|
||||||
Create Date: 2025-01-10 14:00:00.000000
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@
|
||||||
|
|
||||||
Revision ID: 25
|
Revision ID: 25
|
||||||
Revises: 24
|
Revises: 24
|
||||||
Create Date: 2025-01-10 14:00:00.000000
|
|
||||||
|
|
||||||
Changes:
|
Changes:
|
||||||
1. Migrate llm_configs from user association to search_space association
|
1. Migrate llm_configs from user association to search_space association
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
"""Add language column to llm_configs
|
||||||
|
|
||||||
|
Revision ID: 26
|
||||||
|
Revises: 25
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
1. Add language column to llm_configs table with default value of 'English'
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "26"
|
||||||
|
down_revision: str | None = "25"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Add the ``language`` column to the ``llm_configs`` table.

    The migration is defensive: it inspects the live schema first and is a
    no-op when the column already exists, so re-running it is safe.
    """
    from sqlalchemy import inspect

    bind = op.get_bind()
    existing_columns = {col["name"] for col in inspect(bind).get_columns("llm_configs")}

    # Column already present — nothing to do.
    if "language" in existing_columns:
        return

    op.add_column(
        "llm_configs",
        sa.Column(
            "language",
            sa.String(length=50),
            nullable=True,
            server_default="English",
        ),
    )

    # Backfill: make sure every pre-existing row carries the default value.
    op.execute(
        """
        UPDATE llm_configs
        SET language = 'English'
        WHERE language IS NULL
        """
    )
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Remove the ``language`` column from the ``llm_configs`` table.

    Mirrors :func:`upgrade`: the live schema is inspected first, so the drop
    only runs when the column is actually present (safe to re-run).
    """
    from sqlalchemy import inspect

    bind = op.get_bind()
    existing_columns = {col["name"] for col in inspect(bind).get_columns("llm_configs")}

    if "language" in existing_columns:
        op.drop_column("llm_configs", "language")
|
||||||
|
|
@ -37,8 +37,7 @@ class Configuration:
|
||||||
search_mode: SearchMode
|
search_mode: SearchMode
|
||||||
research_mode: ResearchMode
|
research_mode: ResearchMode
|
||||||
document_ids_to_add_in_context: list[int]
|
document_ids_to_add_in_context: list[int]
|
||||||
language: str | None = None
|
language: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_runnable_config(
|
def from_runnable_config(
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
|
||||||
def _build_language_instruction(language: str | None = None):
|
def _build_language_instruction(language: str | None = None):
|
||||||
if language:
|
if language:
|
||||||
return f"\n\nIMPORTANT: Please respond in {language} language. All your responses, explanations, and analysis should be written in {language}."
|
return f"\n\nIMPORTANT: Please respond in {language} language. All your responses, explanations, and analysis should be written in {language}."
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def get_answer_outline_system_prompt(language: str | None = None) -> str:
|
def get_answer_outline_system_prompt(language: str | None = None) -> str:
|
||||||
language_instruction = _build_language_instruction(language)
|
language_instruction = _build_language_instruction(language)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -102,7 +102,7 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any
|
||||||
user_query = configuration.user_query
|
user_query = configuration.user_query
|
||||||
user_id = configuration.user_id
|
user_id = configuration.user_id
|
||||||
search_space_id = configuration.search_space_id
|
search_space_id = configuration.search_space_id
|
||||||
language = configuration.language
|
language = configuration.language
|
||||||
# Get user's fast LLM
|
# Get user's fast LLM
|
||||||
llm = await get_user_fast_llm(state.db_session, user_id, search_space_id)
|
llm = await get_user_fast_llm(state.db_session, user_id, search_space_id)
|
||||||
if not llm:
|
if not llm:
|
||||||
|
|
@ -127,7 +127,9 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Use initial system prompt for token calculation
|
# Use initial system prompt for token calculation
|
||||||
initial_system_prompt = get_qna_citation_system_prompt(chat_history_str, language)
|
initial_system_prompt = get_qna_citation_system_prompt(
|
||||||
|
chat_history_str, language
|
||||||
|
)
|
||||||
base_messages = [
|
base_messages = [
|
||||||
SystemMessage(content=initial_system_prompt),
|
SystemMessage(content=initial_system_prompt),
|
||||||
HumanMessage(content=base_human_message_template),
|
HumanMessage(content=base_human_message_template),
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,11 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from ..prompts import _build_language_instruction
|
from ..prompts import _build_language_instruction
|
||||||
|
|
||||||
def get_qna_citation_system_prompt(chat_history: str | None = None, language: str | None = None):
|
|
||||||
|
def get_qna_citation_system_prompt(
|
||||||
|
chat_history: str | None = None, language: str | None = None
|
||||||
|
):
|
||||||
chat_history_section = (
|
chat_history_section = (
|
||||||
f"""
|
f"""
|
||||||
<chat_history>
|
<chat_history>
|
||||||
|
|
@ -15,7 +19,7 @@ NO CHAT HISTORY PROVIDED
|
||||||
</chat_history>
|
</chat_history>
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add language instruction if specified
|
# Add language instruction if specified
|
||||||
language_instruction = _build_language_instruction(language)
|
language_instruction = _build_language_instruction(language)
|
||||||
return f"""
|
return f"""
|
||||||
|
|
@ -151,7 +155,9 @@ Make sure your response:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_qna_no_documents_system_prompt(chat_history: str | None = None, language: str | None = None):
|
def get_qna_no_documents_system_prompt(
|
||||||
|
chat_history: str | None = None, language: str | None = None
|
||||||
|
):
|
||||||
chat_history_section = (
|
chat_history_section = (
|
||||||
f"""
|
f"""
|
||||||
<chat_history>
|
<chat_history>
|
||||||
|
|
@ -165,7 +171,7 @@ NO CHAT HISTORY PROVIDED
|
||||||
</chat_history>
|
</chat_history>
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add language instruction if specified
|
# Add language instruction if specified
|
||||||
language_instruction = _build_language_instruction(language)
|
language_instruction = _build_language_instruction(language)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,11 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from ..prompts import _build_language_instruction
|
from ..prompts import _build_language_instruction
|
||||||
def get_citation_system_prompt(chat_history: str | None = None, language: str | None = None):
|
|
||||||
|
|
||||||
|
def get_citation_system_prompt(
|
||||||
|
chat_history: str | None = None, language: str | None = None
|
||||||
|
):
|
||||||
chat_history_section = (
|
chat_history_section = (
|
||||||
f"""
|
f"""
|
||||||
<chat_history>
|
<chat_history>
|
||||||
|
|
@ -14,7 +19,7 @@ NO CHAT HISTORY PROVIDED
|
||||||
</chat_history>
|
</chat_history>
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add language instruction if specified
|
# Add language instruction if specified
|
||||||
language_instruction = _build_language_instruction(language)
|
language_instruction = _build_language_instruction(language)
|
||||||
|
|
||||||
|
|
@ -158,7 +163,9 @@ Make sure your response:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_no_documents_system_prompt(chat_history: str | None = None, language: str | None = None):
|
def get_no_documents_system_prompt(
|
||||||
|
chat_history: str | None = None, language: str | None = None
|
||||||
|
):
|
||||||
chat_history_section = (
|
chat_history_section = (
|
||||||
f"""
|
f"""
|
||||||
<chat_history>
|
<chat_history>
|
||||||
|
|
@ -172,7 +179,7 @@ NO CHAT HISTORY PROVIDED
|
||||||
</chat_history>
|
</chat_history>
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add language instruction if specified
|
# Add language instruction if specified
|
||||||
language_instruction = _build_language_instruction(language)
|
language_instruction = _build_language_instruction(language)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
|
||||||
from app.db import Chat, SearchSpace, User, UserSearchSpacePreference, get_async_session
|
from app.db import Chat, SearchSpace, User, UserSearchSpacePreference, get_async_session
|
||||||
from app.schemas import (
|
from app.schemas import (
|
||||||
AISDKChatRequest,
|
AISDKChatRequest,
|
||||||
|
|
@ -64,47 +63,53 @@ async def handle_chat_data(
|
||||||
language_result = await session.execute(
|
language_result = await session.execute(
|
||||||
select(UserSearchSpacePreference)
|
select(UserSearchSpacePreference)
|
||||||
.options(
|
.options(
|
||||||
selectinload(UserSearchSpacePreference.search_space).selectinload(SearchSpace.llm_configs),
|
selectinload(UserSearchSpacePreference.search_space).selectinload(
|
||||||
|
SearchSpace.llm_configs
|
||||||
|
),
|
||||||
selectinload(UserSearchSpacePreference.long_context_llm),
|
selectinload(UserSearchSpacePreference.long_context_llm),
|
||||||
selectinload(UserSearchSpacePreference.fast_llm),
|
selectinload(UserSearchSpacePreference.fast_llm),
|
||||||
selectinload(UserSearchSpacePreference.strategic_llm)
|
selectinload(UserSearchSpacePreference.strategic_llm),
|
||||||
)
|
)
|
||||||
.filter(
|
.filter(
|
||||||
UserSearchSpacePreference.search_space_id == search_space_id,
|
UserSearchSpacePreference.search_space_id == search_space_id,
|
||||||
UserSearchSpacePreference.user_id == user.id
|
UserSearchSpacePreference.user_id == user.id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
user_preference = language_result.scalars().first()
|
user_preference = language_result.scalars().first()
|
||||||
# print("UserSearchSpacePreference:", user_preference)
|
# print("UserSearchSpacePreference:", user_preference)
|
||||||
|
|
||||||
language = None
|
language = None
|
||||||
if user_preference and user_preference.search_space and user_preference.search_space.llm_configs:
|
if (
|
||||||
|
user_preference
|
||||||
|
and user_preference.search_space
|
||||||
|
and user_preference.search_space.llm_configs
|
||||||
|
):
|
||||||
llm_configs = user_preference.search_space.llm_configs
|
llm_configs = user_preference.search_space.llm_configs
|
||||||
|
|
||||||
|
for preferred_llm in [
|
||||||
for preferred_llm in [user_preference.fast_llm, user_preference.long_context_llm, user_preference.strategic_llm]:
|
user_preference.fast_llm,
|
||||||
if preferred_llm and getattr(preferred_llm, 'language', None):
|
user_preference.long_context_llm,
|
||||||
|
user_preference.strategic_llm,
|
||||||
|
]:
|
||||||
|
if preferred_llm and getattr(preferred_llm, "language", None):
|
||||||
language = preferred_llm.language
|
language = preferred_llm.language
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
if not language:
|
if not language:
|
||||||
first_llm_config = llm_configs[0]
|
first_llm_config = llm_configs[0]
|
||||||
language = getattr(first_llm_config, 'language', None)
|
language = getattr(first_llm_config, "language", None)
|
||||||
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403, detail="You don't have access to this search space"
|
status_code=403, detail="You don't have access to this search space"
|
||||||
) from None
|
) from None
|
||||||
|
|
||||||
langchain_chat_history = []
|
langchain_chat_history = []
|
||||||
for message in messages[:-1]:
|
for message in messages[:-1]:
|
||||||
if message["role"] == "user":
|
if message["role"] == "user":
|
||||||
langchain_chat_history.append(HumanMessage(content=message["content"]))
|
langchain_chat_history.append(HumanMessage(content=message["content"]))
|
||||||
elif message["role"] == "assistant":
|
elif message["role"] == "assistant":
|
||||||
langchain_chat_history.append(AIMessage(content=message["content"]))
|
langchain_chat_history.append(AIMessage(content=message["content"]))
|
||||||
|
|
||||||
|
|
||||||
response = StreamingResponse(
|
response = StreamingResponse(
|
||||||
stream_connector_search_results(
|
stream_connector_search_results(
|
||||||
|
|
@ -117,7 +122,7 @@ async def handle_chat_data(
|
||||||
langchain_chat_history,
|
langchain_chat_history,
|
||||||
search_mode_str,
|
search_mode_str,
|
||||||
document_ids_to_add_in_context,
|
document_ids_to_add_in_context,
|
||||||
language,
|
language,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -299,10 +299,10 @@ async def update_user_llm_preferences(
|
||||||
|
|
||||||
# Validate that all provided LLM config IDs belong to the search space
|
# Validate that all provided LLM config IDs belong to the search space
|
||||||
update_data = preferences.model_dump(exclude_unset=True)
|
update_data = preferences.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
# Store language from configs to validate consistency
|
# Store language from configs to validate consistency
|
||||||
languages = set()
|
languages = set()
|
||||||
|
|
||||||
for _key, llm_config_id in update_data.items():
|
for _key, llm_config_id in update_data.items():
|
||||||
if llm_config_id is not None:
|
if llm_config_id is not None:
|
||||||
# Verify the LLM config belongs to the search space
|
# Verify the LLM config belongs to the search space
|
||||||
|
|
@ -318,10 +318,10 @@ async def update_user_llm_preferences(
|
||||||
status_code=404,
|
status_code=404,
|
||||||
detail=f"LLM configuration {llm_config_id} not found in this search space",
|
detail=f"LLM configuration {llm_config_id} not found in this search space",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Collect language for consistency check
|
# Collect language for consistency check
|
||||||
languages.add(llm_config.language)
|
languages.add(llm_config.language)
|
||||||
|
|
||||||
# Check if all selected LLM configs have the same language
|
# Check if all selected LLM configs have the same language
|
||||||
if len(languages) > 1:
|
if len(languages) > 1:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue