mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
* feat: upload file and store embedding * feat: add documents in nodes * feat: add openai embedding service
405 lines
13 KiB
Python
405 lines
13 KiB
Python
"""API routes for knowledge base operations."""
|
|
|
|
import uuid
|
|
from typing import Annotated, Optional
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from loguru import logger
|
|
|
|
from api.db import db_client
|
|
from api.schemas.knowledge_base import (
|
|
ChunkSearchRequestSchema,
|
|
ChunkSearchResponseSchema,
|
|
DocumentListResponseSchema,
|
|
DocumentResponseSchema,
|
|
DocumentUploadRequestSchema,
|
|
DocumentUploadResponseSchema,
|
|
ProcessDocumentRequestSchema,
|
|
)
|
|
from api.services.auth.depends import get_user
|
|
from api.services.storage import storage_fs
|
|
from api.tasks.arq import enqueue_job
|
|
from api.tasks.function_names import FunctionNames
|
|
|
|
# Router that groups every endpoint defined in this module under the
# "/knowledge-base" URL prefix and the "knowledge-base" tag.
router = APIRouter(prefix="/knowledge-base", tags=["knowledge-base"])
|
|
|
|
|
|
@router.post(
    "/upload-url",
    response_model=DocumentUploadResponseSchema,
    summary="Get presigned URL for document upload",
)
async def get_upload_url(
    request: DocumentUploadRequestSchema,
    user=Depends(get_user),
):
    """Generate a presigned PUT URL for uploading a document.

    This endpoint:
    1. Generates a unique document UUID for organizing the S3 key
    2. Generates a presigned S3/MinIO URL for uploading the file
    3. Returns the upload URL and document metadata

    After uploading to the returned URL, call /process-document to create
    the document record and trigger processing.

    Only the final path component of ``request.filename`` is used in the
    storage key, so the key always has the shape
    ``knowledge_base/{org_id}/{document_uuid}/{filename}``.

    Access Control:
    * All authenticated users can upload documents scoped to their organization.

    Raises:
        HTTPException: 400 when the filename has no usable final component,
            500 when the presigned URL cannot be generated.
    """
    try:
        # The filename is client-supplied: drop any directory components so it
        # cannot add or climb path segments inside the storage key.
        safe_filename = request.filename.split("/")[-1]
        if safe_filename in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="Invalid filename")

        # Generate unique document UUID for S3 organization
        document_uuid = str(uuid.uuid4())

        # Generate S3 key: knowledge_base/{org_id}/{document_uuid}/{filename}
        s3_key = (
            f"knowledge_base/{user.selected_organization_id}/"
            f"{document_uuid}/{safe_filename}"
        )

        # Generate presigned PUT URL (valid for 30 minutes)
        upload_url = await storage_fs.aget_presigned_put_url(
            file_path=s3_key,
            expiration=1800,  # 30 minutes
            content_type=request.mime_type,
            max_size=100_000_000,  # 100MB max
        )

        if not upload_url:
            raise HTTPException(
                status_code=500, detail="Failed to generate presigned upload URL"
            )

        logger.info(
            f"Generated upload URL for document {document_uuid}, "
            f"user {user.id}, org {user.selected_organization_id}"
        )

        return DocumentUploadResponseSchema(
            upload_url=upload_url,
            document_uuid=document_uuid,
            s3_key=s3_key,
        )

    except HTTPException:
        # Let deliberate HTTP errors through unchanged instead of re-wrapping
        # them in the generic 500 below (same pattern as the other endpoints).
        raise
    except Exception as exc:
        logger.error(f"Error generating upload URL: {exc}")
        raise HTTPException(
            status_code=500, detail="Failed to generate upload URL"
        ) from exc
