dograh/api/routes/knowledge_base.py
2026-01-17 13:36:26 +05:30

405 lines
13 KiB
Python

"""API routes for knowledge base operations."""
import uuid
from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from loguru import logger
from api.db import db_client
from api.schemas.knowledge_base import (
ChunkSearchRequestSchema,
ChunkSearchResponseSchema,
DocumentListResponseSchema,
DocumentResponseSchema,
DocumentUploadRequestSchema,
DocumentUploadResponseSchema,
ProcessDocumentRequestSchema,
)
from api.services.auth.depends import get_user
from api.services.storage import storage_fs
from api.tasks.arq import enqueue_job
from api.tasks.function_names import FunctionNames
# Router that groups every endpoint in this module under the
# "/knowledge-base" URL prefix and the "knowledge-base" tag.
router = APIRouter(
    prefix="/knowledge-base",
    tags=["knowledge-base"],
)
@router.post(
    "/upload-url",
    response_model=DocumentUploadResponseSchema,
    summary="Get presigned URL for document upload",
)
async def get_upload_url(
    request: DocumentUploadRequestSchema,
    user=Depends(get_user),
):
    """Generate a presigned PUT URL for uploading a document.

    This endpoint:
    1. Generates a unique document UUID for organizing the S3 key
    2. Generates a presigned S3/MinIO URL for uploading the file
    3. Returns the upload URL and document metadata

    After uploading to the returned URL, call /process-document to create
    the document record and trigger processing.

    Access Control:
    * All authenticated users can upload documents scoped to their organization.

    Args:
        request: Upload request; ``filename`` and ``mime_type`` are read here.
        user: Authenticated user resolved by the ``get_user`` dependency.

    Returns:
        DocumentUploadResponseSchema with ``upload_url``, ``document_uuid``
        and ``s3_key``.

    Raises:
        HTTPException: 500 if the storage backend returns no URL or any
            unexpected error occurs.
    """
    try:
        # Generate unique document UUID for S3 organization
        document_uuid = str(uuid.uuid4())

        # Generate S3 key: knowledge_base/{org_id}/{document_uuid}/{filename}
        # NOTE(review): request.filename is embedded as-is; confirm the request
        # schema rejects path separators / ".." segments.
        s3_key = f"knowledge_base/{user.selected_organization_id}/{document_uuid}/{request.filename}"

        # Generate presigned PUT URL (valid for 30 minutes)
        upload_url = await storage_fs.aget_presigned_put_url(
            file_path=s3_key,
            expiration=1800,  # 30 minutes
            content_type=request.mime_type,
            max_size=100_000_000,  # 100MB max
        )

        if not upload_url:
            # Log here because the HTTPException below is re-raised untouched
            # and therefore never reaches the generic error logger.
            logger.error(
                f"Error generating upload URL: storage returned no URL for {s3_key}"
            )
            raise HTTPException(
                status_code=500, detail="Failed to generate presigned upload URL"
            )

        logger.info(
            f"Generated upload URL for document {document_uuid}, "
            f"user {user.id}, org {user.selected_organization_id}"
        )

        return DocumentUploadResponseSchema(
            upload_url=upload_url,
            document_uuid=document_uuid,
            s3_key=s3_key,
        )
    except HTTPException:
        # Deliberate HTTP errors pass through unchanged, matching the other
        # handlers in this module, instead of being re-wrapped below.
        raise
    except Exception as exc:
        logger.error(f"Error generating upload URL: {exc}")
        raise HTTPException(
            status_code=500, detail="Failed to generate upload URL"
        ) from exc
