# Mirror of https://github.com/dograh-hq/dograh.git (synced 2026-06-07 07:55:16 +02:00).
# Commit: "Add a worker sync event so that runtime updates on one worker can
# propagate across other workers using pubsub for multi-worker deployments."
# File: Python, 500 lines, 17 KiB.
"""Database client for managing knowledge base documents and chunks."""

import hashlib
from pathlib import Path
from typing import ClassVar, Dict, List, Optional

from loguru import logger
from sqlalchemy import select
from sqlalchemy.orm import selectinload

from api.db.base_client import BaseDBClient
from api.db.models import KnowledgeBaseChunkModel, KnowledgeBaseDocumentModel


class KnowledgeBaseClient(BaseDBClient):
    """Client for managing knowledge base documents and vector embeddings.

    Documents are stored as ``KnowledgeBaseDocumentModel`` rows and their
    chunks (text plus embedding) as ``KnowledgeBaseChunkModel`` rows. Each
    database method opens its own session via ``self.async_session()``.
    """

    # Extension -> MIME type lookup used by ``get_mime_type``. Built once at
    # class level instead of being recreated on every call.
    _MIME_TYPES: ClassVar[Dict[str, str]] = {
        ".pdf": "application/pdf",
        ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        ".doc": "application/msword",
        ".txt": "text/plain",
        ".json": "application/json",
        ".html": "text/html",
        ".md": "text/markdown",
    }

    async def create_document(
        self,
        organization_id: int,
        created_by: int,
        filename: str,
        file_size_bytes: int,
        file_hash: str,
        mime_type: str,
        source_url: Optional[str] = None,
        custom_metadata: Optional[dict] = None,
        docling_metadata: Optional[dict] = None,
        document_uuid: Optional[str] = None,
    ) -> KnowledgeBaseDocumentModel:
        """Create a new knowledge base document record.

        The record is created with ``processing_status="pending"`` and
        ``total_chunks=0``; missing metadata dicts are stored as ``{}``.

        Args:
            organization_id: ID of the organization
            created_by: ID of the user uploading the document
            filename: Name of the file
            file_size_bytes: Size of the file in bytes
            file_hash: SHA-256 hash of the file
            mime_type: MIME type of the file
            source_url: Optional URL if document was fetched from web
            custom_metadata: Optional custom metadata dictionary
            docling_metadata: Optional docling processing metadata
            document_uuid: Optional UUID to use (if not provided or empty,
                the model's own value is kept)

        Returns:
            The created KnowledgeBaseDocumentModel, refreshed after commit.
        """
        async with self.async_session() as session:
            document = KnowledgeBaseDocumentModel(
                organization_id=organization_id,
                created_by=created_by,
                filename=filename,
                file_size_bytes=file_size_bytes,
                file_hash=file_hash,
                mime_type=mime_type,
                source_url=source_url,
                custom_metadata=custom_metadata or {},
                docling_metadata=docling_metadata or {},
                processing_status="pending",
                total_chunks=0,
            )

            # Use the caller's UUID when given; otherwise let the model
            # generate one.
            if document_uuid:
                document.document_uuid = document_uuid

            session.add(document)
            await session.commit()
            # Reload so database-populated columns are available on return.
            await session.refresh(document)

            logger.info(
                f"Created document '{filename}' ({document.document_uuid}) "
                f"for organization {organization_id}"
            )
            return document

    async def get_document_by_id(
        self,
        document_id: int,
    ) -> Optional[KnowledgeBaseDocumentModel]:
        """Get a document by its database ID.

        Unlike the UUID/hash lookups, this query is neither scoped to an
        organization nor filtered on ``is_active``, so soft-deleted documents
        are returned too.

        Args:
            document_id: The database ID of the document

        Returns:
            KnowledgeBaseDocumentModel if found, None otherwise
        """
        async with self.async_session() as session:
            query = select(KnowledgeBaseDocumentModel).where(
                KnowledgeBaseDocumentModel.id == document_id
            )

            result = await session.execute(query)
            return result.scalar_one_or_none()

    async def get_document_by_uuid(
        self,
        document_uuid: str,
        organization_id: int,
    ) -> Optional[KnowledgeBaseDocumentModel]:
        """Get an active document by its UUID, scoped to an organization.

        The ``created_by_user`` relationship is eagerly loaded so it can be
        accessed after the session closes.

        Args:
            document_uuid: The unique document UUID
            organization_id: ID of the organization

        Returns:
            KnowledgeBaseDocumentModel if found, None otherwise
        """
        async with self.async_session() as session:
            query = (
                select(KnowledgeBaseDocumentModel)
                .where(
                    KnowledgeBaseDocumentModel.document_uuid == document_uuid,
                    KnowledgeBaseDocumentModel.organization_id == organization_id,
                    # ``== True`` builds a SQL expression; ``is True`` would not.
                    KnowledgeBaseDocumentModel.is_active == True,  # noqa: E712
                )
                .options(selectinload(KnowledgeBaseDocumentModel.created_by_user))
            )

            result = await session.execute(query)
            return result.scalar_one_or_none()

    async def get_document_by_hash(
        self,
        file_hash: str,
        organization_id: int,
    ) -> Optional[KnowledgeBaseDocumentModel]:
        """Check whether an active document with the same hash already exists.

        If several active documents share the hash (duplicates), the oldest
        one by ``created_at`` is returned.

        Args:
            file_hash: SHA-256 hash of the file
            organization_id: ID of the organization

        Returns:
            KnowledgeBaseDocumentModel if found, None otherwise
        """
        async with self.async_session() as session:
            query = (
                select(KnowledgeBaseDocumentModel)
                .where(
                    KnowledgeBaseDocumentModel.file_hash == file_hash,
                    KnowledgeBaseDocumentModel.organization_id == organization_id,
                    KnowledgeBaseDocumentModel.is_active == True,  # noqa: E712
                )
                .order_by(KnowledgeBaseDocumentModel.created_at.asc())
                .limit(1)
            )

            result = await session.execute(query)
            return result.scalars().first()

    async def get_documents_for_organization(
        self,
        organization_id: int,
        processing_status: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> List[KnowledgeBaseDocumentModel]:
        """Get a page of active documents for an organization, newest first.

        Args:
            organization_id: ID of the organization
            processing_status: Optional filter by status (ignored when empty)
            limit: Maximum number of documents to return
            offset: Number of documents to skip

        Returns:
            List of KnowledgeBaseDocumentModel instances
        """
        async with self.async_session() as session:
            query = select(KnowledgeBaseDocumentModel).where(
                KnowledgeBaseDocumentModel.organization_id == organization_id,
                KnowledgeBaseDocumentModel.is_active == True,  # noqa: E712
            )

            if processing_status:
                query = query.where(
                    KnowledgeBaseDocumentModel.processing_status == processing_status
                )

            query = (
                query.order_by(KnowledgeBaseDocumentModel.created_at.desc())
                .limit(limit)
                .offset(offset)
            )

            result = await session.execute(query)
            return list(result.scalars().all())

    async def update_document_metadata(
        self,
        document_id: int,
        file_size_bytes: Optional[int] = None,
        file_hash: Optional[str] = None,
        mime_type: Optional[str] = None,
    ) -> Optional[KnowledgeBaseDocumentModel]:
        """Update document file metadata.

        Only arguments that are not ``None`` are written; the rest keep their
        current values.

        Args:
            document_id: ID of the document
            file_size_bytes: Optional file size in bytes
            file_hash: Optional SHA-256 hash of the file
            mime_type: Optional MIME type

        Returns:
            The updated KnowledgeBaseDocumentModel, or None if no document
            with ``document_id`` exists.
        """
        async with self.async_session() as session:
            query = select(KnowledgeBaseDocumentModel).where(
                KnowledgeBaseDocumentModel.id == document_id
            )
            result = await session.execute(query)
            document = result.scalar_one_or_none()

            if not document:
                return None

            if file_size_bytes is not None:
                document.file_size_bytes = file_size_bytes
            if file_hash is not None:
                document.file_hash = file_hash
            if mime_type is not None:
                document.mime_type = mime_type

            await session.commit()
            await session.refresh(document)

            logger.info(f"Updated document {document_id} metadata")
            return document

    async def update_document_status(
        self,
        document_id: int,
        status: str,
        error_message: Optional[str] = None,
        total_chunks: Optional[int] = None,
        docling_metadata: Optional[dict] = None,
    ) -> Optional[KnowledgeBaseDocumentModel]:
        """Update document processing status.

        ``status`` is always written. ``error_message`` and
        ``docling_metadata`` are written only when truthy; ``total_chunks``
        only when not ``None``.

        Args:
            document_id: ID of the document
            status: New status (pending, processing, completed, failed)
            error_message: Optional error message if status is failed
            total_chunks: Optional total number of chunks
            docling_metadata: Optional docling metadata

        Returns:
            The updated KnowledgeBaseDocumentModel, or None if no document
            with ``document_id`` exists.
        """
        async with self.async_session() as session:
            query = select(KnowledgeBaseDocumentModel).where(
                KnowledgeBaseDocumentModel.id == document_id
            )
            result = await session.execute(query)
            document = result.scalar_one_or_none()

            if not document:
                return None

            document.processing_status = status
            # NOTE(review): a previous processing_error is never cleared here,
            # so a document that failed and later completes keeps its old
            # error text. Confirm whether callers expect that.
            if error_message:
                document.processing_error = error_message
            if total_chunks is not None:
                document.total_chunks = total_chunks
            if docling_metadata:
                document.docling_metadata = docling_metadata

            await session.commit()
            await session.refresh(document)

            logger.info(f"Updated document {document_id} status to {status}")
            return document

    async def create_chunks_batch(
        self,
        chunks: List[KnowledgeBaseChunkModel],
    ) -> List[KnowledgeBaseChunkModel]:
        """Insert multiple chunks in a single transaction.

        Args:
            chunks: List of KnowledgeBaseChunkModel instances

        Returns:
            The same chunk instances, refreshed after commit so generated
            columns such as the primary key are loaded.
        """
        async with self.async_session() as session:
            session.add_all(chunks)
            await session.commit()

            # One refresh per chunk: reloads database-populated values after
            # the commit.
            for chunk in chunks:
                await session.refresh(chunk)

            logger.info(f"Created {len(chunks)} chunks")
            return chunks

    async def get_chunks_for_document(
        self,
        document_id: int,
        organization_id: int,
    ) -> List[KnowledgeBaseChunkModel]:
        """Get all chunks for a document, ordered by ``chunk_index``.

        This does not check the parent document's ``is_active`` flag, so
        chunks of a soft-deleted document are still returned.

        Args:
            document_id: ID of the document
            organization_id: ID of the organization (for authorization)

        Returns:
            List of KnowledgeBaseChunkModel instances
        """
        async with self.async_session() as session:
            query = (
                select(KnowledgeBaseChunkModel)
                .where(
                    KnowledgeBaseChunkModel.document_id == document_id,
                    KnowledgeBaseChunkModel.organization_id == organization_id,
                )
                .order_by(KnowledgeBaseChunkModel.chunk_index)
            )

            result = await session.execute(query)
            return list(result.scalars().all())

    async def search_similar_chunks(
        self,
        query_embedding: List[float],
        organization_id: int,
        limit: int = 5,
        document_ids: Optional[List[int]] = None,
        document_uuids: Optional[List[str]] = None,
        embedding_model: Optional[str] = None,
    ) -> List[dict]:
        """Search for similar chunks using vector similarity.

        Returns the top-k most similar chunks of active documents without any
        similarity threshold; filtering and reranking belong in the
        application layer. ``similarity`` is ``1 - (embedding <=> query)``,
        where ``<=>`` is pgvector's cosine-distance operator.

        Empty ``document_ids`` / ``document_uuids`` lists are treated exactly
        like ``None``: no document filter is applied.

        Args:
            query_embedding: The query embedding vector
            organization_id: Organization ID for scoping
            limit: Maximum number of results to return
            document_ids: Optional list of document IDs to filter by
            document_uuids: Optional list of document UUIDs to filter by
            embedding_model: Optional embedding model to filter by (for
                dimension compatibility)

        Returns:
            List of dictionaries with chunk data and similarity scores,
            ordered by similarity (highest first)
        """
        async with self.async_session() as session:
            # Execute directly on the asyncpg connection; this avoids
            # parameter binding issues with text() and asyncpg.
            connection = await session.connection()
            raw_connection = await connection.get_raw_connection()

            # Text form of the vector, cast with ::vector in the SQL below.
            embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"

            # Fixed positional parameters: $1 embedding, $2 organization_id,
            # $3 limit. Optional filters are bound from $4 onwards.
            params: list = [embedding_str, organization_id, limit]

            def _bind(value) -> str:
                """Append ``value`` to params and return its ``$n`` placeholder."""
                params.append(value)
                return f"${len(params)}"

            where_conditions = [
                "c.organization_id = $2",
                "d.is_active = true",
            ]

            if document_ids:
                placeholders = ", ".join([_bind(doc_id) for doc_id in document_ids])
                where_conditions.append(f"c.document_id IN ({placeholders})")

            if document_uuids:
                placeholders = ", ".join(
                    [_bind(doc_uuid) for doc_uuid in document_uuids]
                )
                where_conditions.append(f"d.document_uuid IN ({placeholders})")

            # Restrict to one embedding model so vector dimensions match.
            if embedding_model:
                where_conditions.append(
                    f"c.embedding_model = {_bind(embedding_model)}"
                )

            # Only placeholders (never user values) are interpolated here.
            where_clause = " AND ".join(where_conditions)
            query_sql = f"""
                SELECT
                    c.id,
                    c.document_id,
                    c.chunk_text,
                    c.contextualized_text,
                    c.chunk_metadata,
                    c.chunk_index,
                    d.filename,
                    d.document_uuid,
                    1 - (c.embedding <=> $1::vector) as similarity
                FROM knowledge_base_chunks c
                JOIN knowledge_base_documents d ON c.document_id = d.id
                WHERE {where_clause}
                ORDER BY c.embedding <=> $1::vector
                LIMIT $3
            """

            rows = await raw_connection.driver_connection.fetch(
                query_sql,
                *params,
            )

            # Convert asyncpg records to plain dictionaries.
            return [dict(row) for row in rows]

    async def delete_document(
        self,
        document_uuid: str,
        organization_id: int,
    ) -> bool:
        """Soft delete a document by setting ``is_active`` to False.

        No rows are removed, so no foreign-key cascade fires: the document's
        chunks stay in the database. They are excluded from
        ``search_similar_chunks`` (which joins on ``d.is_active = true``) but
        are still returned by ``get_chunks_for_document``.

        The lookup does not filter on ``is_active``, so deleting a document
        that is already inactive still returns True.

        Args:
            document_uuid: The unique document UUID
            organization_id: ID of the organization (for authorization)

        Returns:
            True if a matching document was found and marked inactive,
            False if not found
        """
        async with self.async_session() as session:
            query = select(KnowledgeBaseDocumentModel).where(
                KnowledgeBaseDocumentModel.document_uuid == document_uuid,
                KnowledgeBaseDocumentModel.organization_id == organization_id,
            )

            result = await session.execute(query)
            document = result.scalar_one_or_none()

            if not document:
                return False

            document.is_active = False
            await session.commit()

            logger.info(
                f"Deleted document {document_uuid} for organization {organization_id}"
            )
            return True

    @staticmethod
    def compute_file_hash(file_path: str) -> str:
        """Compute the SHA-256 hash of a file's contents.

        The file is read in fixed-size binary blocks, so large files are
        hashed without loading them into memory at once.

        Args:
            file_path: Path to the file

        Returns:
            SHA-256 hash as a lowercase hex string
        """
        block_size = 65536  # 64 KiB: fewer read calls than 4 KiB, same digest
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            for byte_block in iter(lambda: f.read(block_size), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()

    @staticmethod
    def get_mime_type(file_path: str) -> str:
        """Get the MIME type based on the file extension (case-insensitive).

        Args:
            file_path: Path to the file

        Returns:
            MIME type string, or ``"application/octet-stream"`` for
            extensions not in the lookup table
        """
        extension = Path(file_path).suffix.lower()
        return KnowledgeBaseClient._MIME_TYPES.get(
            extension, "application/octet-stream"
        )