chat-routes vulnerability fixed

This commit is contained in:
Aditya Vaish 2025-10-06 16:47:02 +05:30
parent 6b7ce53c58
commit aab2f2dfef

View file

@ -1,3 +1,6 @@
import re
from typing import Any
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from langchain.schema import AIMessage, HumanMessage
@ -20,36 +23,351 @@ from app.utils.check_ownership import check_ownership
router = APIRouter()
def validate_search_space_id(search_space_id: Any) -> int:
"""
Validate and convert search_space_id to integer.
Args:
search_space_id: The search space ID to validate
Returns:
int: Validated search space ID
Raises:
HTTPException: If validation fails
"""
if search_space_id is None:
raise HTTPException(
status_code=400,
detail="search_space_id is required"
)
if isinstance(search_space_id, int):
if search_space_id <= 0:
raise HTTPException(
status_code=400,
detail="search_space_id must be a positive integer"
)
return search_space_id
if isinstance(search_space_id, str):
# Check if it's a valid integer string
if not search_space_id.strip():
raise HTTPException(
status_code=400,
detail="search_space_id cannot be empty"
)
# Check for valid integer format (no leading zeros, no decimal points)
if not re.match(r'^[1-9]\d*$', search_space_id.strip()):
raise HTTPException(
status_code=400,
detail="search_space_id must be a valid positive integer"
)
try:
value = int(search_space_id.strip())
if value <= 0:
raise HTTPException(
status_code=400,
detail="search_space_id must be a positive integer"
)
return value
except ValueError:
raise HTTPException(
status_code=400,
detail="search_space_id must be a valid integer"
) from None
raise HTTPException(
status_code=400,
detail="search_space_id must be an integer or string representation of an integer"
)
def validate_document_ids(document_ids: Any) -> list[int]:
"""
Validate and convert document_ids to list of integers.
Args:
document_ids: The document IDs to validate
Returns:
List[int]: Validated list of document IDs
Raises:
HTTPException: If validation fails
"""
if document_ids is None:
return []
if not isinstance(document_ids, list):
raise HTTPException(
status_code=400,
detail="document_ids_to_add_in_context must be a list"
)
validated_ids = []
for i, doc_id in enumerate(document_ids):
if isinstance(doc_id, int):
if doc_id <= 0:
raise HTTPException(
status_code=400,
detail=f"document_ids_to_add_in_context[{i}] must be a positive integer"
)
validated_ids.append(doc_id)
elif isinstance(doc_id, str):
if not doc_id.strip():
raise HTTPException(
status_code=400,
detail=f"document_ids_to_add_in_context[{i}] cannot be empty"
)
if not re.match(r'^[1-9]\d*$', doc_id.strip()):
raise HTTPException(
status_code=400,
detail=f"document_ids_to_add_in_context[{i}] must be a valid positive integer"
)
try:
value = int(doc_id.strip())
if value <= 0:
raise HTTPException(
status_code=400,
detail=f"document_ids_to_add_in_context[{i}] must be a positive integer"
)
validated_ids.append(value)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"document_ids_to_add_in_context[{i}] must be a valid integer"
) from None
else:
raise HTTPException(
status_code=400,
detail=f"document_ids_to_add_in_context[{i}] must be an integer or string representation of an integer"
)
return validated_ids
def validate_connectors(connectors: Any) -> list[str]:
"""
Validate selected_connectors list.
Args:
connectors: The connectors to validate
Returns:
List[str]: Validated list of connector names
Raises:
HTTPException: If validation fails
"""
if connectors is None:
return []
if not isinstance(connectors, list):
raise HTTPException(
status_code=400,
detail="selected_connectors must be a list"
)
validated_connectors = []
for i, connector in enumerate(connectors):
if not isinstance(connector, str):
raise HTTPException(
status_code=400,
detail=f"selected_connectors[{i}] must be a string"
)
if not connector.strip():
raise HTTPException(
status_code=400,
detail=f"selected_connectors[{i}] cannot be empty"
)
# Basic sanitization - remove any potentially dangerous characters
sanitized = re.sub(r'[^\w\-_]', '', connector.strip())
if not sanitized:
raise HTTPException(
status_code=400,
detail=f"selected_connectors[{i}] contains invalid characters"
)
validated_connectors.append(sanitized)
return validated_connectors
def validate_research_mode(research_mode: Any) -> str:
"""
Validate research_mode parameter.
Args:
research_mode: The research mode to validate
Returns:
str: Validated research mode
Raises:
HTTPException: If validation fails
"""
if research_mode is None:
return "GENERAL" # Default value
if not isinstance(research_mode, str):
raise HTTPException(
status_code=400,
detail="research_mode must be a string"
)
valid_modes = ["GENERAL", "DEEP", "DEEPER"]
if research_mode.upper() not in valid_modes:
raise HTTPException(
status_code=400,
detail=f"research_mode must be one of: {', '.join(valid_modes)}"
)
return research_mode.upper()
def validate_search_mode(search_mode: Any) -> str:
"""
Validate search_mode parameter.
Args:
search_mode: The search mode to validate
Returns:
str: Validated search mode
Raises:
HTTPException: If validation fails
"""
if search_mode is None:
return "CHUNKS" # Default value
if not isinstance(search_mode, str):
raise HTTPException(
status_code=400,
detail="search_mode must be a string"
)
valid_modes = ["CHUNKS", "DOCUMENTS"]
if search_mode.upper() not in valid_modes:
raise HTTPException(
status_code=400,
detail=f"search_mode must be one of: {', '.join(valid_modes)}"
)
return search_mode.upper()
def validate_messages(messages: Any) -> list[dict]:
"""
Validate messages structure.
Args:
messages: The messages to validate
Returns:
List[dict]: Validated messages
Raises:
HTTPException: If validation fails
"""
if not isinstance(messages, list):
raise HTTPException(
status_code=400,
detail="messages must be a list"
)
if not messages:
raise HTTPException(
status_code=400,
detail="messages cannot be empty"
)
validated_messages = []
for i, message in enumerate(messages):
if not isinstance(message, dict):
raise HTTPException(
status_code=400,
detail=f"messages[{i}] must be a dictionary"
)
if "role" not in message:
raise HTTPException(
status_code=400,
detail=f"messages[{i}] must have a 'role' field"
)
if "content" not in message:
raise HTTPException(
status_code=400,
detail=f"messages[{i}] must have a 'content' field"
)
role = message["role"]
if not isinstance(role, str) or role not in ["user", "assistant", "system"]:
raise HTTPException(
status_code=400,
detail=f"messages[{i}].role must be 'user', 'assistant', or 'system'"
)
content = message["content"]
if not isinstance(content, str):
raise HTTPException(
status_code=400,
detail=f"messages[{i}].content must be a string"
)
if not content.strip():
raise HTTPException(
status_code=400,
detail=f"messages[{i}].content cannot be empty"
)
# Basic content sanitization
sanitized_content = content.strip()
if len(sanitized_content) > 10000: # Reasonable limit
raise HTTPException(
status_code=400,
detail=f"messages[{i}].content is too long (max 10000 characters)"
)
validated_messages.append({
"role": role,
"content": sanitized_content
})
return validated_messages
@router.post("/chat")
async def handle_chat_data(
request: AISDKChatRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
messages = request.messages
# Validate and sanitize all input data
messages = validate_messages(request.messages)
if messages[-1]["role"] != "user":
raise HTTPException(
status_code=400, detail="Last message must be a user message"
)
user_query = messages[-1]["content"]
search_space_id = request.data.get("search_space_id")
research_mode: str = request.data.get("research_mode")
selected_connectors: list[str] = request.data.get("selected_connectors")
document_ids_to_add_in_context: list[int] = request.data.get(
"document_ids_to_add_in_context"
)
search_mode_str = request.data.get("search_mode", "CHUNKS")
# Convert search_space_id to integer if it's a string
if search_space_id and isinstance(search_space_id, str):
try:
search_space_id = int(search_space_id)
except ValueError:
raise HTTPException(
status_code=400, detail="Invalid search_space_id format"
) from None
# Extract and validate data from request
request_data = request.data or {}
search_space_id = validate_search_space_id(request_data.get("search_space_id"))
research_mode = validate_research_mode(request_data.get("research_mode"))
selected_connectors = validate_connectors(request_data.get("selected_connectors"))
document_ids_to_add_in_context = validate_document_ids(request_data.get("document_ids_to_add_in_context"))
search_mode_str = validate_search_mode(request_data.get("search_mode"))
# Check if the search space belongs to the current user
try:
@ -126,6 +444,25 @@ async def read_chats(
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
# Validate pagination parameters
if skip < 0:
raise HTTPException(
status_code=400,
detail="skip must be a non-negative integer"
)
if limit <= 0 or limit > 1000: # Reasonable upper limit
raise HTTPException(
status_code=400,
detail="limit must be between 1 and 1000"
)
# Validate search_space_id if provided
if search_space_id is not None and search_space_id <= 0:
raise HTTPException(
status_code=400,
detail="search_space_id must be a positive integer"
)
try:
# Select specific fields excluding messages
query = (