@router.post(
    "/process-document",
    response_model=DocumentResponseSchema,
    summary="Trigger document processing",
)
async def process_document(
    request: ProcessDocumentRequestSchema,
    user=Depends(get_user),
):
    """Trigger asynchronous processing of an uploaded document.

    This endpoint should be called after successfully uploading a file to the
    presigned URL. It will:
    1. Verify that the supplied S3 key lies under the caller's organization
    2. Create a document record in the database with the specified UUID
    3. Enqueue a background task to process the document (chunking and embedding)

    The document status will be updated from 'pending' -> 'processing' ->
    'completed' or 'failed'.

    Embedding Services:
    * openai (default): High-quality 1536-dimensional embeddings (requires OPENAI_API_KEY)
    * sentence_transformer: Free, offline-capable, 384-dimensional embeddings

    Access Control:
    * Users can only process documents in their organization. This is
      enforced by requiring ``s3_key`` to start with
      ``knowledge_base/{organization_id}/``, the layout produced by
      ``/upload-url``.

    Args:
        request: Carries ``s3_key``, ``document_uuid`` and ``embedding_service``.
        user: Authenticated user resolved by the ``get_user`` dependency.

    Returns:
        DocumentResponseSchema describing the newly created, still pending,
        document.

    Raises:
        HTTPException: 403 if ``s3_key`` is outside the caller's organization
            prefix or contains a ``..`` segment; 500 on unexpected errors.
    """
    try:
        # s3_key is client-supplied: without this check a caller could submit
        # another organization's key and have that file ingested into their
        # own knowledge base.
        org_prefix = f"knowledge_base/{user.selected_organization_id}/"
        key_segments = request.s3_key.split("/")
        if not request.s3_key.startswith(org_prefix) or ".." in key_segments:
            logger.warning(
                f"Rejected s3_key outside organization prefix: {request.s3_key}, "
                f"user {user.id}, org {user.selected_organization_id}"
            )
            raise HTTPException(
                status_code=403,
                detail="s3_key does not belong to your organization",
            )

        # Extract filename from s3_key
        filename = key_segments[-1]

        # Create document record with the specific UUID from upload
        document = await db_client.create_document(
            organization_id=user.selected_organization_id,
            created_by=user.id,
            filename=filename,
            file_size_bytes=0,  # Will be updated by background task
            file_hash="",  # Will be computed by background task
            mime_type="application/octet-stream",  # Will be detected by background task
            custom_metadata={"s3_key": request.s3_key},
            document_uuid=request.document_uuid,  # Use UUID from upload
        )

        # Enqueue background task for processing
        await enqueue_job(
            FunctionNames.PROCESS_KNOWLEDGE_BASE_DOCUMENT,
            document.id,
            request.s3_key,
            user.selected_organization_id,
            128,  # max_tokens (default)
            request.embedding_service,
        )

        logger.info(
            f"Created document {request.document_uuid} (id={document.id}) and enqueued processing "
            f"with {request.embedding_service} embeddings, org {user.selected_organization_id}"
        )

        # The response mirrors the placeholder values written above rather
        # than re-reading the row; only id and timestamps come from the record.
        return DocumentResponseSchema(
            id=document.id,
            document_uuid=request.document_uuid,
            filename=filename,
            file_size_bytes=0,
            file_hash="",
            mime_type="application/octet-stream",
            processing_status="pending",
            processing_error=None,
            total_chunks=0,
            custom_metadata={"s3_key": request.s3_key},
            docling_metadata={},
            source_url=None,
            created_at=document.created_at,
            updated_at=document.updated_at,
            organization_id=user.selected_organization_id,
            created_by=user.id,
            is_active=True,
        )
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Error processing document: {exc}")
        raise HTTPException(
            status_code=500, detail="Failed to process document"
        ) from exc
@router.get(
    "/documents",
    response_model=DocumentListResponseSchema,
    summary="List documents",
)
async def list_documents(
    status: Annotated[
        Optional[str],
        Query(description="Filter by processing status"),
    ] = None,
    limit: Annotated[int, Query(ge=1, le=100)] = 100,
    offset: Annotated[int, Query(ge=0)] = 0,
    user=Depends(get_user),
):
    """Return one page of documents belonging to the caller's organization.

    Args:
        status: Optional processing-status filter forwarded to the database.
        limit: Page size, between 1 and 100 (default 100).
        offset: Number of documents to skip, 0 or greater.
        user: Authenticated user resolved by the ``get_user`` dependency.

    Returns:
        DocumentListResponseSchema. ``total`` is the number of documents in
        the returned page, not an overall count.

    Raises:
        HTTPException: 500 if the lookup or conversion fails.

    Access Control:
    * Users can only see documents from their organization.
    """

    def as_response(row):
        # Copy a database row field-for-field into the response schema.
        return DocumentResponseSchema(
            id=row.id,
            document_uuid=row.document_uuid,
            filename=row.filename,
            file_size_bytes=row.file_size_bytes,
            file_hash=row.file_hash,
            mime_type=row.mime_type,
            processing_status=row.processing_status,
            processing_error=row.processing_error,
            total_chunks=row.total_chunks,
            custom_metadata=row.custom_metadata,
            docling_metadata=row.docling_metadata,
            source_url=row.source_url,
            created_at=row.created_at,
            updated_at=row.updated_at,
            organization_id=row.organization_id,
            created_by=row.created_by,
            is_active=row.is_active,
        )

    try:
        rows = await db_client.get_documents_for_organization(
            organization_id=user.selected_organization_id,
            processing_status=status,
            limit=limit,
            offset=offset,
        )
        page = [as_response(row) for row in rows]
        return DocumentListResponseSchema(
            documents=page,
            total=len(page),
            limit=limit,
            offset=offset,
        )
    except Exception as exc:
        logger.error(f"Error listing documents: {exc}")
        raise HTTPException(status_code=500, detail="Failed to list documents") from exc