|
|
|
|
|
|
@router.post(
    "/process-document",
    response_model=DocumentResponseSchema,
    summary="Trigger document processing",
)
async def process_document(
    request: ProcessDocumentRequestSchema,
    user=Depends(get_user),
):
    """Trigger asynchronous processing of an uploaded document.

    This endpoint should be called after successfully uploading a file to the presigned URL.
    It will:
    1. Verify that the supplied S3 key lies under the caller's organization prefix
    2. Create a document record in the database with the specified UUID
    3. Enqueue a background task to process the document (chunking and embedding)

    The document status will be updated from 'pending' -> 'processing' -> 'completed' or 'failed'.

    Embedding Services:
    * openai (default): High-quality 1536-dimensional embeddings (requires OPENAI_API_KEY)
    * sentence_transformer: Free, offline-capable, 384-dimensional embeddings

    Access Control:
    * Users can only process documents in their organization. The S3 key must
      start with ``knowledge_base/{organization_id}/`` (the layout produced by
      /upload-url) and must not contain ``..`` segments.

    Raises:
        HTTPException: 403 when the S3 key is outside the caller's
            organization prefix, 500 on any unexpected failure.
    """
    # The S3 key comes from the client, so it must be checked before the
    # background worker is asked to read it; otherwise a caller could have
    # another organization's object ingested into their own knowledge base.
    expected_prefix = f"knowledge_base/{user.selected_organization_id}/"
    key_segments = request.s3_key.split("/")
    if not request.s3_key.startswith(expected_prefix) or ".." in key_segments:
        raise HTTPException(
            status_code=403,
            detail="s3_key does not belong to your organization",
        )

    try:
        # Extract filename from s3_key (its last path segment)
        filename = key_segments[-1]

        # Create document record with the specific UUID from upload
        document = await db_client.create_document(
            organization_id=user.selected_organization_id,
            created_by=user.id,
            filename=filename,
            file_size_bytes=0,  # Will be updated by background task
            file_hash="",  # Will be computed by background task
            mime_type="application/octet-stream",  # Will be detected by background task
            custom_metadata={"s3_key": request.s3_key},
            document_uuid=request.document_uuid,  # Use UUID from upload
        )

        # Enqueue background task for processing
        await enqueue_job(
            FunctionNames.PROCESS_KNOWLEDGE_BASE_DOCUMENT,
            document.id,
            request.s3_key,
            user.selected_organization_id,
            128,  # max_tokens (default)
            request.embedding_service,
        )

        logger.info(
            f"Created document {request.document_uuid} (id={document.id}) and enqueued processing "
            f"with {request.embedding_service} embeddings, org {user.selected_organization_id}"
        )

        # The response mirrors the placeholder values just written; the
        # background task fills in the real size, hash and MIME type later.
        return DocumentResponseSchema(
            id=document.id,
            document_uuid=request.document_uuid,
            filename=filename,
            file_size_bytes=0,
            file_hash="",
            mime_type="application/octet-stream",
            processing_status="pending",
            processing_error=None,
            total_chunks=0,
            custom_metadata={"s3_key": request.s3_key},
            docling_metadata={},
            source_url=None,
            created_at=document.created_at,
            updated_at=document.updated_at,
            organization_id=user.selected_organization_id,
            created_by=user.id,
            is_active=True,
        )

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Error processing document: {exc}")
        raise HTTPException(
            status_code=500, detail="Failed to process document"
        ) from exc
|
|
|
|
|
|
@router.get(
    "/documents",
    response_model=DocumentListResponseSchema,
    summary="List documents",
)
async def list_documents(
    status: Annotated[
        Optional[str],
        Query(description="Filter by processing status"),
    ] = None,
    limit: Annotated[int, Query(ge=1, le=100)] = 100,
    offset: Annotated[int, Query(ge=0)] = 0,
    user=Depends(get_user),
):
    """Return one page of documents belonging to the caller's organization.

    ``status`` optionally narrows the listing to a single processing status;
    ``limit`` and ``offset`` select the page. The ``total`` field of the
    response is the number of documents on the returned page.

    Access Control:
    * Users can only see documents from their organization.

    Raises:
        HTTPException: 500 when the lookup or the conversion fails.
    """
    try:
        rows = await db_client.get_documents_for_organization(
            organization_id=user.selected_organization_id,
            processing_status=status,
            limit=limit,
            offset=offset,
        )

        # Map each database row onto the public response schema.
        page = []
        for row in rows:
            page.append(
                DocumentResponseSchema(
                    id=row.id,
                    document_uuid=row.document_uuid,
                    filename=row.filename,
                    file_size_bytes=row.file_size_bytes,
                    file_hash=row.file_hash,
                    mime_type=row.mime_type,
                    processing_status=row.processing_status,
                    processing_error=row.processing_error,
                    total_chunks=row.total_chunks,
                    custom_metadata=row.custom_metadata,
                    docling_metadata=row.docling_metadata,
                    source_url=row.source_url,
                    created_at=row.created_at,
                    updated_at=row.updated_at,
                    organization_id=row.organization_id,
                    created_by=row.created_by,
                    is_active=row.is_active,
                )
            )

        page_size = len(page)
        return DocumentListResponseSchema(
            documents=page,
            total=page_size,
            limit=limit,
            offset=offset,
        )

    except Exception as exc:
        logger.error(f"Error listing documents: {exc}")
        raise HTTPException(status_code=500, detail="Failed to list documents") from exc
