mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-10 08:05:22 +02:00
feat: knowledge base functionality for the voice agent (#120)
* feat: upload file and store embedding * feat: add documents in nodes * feat: add openai embedding service
This commit is contained in:
parent
e2fa4bbb98
commit
ef5b9e40a9
52 changed files with 4551 additions and 114 deletions
499
api/db/knowledge_base_client.py
Normal file
499
api/db/knowledge_base_client.py
Normal file
|
|
@ -0,0 +1,499 @@
|
|||
"""Database client for managing knowledge base documents and chunks."""
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import KnowledgeBaseChunkModel, KnowledgeBaseDocumentModel
|
||||
|
||||
|
||||
class KnowledgeBaseClient(BaseDBClient):
|
||||
"""Client for managing knowledge base documents and vector embeddings."""
|
||||
|
||||
async def create_document(
|
||||
self,
|
||||
organization_id: int,
|
||||
created_by: int,
|
||||
filename: str,
|
||||
file_size_bytes: int,
|
||||
file_hash: str,
|
||||
mime_type: str,
|
||||
source_url: Optional[str] = None,
|
||||
custom_metadata: Optional[dict] = None,
|
||||
docling_metadata: Optional[dict] = None,
|
||||
document_uuid: Optional[str] = None,
|
||||
) -> KnowledgeBaseDocumentModel:
|
||||
"""Create a new knowledge base document record.
|
||||
|
||||
Args:
|
||||
organization_id: ID of the organization
|
||||
created_by: ID of the user uploading the document
|
||||
filename: Name of the file
|
||||
file_size_bytes: Size of the file in bytes
|
||||
file_hash: SHA-256 hash of the file
|
||||
mime_type: MIME type of the file
|
||||
source_url: Optional URL if document was fetched from web
|
||||
custom_metadata: Optional custom metadata dictionary
|
||||
docling_metadata: Optional docling processing metadata
|
||||
document_uuid: Optional UUID to use (if not provided, one will be generated)
|
||||
|
||||
Returns:
|
||||
The created KnowledgeBaseDocumentModel
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
document = KnowledgeBaseDocumentModel(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
filename=filename,
|
||||
file_size_bytes=file_size_bytes,
|
||||
file_hash=file_hash,
|
||||
mime_type=mime_type,
|
||||
source_url=source_url,
|
||||
custom_metadata=custom_metadata or {},
|
||||
docling_metadata=docling_metadata or {},
|
||||
processing_status="pending",
|
||||
total_chunks=0,
|
||||
)
|
||||
|
||||
# Use provided UUID or let the model generate one
|
||||
if document_uuid:
|
||||
document.document_uuid = document_uuid
|
||||
|
||||
session.add(document)
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
|
||||
logger.info(
|
||||
f"Created document '{filename}' ({document.document_uuid}) "
|
||||
f"for organization {organization_id}"
|
||||
)
|
||||
return document
|
||||
|
||||
async def get_document_by_id(
|
||||
self,
|
||||
document_id: int,
|
||||
) -> Optional[KnowledgeBaseDocumentModel]:
|
||||
"""Get a document by its database ID.
|
||||
|
||||
Args:
|
||||
document_id: The database ID of the document
|
||||
|
||||
Returns:
|
||||
KnowledgeBaseDocumentModel if found, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(KnowledgeBaseDocumentModel).where(
|
||||
KnowledgeBaseDocumentModel.id == document_id
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_document_by_uuid(
|
||||
self,
|
||||
document_uuid: str,
|
||||
organization_id: int,
|
||||
) -> Optional[KnowledgeBaseDocumentModel]:
|
||||
"""Get a document by its UUID, scoped to organization.
|
||||
|
||||
Args:
|
||||
document_uuid: The unique document UUID
|
||||
organization_id: ID of the organization
|
||||
|
||||
Returns:
|
||||
KnowledgeBaseDocumentModel if found, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(KnowledgeBaseDocumentModel)
|
||||
.where(
|
||||
KnowledgeBaseDocumentModel.document_uuid == document_uuid,
|
||||
KnowledgeBaseDocumentModel.organization_id == organization_id,
|
||||
KnowledgeBaseDocumentModel.is_active == True,
|
||||
)
|
||||
.options(selectinload(KnowledgeBaseDocumentModel.created_by_user))
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_document_by_hash(
|
||||
self,
|
||||
file_hash: str,
|
||||
organization_id: int,
|
||||
) -> Optional[KnowledgeBaseDocumentModel]:
|
||||
"""Check if a document with the same hash already exists.
|
||||
|
||||
Returns the first matching document if multiple exist (can happen with duplicates).
|
||||
|
||||
Args:
|
||||
file_hash: SHA-256 hash of the file
|
||||
organization_id: ID of the organization
|
||||
|
||||
Returns:
|
||||
KnowledgeBaseDocumentModel if found, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(KnowledgeBaseDocumentModel)
|
||||
.where(
|
||||
KnowledgeBaseDocumentModel.file_hash == file_hash,
|
||||
KnowledgeBaseDocumentModel.organization_id == organization_id,
|
||||
KnowledgeBaseDocumentModel.is_active == True,
|
||||
)
|
||||
.order_by(KnowledgeBaseDocumentModel.created_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_documents_for_organization(
|
||||
self,
|
||||
organization_id: int,
|
||||
processing_status: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> List[KnowledgeBaseDocumentModel]:
|
||||
"""Get all documents for an organization.
|
||||
|
||||
Args:
|
||||
organization_id: ID of the organization
|
||||
processing_status: Optional filter by status
|
||||
limit: Maximum number of documents to return
|
||||
offset: Number of documents to skip
|
||||
|
||||
Returns:
|
||||
List of KnowledgeBaseDocumentModel instances
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(KnowledgeBaseDocumentModel).where(
|
||||
KnowledgeBaseDocumentModel.organization_id == organization_id,
|
||||
KnowledgeBaseDocumentModel.is_active == True,
|
||||
)
|
||||
|
||||
if processing_status:
|
||||
query = query.where(
|
||||
KnowledgeBaseDocumentModel.processing_status == processing_status
|
||||
)
|
||||
|
||||
query = (
|
||||
query.order_by(KnowledgeBaseDocumentModel.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def update_document_metadata(
|
||||
self,
|
||||
document_id: int,
|
||||
file_size_bytes: Optional[int] = None,
|
||||
file_hash: Optional[str] = None,
|
||||
mime_type: Optional[str] = None,
|
||||
) -> Optional[KnowledgeBaseDocumentModel]:
|
||||
"""Update document file metadata.
|
||||
|
||||
Args:
|
||||
document_id: ID of the document
|
||||
file_size_bytes: Optional file size in bytes
|
||||
file_hash: Optional SHA-256 hash of the file
|
||||
mime_type: Optional MIME type
|
||||
|
||||
Returns:
|
||||
Updated KnowledgeBaseDocumentModel
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(KnowledgeBaseDocumentModel).where(
|
||||
KnowledgeBaseDocumentModel.id == document_id
|
||||
)
|
||||
result = await session.execute(query)
|
||||
document = result.scalar_one_or_none()
|
||||
|
||||
if not document:
|
||||
return None
|
||||
|
||||
if file_size_bytes is not None:
|
||||
document.file_size_bytes = file_size_bytes
|
||||
if file_hash is not None:
|
||||
document.file_hash = file_hash
|
||||
if mime_type is not None:
|
||||
document.mime_type = mime_type
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
|
||||
logger.info(f"Updated document {document_id} metadata")
|
||||
return document
|
||||
|
||||
async def update_document_status(
|
||||
self,
|
||||
document_id: int,
|
||||
status: str,
|
||||
error_message: Optional[str] = None,
|
||||
total_chunks: Optional[int] = None,
|
||||
docling_metadata: Optional[dict] = None,
|
||||
) -> Optional[KnowledgeBaseDocumentModel]:
|
||||
"""Update document processing status.
|
||||
|
||||
Args:
|
||||
document_id: ID of the document
|
||||
status: New status (pending, processing, completed, failed)
|
||||
error_message: Optional error message if status is failed
|
||||
total_chunks: Optional total number of chunks
|
||||
docling_metadata: Optional docling metadata
|
||||
|
||||
Returns:
|
||||
Updated KnowledgeBaseDocumentModel
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(KnowledgeBaseDocumentModel).where(
|
||||
KnowledgeBaseDocumentModel.id == document_id
|
||||
)
|
||||
result = await session.execute(query)
|
||||
document = result.scalar_one_or_none()
|
||||
|
||||
if not document:
|
||||
return None
|
||||
|
||||
document.processing_status = status
|
||||
if error_message:
|
||||
document.processing_error = error_message
|
||||
if total_chunks is not None:
|
||||
document.total_chunks = total_chunks
|
||||
if docling_metadata:
|
||||
document.docling_metadata = docling_metadata
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
|
||||
logger.info(f"Updated document {document_id} status to {status}")
|
||||
return document
|
||||
|
||||
async def create_chunks_batch(
|
||||
self,
|
||||
chunks: List[KnowledgeBaseChunkModel],
|
||||
) -> List[KnowledgeBaseChunkModel]:
|
||||
"""Create multiple chunks in a batch.
|
||||
|
||||
Args:
|
||||
chunks: List of KnowledgeBaseChunkModel instances
|
||||
|
||||
Returns:
|
||||
List of created chunks with IDs
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
session.add_all(chunks)
|
||||
await session.commit()
|
||||
|
||||
for chunk in chunks:
|
||||
await session.refresh(chunk)
|
||||
|
||||
logger.info(f"Created {len(chunks)} chunks")
|
||||
return chunks
|
||||
|
||||
async def get_chunks_for_document(
|
||||
self,
|
||||
document_id: int,
|
||||
organization_id: int,
|
||||
) -> List[KnowledgeBaseChunkModel]:
|
||||
"""Get all chunks for a document.
|
||||
|
||||
Args:
|
||||
document_id: ID of the document
|
||||
organization_id: ID of the organization (for authorization)
|
||||
|
||||
Returns:
|
||||
List of KnowledgeBaseChunkModel instances
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(KnowledgeBaseChunkModel)
|
||||
.where(
|
||||
KnowledgeBaseChunkModel.document_id == document_id,
|
||||
KnowledgeBaseChunkModel.organization_id == organization_id,
|
||||
)
|
||||
.order_by(KnowledgeBaseChunkModel.chunk_index)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def search_similar_chunks(
|
||||
self,
|
||||
query_embedding: List[float],
|
||||
organization_id: int,
|
||||
limit: int = 5,
|
||||
document_ids: Optional[List[int]] = None,
|
||||
document_uuids: Optional[List[str]] = None,
|
||||
embedding_model: Optional[str] = None,
|
||||
) -> List[dict]:
|
||||
"""Search for similar chunks using vector similarity.
|
||||
|
||||
Returns top-k most similar chunks without any similarity threshold filtering.
|
||||
Filtering and reranking should be done at the application layer.
|
||||
|
||||
Args:
|
||||
query_embedding: The query embedding vector
|
||||
organization_id: Organization ID for scoping
|
||||
limit: Maximum number of results to return
|
||||
document_ids: Optional list of document IDs to filter by
|
||||
document_uuids: Optional list of document UUIDs to filter by
|
||||
embedding_model: Optional embedding model to filter by (for dimension compatibility)
|
||||
|
||||
Returns:
|
||||
List of dictionaries with chunk data and similarity scores, ordered by similarity (highest first)
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Get the raw connection to execute directly with asyncpg
|
||||
# This avoids parameter binding issues with text() and asyncpg
|
||||
connection = await session.connection()
|
||||
raw_connection = await connection.get_raw_connection()
|
||||
|
||||
# Build WHERE clause conditions (no similarity threshold)
|
||||
where_conditions = [
|
||||
"c.organization_id = $2",
|
||||
"d.is_active = true",
|
||||
]
|
||||
params = [
|
||||
None,
|
||||
organization_id,
|
||||
limit,
|
||||
] # $1 will be embedding_str, $3 is limit
|
||||
param_index = 4 # Next available parameter index
|
||||
|
||||
# Add document_ids filter if provided
|
||||
if document_ids:
|
||||
placeholders = ", ".join(
|
||||
f"${param_index + i}" for i in range(len(document_ids))
|
||||
)
|
||||
where_conditions.append(f"c.document_id IN ({placeholders})")
|
||||
params.extend(document_ids)
|
||||
param_index += len(document_ids)
|
||||
|
||||
# Add document_uuids filter if provided
|
||||
if document_uuids:
|
||||
placeholders = ", ".join(
|
||||
f"${param_index + i}" for i in range(len(document_uuids))
|
||||
)
|
||||
where_conditions.append(f"d.document_uuid IN ({placeholders})")
|
||||
params.extend(document_uuids)
|
||||
param_index += len(document_uuids)
|
||||
|
||||
# Add embedding_model filter if provided (for dimension compatibility)
|
||||
if embedding_model:
|
||||
where_conditions.append(f"c.embedding_model = ${param_index}")
|
||||
params.append(embedding_model)
|
||||
param_index += 1
|
||||
|
||||
# Build the complete SQL query
|
||||
where_clause = " AND ".join(where_conditions)
|
||||
query_sql = f"""
|
||||
SELECT
|
||||
c.id,
|
||||
c.document_id,
|
||||
c.chunk_text,
|
||||
c.contextualized_text,
|
||||
c.chunk_metadata,
|
||||
c.chunk_index,
|
||||
d.filename,
|
||||
d.document_uuid,
|
||||
1 - (c.embedding <=> $1::vector) as similarity
|
||||
FROM knowledge_base_chunks c
|
||||
JOIN knowledge_base_documents d ON c.document_id = d.id
|
||||
WHERE {where_clause}
|
||||
ORDER BY c.embedding <=> $1::vector
|
||||
LIMIT $3
|
||||
"""
|
||||
|
||||
# Convert embedding to string format for PostgreSQL vector type
|
||||
embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"
|
||||
params[0] = embedding_str # Set $1
|
||||
|
||||
# Execute query directly with asyncpg
|
||||
rows = await raw_connection.driver_connection.fetch(
|
||||
query_sql,
|
||||
*params,
|
||||
)
|
||||
|
||||
# Convert asyncpg records to dictionaries
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
async def delete_document(
|
||||
self,
|
||||
document_uuid: str,
|
||||
organization_id: int,
|
||||
) -> bool:
|
||||
"""Soft delete a document by setting is_active to False.
|
||||
|
||||
This will also cascade delete all chunks via the database foreign key.
|
||||
|
||||
Args:
|
||||
document_uuid: The unique document UUID
|
||||
organization_id: ID of the organization (for authorization)
|
||||
|
||||
Returns:
|
||||
True if document was deleted, False if not found
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(KnowledgeBaseDocumentModel).where(
|
||||
KnowledgeBaseDocumentModel.document_uuid == document_uuid,
|
||||
KnowledgeBaseDocumentModel.organization_id == organization_id,
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
document = result.scalar_one_or_none()
|
||||
|
||||
if not document:
|
||||
return False
|
||||
|
||||
document.is_active = False
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Deleted document {document_uuid} for organization {organization_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def compute_file_hash(file_path: str) -> str:
|
||||
"""Compute SHA-256 hash of a file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
SHA-256 hash as hex string
|
||||
"""
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(file_path, "rb") as f:
|
||||
for byte_block in iter(lambda: f.read(4096), b""):
|
||||
sha256_hash.update(byte_block)
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def get_mime_type(file_path: str) -> str:
|
||||
"""Get MIME type based on file extension.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
MIME type string
|
||||
"""
|
||||
extension = Path(file_path).suffix.lower()
|
||||
mime_types = {
|
||||
".pdf": "application/pdf",
|
||||
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
".doc": "application/msword",
|
||||
".txt": "text/plain",
|
||||
".html": "text/html",
|
||||
".md": "text/markdown",
|
||||
}
|
||||
return mime_types.get(extension, "application/octet-stream")
|
||||
Loading…
Add table
Add a link
Reference in a new issue