From aab2f2dfefd3a4c7ed8ec9bba3014406c0d60490 Mon Sep 17 00:00:00 2001 From: Aditya Vaish Date: Mon, 6 Oct 2025 16:47:02 +0530 Subject: [PATCH] chat-routes vulnerability fixed --- surfsense_backend/app/routes/chats_routes.py | 373 ++++++++++++++++++- 1 file changed, 355 insertions(+), 18 deletions(-) diff --git a/surfsense_backend/app/routes/chats_routes.py b/surfsense_backend/app/routes/chats_routes.py index 21af85c0f..9783464f5 100644 --- a/surfsense_backend/app/routes/chats_routes.py +++ b/surfsense_backend/app/routes/chats_routes.py @@ -1,3 +1,6 @@ +import re +from typing import Any + from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse from langchain.schema import AIMessage, HumanMessage @@ -20,36 +23,351 @@ from app.utils.check_ownership import check_ownership router = APIRouter() +def validate_search_space_id(search_space_id: Any) -> int: + """ + Validate and convert search_space_id to integer. + + Args: + search_space_id: The search space ID to validate + + Returns: + int: Validated search space ID + + Raises: + HTTPException: If validation fails + """ + if search_space_id is None: + raise HTTPException( + status_code=400, + detail="search_space_id is required" + ) + + if isinstance(search_space_id, int): + if search_space_id <= 0: + raise HTTPException( + status_code=400, + detail="search_space_id must be a positive integer" + ) + return search_space_id + + if isinstance(search_space_id, str): + # Check if it's a valid integer string + if not search_space_id.strip(): + raise HTTPException( + status_code=400, + detail="search_space_id cannot be empty" + ) + + # Check for valid integer format (no leading zeros, no decimal points) + if not re.match(r'^[1-9]\d*$', search_space_id.strip()): + raise HTTPException( + status_code=400, + detail="search_space_id must be a valid positive integer" + ) + + try: + value = int(search_space_id.strip()) + if value <= 0: + raise HTTPException( + status_code=400, + detail="search_space_id must be a positive integer" + ) + return value + except ValueError: + raise HTTPException( + status_code=400, + detail="search_space_id must be a valid integer" + ) from None + + raise HTTPException( + status_code=400, + detail="search_space_id must be an integer or string representation of an integer" + ) + + +def validate_document_ids(document_ids: Any) -> list[int]: + """ + Validate and convert document_ids to list of integers. + + Args: + document_ids: The document IDs to validate + + Returns: + List[int]: Validated list of document IDs + + Raises: + HTTPException: If validation fails + """ + if document_ids is None: + return [] + + if not isinstance(document_ids, list): + raise HTTPException( + status_code=400, + detail="document_ids_to_add_in_context must be a list" + ) + + validated_ids = [] + for i, doc_id in enumerate(document_ids): + if isinstance(doc_id, int): + if doc_id <= 0: + raise HTTPException( + status_code=400, + detail=f"document_ids_to_add_in_context[{i}] must be a positive integer" + ) + validated_ids.append(doc_id) + elif isinstance(doc_id, str): + if not doc_id.strip(): + raise HTTPException( + status_code=400, + detail=f"document_ids_to_add_in_context[{i}] cannot be empty" + ) + + if not re.match(r'^[1-9]\d*$', doc_id.strip()): + raise HTTPException( + status_code=400, + detail=f"document_ids_to_add_in_context[{i}] must be a valid positive integer" + ) + + try: + value = int(doc_id.strip()) + if value <= 0: + raise HTTPException( + status_code=400, + detail=f"document_ids_to_add_in_context[{i}] must be a positive integer" + ) + validated_ids.append(value) + except ValueError: + raise HTTPException( + status_code=400, + detail=f"document_ids_to_add_in_context[{i}] must be a valid integer" + ) from None + else: + raise HTTPException( + status_code=400, + detail=f"document_ids_to_add_in_context[{i}] must be an integer or string representation of an integer" + ) + + return validated_ids + + +def validate_connectors(connectors: Any) -> list[str]: + """ + Validate selected_connectors list. + + Args: + connectors: The connectors to validate + + Returns: + List[str]: Validated list of connector names + + Raises: + HTTPException: If validation fails + """ + if connectors is None: + return [] + + if not isinstance(connectors, list): + raise HTTPException( + status_code=400, + detail="selected_connectors must be a list" + ) + + validated_connectors = [] + for i, connector in enumerate(connectors): + if not isinstance(connector, str): + raise HTTPException( + status_code=400, + detail=f"selected_connectors[{i}] must be a string" + ) + + if not connector.strip(): + raise HTTPException( + status_code=400, + detail=f"selected_connectors[{i}] cannot be empty" + ) + + # Basic sanitization - remove any potentially dangerous characters + sanitized = re.sub(r'[^\w\-_]', '', connector.strip()) + if not sanitized: + raise HTTPException( + status_code=400, + detail=f"selected_connectors[{i}] contains invalid characters" + ) + + validated_connectors.append(sanitized) + + return validated_connectors + + +def validate_research_mode(research_mode: Any) -> str: + """ + Validate research_mode parameter. + + Args: + research_mode: The research mode to validate + + Returns: + str: Validated research mode + + Raises: + HTTPException: If validation fails + """ + if research_mode is None: + return "GENERAL" # Default value + + if not isinstance(research_mode, str): + raise HTTPException( + status_code=400, + detail="research_mode must be a string" + ) + + valid_modes = ["GENERAL", "DEEP", "DEEPER"] + if research_mode.upper() not in valid_modes: + raise HTTPException( + status_code=400, + detail=f"research_mode must be one of: {', '.join(valid_modes)}" + ) + + return research_mode.upper() + + +def validate_search_mode(search_mode: Any) -> str: + """ + Validate search_mode parameter. + + Args: + search_mode: The search mode to validate + + Returns: + str: Validated search mode + + Raises: + HTTPException: If validation fails + """ + if search_mode is None: + return "CHUNKS" # Default value + + if not isinstance(search_mode, str): + raise HTTPException( + status_code=400, + detail="search_mode must be a string" + ) + + valid_modes = ["CHUNKS", "DOCUMENTS"] + if search_mode.upper() not in valid_modes: + raise HTTPException( + status_code=400, + detail=f"search_mode must be one of: {', '.join(valid_modes)}" + ) + + return search_mode.upper() + + +def validate_messages(messages: Any) -> list[dict]: + """ + Validate messages structure. + + Args: + messages: The messages to validate + + Returns: + List[dict]: Validated messages + + Raises: + HTTPException: If validation fails + """ + if not isinstance(messages, list): + raise HTTPException( + status_code=400, + detail="messages must be a list" + ) + + if not messages: + raise HTTPException( + status_code=400, + detail="messages cannot be empty" + ) + + validated_messages = [] + for i, message in enumerate(messages): + if not isinstance(message, dict): + raise HTTPException( + status_code=400, + detail=f"messages[{i}] must be a dictionary" + ) + + if "role" not in message: + raise HTTPException( + status_code=400, + detail=f"messages[{i}] must have a 'role' field" + ) + + if "content" not in message: + raise HTTPException( + status_code=400, + detail=f"messages[{i}] must have a 'content' field" + ) + + role = message["role"] + if not isinstance(role, str) or role not in ["user", "assistant", "system"]: + raise HTTPException( + status_code=400, + detail=f"messages[{i}].role must be 'user', 'assistant', or 'system'" + ) + + content = message["content"] + if not isinstance(content, str): + raise HTTPException( + status_code=400, + detail=f"messages[{i}].content must be a string" + ) + + if not content.strip(): + raise HTTPException( + status_code=400, + detail=f"messages[{i}].content cannot be empty" + ) + + # Basic content sanitization + sanitized_content = content.strip() + if len(sanitized_content) > 10000: # Reasonable limit + raise HTTPException( + status_code=400, + detail=f"messages[{i}].content is too long (max 10000 characters)" + ) + + validated_messages.append({ + "role": role, + "content": sanitized_content + }) + + return validated_messages + + @router.post("/chat") async def handle_chat_data( request: AISDKChatRequest, session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - messages = request.messages + # Validate and sanitize all input data + messages = validate_messages(request.messages) + if messages[-1]["role"] != "user": raise HTTPException( status_code=400, detail="Last message must be a user message" ) user_query = messages[-1]["content"] - search_space_id = request.data.get("search_space_id") - research_mode: str = request.data.get("research_mode") - selected_connectors: list[str] = request.data.get("selected_connectors") - document_ids_to_add_in_context: list[int] = request.data.get( - "document_ids_to_add_in_context" - ) - - search_mode_str = request.data.get("search_mode", "CHUNKS") - - # Convert search_space_id to integer if it's a string - if search_space_id and isinstance(search_space_id, str): - try: - search_space_id = int(search_space_id) - except ValueError: - raise HTTPException( - status_code=400, detail="Invalid search_space_id format" - ) from None + + # Extract and validate data from request + request_data = request.data or {} + search_space_id = validate_search_space_id(request_data.get("search_space_id")) + research_mode = validate_research_mode(request_data.get("research_mode")) + selected_connectors = validate_connectors(request_data.get("selected_connectors")) + document_ids_to_add_in_context = validate_document_ids(request_data.get("document_ids_to_add_in_context")) + search_mode_str = validate_search_mode(request_data.get("search_mode")) # Check if the search space belongs to the current user try: @@ -126,6 +444,25 @@ async def read_chats( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + # Validate pagination parameters + if skip < 0: + raise HTTPException( + status_code=400, + detail="skip must be a non-negative integer" + ) + + if limit <= 0 or limit > 1000: # Reasonable upper limit + raise HTTPException( + status_code=400, + detail="limit must be between 1 and 1000" + ) + + # Validate search_space_id if provided + if search_space_id is not None and search_space_id <= 0: + raise HTTPException( + status_code=400, + detail="search_space_id must be a positive integer" + ) try: # Select specific fields excluding messages query = (