mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-16 08:25:18 +02:00
feat: knowledge base functionality for the voice agent (#120)
* feat: upload file and store embedding * feat: add documents in nodes * feat: add openai embedding service
This commit is contained in:
parent
e2fa4bbb98
commit
ef5b9e40a9
52 changed files with 4551 additions and 114 deletions
405
api/routes/knowledge_base.py
Normal file
405
api/routes/knowledge_base.py
Normal file
|
|
@ -0,0 +1,405 @@
|
|||
"""API routes for knowledge base operations."""
|
||||
|
||||
import uuid
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.schemas.knowledge_base import (
|
||||
ChunkSearchRequestSchema,
|
||||
ChunkSearchResponseSchema,
|
||||
DocumentListResponseSchema,
|
||||
DocumentResponseSchema,
|
||||
DocumentUploadRequestSchema,
|
||||
DocumentUploadResponseSchema,
|
||||
ProcessDocumentRequestSchema,
|
||||
)
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.storage import storage_fs
|
||||
from api.tasks.arq import enqueue_job
|
||||
from api.tasks.function_names import FunctionNames
|
||||
|
||||
router = APIRouter(prefix="/knowledge-base", tags=["knowledge-base"])
|
||||
|
||||
|
||||
@router.post(
|
||||
"/upload-url",
|
||||
response_model=DocumentUploadResponseSchema,
|
||||
summary="Get presigned URL for document upload",
|
||||
)
|
||||
async def get_upload_url(
|
||||
request: DocumentUploadRequestSchema,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Generate a presigned PUT URL for uploading a document.
|
||||
|
||||
This endpoint:
|
||||
1. Generates a unique document UUID for organizing the S3 key
|
||||
2. Generates a presigned S3/MinIO URL for uploading the file
|
||||
3. Returns the upload URL and document metadata
|
||||
|
||||
After uploading to the returned URL, call /process-document to create
|
||||
the document record and trigger processing.
|
||||
|
||||
Access Control:
|
||||
* All authenticated users can upload documents scoped to their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Generate unique document UUID for S3 organization
|
||||
document_uuid = str(uuid.uuid4())
|
||||
|
||||
# Generate S3 key: knowledge_base/{org_id}/{document_uuid}/{filename}
|
||||
s3_key = f"knowledge_base/{user.selected_organization_id}/{document_uuid}/{request.filename}"
|
||||
|
||||
# Generate presigned PUT URL (valid for 30 minutes)
|
||||
upload_url = await storage_fs.aget_presigned_put_url(
|
||||
file_path=s3_key,
|
||||
expiration=1800, # 30 minutes
|
||||
content_type=request.mime_type,
|
||||
max_size=100_000_000, # 100MB max
|
||||
)
|
||||
|
||||
if not upload_url:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to generate presigned upload URL"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Generated upload URL for document {document_uuid}, "
|
||||
f"user {user.id}, org {user.selected_organization_id}"
|
||||
)
|
||||
|
||||
return DocumentUploadResponseSchema(
|
||||
upload_url=upload_url,
|
||||
document_uuid=document_uuid,
|
||||
s3_key=s3_key,
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error generating upload URL: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to generate upload URL"
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/process-document",
|
||||
response_model=DocumentResponseSchema,
|
||||
summary="Trigger document processing",
|
||||
)
|
||||
async def process_document(
|
||||
request: ProcessDocumentRequestSchema,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Trigger asynchronous processing of an uploaded document.
|
||||
|
||||
This endpoint should be called after successfully uploading a file to the presigned URL.
|
||||
It will:
|
||||
1. Create a document record in the database with the specified UUID
|
||||
2. Enqueue a background task to process the document (chunking and embedding)
|
||||
|
||||
The document status will be updated from 'pending' -> 'processing' -> 'completed' or 'failed'.
|
||||
|
||||
Embedding Services:
|
||||
* openai (default): High-quality 1536-dimensional embeddings (requires OPENAI_API_KEY)
|
||||
* sentence_transformer: Free, offline-capable, 384-dimensional embeddings
|
||||
|
||||
Access Control:
|
||||
* Users can only process documents in their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Extract filename from s3_key
|
||||
filename = request.s3_key.split("/")[-1]
|
||||
|
||||
# Create document record with the specific UUID from upload
|
||||
document = await db_client.create_document(
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.id,
|
||||
filename=filename,
|
||||
file_size_bytes=0, # Will be updated by background task
|
||||
file_hash="", # Will be computed by background task
|
||||
mime_type="application/octet-stream", # Will be detected by background task
|
||||
custom_metadata={"s3_key": request.s3_key},
|
||||
document_uuid=request.document_uuid, # Use UUID from upload
|
||||
)
|
||||
|
||||
# Enqueue background task for processing
|
||||
await enqueue_job(
|
||||
FunctionNames.PROCESS_KNOWLEDGE_BASE_DOCUMENT,
|
||||
document.id,
|
||||
request.s3_key,
|
||||
user.selected_organization_id,
|
||||
128, # max_tokens (default)
|
||||
request.embedding_service,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created document {request.document_uuid} (id={document.id}) and enqueued processing "
|
||||
f"with {request.embedding_service} embeddings, org {user.selected_organization_id}"
|
||||
)
|
||||
|
||||
return DocumentResponseSchema(
|
||||
id=document.id,
|
||||
document_uuid=request.document_uuid,
|
||||
filename=filename,
|
||||
file_size_bytes=0,
|
||||
file_hash="",
|
||||
mime_type="application/octet-stream",
|
||||
processing_status="pending",
|
||||
processing_error=None,
|
||||
total_chunks=0,
|
||||
custom_metadata={"s3_key": request.s3_key},
|
||||
docling_metadata={},
|
||||
source_url=None,
|
||||
created_at=document.created_at,
|
||||
updated_at=document.updated_at,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.id,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Error processing document: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to process document"
|
||||
) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/documents",
|
||||
response_model=DocumentListResponseSchema,
|
||||
summary="List documents",
|
||||
)
|
||||
async def list_documents(
|
||||
status: Annotated[
|
||||
Optional[str],
|
||||
Query(description="Filter by processing status"),
|
||||
] = None,
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 100,
|
||||
offset: Annotated[int, Query(ge=0)] = 0,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""List all documents for the user's organization.
|
||||
|
||||
Access Control:
|
||||
* Users can only see documents from their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
documents = await db_client.get_documents_for_organization(
|
||||
organization_id=user.selected_organization_id,
|
||||
processing_status=status,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
# Convert to response schema
|
||||
document_list = [
|
||||
DocumentResponseSchema(
|
||||
id=doc.id,
|
||||
document_uuid=doc.document_uuid,
|
||||
filename=doc.filename,
|
||||
file_size_bytes=doc.file_size_bytes,
|
||||
file_hash=doc.file_hash,
|
||||
mime_type=doc.mime_type,
|
||||
processing_status=doc.processing_status,
|
||||
processing_error=doc.processing_error,
|
||||
total_chunks=doc.total_chunks,
|
||||
custom_metadata=doc.custom_metadata,
|
||||
docling_metadata=doc.docling_metadata,
|
||||
source_url=doc.source_url,
|
||||
created_at=doc.created_at,
|
||||
updated_at=doc.updated_at,
|
||||
organization_id=doc.organization_id,
|
||||
created_by=doc.created_by,
|
||||
is_active=doc.is_active,
|
||||
)
|
||||
for doc in documents
|
||||
]
|
||||
|
||||
return DocumentListResponseSchema(
|
||||
documents=document_list,
|
||||
total=len(document_list),
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error listing documents: {exc}")
|
||||
raise HTTPException(status_code=500, detail="Failed to list documents") from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/documents/{document_uuid}",
|
||||
response_model=DocumentResponseSchema,
|
||||
summary="Get document details",
|
||||
)
|
||||
async def get_document(
|
||||
document_uuid: str,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Get details of a specific document.
|
||||
|
||||
Access Control:
|
||||
* Users can only access documents from their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
document = await db_client.get_document_by_uuid(
|
||||
document_uuid=document_uuid,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
|
||||
if not document:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
return DocumentResponseSchema(
|
||||
id=document.id,
|
||||
document_uuid=document.document_uuid,
|
||||
filename=document.filename,
|
||||
file_size_bytes=document.file_size_bytes,
|
||||
file_hash=document.file_hash,
|
||||
mime_type=document.mime_type,
|
||||
processing_status=document.processing_status,
|
||||
processing_error=document.processing_error,
|
||||
total_chunks=document.total_chunks,
|
||||
custom_metadata=document.custom_metadata,
|
||||
docling_metadata=document.docling_metadata,
|
||||
source_url=document.source_url,
|
||||
created_at=document.created_at,
|
||||
updated_at=document.updated_at,
|
||||
organization_id=document.organization_id,
|
||||
created_by=document.created_by,
|
||||
is_active=document.is_active,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Error getting document: {exc}")
|
||||
raise HTTPException(status_code=500, detail="Failed to get document") from exc
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/documents/{document_uuid}",
|
||||
summary="Delete document",
|
||||
)
|
||||
async def delete_document(
|
||||
document_uuid: str,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Soft delete a document and its chunks.
|
||||
|
||||
Access Control:
|
||||
* Users can only delete documents from their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
success = await db_client.delete_document(
|
||||
document_uuid=document_uuid,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
logger.info(
|
||||
f"Deleted document {document_uuid}, "
|
||||
f"user {user.id}, org {user.selected_organization_id}"
|
||||
)
|
||||
|
||||
return {"success": True, "message": "Document deleted successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Error deleting document: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to delete document"
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/search",
|
||||
response_model=ChunkSearchResponseSchema,
|
||||
summary="Search for similar chunks",
|
||||
)
|
||||
async def search_chunks(
|
||||
request: ChunkSearchRequestSchema,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Search for document chunks similar to the query.
|
||||
|
||||
This endpoint uses vector similarity search to find relevant chunks.
|
||||
Results are returned without threshold filtering - apply similarity
|
||||
thresholds at the application layer after optional reranking.
|
||||
|
||||
Access Control:
|
||||
* Users can only search documents from their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Import here to avoid circular dependency
|
||||
from api.services.gen_ai import OpenAIEmbeddingService
|
||||
|
||||
# Try to get user's embeddings configuration
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
|
||||
if user_config.embeddings:
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
|
||||
# Initialize embedding service with user config or fallback to env
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
)
|
||||
|
||||
# Perform search
|
||||
results = await embedding_service.search_similar_chunks(
|
||||
query=request.query,
|
||||
organization_id=user.selected_organization_id,
|
||||
limit=request.limit,
|
||||
document_uuids=request.document_uuids,
|
||||
)
|
||||
|
||||
# Apply similarity threshold if provided
|
||||
if request.min_similarity is not None:
|
||||
results = [r for r in results if r["similarity"] >= request.min_similarity]
|
||||
|
||||
# Convert to response schema
|
||||
from api.schemas.knowledge_base import ChunkResponseSchema
|
||||
|
||||
chunks = [
|
||||
ChunkResponseSchema(
|
||||
id=r["id"],
|
||||
document_id=r["document_id"],
|
||||
chunk_text=r["chunk_text"],
|
||||
contextualized_text=r.get("contextualized_text"),
|
||||
chunk_index=r["chunk_index"],
|
||||
chunk_metadata=r["chunk_metadata"],
|
||||
filename=r["filename"],
|
||||
document_uuid=r["document_uuid"],
|
||||
similarity=r["similarity"],
|
||||
)
|
||||
for r in results
|
||||
]
|
||||
|
||||
return ChunkSearchResponseSchema(
|
||||
chunks=chunks,
|
||||
query=request.query,
|
||||
total_results=len(chunks),
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error searching chunks: {exc}")
|
||||
raise HTTPException(status_code=500, detail="Failed to search chunks") from exc
|
||||
Loading…
Add table
Add a link
Reference in a new issue