@router.get(
    "/documents/{document_uuid}",
    response_model=DocumentResponseSchema,
    summary="Get document details",
)
async def get_document(
    document_uuid: str,
    user=Depends(get_user),
):
    """Fetch a single document by UUID within the caller's organization.

    Args:
        document_uuid: UUID of the document, taken from the URL path.
        user: Authenticated user resolved by the ``get_user`` dependency.

    Returns:
        DocumentResponseSchema populated from the stored record.

    Raises:
        HTTPException: 404 if the lookup returns nothing; 500 on unexpected
            errors.

    Access Control:
    * Users can only access documents from their organization.
    """
    # Attributes copied one-to-one from the record into the response, in
    # the order the schema is populated.
    response_fields = (
        "id",
        "document_uuid",
        "filename",
        "file_size_bytes",
        "file_hash",
        "mime_type",
        "processing_status",
        "processing_error",
        "total_chunks",
        "custom_metadata",
        "docling_metadata",
        "source_url",
        "created_at",
        "updated_at",
        "organization_id",
        "created_by",
        "is_active",
    )
    try:
        record = await db_client.get_document_by_uuid(
            document_uuid=document_uuid,
            organization_id=user.selected_organization_id,
        )
        if not record:
            raise HTTPException(status_code=404, detail="Document not found")

        payload = {name: getattr(record, name) for name in response_fields}
        return DocumentResponseSchema(**payload)
    except HTTPException:
        # Let the 404 above through untouched.
        raise
    except Exception as exc:
        logger.error(f"Error getting document: {exc}")
        raise HTTPException(status_code=500, detail="Failed to get document") from exc
@router.delete(
    "/documents/{document_uuid}",
    summary="Delete document",
)
async def delete_document(
    document_uuid: str,
    user=Depends(get_user),
):
    """Delete a document (and its chunks) from the caller's organization.

    The deletion itself is delegated to ``db_client.delete_document``; this
    handler describes it as a soft delete.

    Args:
        document_uuid: UUID of the document, taken from the URL path.
        user: Authenticated user resolved by the ``get_user`` dependency.

    Returns:
        A dict with ``success`` set to True and a confirmation ``message``.

    Raises:
        HTTPException: 404 if the database layer reports nothing was deleted;
            500 on unexpected errors.

    Access Control:
    * Users can only delete documents from their organization.
    """
    try:
        deleted = await db_client.delete_document(
            document_uuid=document_uuid,
            organization_id=user.selected_organization_id,
        )
        if deleted:
            logger.info(
                f"Deleted document {document_uuid}, "
                f"user {user.id}, org {user.selected_organization_id}"
            )
            return {"success": True, "message": "Document deleted successfully"}

        # A falsy result means no matching document in this organization.
        raise HTTPException(status_code=404, detail="Document not found")
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Error deleting document: {exc}")
        raise HTTPException(
            status_code=500, detail="Failed to delete document"
        ) from exc
@router.post(
    "/search",
    response_model=ChunkSearchResponseSchema,
    summary="Search for similar chunks",
)
async def search_chunks(
    request: ChunkSearchRequestSchema,
    user=Depends(get_user),
):
    """Search for document chunks similar to the query.

    This endpoint uses vector similarity search to find relevant chunks.
    When ``request.min_similarity`` is set, chunks scoring below it are
    dropped after the search, so fewer than ``request.limit`` results may be
    returned. When it is ``None`` no threshold is applied.

    Args:
        request: Carries ``query``, ``limit``, optional ``document_uuids``
            and optional ``min_similarity``.
        user: Authenticated user resolved by the ``get_user`` dependency.

    Returns:
        ChunkSearchResponseSchema with the matching chunks, the echoed query
        and ``total_results`` equal to the number of chunks returned.

    Raises:
        HTTPException: 500 if configuration lookup, embedding or search fails.

    Access Control:
    * Users can only search documents from their organization.
    """
    try:
        # Import here to avoid circular dependency
        from api.schemas.knowledge_base import ChunkResponseSchema
        from api.services.gen_ai import OpenAIEmbeddingService

        # Try to get user's embeddings configuration
        user_config = await db_client.get_user_configurations(user.id)
        embeddings_api_key = None
        embeddings_model = None
        if user_config.embeddings:
            embeddings_api_key = user_config.embeddings.api_key
            embeddings_model = user_config.embeddings.model

        # api_key may be None here; presumably the service then falls back
        # to environment configuration — confirm in OpenAIEmbeddingService.
        embedding_service = OpenAIEmbeddingService(
            db_client=db_client,
            api_key=embeddings_api_key,
            model_id=embeddings_model or "text-embedding-3-small",
        )

        # Perform search
        results = await embedding_service.search_similar_chunks(
            query=request.query,
            organization_id=user.selected_organization_id,
            limit=request.limit,
            document_uuids=request.document_uuids,
        )

        # Apply similarity threshold if provided
        if request.min_similarity is not None:
            results = [r for r in results if r["similarity"] >= request.min_similarity]

        # Convert to response schema
        chunks = [
            ChunkResponseSchema(
                id=r["id"],
                document_id=r["document_id"],
                chunk_text=r["chunk_text"],
                contextualized_text=r.get("contextualized_text"),
                chunk_index=r["chunk_index"],
                chunk_metadata=r["chunk_metadata"],
                filename=r["filename"],
                document_uuid=r["document_uuid"],
                similarity=r["similarity"],
            )
            for r in results
        ]

        return ChunkSearchResponseSchema(
            chunks=chunks,
            query=request.query,
            total_results=len(chunks),
        )
    except HTTPException:
        # Pass HTTP errors through unchanged, matching the other handlers in
        # this module.
        raise
    except Exception as exc:
        logger.error(f"Error searching chunks: {exc}")
        raise HTTPException(status_code=500, detail="Failed to search chunks") from exc