mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-26 01:06:23 +02:00
pending issues resolved
This commit is contained in:
parent
aab2f2dfef
commit
d1f21a8dc6
3 changed files with 683 additions and 510 deletions
|
|
@ -1,6 +1,3 @@
|
|||
import re
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from langchain.schema import AIMessage, HumanMessage
|
||||
|
|
@ -19,330 +16,18 @@ from app.schemas import (
|
|||
from app.tasks.stream_connector_search_results import stream_connector_search_results
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
from app.utils.validators import (
|
||||
validate_search_space_id,
|
||||
validate_document_ids,
|
||||
validate_connectors,
|
||||
validate_research_mode,
|
||||
validate_search_mode,
|
||||
validate_messages,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def validate_search_space_id(search_space_id: Any) -> int:
|
||||
"""
|
||||
Validate and convert search_space_id to integer.
|
||||
|
||||
Args:
|
||||
search_space_id: The search space ID to validate
|
||||
|
||||
Returns:
|
||||
int: Validated search space ID
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails
|
||||
"""
|
||||
if search_space_id is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="search_space_id is required"
|
||||
)
|
||||
|
||||
if isinstance(search_space_id, int):
|
||||
if search_space_id <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="search_space_id must be a positive integer"
|
||||
)
|
||||
return search_space_id
|
||||
|
||||
if isinstance(search_space_id, str):
|
||||
# Check if it's a valid integer string
|
||||
if not search_space_id.strip():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="search_space_id cannot be empty"
|
||||
)
|
||||
|
||||
# Check for valid integer format (no leading zeros, no decimal points)
|
||||
if not re.match(r'^[1-9]\d*$', search_space_id.strip()):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="search_space_id must be a valid positive integer"
|
||||
)
|
||||
|
||||
try:
|
||||
value = int(search_space_id.strip())
|
||||
if value <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="search_space_id must be a positive integer"
|
||||
)
|
||||
return value
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="search_space_id must be a valid integer"
|
||||
) from None
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="search_space_id must be an integer or string representation of an integer"
|
||||
)
|
||||
|
||||
|
||||
def validate_document_ids(document_ids: Any) -> list[int]:
|
||||
"""
|
||||
Validate and convert document_ids to list of integers.
|
||||
|
||||
Args:
|
||||
document_ids: The document IDs to validate
|
||||
|
||||
Returns:
|
||||
List[int]: Validated list of document IDs
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails
|
||||
"""
|
||||
if document_ids is None:
|
||||
return []
|
||||
|
||||
if not isinstance(document_ids, list):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="document_ids_to_add_in_context must be a list"
|
||||
)
|
||||
|
||||
validated_ids = []
|
||||
for i, doc_id in enumerate(document_ids):
|
||||
if isinstance(doc_id, int):
|
||||
if doc_id <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"document_ids_to_add_in_context[{i}] must be a positive integer"
|
||||
)
|
||||
validated_ids.append(doc_id)
|
||||
elif isinstance(doc_id, str):
|
||||
if not doc_id.strip():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"document_ids_to_add_in_context[{i}] cannot be empty"
|
||||
)
|
||||
|
||||
if not re.match(r'^[1-9]\d*$', doc_id.strip()):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"document_ids_to_add_in_context[{i}] must be a valid positive integer"
|
||||
)
|
||||
|
||||
try:
|
||||
value = int(doc_id.strip())
|
||||
if value <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"document_ids_to_add_in_context[{i}] must be a positive integer"
|
||||
)
|
||||
validated_ids.append(value)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"document_ids_to_add_in_context[{i}] must be a valid integer"
|
||||
) from None
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"document_ids_to_add_in_context[{i}] must be an integer or string representation of an integer"
|
||||
)
|
||||
|
||||
return validated_ids
|
||||
|
||||
|
||||
def validate_connectors(connectors: Any) -> list[str]:
|
||||
"""
|
||||
Validate selected_connectors list.
|
||||
|
||||
Args:
|
||||
connectors: The connectors to validate
|
||||
|
||||
Returns:
|
||||
List[str]: Validated list of connector names
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails
|
||||
"""
|
||||
if connectors is None:
|
||||
return []
|
||||
|
||||
if not isinstance(connectors, list):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="selected_connectors must be a list"
|
||||
)
|
||||
|
||||
validated_connectors = []
|
||||
for i, connector in enumerate(connectors):
|
||||
if not isinstance(connector, str):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"selected_connectors[{i}] must be a string"
|
||||
)
|
||||
|
||||
if not connector.strip():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"selected_connectors[{i}] cannot be empty"
|
||||
)
|
||||
|
||||
# Basic sanitization - remove any potentially dangerous characters
|
||||
sanitized = re.sub(r'[^\w\-_]', '', connector.strip())
|
||||
if not sanitized:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"selected_connectors[{i}] contains invalid characters"
|
||||
)
|
||||
|
||||
validated_connectors.append(sanitized)
|
||||
|
||||
return validated_connectors
|
||||
|
||||
|
||||
def validate_research_mode(research_mode: Any) -> str:
|
||||
"""
|
||||
Validate research_mode parameter.
|
||||
|
||||
Args:
|
||||
research_mode: The research mode to validate
|
||||
|
||||
Returns:
|
||||
str: Validated research mode
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails
|
||||
"""
|
||||
if research_mode is None:
|
||||
return "GENERAL" # Default value
|
||||
|
||||
if not isinstance(research_mode, str):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="research_mode must be a string"
|
||||
)
|
||||
|
||||
valid_modes = ["GENERAL", "DEEP", "DEEPER"]
|
||||
if research_mode.upper() not in valid_modes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"research_mode must be one of: {', '.join(valid_modes)}"
|
||||
)
|
||||
|
||||
return research_mode.upper()
|
||||
|
||||
|
||||
def validate_search_mode(search_mode: Any) -> str:
|
||||
"""
|
||||
Validate search_mode parameter.
|
||||
|
||||
Args:
|
||||
search_mode: The search mode to validate
|
||||
|
||||
Returns:
|
||||
str: Validated search mode
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails
|
||||
"""
|
||||
if search_mode is None:
|
||||
return "CHUNKS" # Default value
|
||||
|
||||
if not isinstance(search_mode, str):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="search_mode must be a string"
|
||||
)
|
||||
|
||||
valid_modes = ["CHUNKS", "DOCUMENTS"]
|
||||
if search_mode.upper() not in valid_modes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"search_mode must be one of: {', '.join(valid_modes)}"
|
||||
)
|
||||
|
||||
return search_mode.upper()
|
||||
|
||||
|
||||
def validate_messages(messages: Any) -> list[dict]:
|
||||
"""
|
||||
Validate messages structure.
|
||||
|
||||
Args:
|
||||
messages: The messages to validate
|
||||
|
||||
Returns:
|
||||
List[dict]: Validated messages
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails
|
||||
"""
|
||||
if not isinstance(messages, list):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="messages must be a list"
|
||||
)
|
||||
|
||||
if not messages:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="messages cannot be empty"
|
||||
)
|
||||
|
||||
validated_messages = []
|
||||
for i, message in enumerate(messages):
|
||||
if not isinstance(message, dict):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"messages[{i}] must be a dictionary"
|
||||
)
|
||||
|
||||
if "role" not in message:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"messages[{i}] must have a 'role' field"
|
||||
)
|
||||
|
||||
if "content" not in message:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"messages[{i}] must have a 'content' field"
|
||||
)
|
||||
|
||||
role = message["role"]
|
||||
if not isinstance(role, str) or role not in ["user", "assistant", "system"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"messages[{i}].role must be 'user', 'assistant', or 'system'"
|
||||
)
|
||||
|
||||
content = message["content"]
|
||||
if not isinstance(content, str):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"messages[{i}].content must be a string"
|
||||
)
|
||||
|
||||
if not content.strip():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"messages[{i}].content cannot be empty"
|
||||
)
|
||||
|
||||
# Basic content sanitization
|
||||
sanitized_content = content.strip()
|
||||
if len(sanitized_content) > 10000: # Reasonable limit
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"messages[{i}].content is too long (max 10000 characters)"
|
||||
)
|
||||
|
||||
validated_messages.append({
|
||||
"role": role,
|
||||
"content": sanitized_content
|
||||
})
|
||||
|
||||
return validated_messages
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue