feat: add missing migration

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2025-10-12 20:15:27 -07:00
parent 9429c2b06b
commit a3f50ebc4d
11 changed files with 126 additions and 38 deletions

View file

@ -2,7 +2,6 @@
Revision ID: '23' Revision ID: '23'
Revises: '22' Revises: '22'
Create Date: 2025-01-10 12:00:00.000000
""" """

View file

@ -2,7 +2,6 @@
Revision ID: 24 Revision ID: 24
Revises: 23 Revises: 23
Create Date: 2025-01-10 14:00:00.000000
""" """

View file

@ -2,7 +2,6 @@
Revision ID: 25 Revision ID: 25
Revises: 24 Revises: 24
Create Date: 2025-01-10 14:00:00.000000
Changes: Changes:
1. Migrate llm_configs from user association to search_space association 1. Migrate llm_configs from user association to search_space association

View file

@ -0,0 +1,69 @@
"""Add language column to llm_configs
Revision ID: 26
Revises: 25
Changes:
1. Add language column to llm_configs table with default value of 'English'
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "26"
down_revision: str | None = "25"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
    """Add the nullable ``language`` column to the ``llm_configs`` table.

    The column gets a server-side default of 'English', and any
    pre-existing rows with a NULL language are backfilled to 'English'.
    The ``add_column`` is guarded by an inspector check so re-running the
    migration against a schema that already has the column is a no-op.
    """
    conn = op.get_bind()
    # sa (sqlalchemy) is imported at module level; no need for a
    # redundant function-local `from sqlalchemy import inspect`.
    inspector = sa.inspect(conn)

    # Set membership is the idiomatic form for an existence check.
    existing_columns = {col["name"] for col in inspector.get_columns("llm_configs")}

    if "language" not in existing_columns:
        op.add_column(
            "llm_configs",
            sa.Column(
                "language",
                sa.String(length=50),
                nullable=True,
                server_default="English",
            ),
        )

    # Backfill runs unconditionally (matching the original migration):
    # even when the column pre-existed, NULLs are normalized to 'English'.
    op.execute(
        """
        UPDATE llm_configs
        SET language = 'English'
        WHERE language IS NULL
        """
    )
def downgrade() -> None:
    """Drop the ``language`` column from ``llm_configs`` if it exists.

    Mirrors ``upgrade``: the inspector guard makes the drop idempotent,
    so downgrading a schema that never had the column is a no-op.
    """
    conn = op.get_bind()
    # Use sa.inspect from the module-level import instead of a redundant
    # function-local `from sqlalchemy import inspect`.
    inspector = sa.inspect(conn)

    existing_columns = {col["name"] for col in inspector.get_columns("llm_configs")}

    if "language" in existing_columns:
        op.drop_column("llm_configs", "language")

View file

@ -37,8 +37,7 @@ class Configuration:
search_mode: SearchMode search_mode: SearchMode
research_mode: ResearchMode research_mode: ResearchMode
document_ids_to_add_in_context: list[int] document_ids_to_add_in_context: list[int]
language: str | None = None language: str | None = None
@classmethod @classmethod
def from_runnable_config( def from_runnable_config(

View file

@ -1,9 +1,12 @@
import datetime import datetime
def _build_language_instruction(language: str | None = None): def _build_language_instruction(language: str | None = None):
if language: if language:
return f"\n\nIMPORTANT: Please respond in {language} language. All your responses, explanations, and analysis should be written in {language}." return f"\n\nIMPORTANT: Please respond in {language} language. All your responses, explanations, and analysis should be written in {language}."
return "" return ""
def get_answer_outline_system_prompt(language: str | None = None) -> str: def get_answer_outline_system_prompt(language: str | None = None) -> str:
language_instruction = _build_language_instruction(language) language_instruction = _build_language_instruction(language)

View file

@ -102,7 +102,7 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any
user_query = configuration.user_query user_query = configuration.user_query
user_id = configuration.user_id user_id = configuration.user_id
search_space_id = configuration.search_space_id search_space_id = configuration.search_space_id
language = configuration.language language = configuration.language
# Get user's fast LLM # Get user's fast LLM
llm = await get_user_fast_llm(state.db_session, user_id, search_space_id) llm = await get_user_fast_llm(state.db_session, user_id, search_space_id)
if not llm: if not llm:
@ -127,7 +127,9 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any
""" """
# Use initial system prompt for token calculation # Use initial system prompt for token calculation
initial_system_prompt = get_qna_citation_system_prompt(chat_history_str, language) initial_system_prompt = get_qna_citation_system_prompt(
chat_history_str, language
)
base_messages = [ base_messages = [
SystemMessage(content=initial_system_prompt), SystemMessage(content=initial_system_prompt),
HumanMessage(content=base_human_message_template), HumanMessage(content=base_human_message_template),

View file

@ -1,7 +1,11 @@
import datetime import datetime
from ..prompts import _build_language_instruction from ..prompts import _build_language_instruction
def get_qna_citation_system_prompt(chat_history: str | None = None, language: str | None = None):
def get_qna_citation_system_prompt(
chat_history: str | None = None, language: str | None = None
):
chat_history_section = ( chat_history_section = (
f""" f"""
<chat_history> <chat_history>
@ -15,7 +19,7 @@ NO CHAT HISTORY PROVIDED
</chat_history> </chat_history>
""" """
) )
# Add language instruction if specified # Add language instruction if specified
language_instruction = _build_language_instruction(language) language_instruction = _build_language_instruction(language)
return f""" return f"""
@ -151,7 +155,9 @@ Make sure your response:
""" """
def get_qna_no_documents_system_prompt(chat_history: str | None = None, language: str | None = None): def get_qna_no_documents_system_prompt(
chat_history: str | None = None, language: str | None = None
):
chat_history_section = ( chat_history_section = (
f""" f"""
<chat_history> <chat_history>
@ -165,7 +171,7 @@ NO CHAT HISTORY PROVIDED
</chat_history> </chat_history>
""" """
) )
# Add language instruction if specified # Add language instruction if specified
language_instruction = _build_language_instruction(language) language_instruction = _build_language_instruction(language)

View file

@ -1,6 +1,11 @@
import datetime import datetime
from ..prompts import _build_language_instruction from ..prompts import _build_language_instruction
def get_citation_system_prompt(chat_history: str | None = None, language: str | None = None):
def get_citation_system_prompt(
chat_history: str | None = None, language: str | None = None
):
chat_history_section = ( chat_history_section = (
f""" f"""
<chat_history> <chat_history>
@ -14,7 +19,7 @@ NO CHAT HISTORY PROVIDED
</chat_history> </chat_history>
""" """
) )
# Add language instruction if specified # Add language instruction if specified
language_instruction = _build_language_instruction(language) language_instruction = _build_language_instruction(language)
@ -158,7 +163,9 @@ Make sure your response:
""" """
def get_no_documents_system_prompt(chat_history: str | None = None, language: str | None = None): def get_no_documents_system_prompt(
chat_history: str | None = None, language: str | None = None
):
chat_history_section = ( chat_history_section = (
f""" f"""
<chat_history> <chat_history>
@ -172,7 +179,7 @@ NO CHAT HISTORY PROVIDED
</chat_history> </chat_history>
""" """
) )
# Add language instruction if specified # Add language instruction if specified
language_instruction = _build_language_instruction(language) language_instruction = _build_language_instruction(language)

View file

@ -6,7 +6,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from app.db import Chat, SearchSpace, User, UserSearchSpacePreference, get_async_session from app.db import Chat, SearchSpace, User, UserSearchSpacePreference, get_async_session
from app.schemas import ( from app.schemas import (
AISDKChatRequest, AISDKChatRequest,
@ -64,47 +63,53 @@ async def handle_chat_data(
language_result = await session.execute( language_result = await session.execute(
select(UserSearchSpacePreference) select(UserSearchSpacePreference)
.options( .options(
selectinload(UserSearchSpacePreference.search_space).selectinload(SearchSpace.llm_configs), selectinload(UserSearchSpacePreference.search_space).selectinload(
SearchSpace.llm_configs
),
selectinload(UserSearchSpacePreference.long_context_llm), selectinload(UserSearchSpacePreference.long_context_llm),
selectinload(UserSearchSpacePreference.fast_llm), selectinload(UserSearchSpacePreference.fast_llm),
selectinload(UserSearchSpacePreference.strategic_llm) selectinload(UserSearchSpacePreference.strategic_llm),
) )
.filter( .filter(
UserSearchSpacePreference.search_space_id == search_space_id, UserSearchSpacePreference.search_space_id == search_space_id,
UserSearchSpacePreference.user_id == user.id UserSearchSpacePreference.user_id == user.id,
) )
) )
user_preference = language_result.scalars().first() user_preference = language_result.scalars().first()
# print("UserSearchSpacePreference:", user_preference) # print("UserSearchSpacePreference:", user_preference)
language = None language = None
if user_preference and user_preference.search_space and user_preference.search_space.llm_configs: if (
user_preference
and user_preference.search_space
and user_preference.search_space.llm_configs
):
llm_configs = user_preference.search_space.llm_configs llm_configs = user_preference.search_space.llm_configs
for preferred_llm in [
for preferred_llm in [user_preference.fast_llm, user_preference.long_context_llm, user_preference.strategic_llm]: user_preference.fast_llm,
if preferred_llm and getattr(preferred_llm, 'language', None): user_preference.long_context_llm,
user_preference.strategic_llm,
]:
if preferred_llm and getattr(preferred_llm, "language", None):
language = preferred_llm.language language = preferred_llm.language
break break
if not language: if not language:
first_llm_config = llm_configs[0] first_llm_config = llm_configs[0]
language = getattr(first_llm_config, 'language', None) language = getattr(first_llm_config, "language", None)
except HTTPException: except HTTPException:
raise HTTPException( raise HTTPException(
status_code=403, detail="You don't have access to this search space" status_code=403, detail="You don't have access to this search space"
) from None ) from None
langchain_chat_history = [] langchain_chat_history = []
for message in messages[:-1]: for message in messages[:-1]:
if message["role"] == "user": if message["role"] == "user":
langchain_chat_history.append(HumanMessage(content=message["content"])) langchain_chat_history.append(HumanMessage(content=message["content"]))
elif message["role"] == "assistant": elif message["role"] == "assistant":
langchain_chat_history.append(AIMessage(content=message["content"])) langchain_chat_history.append(AIMessage(content=message["content"]))
response = StreamingResponse( response = StreamingResponse(
stream_connector_search_results( stream_connector_search_results(
@ -117,7 +122,7 @@ async def handle_chat_data(
langchain_chat_history, langchain_chat_history,
search_mode_str, search_mode_str,
document_ids_to_add_in_context, document_ids_to_add_in_context,
language, language,
) )
) )

View file

@ -299,10 +299,10 @@ async def update_user_llm_preferences(
# Validate that all provided LLM config IDs belong to the search space # Validate that all provided LLM config IDs belong to the search space
update_data = preferences.model_dump(exclude_unset=True) update_data = preferences.model_dump(exclude_unset=True)
# Store language from configs to validate consistency # Store language from configs to validate consistency
languages = set() languages = set()
for _key, llm_config_id in update_data.items(): for _key, llm_config_id in update_data.items():
if llm_config_id is not None: if llm_config_id is not None:
# Verify the LLM config belongs to the search space # Verify the LLM config belongs to the search space
@ -318,10 +318,10 @@ async def update_user_llm_preferences(
status_code=404, status_code=404,
detail=f"LLM configuration {llm_config_id} not found in this search space", detail=f"LLM configuration {llm_config_id} not found in this search space",
) )
# Collect language for consistency check # Collect language for consistency check
languages.add(llm_config.language) languages.add(llm_config.language)
# Check if all selected LLM configs have the same language # Check if all selected LLM configs have the same language
if len(languages) > 1: if len(languages) > 1:
raise HTTPException( raise HTTPException(