|
|
|
|
|
|
@router.get(
    "/documents/{document_uuid}",
    response_model=DocumentResponseSchema,
    summary="Get document details",
)
async def get_document(
    document_uuid: str,
    user=Depends(get_user),
):
    """Fetch a single document by its UUID.

    The lookup is restricted to the caller's selected organization.

    Access Control:
    * Users can only access documents from their organization.

    Raises:
        HTTPException: 404 when no matching document is found, 500 on any
            unexpected failure.
    """
    try:
        record = await db_client.get_document_by_uuid(
            document_uuid=document_uuid,
            organization_id=user.selected_organization_id,
        )

        if record:
            return DocumentResponseSchema(
                id=record.id,
                document_uuid=record.document_uuid,
                filename=record.filename,
                file_size_bytes=record.file_size_bytes,
                file_hash=record.file_hash,
                mime_type=record.mime_type,
                processing_status=record.processing_status,
                processing_error=record.processing_error,
                total_chunks=record.total_chunks,
                custom_metadata=record.custom_metadata,
                docling_metadata=record.docling_metadata,
                source_url=record.source_url,
                created_at=record.created_at,
                updated_at=record.updated_at,
                organization_id=record.organization_id,
                created_by=record.created_by,
                is_active=record.is_active,
            )

        # Nothing matched within the caller's organization.
        raise HTTPException(status_code=404, detail="Document not found")

    except HTTPException:
        # Pass the 404 through instead of converting it to a 500.
        raise
    except Exception as exc:
        logger.error(f"Error getting document: {exc}")
        raise HTTPException(status_code=500, detail="Failed to get document") from exc
|
|
|
|
|
|
@router.delete(
    "/documents/{document_uuid}",
    summary="Delete document",
)
async def delete_document(
    document_uuid: str,
    user=Depends(get_user),
):
    """Soft delete a document and its chunks.

    The deletion is delegated to ``db_client.delete_document`` and is
    restricted to the caller's selected organization.

    Access Control:
    * Users can only delete documents from their organization.

    Raises:
        HTTPException: 404 when the database layer reports that nothing was
            deleted, 500 on any unexpected failure.
    """
    try:
        deleted = await db_client.delete_document(
            document_uuid=document_uuid,
            organization_id=user.selected_organization_id,
        )

        if deleted:
            logger.info(
                f"Deleted document {document_uuid}, "
                f"user {user.id}, org {user.selected_organization_id}"
            )
            return {"success": True, "message": "Document deleted successfully"}

        # A falsy result is reported to the client as "not found".
        raise HTTPException(status_code=404, detail="Document not found")

    except HTTPException:
        # Pass the 404 through instead of converting it to a 500.
        raise
    except Exception as exc:
        logger.error(f"Error deleting document: {exc}")
        raise HTTPException(
            status_code=500, detail="Failed to delete document"
        ) from exc
|
|
|
|
|
|
@router.post(
    "/search",
    response_model=ChunkSearchResponseSchema,
    summary="Search for similar chunks",
)
async def search_chunks(
    request: ChunkSearchRequestSchema,
    user=Depends(get_user),
):
    """Search for document chunks similar to the query.

    The query is passed to ``OpenAIEmbeddingService.search_similar_chunks``,
    scoped to the caller's organization and optionally to
    ``request.document_uuids``. The service is built from the user's stored
    embeddings configuration when one exists; otherwise the model defaults to
    ``text-embedding-3-small`` and no API key is passed.

    When ``request.min_similarity`` is provided, results whose similarity is
    below it are dropped before the response is built.

    Access Control:
    * Users can only search documents from their organization.

    Raises:
        HTTPException: 500 on any failure during the search.
    """
    try:
        # Import here to avoid circular dependency
        from api.services.gen_ai import OpenAIEmbeddingService

        # Pull the per-user embeddings settings, if the user has any.
        user_config = await db_client.get_user_configurations(user.id)
        embeddings_cfg = user_config.embeddings
        if embeddings_cfg:
            api_key, model_id = embeddings_cfg.api_key, embeddings_cfg.model
        else:
            api_key = model_id = None

        # NOTE(review): with api_key=None the service presumably falls back to
        # environment configuration - confirm against OpenAIEmbeddingService.
        service = OpenAIEmbeddingService(
            db_client=db_client,
            api_key=api_key,
            model_id=model_id or "text-embedding-3-small",
        )

        hits = await service.search_similar_chunks(
            query=request.query,
            organization_id=user.selected_organization_id,
            limit=request.limit,
            document_uuids=request.document_uuids,
        )

        # Optional client-requested similarity cut-off.
        threshold = request.min_similarity
        if threshold is not None:
            kept = []
            for hit in hits:
                if hit["similarity"] >= threshold:
                    kept.append(hit)
            hits = kept

        from api.schemas.knowledge_base import ChunkResponseSchema

        # Map raw result dicts onto the public chunk schema.
        chunks = []
        for hit in hits:
            chunks.append(
                ChunkResponseSchema(
                    id=hit["id"],
                    document_id=hit["document_id"],
                    chunk_text=hit["chunk_text"],
                    contextualized_text=hit.get("contextualized_text"),
                    chunk_index=hit["chunk_index"],
                    chunk_metadata=hit["chunk_metadata"],
                    filename=hit["filename"],
                    document_uuid=hit["document_uuid"],
                    similarity=hit["similarity"],
                )
            )

        return ChunkSearchResponseSchema(
            chunks=chunks,
            query=request.query,
            total_results=len(chunks),
        )

    except Exception as exc:
        logger.error(f"Error searching chunks: {exc}")
        raise HTTPException(status_code=500, detail="Failed to search chunks") from exc
|