2025-10-06 16:47:02 +05:30
|
|
|
import re
|
|
|
|
|
from typing import Any
|
|
|
|
|
|
2025-04-20 19:19:35 -07:00
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
|
|
|
from fastapi.responses import StreamingResponse
|
2025-07-24 14:43:48 -07:00
|
|
|
from langchain.schema import AIMessage, HumanMessage
|
2025-04-20 19:19:35 -07:00
|
|
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
from sqlalchemy.future import select
|
2025-05-11 23:05:56 -07:00
|
|
|
|
2025-07-24 14:43:48 -07:00
|
|
|
from app.db import Chat, SearchSpace, User, get_async_session
|
2025-08-02 21:20:36 -07:00
|
|
|
from app.schemas import (
|
|
|
|
|
AISDKChatRequest,
|
|
|
|
|
ChatCreate,
|
|
|
|
|
ChatRead,
|
|
|
|
|
ChatReadWithoutMessages,
|
|
|
|
|
ChatUpdate,
|
|
|
|
|
)
|
2025-07-24 14:43:48 -07:00
|
|
|
from app.tasks.stream_connector_search_results import stream_connector_search_results
|
|
|
|
|
from app.users import current_active_user
|
|
|
|
|
from app.utils.check_ownership import check_ownership
|
2025-05-11 23:05:56 -07:00
|
|
|
|
2025-03-14 18:53:14 -07:00
|
|
|
# APIRouter on which the chat endpoints in this module register themselves
# via the @router.* decorators.
router = APIRouter()
|
|
|
|
|
|
2025-07-24 14:43:48 -07:00
|
|
|
|
2025-10-06 16:47:02 +05:30
|
|
|
def validate_search_space_id(search_space_id: Any) -> int:
    """
    Validate and convert search_space_id to integer.

    Accepts either a positive ``int`` or a string holding a canonical positive
    decimal integer (ASCII digits only, no sign, no leading zeros; surrounding
    whitespace is ignored). ``bool`` is rejected even though it subclasses
    ``int``, so a JSON ``true`` cannot silently become search space 1.

    Args:
        search_space_id: The search space ID to validate

    Returns:
        int: Validated search space ID

    Raises:
        HTTPException: 400 if the value is missing, empty, of an unsupported
            type, or not a positive integer
    """
    if search_space_id is None:
        raise HTTPException(status_code=400, detail="search_space_id is required")

    # bool is a subclass of int; exclude it so True/False are not taken as 1/0.
    if isinstance(search_space_id, int) and not isinstance(search_space_id, bool):
        if search_space_id <= 0:
            raise HTTPException(
                status_code=400, detail="search_space_id must be a positive integer"
            )
        return search_space_id

    if isinstance(search_space_id, str):
        candidate = search_space_id.strip()
        if not candidate:
            raise HTTPException(
                status_code=400, detail="search_space_id cannot be empty"
            )

        # [0-9] rather than \d: for str patterns \d also matches non-ASCII
        # Unicode digits, which int() would convert. A leading [1-9] already
        # guarantees the value is > 0, so no separate range check is needed.
        if not re.fullmatch(r"[1-9][0-9]*", candidate):
            raise HTTPException(
                status_code=400,
                detail="search_space_id must be a valid positive integer",
            )

        try:
            return int(candidate)
        except ValueError:
            # Recent CPython versions refuse to convert very long digit strings
            # (int max-str-digits limit) even though the regex matched.
            raise HTTPException(
                status_code=400, detail="search_space_id must be a valid integer"
            ) from None

    raise HTTPException(
        status_code=400,
        detail="search_space_id must be an integer or string representation of an integer",
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate_document_ids(document_ids: Any) -> list[int]:
    """
    Validate and convert document_ids to list of integers.

    ``None`` yields an empty list. Otherwise the value must be a list whose
    entries are positive ``int`` values or strings holding a canonical positive
    decimal integer (ASCII digits only, no sign, no leading zeros; surrounding
    whitespace ignored). ``bool`` entries are rejected even though ``bool``
    subclasses ``int``.

    Args:
        document_ids: The document IDs to validate

    Returns:
        list[int]: Validated document IDs, in input order

    Raises:
        HTTPException: 400 if the value is not a list or any entry is invalid;
            the detail names the offending index
    """
    if document_ids is None:
        return []

    if not isinstance(document_ids, list):
        raise HTTPException(
            status_code=400, detail="document_ids_to_add_in_context must be a list"
        )

    validated_ids: list[int] = []
    for i, doc_id in enumerate(document_ids):
        field = f"document_ids_to_add_in_context[{i}]"

        # bool is a subclass of int; exclude it so True/False are not IDs 1/0.
        if isinstance(doc_id, int) and not isinstance(doc_id, bool):
            if doc_id <= 0:
                raise HTTPException(
                    status_code=400, detail=f"{field} must be a positive integer"
                )
            validated_ids.append(doc_id)
        elif isinstance(doc_id, str):
            candidate = doc_id.strip()
            if not candidate:
                raise HTTPException(status_code=400, detail=f"{field} cannot be empty")

            # [0-9] rather than \d, which also matches non-ASCII Unicode digits
            # for str patterns. A leading [1-9] already guarantees value > 0.
            if not re.fullmatch(r"[1-9][0-9]*", candidate):
                raise HTTPException(
                    status_code=400,
                    detail=f"{field} must be a valid positive integer",
                )

            try:
                validated_ids.append(int(candidate))
            except ValueError:
                # Very long digit strings exceed CPython's int max-str-digits.
                raise HTTPException(
                    status_code=400, detail=f"{field} must be a valid integer"
                ) from None
        else:
            raise HTTPException(
                status_code=400,
                detail=f"{field} must be an integer or string representation of an integer",
            )

    return validated_ids
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate_connectors(connectors: Any) -> list[str]:
    """
    Validate and sanitize the selected_connectors list.

    ``None`` yields an empty list. Otherwise every entry must be a non-blank
    string; each is stripped and then reduced to word characters (Unicode-aware
    ``\\w``), hyphens and underscores.

    Args:
        connectors: The connectors to validate

    Returns:
        list[str]: Sanitized connector names, in input order

    Raises:
        HTTPException: 400 if ``connectors`` is not a list, or an entry is not
            a string, is blank, or has no allowed characters left
    """
    if connectors is None:
        return []
    if not isinstance(connectors, list):
        raise HTTPException(status_code=400, detail="selected_connectors must be a list")

    cleaned: list[str] = []
    for i, name in enumerate(connectors):
        if not isinstance(name, str):
            raise HTTPException(
                status_code=400, detail=f"selected_connectors[{i}] must be a string"
            )

        trimmed = name.strip()
        if not trimmed:
            raise HTTPException(
                status_code=400, detail=f"selected_connectors[{i}] cannot be empty"
            )

        # NOTE(review): disallowed characters are silently dropped rather than
        # rejected (e.g. "A B" becomes "AB") — confirm this is intended.
        safe = re.sub(r'[^\w\-_]', '', trimmed)
        if not safe:
            raise HTTPException(
                status_code=400,
                detail=f"selected_connectors[{i}] contains invalid characters",
            )
        cleaned.append(safe)

    return cleaned
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate_research_mode(research_mode: Any) -> str:
    """
    Validate the research_mode parameter and normalize it to upper case.

    Matching is case-insensitive via ``str.upper()``; no whitespace is
    stripped.

    Args:
        research_mode: The research mode to validate; ``None`` means default

    Returns:
        str: "GENERAL" when ``research_mode`` is None, otherwise the
            upper-cased mode (one of GENERAL, DEEP, DEEPER)

    Raises:
        HTTPException: 400 if the value is not a string or not a known mode
    """
    if research_mode is None:
        return "GENERAL"  # Default value
    if not isinstance(research_mode, str):
        raise HTTPException(status_code=400, detail="research_mode must be a string")

    allowed = ("GENERAL", "DEEP", "DEEPER")
    normalized = research_mode.upper()
    if normalized in allowed:
        return normalized

    raise HTTPException(
        status_code=400,
        detail=f"research_mode must be one of: {', '.join(allowed)}",
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate_search_mode(search_mode: Any) -> str:
    """
    Validate the search_mode parameter and normalize it to upper case.

    Matching is case-insensitive via ``str.upper()``; no whitespace is
    stripped.

    Args:
        search_mode: The search mode to validate; ``None`` means default

    Returns:
        str: "CHUNKS" when ``search_mode`` is None, otherwise the upper-cased
            mode (CHUNKS or DOCUMENTS)

    Raises:
        HTTPException: 400 if the value is not a string or not a known mode
    """
    if search_mode is None:
        return "CHUNKS"  # Default value
    if not isinstance(search_mode, str):
        raise HTTPException(status_code=400, detail="search_mode must be a string")

    allowed = ("CHUNKS", "DOCUMENTS")
    normalized = search_mode.upper()
    if normalized in allowed:
        return normalized

    raise HTTPException(
        status_code=400,
        detail=f"search_mode must be one of: {', '.join(allowed)}",
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate_messages(messages: Any) -> list[dict]:
    """
    Validate the chat message list and return normalized copies.

    Each entry must be a dict with a ``role`` in {"user", "assistant",
    "system"} and a string ``content`` that is non-blank and at most 10000
    characters once stripped. Extra keys are dropped from the output.

    Args:
        messages: The messages to validate

    Returns:
        list[dict]: ``{"role": ..., "content": <stripped content>}`` per
            message, in input order

    Raises:
        HTTPException: 400 if ``messages`` is not a non-empty list or any
            entry violates the rules above; the detail names the index
    """
    if not isinstance(messages, list):
        raise HTTPException(status_code=400, detail="messages must be a list")
    if not messages:
        raise HTTPException(status_code=400, detail="messages cannot be empty")

    allowed_roles = ("user", "assistant", "system")
    max_chars = 10000  # Reasonable limit
    normalized: list[dict] = []

    for i, entry in enumerate(messages):
        if not isinstance(entry, dict):
            raise HTTPException(status_code=400, detail=f"messages[{i}] must be a dictionary")
        if "role" not in entry:
            raise HTTPException(status_code=400, detail=f"messages[{i}] must have a 'role' field")
        if "content" not in entry:
            raise HTTPException(status_code=400, detail=f"messages[{i}] must have a 'content' field")

        role = entry["role"]
        if not (isinstance(role, str) and role in allowed_roles):
            raise HTTPException(
                status_code=400,
                detail=f"messages[{i}].role must be 'user', 'assistant', or 'system'",
            )

        raw_text = entry["content"]
        if not isinstance(raw_text, str):
            raise HTTPException(status_code=400, detail=f"messages[{i}].content must be a string")

        text = raw_text.strip()
        if not text:
            raise HTTPException(status_code=400, detail=f"messages[{i}].content cannot be empty")
        if len(text) > max_chars:
            raise HTTPException(
                status_code=400,
                detail=f"messages[{i}].content is too long (max 10000 characters)",
            )

        normalized.append({"role": role, "content": text})

    return normalized
|
|
|
|
|
|
|
|
|
|
|
2025-03-14 18:53:14 -07:00
|
|
|
@router.post("/chat")
async def handle_chat_data(
    request: AISDKChatRequest,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Stream an answer to the latest user message of an AI SDK chat request.

    The handler proceeds in four steps:
    1. Validate the message list and every option in ``request.data``.
    2. Confirm the requested search space belongs to ``user``.
    3. Convert the earlier user/assistant turns into LangChain history.
    4. Return a StreamingResponse driven by ``stream_connector_search_results``.

    Raises:
        HTTPException: 400 on invalid input or when the last message is not a
            user message; 403 when the search-space ownership check fails.
    """
    # Validate and sanitize all input data
    messages = validate_messages(request.messages)

    latest = messages[-1]
    if latest["role"] != "user":
        raise HTTPException(status_code=400, detail="Last message must be a user message")
    user_query = latest["content"]

    # Options arrive in the free-form ``data`` dict; a missing dict means defaults.
    payload = request.data or {}
    search_space_id = validate_search_space_id(payload.get("search_space_id"))
    research_mode = validate_research_mode(payload.get("research_mode"))
    selected_connectors = validate_connectors(payload.get("selected_connectors"))
    document_ids_to_add_in_context = validate_document_ids(
        payload.get("document_ids_to_add_in_context")
    )
    search_mode_str = validate_search_mode(payload.get("search_mode"))

    # Any HTTPException from the ownership helper is reported uniformly as 403.
    try:
        await check_ownership(session, SearchSpace, search_space_id, user)
    except HTTPException:
        raise HTTPException(
            status_code=403, detail="You don't have access to this search space"
        ) from None

    # Earlier turns become LangChain history; roles without a mapping here
    # (validate_messages also admits "system") are not forwarded.
    role_to_message = {"user": HumanMessage, "assistant": AIMessage}
    langchain_chat_history = [
        role_to_message[turn["role"]](content=turn["content"])
        for turn in messages[:-1]
        if turn["role"] in role_to_message
    ]

    result_stream = stream_connector_search_results(
        user_query,
        user.id,
        search_space_id,
        session,
        research_mode,
        selected_connectors,
        langchain_chat_history,
        search_mode_str,
        document_ids_to_add_in_context,
    )
    response = StreamingResponse(result_stream)
    response.headers["x-vercel-ai-data-stream"] = "v1"
    return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/chats/", response_model=ChatRead)
async def create_chat(
    chat: ChatCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Create a chat inside a search space owned by the current user.

    Returns:
        The persisted Chat, refreshed after commit.

    Raises:
        HTTPException: whatever check_ownership raises is re-raised as-is;
            otherwise, after rolling back the session, 400 for an
            IntegrityError, 503 for an OperationalError and 500 for any other
            exception.
    """
    try:
        await check_ownership(session, SearchSpace, chat.search_space_id, user)
        new_chat = Chat(**chat.model_dump())
        session.add(new_chat)
        await session.commit()
        await session.refresh(new_chat)
    except HTTPException:
        raise
    except Exception as exc:
        await session.rollback()
        # Map the failure to a client-facing status; checked in the same
        # precedence as separate except clauses would be.
        if isinstance(exc, IntegrityError):
            status, detail = (
                400,
                "Database constraint violation. Please check your input data.",
            )
        elif isinstance(exc, OperationalError):
            status, detail = 503, "Database operation failed. Please try again later."
        else:
            status, detail = 500, "An unexpected error occurred while creating the chat."
        raise HTTPException(status_code=status, detail=detail) from None
    else:
        return new_chat
|
2025-03-14 18:53:14 -07:00
|
|
|
|
|
|
|
|
|
2025-08-02 21:20:36 -07:00
|
|
|
@router.get("/chats/", response_model=list[ChatReadWithoutMessages])
async def read_chats(
    skip: int = 0,
    limit: int = 100,
    search_space_id: int | None = None,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    List the current user's chats without their messages, paginated.

    Args:
        skip: Number of rows to skip; must be >= 0.
        limit: Maximum number of rows to return; must be in [1, 1000].
        search_space_id: Optional filter; must be positive when given.

    Returns:
        Rows carrying id, type, title, initial_connectors, search_space_id and
        created_at, restricted to search spaces owned by the user and ordered
        by ``Chat.id`` so that OFFSET/LIMIT pagination is deterministic.

    Raises:
        HTTPException: 400 on invalid pagination or filter values, 503 on an
            OperationalError, 500 on any other failure.
    """
    # Validate pagination parameters
    if skip < 0:
        raise HTTPException(status_code=400, detail="skip must be a non-negative integer")
    if limit <= 0 or limit > 1000:  # Reasonable upper limit
        raise HTTPException(status_code=400, detail="limit must be between 1 and 1000")

    # Validate search_space_id if provided
    if search_space_id is not None and search_space_id <= 0:
        raise HTTPException(
            status_code=400, detail="search_space_id must be a positive integer"
        )

    try:
        # Select specific fields excluding messages
        query = (
            select(
                Chat.id,
                Chat.type,
                Chat.title,
                Chat.initial_connectors,
                Chat.search_space_id,
                Chat.created_at,
            )
            .join(SearchSpace)
            .filter(SearchSpace.user_id == user.id)
        )

        # Filter by search_space_id if provided
        if search_space_id is not None:
            query = query.filter(Chat.search_space_id == search_space_id)

        # Without ORDER BY the database may return rows in any order, so
        # consecutive OFFSET/LIMIT pages could overlap or skip chats. Order by
        # Chat.id (presumably the primary key) for a stable sequence.
        query = query.order_by(Chat.id).offset(skip).limit(limit)

        result = await session.execute(query)
        return result.all()
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception:
        raise HTTPException(
            status_code=500, detail="An unexpected error occurred while fetching chats."
        ) from None
|
2025-03-14 18:53:14 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/chats/{chat_id}", response_model=ChatRead)
async def read_chat(
    chat_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Fetch a single chat, including its messages, owned by the current user.

    Ownership is enforced inside the query by joining SearchSpace and filtering
    on ``SearchSpace.user_id``, so another user's chat looks the same as a
    missing one. update_chat and delete_chat also call this function directly.

    Returns:
        The matching Chat instance.

    Raises:
        HTTPException: 404 if no such chat is visible to the user, 503 on an
            OperationalError, 500 on any other failure while querying.
    """
    try:
        result = await session.execute(
            select(Chat)
            .join(SearchSpace)
            .filter(Chat.id == chat_id, SearchSpace.user_id == user.id)
        )
        chat = result.scalars().first()
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception:
        raise HTTPException(
            status_code=500,
            detail="An unexpected error occurred while fetching the chat.",
        ) from None

    # Raised outside the try: HTTPException is an Exception subclass, so the
    # generic handler above would otherwise convert this 404 into a 500.
    if chat is None:
        raise HTTPException(
            status_code=404,
            detail="Chat not found or you don't have permission to access it",
        )
    return chat
|
2025-03-14 18:53:14 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.put("/chats/{chat_id}", response_model=ChatRead)
async def update_chat(
    chat_id: int,
    chat_update: ChatUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Partially update a chat owned by the current user.

    Only the fields explicitly set in ``chat_update`` are applied
    (``model_dump(exclude_unset=True)``).

    Returns:
        The updated Chat, refreshed after commit.

    Raises:
        HTTPException: re-raised from read_chat when the chat is not visible to
            the user, or from check_ownership when the update targets a search
            space the user does not own; otherwise, after rollback, 400 for an
            IntegrityError, 503 for an OperationalError, 500 for anything else.
    """
    try:
        db_chat = await read_chat(chat_id, session, user)
        update_data = chat_update.model_dump(exclude_unset=True)

        # read_chat only proves ownership of the chat's *current* search space.
        # If the payload moves the chat, the destination must belong to the
        # user too (create_chat performs the same check on creation).
        # NOTE(review): whether ChatUpdate exposes search_space_id is decided
        # in app.schemas; when it does not, this branch never runs.
        target_space_id = update_data.get("search_space_id")
        if target_space_id is not None and target_space_id != db_chat.search_space_id:
            await check_ownership(session, SearchSpace, target_space_id, user)

        for key, value in update_data.items():
            setattr(db_chat, key, value)
        await session.commit()
        await session.refresh(db_chat)
        return db_chat
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Database constraint violation. Please check your input data.",
        ) from None
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail="An unexpected error occurred while updating the chat.",
        ) from None
|
2025-03-14 18:53:14 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.delete("/chats/{chat_id}", response_model=dict)
async def delete_chat(
    chat_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Delete a chat owned by the current user.

    Returns:
        dict: ``{"message": "Chat deleted successfully"}`` on success.

    Raises:
        HTTPException: re-raised from read_chat when the chat is not visible
            to the user; otherwise, after rolling back the session, 400 for an
            IntegrityError, 503 for an OperationalError and 500 for any other
            exception.
    """
    try:
        doomed_chat = await read_chat(chat_id, session, user)
        await session.delete(doomed_chat)
        await session.commit()
    except HTTPException:
        raise
    except Exception as exc:
        await session.rollback()
        # Map the failure to a client-facing status; checked in the same
        # precedence as separate except clauses would be.
        if isinstance(exc, IntegrityError):
            status, detail = 400, "Cannot delete chat due to existing dependencies."
        elif isinstance(exc, OperationalError):
            status, detail = 503, "Database operation failed. Please try again later."
        else:
            status, detail = 500, "An unexpected error occurred while deleting the chat."
        raise HTTPException(status_code=status, detail=detail) from None
    else:
        return {"message": "Chat deleted successfully"}
|