# Mirror of https://github.com/dograh-hq/dograh.git
# Synced 2026-06-07 07:55:16 +02:00 | 499 lines | 17 KiB | Python
"""Database client for managing knowledge base documents and chunks."""
|
|
|
|
import hashlib
|
|
from pathlib import Path
|
|
from typing import List, Optional
|
|
|
|
from loguru import logger
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
from api.db.base_client import BaseDBClient
|
|
from api.db.models import KnowledgeBaseChunkModel, KnowledgeBaseDocumentModel
|
|
|
|
|
|
class KnowledgeBaseClient(BaseDBClient):
    """Client for managing knowledge base documents and vector embeddings.

    Every coroutine opens its own session through ``self.async_session()``
    and closes it before returning, so returned model instances are detached
    from any session.
    """

    # Extension (lower-case, with leading dot) -> MIME type used by
    # ``get_mime_type``. Anything not listed maps to application/octet-stream.
    _MIME_TYPES = {
        ".pdf": "application/pdf",
        ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        ".doc": "application/msword",
        ".txt": "text/plain",
        ".html": "text/html",
        ".md": "text/markdown",
    }

    # Read size for ``compute_file_hash``; the digest does not depend on it.
    _HASH_BLOCK_SIZE = 64 * 1024

    async def create_document(
        self,
        organization_id: int,
        created_by: int,
        filename: str,
        file_size_bytes: int,
        file_hash: str,
        mime_type: str,
        source_url: Optional[str] = None,
        custom_metadata: Optional[dict] = None,
        docling_metadata: Optional[dict] = None,
        document_uuid: Optional[str] = None,
    ) -> KnowledgeBaseDocumentModel:
        """Create a new knowledge base document record.

        The record is stored with ``processing_status="pending"`` and
        ``total_chunks=0``.

        Args:
            organization_id: ID of the organization
            created_by: ID of the user uploading the document
            filename: Name of the file
            file_size_bytes: Size of the file in bytes
            file_hash: SHA-256 hash of the file
            mime_type: MIME type of the file
            source_url: Optional URL if document was fetched from web
            custom_metadata: Optional custom metadata dictionary (stored as
                ``{}`` when omitted or empty)
            docling_metadata: Optional docling processing metadata (stored as
                ``{}`` when omitted or empty)
            document_uuid: Optional UUID to use (if not provided, the model's
                own value is left in place)

        Returns:
            The created KnowledgeBaseDocumentModel, refreshed after commit
        """
        async with self.async_session() as session:
            document = KnowledgeBaseDocumentModel(
                organization_id=organization_id,
                created_by=created_by,
                filename=filename,
                file_size_bytes=file_size_bytes,
                file_hash=file_hash,
                mime_type=mime_type,
                source_url=source_url,
                custom_metadata=custom_metadata or {},
                docling_metadata=docling_metadata or {},
                processing_status="pending",
                total_chunks=0,
            )

            # Use provided UUID or let the model generate one
            if document_uuid:
                document.document_uuid = document_uuid

            session.add(document)
            await session.commit()
            # Reload so database-populated columns are available to the caller.
            await session.refresh(document)

            logger.info(
                f"Created document '{filename}' ({document.document_uuid}) "
                f"for organization {organization_id}"
            )
            return document

    async def get_document_by_id(
        self,
        document_id: int,
    ) -> Optional[KnowledgeBaseDocumentModel]:
        """Get a document by its database ID.

        Unlike the UUID/hash lookups, this query is neither scoped to an
        organization nor filtered on ``is_active``.

        Args:
            document_id: The database ID of the document

        Returns:
            KnowledgeBaseDocumentModel if found, None otherwise
        """
        async with self.async_session() as session:
            query = select(KnowledgeBaseDocumentModel).where(
                KnowledgeBaseDocumentModel.id == document_id
            )

            result = await session.execute(query)
            return result.scalar_one_or_none()

    async def get_document_by_uuid(
        self,
        document_uuid: str,
        organization_id: int,
    ) -> Optional[KnowledgeBaseDocumentModel]:
        """Get an active document by its UUID, scoped to organization.

        The ``created_by_user`` relationship is eagerly loaded so it can be
        read after the session is closed.

        Args:
            document_uuid: The unique document UUID
            organization_id: ID of the organization

        Returns:
            KnowledgeBaseDocumentModel if found, None otherwise
        """
        async with self.async_session() as session:
            query = (
                select(KnowledgeBaseDocumentModel)
                .where(
                    KnowledgeBaseDocumentModel.document_uuid == document_uuid,
                    KnowledgeBaseDocumentModel.organization_id == organization_id,
                    KnowledgeBaseDocumentModel.is_active.is_(True),
                )
                .options(selectinload(KnowledgeBaseDocumentModel.created_by_user))
            )

            result = await session.execute(query)
            return result.scalar_one_or_none()

    async def get_document_by_hash(
        self,
        file_hash: str,
        organization_id: int,
    ) -> Optional[KnowledgeBaseDocumentModel]:
        """Check if an active document with the same hash already exists.

        Returns the oldest matching document if multiple exist (can happen
        with duplicates).

        Args:
            file_hash: SHA-256 hash of the file
            organization_id: ID of the organization

        Returns:
            KnowledgeBaseDocumentModel if found, None otherwise
        """
        async with self.async_session() as session:
            query = (
                select(KnowledgeBaseDocumentModel)
                .where(
                    KnowledgeBaseDocumentModel.file_hash == file_hash,
                    KnowledgeBaseDocumentModel.organization_id == organization_id,
                    KnowledgeBaseDocumentModel.is_active.is_(True),
                )
                # ``id`` breaks ties between rows sharing a created_at value so
                # the same duplicate is returned on every call.
                .order_by(
                    KnowledgeBaseDocumentModel.created_at.asc(),
                    KnowledgeBaseDocumentModel.id.asc(),
                )
                .limit(1)
            )

            result = await session.execute(query)
            return result.scalars().first()

    async def get_documents_for_organization(
        self,
        organization_id: int,
        processing_status: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> List[KnowledgeBaseDocumentModel]:
        """Get active documents for an organization, newest first.

        Args:
            organization_id: ID of the organization
            processing_status: Optional filter by status (ignored when None
                or empty)
            limit: Maximum number of documents to return
            offset: Number of documents to skip

        Returns:
            List of KnowledgeBaseDocumentModel instances
        """
        async with self.async_session() as session:
            query = select(KnowledgeBaseDocumentModel).where(
                KnowledgeBaseDocumentModel.organization_id == organization_id,
                KnowledgeBaseDocumentModel.is_active.is_(True),
            )

            if processing_status:
                query = query.where(
                    KnowledgeBaseDocumentModel.processing_status == processing_status
                )

            query = (
                # ``id`` is a tie-breaker: without it, rows sharing a
                # created_at value could move between pages of LIMIT/OFFSET.
                query.order_by(
                    KnowledgeBaseDocumentModel.created_at.desc(),
                    KnowledgeBaseDocumentModel.id.desc(),
                )
                .limit(limit)
                .offset(offset)
            )

            result = await session.execute(query)
            return list(result.scalars().all())

    async def update_document_metadata(
        self,
        document_id: int,
        file_size_bytes: Optional[int] = None,
        file_hash: Optional[str] = None,
        mime_type: Optional[str] = None,
    ) -> Optional[KnowledgeBaseDocumentModel]:
        """Update document file metadata.

        Only arguments that are not None are written; the rest keep their
        stored values.

        Args:
            document_id: ID of the document
            file_size_bytes: Optional file size in bytes
            file_hash: Optional SHA-256 hash of the file
            mime_type: Optional MIME type

        Returns:
            Updated KnowledgeBaseDocumentModel, or None if no document has
            this ID
        """
        async with self.async_session() as session:
            query = select(KnowledgeBaseDocumentModel).where(
                KnowledgeBaseDocumentModel.id == document_id
            )
            result = await session.execute(query)
            document = result.scalar_one_or_none()

            if not document:
                return None

            if file_size_bytes is not None:
                document.file_size_bytes = file_size_bytes
            if file_hash is not None:
                document.file_hash = file_hash
            if mime_type is not None:
                document.mime_type = mime_type

            await session.commit()
            await session.refresh(document)

            logger.info(f"Updated document {document_id} metadata")
            return document

    async def update_document_status(
        self,
        document_id: int,
        status: str,
        error_message: Optional[str] = None,
        total_chunks: Optional[int] = None,
        docling_metadata: Optional[dict] = None,
    ) -> Optional[KnowledgeBaseDocumentModel]:
        """Update document processing status.

        Args:
            document_id: ID of the document
            status: New status (pending, processing, completed, failed)
            error_message: Optional error message if status is failed; None
                or an empty string leaves the stored error untouched
            total_chunks: Optional total number of chunks (0 is written)
            docling_metadata: Optional docling metadata; None or an empty
                dict leaves the stored metadata untouched

        Returns:
            Updated KnowledgeBaseDocumentModel, or None if no document has
            this ID
        """
        async with self.async_session() as session:
            query = select(KnowledgeBaseDocumentModel).where(
                KnowledgeBaseDocumentModel.id == document_id
            )
            result = await session.execute(query)
            document = result.scalar_one_or_none()

            if not document:
                return None

            document.processing_status = status
            # NOTE(review): a previously stored processing_error is not
            # cleared when the status later moves to a non-failed value --
            # confirm whether callers rely on that before changing it.
            if error_message:
                document.processing_error = error_message
            if total_chunks is not None:
                document.total_chunks = total_chunks
            if docling_metadata:
                document.docling_metadata = docling_metadata

            await session.commit()
            await session.refresh(document)

            logger.info(f"Updated document {document_id} status to {status}")
            return document

    async def create_chunks_batch(
        self,
        chunks: List[KnowledgeBaseChunkModel],
    ) -> List[KnowledgeBaseChunkModel]:
        """Create multiple chunks in a single transaction.

        Args:
            chunks: List of KnowledgeBaseChunkModel instances

        Returns:
            The same list, with each chunk refreshed from the database after
            commit. An empty input is returned as-is without opening a
            session.
        """
        if not chunks:
            # Nothing to persist; avoid opening a session and an empty commit.
            return chunks

        async with self.async_session() as session:
            session.add_all(chunks)
            await session.commit()

            # One refresh per chunk so each instance is reloaded after commit.
            for chunk in chunks:
                await session.refresh(chunk)

            logger.info(f"Created {len(chunks)} chunks")
            return chunks

    async def get_chunks_for_document(
        self,
        document_id: int,
        organization_id: int,
    ) -> List[KnowledgeBaseChunkModel]:
        """Get all chunks for a document, ordered by ``chunk_index``.

        Args:
            document_id: ID of the document
            organization_id: ID of the organization (for authorization)

        Returns:
            List of KnowledgeBaseChunkModel instances
        """
        async with self.async_session() as session:
            query = (
                select(KnowledgeBaseChunkModel)
                .where(
                    KnowledgeBaseChunkModel.document_id == document_id,
                    KnowledgeBaseChunkModel.organization_id == organization_id,
                )
                .order_by(KnowledgeBaseChunkModel.chunk_index)
            )

            result = await session.execute(query)
            return list(result.scalars().all())

    async def search_similar_chunks(
        self,
        query_embedding: List[float],
        organization_id: int,
        limit: int = 5,
        document_ids: Optional[List[int]] = None,
        document_uuids: Optional[List[str]] = None,
        embedding_model: Optional[str] = None,
    ) -> List[dict]:
        """Search for similar chunks using vector similarity.

        Returns top-k most similar chunks without any similarity threshold
        filtering. Filtering and reranking should be done at the application
        layer. Only chunks whose parent document is active are considered.

        Args:
            query_embedding: The query embedding vector
            organization_id: Organization ID for scoping
            limit: Maximum number of results to return
            document_ids: Optional list of document IDs to filter by; None or
                an empty list applies no filter
            document_uuids: Optional list of document UUIDs to filter by;
                None or an empty list applies no filter
            embedding_model: Optional embedding model to filter by (for
                dimension compatibility)

        Returns:
            List of dictionaries with the keys id, document_id, chunk_text,
            contextualized_text, chunk_metadata, chunk_index, filename,
            document_uuid and similarity, ordered by similarity (highest
            first)
        """
        # Text form of the vector; it is cast with ``::vector`` inside the SQL.
        embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"

        # Positional parameters: $1 embedding, $2 organization, $3 limit.
        # Further filters append to this list and take the next free index.
        params: list = [embedding_str, organization_id, limit]
        where_conditions = [
            "c.organization_id = $2",
            "d.is_active = true",
        ]

        def add_in_filter(column: str, values: list) -> None:
            """Append ``column IN ($n, ...)`` with one placeholder per value."""
            first_index = len(params) + 1
            placeholders = ", ".join(
                f"${first_index + offset}" for offset in range(len(values))
            )
            where_conditions.append(f"{column} IN ({placeholders})")
            params.extend(values)

        if document_ids:
            add_in_filter("c.document_id", document_ids)

        if document_uuids:
            add_in_filter("d.document_uuid", document_uuids)

        # Restrict to one embedding model (for dimension compatibility)
        if embedding_model:
            where_conditions.append(f"c.embedding_model = ${len(params) + 1}")
            params.append(embedding_model)

        # Only fixed column names and $n placeholders are interpolated here;
        # every caller-supplied value travels as a bound parameter.
        where_clause = " AND ".join(where_conditions)
        query_sql = f"""
            SELECT
                c.id,
                c.document_id,
                c.chunk_text,
                c.contextualized_text,
                c.chunk_metadata,
                c.chunk_index,
                d.filename,
                d.document_uuid,
                1 - (c.embedding <=> $1::vector) as similarity
            FROM knowledge_base_chunks c
            JOIN knowledge_base_documents d ON c.document_id = d.id
            WHERE {where_clause}
            ORDER BY c.embedding <=> $1::vector
            LIMIT $3
        """

        async with self.async_session() as session:
            # Get the raw connection to execute directly with asyncpg
            # This avoids parameter binding issues with text() and asyncpg
            connection = await session.connection()
            raw_connection = await connection.get_raw_connection()

            rows = await raw_connection.driver_connection.fetch(
                query_sql,
                *params,
            )

            # Convert asyncpg records to dictionaries
            return [dict(row) for row in rows]

    async def delete_document(
        self,
        document_uuid: str,
        organization_id: int,
    ) -> bool:
        """Soft delete a document by setting is_active to False.

        No rows are removed: the document row is only flagged inactive and
        its chunks are left in place. ``search_similar_chunks`` skips them
        because it joins on ``d.is_active = true``.

        The lookup does not filter on ``is_active``, so calling this again
        for an already soft-deleted document still returns True.

        Args:
            document_uuid: The unique document UUID
            organization_id: ID of the organization (for authorization)

        Returns:
            True if a matching document was found and flagged inactive,
            False if not found
        """
        async with self.async_session() as session:
            query = select(KnowledgeBaseDocumentModel).where(
                KnowledgeBaseDocumentModel.document_uuid == document_uuid,
                KnowledgeBaseDocumentModel.organization_id == organization_id,
            )

            result = await session.execute(query)
            document = result.scalar_one_or_none()

            if not document:
                return False

            document.is_active = False
            await session.commit()

            logger.info(
                f"Deleted document {document_uuid} for organization {organization_id}"
            )
            return True

    @staticmethod
    def compute_file_hash(file_path: str) -> str:
        """Compute SHA-256 hash of a file.

        The file is read in fixed-size blocks so it is never loaded into
        memory in one piece.

        Args:
            file_path: Path to the file

        Returns:
            SHA-256 hash as hex string
        """
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            while True:
                byte_block = f.read(KnowledgeBaseClient._HASH_BLOCK_SIZE)
                if not byte_block:
                    break
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()

    @staticmethod
    def get_mime_type(file_path: str) -> str:
        """Get MIME type based on file extension.

        The extension is compared case-insensitively; the file itself is not
        opened.

        Args:
            file_path: Path to the file

        Returns:
            MIME type string, ``application/octet-stream`` for extensions
            that are not in the table
        """
        extension = Path(file_path).suffix.lower()
        return KnowledgeBaseClient._MIME_TYPES.get(
            extension, "application/octet-stream"
        )
|