dograh/api/db/knowledge_base_client.py

552 lines
19 KiB
Python
Raw Normal View History

"""Database client for managing knowledge base documents and chunks."""
import hashlib
from pathlib import Path
from typing import List, Optional
from loguru import logger
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from api.db.base_client import BaseDBClient
from api.db.models import KnowledgeBaseChunkModel, KnowledgeBaseDocumentModel
class KnowledgeBaseClient(BaseDBClient):
"""Client for managing knowledge base documents and vector embeddings."""
async def create_document(
self,
organization_id: int,
created_by: int,
filename: str,
file_size_bytes: int,
file_hash: str,
mime_type: str,
source_url: Optional[str] = None,
custom_metadata: Optional[dict] = None,
docling_metadata: Optional[dict] = None,
document_uuid: Optional[str] = None,
retrieval_mode: str = "chunked",
) -> KnowledgeBaseDocumentModel:
"""Create a new knowledge base document record.
Args:
organization_id: ID of the organization
created_by: ID of the user uploading the document
filename: Name of the file
file_size_bytes: Size of the file in bytes
file_hash: SHA-256 hash of the file
mime_type: MIME type of the file
source_url: Optional URL if document was fetched from web
custom_metadata: Optional custom metadata dictionary
docling_metadata: Optional docling processing metadata
document_uuid: Optional UUID to use (if not provided, one will be generated)
Returns:
The created KnowledgeBaseDocumentModel
"""
async with self.async_session() as session:
document = KnowledgeBaseDocumentModel(
organization_id=organization_id,
created_by=created_by,
filename=filename,
file_size_bytes=file_size_bytes,
file_hash=file_hash,
mime_type=mime_type,
source_url=source_url,
custom_metadata=custom_metadata or {},
docling_metadata=docling_metadata or {},
processing_status="pending",
total_chunks=0,
retrieval_mode=retrieval_mode,
)
# Use provided UUID or let the model generate one
if document_uuid:
document.document_uuid = document_uuid
session.add(document)
await session.commit()
await session.refresh(document)
logger.info(
f"Created document '{filename}' ({document.document_uuid}) "
f"for organization {organization_id}"
)
return document
async def get_document_by_id(
self,
document_id: int,
) -> Optional[KnowledgeBaseDocumentModel]:
"""Get a document by its database ID.
Args:
document_id: The database ID of the document
Returns:
KnowledgeBaseDocumentModel if found, None otherwise
"""
async with self.async_session() as session:
query = select(KnowledgeBaseDocumentModel).where(
KnowledgeBaseDocumentModel.id == document_id
)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_document_by_uuid(
self,
document_uuid: str,
organization_id: int,
) -> Optional[KnowledgeBaseDocumentModel]:
"""Get a document by its UUID, scoped to organization.
Args:
document_uuid: The unique document UUID
organization_id: ID of the organization
Returns:
KnowledgeBaseDocumentModel if found, None otherwise
"""
async with self.async_session() as session:
query = (
select(KnowledgeBaseDocumentModel)
.where(
KnowledgeBaseDocumentModel.document_uuid == document_uuid,
KnowledgeBaseDocumentModel.organization_id == organization_id,
KnowledgeBaseDocumentModel.is_active == True,
)
.options(selectinload(KnowledgeBaseDocumentModel.created_by_user))
)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_document_by_hash(
self,
file_hash: str,
organization_id: int,
) -> Optional[KnowledgeBaseDocumentModel]:
"""Check if a document with the same hash already exists.
Returns the first matching document if multiple exist (can happen with duplicates).
Args:
file_hash: SHA-256 hash of the file
organization_id: ID of the organization
Returns:
KnowledgeBaseDocumentModel if found, None otherwise
"""
async with self.async_session() as session:
query = (
select(KnowledgeBaseDocumentModel)
.where(
KnowledgeBaseDocumentModel.file_hash == file_hash,
KnowledgeBaseDocumentModel.organization_id == organization_id,
KnowledgeBaseDocumentModel.is_active == True,
)
.order_by(KnowledgeBaseDocumentModel.created_at.asc())
.limit(1)
)
result = await session.execute(query)
return result.scalars().first()
async def get_documents_for_organization(
self,
organization_id: int,
processing_status: Optional[str] = None,
limit: int = 100,
offset: int = 0,
) -> List[KnowledgeBaseDocumentModel]:
"""Get all documents for an organization.
Args:
organization_id: ID of the organization
processing_status: Optional filter by status
limit: Maximum number of documents to return
offset: Number of documents to skip
Returns:
List of KnowledgeBaseDocumentModel instances
"""
async with self.async_session() as session:
query = select(KnowledgeBaseDocumentModel).where(
KnowledgeBaseDocumentModel.organization_id == organization_id,
KnowledgeBaseDocumentModel.is_active == True,
)
if processing_status:
query = query.where(
KnowledgeBaseDocumentModel.processing_status == processing_status
)
query = (
query.order_by(KnowledgeBaseDocumentModel.created_at.desc())
.limit(limit)
.offset(offset)
)
result = await session.execute(query)
return list(result.scalars().all())
async def update_document_metadata(
self,
document_id: int,
file_size_bytes: Optional[int] = None,
file_hash: Optional[str] = None,
mime_type: Optional[str] = None,
) -> Optional[KnowledgeBaseDocumentModel]:
"""Update document file metadata.
Args:
document_id: ID of the document
file_size_bytes: Optional file size in bytes
file_hash: Optional SHA-256 hash of the file
mime_type: Optional MIME type
Returns:
Updated KnowledgeBaseDocumentModel
"""
async with self.async_session() as session:
query = select(KnowledgeBaseDocumentModel).where(
KnowledgeBaseDocumentModel.id == document_id
)
result = await session.execute(query)
document = result.scalar_one_or_none()
if not document:
return None
if file_size_bytes is not None:
document.file_size_bytes = file_size_bytes
if file_hash is not None:
document.file_hash = file_hash
if mime_type is not None:
document.mime_type = mime_type
await session.commit()
await session.refresh(document)
logger.info(f"Updated document {document_id} metadata")
return document
async def update_document_status(
self,
document_id: int,
status: str,
error_message: Optional[str] = None,
total_chunks: Optional[int] = None,
docling_metadata: Optional[dict] = None,
) -> Optional[KnowledgeBaseDocumentModel]:
"""Update document processing status.
Args:
document_id: ID of the document
status: New status (pending, processing, completed, failed)
error_message: Optional error message if status is failed
total_chunks: Optional total number of chunks
docling_metadata: Optional docling metadata
Returns:
Updated KnowledgeBaseDocumentModel
"""
async with self.async_session() as session:
query = select(KnowledgeBaseDocumentModel).where(
KnowledgeBaseDocumentModel.id == document_id
)
result = await session.execute(query)
document = result.scalar_one_or_none()
if not document:
return None
document.processing_status = status
if error_message:
document.processing_error = error_message
if total_chunks is not None:
document.total_chunks = total_chunks
if docling_metadata:
document.docling_metadata = docling_metadata
await session.commit()
await session.refresh(document)
logger.info(f"Updated document {document_id} status to {status}")
return document
async def create_chunks_batch(
self,
chunks: List[KnowledgeBaseChunkModel],
) -> List[KnowledgeBaseChunkModel]:
"""Create multiple chunks in a batch.
Args:
chunks: List of KnowledgeBaseChunkModel instances
Returns:
List of created chunks with IDs
"""
async with self.async_session() as session:
session.add_all(chunks)
await session.commit()
for chunk in chunks:
await session.refresh(chunk)
logger.info(f"Created {len(chunks)} chunks")
return chunks
async def get_chunks_for_document(
self,
document_id: int,
organization_id: int,
) -> List[KnowledgeBaseChunkModel]:
"""Get all chunks for a document.
Args:
document_id: ID of the document
organization_id: ID of the organization (for authorization)
Returns:
List of KnowledgeBaseChunkModel instances
"""
async with self.async_session() as session:
query = (
select(KnowledgeBaseChunkModel)
.where(
KnowledgeBaseChunkModel.document_id == document_id,
KnowledgeBaseChunkModel.organization_id == organization_id,
)
.order_by(KnowledgeBaseChunkModel.chunk_index)
)
result = await session.execute(query)
return list(result.scalars().all())
async def search_similar_chunks(
self,
query_embedding: List[float],
organization_id: int,
limit: int = 5,
document_ids: Optional[List[int]] = None,
document_uuids: Optional[List[str]] = None,
embedding_model: Optional[str] = None,
) -> List[dict]:
"""Search for similar chunks using vector similarity.
Returns top-k most similar chunks without any similarity threshold filtering.
Filtering and reranking should be done at the application layer.
Args:
query_embedding: The query embedding vector
organization_id: Organization ID for scoping
limit: Maximum number of results to return
document_ids: Optional list of document IDs to filter by
document_uuids: Optional list of document UUIDs to filter by
embedding_model: Optional embedding model to filter by (for dimension compatibility)
Returns:
List of dictionaries with chunk data and similarity scores, ordered by similarity (highest first)
"""
async with self.async_session() as session:
# Get the raw connection to execute directly with asyncpg
# This avoids parameter binding issues with text() and asyncpg
connection = await session.connection()
raw_connection = await connection.get_raw_connection()
# Build WHERE clause conditions (no similarity threshold)
where_conditions = [
"c.organization_id = $2",
"d.is_active = true",
]
params = [
None,
organization_id,
limit,
] # $1 will be embedding_str, $3 is limit
param_index = 4 # Next available parameter index
# Add document_ids filter if provided
if document_ids:
placeholders = ", ".join(
f"${param_index + i}" for i in range(len(document_ids))
)
where_conditions.append(f"c.document_id IN ({placeholders})")
params.extend(document_ids)
param_index += len(document_ids)
# Add document_uuids filter if provided
if document_uuids:
placeholders = ", ".join(
f"${param_index + i}" for i in range(len(document_uuids))
)
where_conditions.append(f"d.document_uuid IN ({placeholders})")
params.extend(document_uuids)
param_index += len(document_uuids)
# Add embedding_model filter if provided (for dimension compatibility)
if embedding_model:
where_conditions.append(f"c.embedding_model = ${param_index}")
params.append(embedding_model)
param_index += 1
# Build the complete SQL query
where_clause = " AND ".join(where_conditions)
query_sql = f"""
SELECT
c.id,
c.document_id,
c.chunk_text,
c.contextualized_text,
c.chunk_metadata,
c.chunk_index,
d.filename,
d.document_uuid,
1 - (c.embedding <=> $1::vector) as similarity
FROM knowledge_base_chunks c
JOIN knowledge_base_documents d ON c.document_id = d.id
WHERE {where_clause}
ORDER BY c.embedding <=> $1::vector
LIMIT $3
"""
# Convert embedding to string format for PostgreSQL vector type
embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"
params[0] = embedding_str # Set $1
# Execute query directly with asyncpg
rows = await raw_connection.driver_connection.fetch(
query_sql,
*params,
)
# Convert asyncpg records to dictionaries
return [dict(row) for row in rows]
async def update_document_full_text(
self,
document_id: int,
full_text: str,
) -> None:
"""Store full document text for full_document retrieval mode.
Args:
document_id: ID of the document
full_text: The full extracted text content
"""
async with self.async_session() as session:
query = select(KnowledgeBaseDocumentModel).where(
KnowledgeBaseDocumentModel.id == document_id
)
result = await session.execute(query)
document = result.scalar_one_or_none()
if document:
document.full_text = full_text
await session.commit()
logger.info(
f"Stored full text for document {document_id} ({len(full_text)} chars)"
)
async def get_full_text_documents(
self,
organization_id: int,
document_uuids: List[str],
) -> List[KnowledgeBaseDocumentModel]:
"""Get full_document mode documents by their UUIDs.
Args:
organization_id: Organization ID for scoping
document_uuids: List of document UUIDs to fetch
Returns:
List of documents with retrieval_mode='full_document' and full_text set
"""
async with self.async_session() as session:
query = select(KnowledgeBaseDocumentModel).where(
KnowledgeBaseDocumentModel.organization_id == organization_id,
KnowledgeBaseDocumentModel.document_uuid.in_(document_uuids),
KnowledgeBaseDocumentModel.retrieval_mode == "full_document",
KnowledgeBaseDocumentModel.is_active == True,
KnowledgeBaseDocumentModel.processing_status == "completed",
)
result = await session.execute(query)
return list(result.scalars().all())
async def delete_document(
self,
document_uuid: str,
organization_id: int,
) -> bool:
"""Soft delete a document by setting is_active to False.
This will also cascade delete all chunks via the database foreign key.
Args:
document_uuid: The unique document UUID
organization_id: ID of the organization (for authorization)
Returns:
True if document was deleted, False if not found
"""
async with self.async_session() as session:
query = select(KnowledgeBaseDocumentModel).where(
KnowledgeBaseDocumentModel.document_uuid == document_uuid,
KnowledgeBaseDocumentModel.organization_id == organization_id,
)
result = await session.execute(query)
document = result.scalar_one_or_none()
if not document:
return False
document.is_active = False
await session.commit()
logger.info(
f"Deleted document {document_uuid} for organization {organization_id}"
)
return True
@staticmethod
def compute_file_hash(file_path: str) -> str:
"""Compute SHA-256 hash of a file.
Args:
file_path: Path to the file
Returns:
SHA-256 hash as hex string
"""
sha256_hash = hashlib.sha256()
with open(file_path, "rb") as f:
for byte_block in iter(lambda: f.read(4096), b""):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest()
@staticmethod
def get_mime_type(file_path: str) -> str:
"""Get MIME type based on file extension.
Args:
file_path: Path to the file
Returns:
MIME type string
"""
extension = Path(file_path).suffix.lower()
mime_types = {
".pdf": "application/pdf",
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
".doc": "application/msword",
".txt": "text/plain",
".json": "application/json",
".html": "text/html",
".md": "text/markdown",
}
return mime_types.get(extension, "application/octet-stream")