mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-10 08:05:22 +02:00
feat: upload file and store embedding
This commit is contained in:
parent
cac25879bf
commit
ec1417da87
21 changed files with 2566 additions and 2 deletions
98
api/alembic/versions/dc33eef8dabe_add_document_tables.py
Normal file
98
api/alembic/versions/dc33eef8dabe_add_document_tables.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
"""add document tables
|
||||
|
||||
Revision ID: dc33eef8dabe
|
||||
Revises: dcb0a27d98c6
|
||||
Create Date: 2026-01-16 13:40:17.808807
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from pgvector.sqlalchemy import Vector
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'dc33eef8dabe'
|
||||
down_revision: Union[str, None] = 'dcb0a27d98c6'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
# Enable pgvector extension
|
||||
op.execute('CREATE EXTENSION IF NOT EXISTS vector')
|
||||
|
||||
sa.Enum('pending', 'processing', 'completed', 'failed', name='document_processing_status').create(op.get_bind())
|
||||
op.create_table('knowledge_base_documents',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('document_uuid', sa.String(length=36), nullable=False),
|
||||
sa.Column('organization_id', sa.Integer(), nullable=False),
|
||||
sa.Column('filename', sa.String(length=500), nullable=False),
|
||||
sa.Column('file_size_bytes', sa.Integer(), nullable=True),
|
||||
sa.Column('file_hash', sa.String(length=64), nullable=True),
|
||||
sa.Column('mime_type', sa.String(length=100), nullable=True),
|
||||
sa.Column('source_url', sa.String(), nullable=True),
|
||||
sa.Column('total_chunks', sa.Integer(), nullable=False),
|
||||
sa.Column('processing_status', postgresql.ENUM('pending', 'processing', 'completed', 'failed', name='document_processing_status', create_type=False), server_default=sa.text("'pending'::document_processing_status"), nullable=False),
|
||||
sa.Column('processing_error', sa.Text(), nullable=True),
|
||||
sa.Column('docling_metadata', sa.JSON(), nullable=False),
|
||||
sa.Column('custom_metadata', sa.JSON(), nullable=False),
|
||||
sa.Column('created_by', sa.Integer(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('archived_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(['created_by'], ['users.id'], ),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
op.create_index('ix_kb_documents_created_at', 'knowledge_base_documents', ['created_at'], unique=False)
|
||||
op.create_index('ix_kb_documents_organization_id', 'knowledge_base_documents', ['organization_id'], unique=False)
|
||||
op.create_index('ix_kb_documents_status', 'knowledge_base_documents', ['processing_status'], unique=False)
|
||||
op.create_index('ix_kb_documents_uuid', 'knowledge_base_documents', ['document_uuid'], unique=False)
|
||||
op.create_index(op.f('ix_knowledge_base_documents_document_uuid'), 'knowledge_base_documents', ['document_uuid'], unique=True)
|
||||
op.create_table('knowledge_base_chunks'),
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('document_id', sa.Integer(), nullable=False),
|
||||
sa.Column('organization_id', sa.Integer(), nullable=False),
|
||||
sa.Column('chunk_text', sa.Text(), nullable=False),
|
||||
sa.Column('contextualized_text', sa.Text(), nullable=True),
|
||||
sa.Column('chunk_index', sa.Integer(), nullable=False),
|
||||
sa.Column('chunk_metadata', sa.JSON(), nullable=False),
|
||||
sa.Column('embedding_model', sa.String(length=200), nullable=False),
|
||||
sa.Column('embedding_dimension', sa.Integer(), nullable=False),
|
||||
sa.Column('embedding', Vector(384), nullable=True),
|
||||
sa.Column('token_count', sa.Integer(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(['document_id'], ['knowledge_base_documents.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('ix_kb_chunks_chunk_index', 'knowledge_base_chunks', ['chunk_index'], unique=False)
|
||||
op.create_index('ix_kb_chunks_document_id', 'knowledge_base_chunks', ['document_id'], unique=False)
|
||||
op.create_index('ix_kb_chunks_embedding_ivfflat', 'knowledge_base_chunks', ['embedding'], unique=False, postgresql_using='ivfflat', postgresql_with={'lists': 100}, postgresql_ops={'embedding': 'vector_cosine_ops'})
|
||||
op.create_index('ix_kb_chunks_organization_id', 'knowledge_base_chunks', ['organization_id'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index('ix_kb_chunks_organization_id', table_name='knowledge_base_chunks')
|
||||
op.drop_index('ix_kb_chunks_embedding_ivfflat', table_name='knowledge_base_chunks', postgresql_using='ivfflat', postgresql_with={'lists': 100}, postgresql_ops={'embedding': 'vector_cosine_ops'})
|
||||
op.drop_index('ix_kb_chunks_document_id', table_name='knowledge_base_chunks')
|
||||
op.drop_index('ix_kb_chunks_chunk_index', table_name='knowledge_base_chunks')
|
||||
op.drop_table('knowledge_base_chunks')
|
||||
op.drop_index(op.f('ix_knowledge_base_documents_document_uuid'), table_name='knowledge_base_documents')
|
||||
op.drop_index('ix_kb_documents_uuid', table_name='knowledge_base_documents')
|
||||
op.drop_index('ix_kb_documents_status', table_name='knowledge_base_documents')
|
||||
op.drop_index('ix_kb_documents_organization_id', table_name='knowledge_base_documents')
|
||||
op.drop_index('ix_kb_documents_created_at', table_name='knowledge_base_documents')
|
||||
op.drop_table('knowledge_base_documents')
|
||||
sa.Enum('pending', 'processing', 'completed', 'failed', name='document_processing_status').drop(op.get_bind())
|
||||
|
||||
# Note: We don't drop the vector extension as it may be used by other tables
|
||||
# If you want to drop it, uncomment the following line:
|
||||
# op.execute('DROP EXTENSION IF EXISTS vector')
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -3,6 +3,7 @@ from api.db.api_key_client import APIKeyClient
|
|||
from api.db.campaign_client import CampaignClient
|
||||
from api.db.embed_token_client import EmbedTokenClient
|
||||
from api.db.integration_client import IntegrationClient
|
||||
from api.db.knowledge_base_client import KnowledgeBaseClient
|
||||
from api.db.looptalk_client import LoopTalkClient
|
||||
from api.db.organization_client import OrganizationClient
|
||||
from api.db.organization_configuration_client import OrganizationConfigurationClient
|
||||
|
|
@ -33,6 +34,7 @@ class DBClient(
|
|||
AgentTriggerClient,
|
||||
WebhookCredentialClient,
|
||||
ToolClient,
|
||||
KnowledgeBaseClient,
|
||||
):
|
||||
"""
|
||||
Unified database client that combines all specialized database operations.
|
||||
|
|
@ -54,6 +56,7 @@ class DBClient(
|
|||
- AgentTriggerClient: handles agent trigger operations for API-based call triggering
|
||||
- WebhookCredentialClient: handles webhook credential operations
|
||||
- ToolClient: handles tool operations for reusable HTTP API tools
|
||||
- KnowledgeBaseClient: handles knowledge base document and vector search operations
|
||||
"""
|
||||
|
||||
pass
|
||||
|
|
|
|||
483
api/db/knowledge_base_client.py
Normal file
483
api/db/knowledge_base_client.py
Normal file
|
|
@ -0,0 +1,483 @@
|
|||
"""Database client for managing knowledge base documents and chunks."""
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import KnowledgeBaseChunkModel, KnowledgeBaseDocumentModel
|
||||
|
||||
|
||||
class KnowledgeBaseClient(BaseDBClient):
|
||||
"""Client for managing knowledge base documents and vector embeddings."""
|
||||
|
||||
async def create_document(
|
||||
self,
|
||||
organization_id: int,
|
||||
created_by: int,
|
||||
filename: str,
|
||||
file_size_bytes: int,
|
||||
file_hash: str,
|
||||
mime_type: str,
|
||||
source_url: Optional[str] = None,
|
||||
custom_metadata: Optional[dict] = None,
|
||||
docling_metadata: Optional[dict] = None,
|
||||
document_uuid: Optional[str] = None,
|
||||
) -> KnowledgeBaseDocumentModel:
|
||||
"""Create a new knowledge base document record.
|
||||
|
||||
Args:
|
||||
organization_id: ID of the organization
|
||||
created_by: ID of the user uploading the document
|
||||
filename: Name of the file
|
||||
file_size_bytes: Size of the file in bytes
|
||||
file_hash: SHA-256 hash of the file
|
||||
mime_type: MIME type of the file
|
||||
source_url: Optional URL if document was fetched from web
|
||||
custom_metadata: Optional custom metadata dictionary
|
||||
docling_metadata: Optional docling processing metadata
|
||||
document_uuid: Optional UUID to use (if not provided, one will be generated)
|
||||
|
||||
Returns:
|
||||
The created KnowledgeBaseDocumentModel
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
document = KnowledgeBaseDocumentModel(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
filename=filename,
|
||||
file_size_bytes=file_size_bytes,
|
||||
file_hash=file_hash,
|
||||
mime_type=mime_type,
|
||||
source_url=source_url,
|
||||
custom_metadata=custom_metadata or {},
|
||||
docling_metadata=docling_metadata or {},
|
||||
processing_status="pending",
|
||||
total_chunks=0,
|
||||
)
|
||||
|
||||
# Use provided UUID or let the model generate one
|
||||
if document_uuid:
|
||||
document.document_uuid = document_uuid
|
||||
|
||||
session.add(document)
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
|
||||
logger.info(
|
||||
f"Created document '{filename}' ({document.document_uuid}) "
|
||||
f"for organization {organization_id}"
|
||||
)
|
||||
return document
|
||||
|
||||
async def get_document_by_id(
|
||||
self,
|
||||
document_id: int,
|
||||
) -> Optional[KnowledgeBaseDocumentModel]:
|
||||
"""Get a document by its database ID.
|
||||
|
||||
Args:
|
||||
document_id: The database ID of the document
|
||||
|
||||
Returns:
|
||||
KnowledgeBaseDocumentModel if found, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(KnowledgeBaseDocumentModel).where(
|
||||
KnowledgeBaseDocumentModel.id == document_id
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_document_by_uuid(
|
||||
self,
|
||||
document_uuid: str,
|
||||
organization_id: int,
|
||||
) -> Optional[KnowledgeBaseDocumentModel]:
|
||||
"""Get a document by its UUID, scoped to organization.
|
||||
|
||||
Args:
|
||||
document_uuid: The unique document UUID
|
||||
organization_id: ID of the organization
|
||||
|
||||
Returns:
|
||||
KnowledgeBaseDocumentModel if found, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(KnowledgeBaseDocumentModel)
|
||||
.where(
|
||||
KnowledgeBaseDocumentModel.document_uuid == document_uuid,
|
||||
KnowledgeBaseDocumentModel.organization_id == organization_id,
|
||||
KnowledgeBaseDocumentModel.is_active == True,
|
||||
)
|
||||
.options(selectinload(KnowledgeBaseDocumentModel.created_by_user))
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_document_by_hash(
|
||||
self,
|
||||
file_hash: str,
|
||||
organization_id: int,
|
||||
) -> Optional[KnowledgeBaseDocumentModel]:
|
||||
"""Check if a document with the same hash already exists.
|
||||
|
||||
Returns the first matching document if multiple exist (can happen with duplicates).
|
||||
|
||||
Args:
|
||||
file_hash: SHA-256 hash of the file
|
||||
organization_id: ID of the organization
|
||||
|
||||
Returns:
|
||||
KnowledgeBaseDocumentModel if found, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(KnowledgeBaseDocumentModel)
|
||||
.where(
|
||||
KnowledgeBaseDocumentModel.file_hash == file_hash,
|
||||
KnowledgeBaseDocumentModel.organization_id == organization_id,
|
||||
KnowledgeBaseDocumentModel.is_active == True,
|
||||
)
|
||||
.order_by(KnowledgeBaseDocumentModel.created_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_documents_for_organization(
|
||||
self,
|
||||
organization_id: int,
|
||||
processing_status: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> List[KnowledgeBaseDocumentModel]:
|
||||
"""Get all documents for an organization.
|
||||
|
||||
Args:
|
||||
organization_id: ID of the organization
|
||||
processing_status: Optional filter by status
|
||||
limit: Maximum number of documents to return
|
||||
offset: Number of documents to skip
|
||||
|
||||
Returns:
|
||||
List of KnowledgeBaseDocumentModel instances
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(KnowledgeBaseDocumentModel).where(
|
||||
KnowledgeBaseDocumentModel.organization_id == organization_id,
|
||||
KnowledgeBaseDocumentModel.is_active == True,
|
||||
)
|
||||
|
||||
if processing_status:
|
||||
query = query.where(
|
||||
KnowledgeBaseDocumentModel.processing_status == processing_status
|
||||
)
|
||||
|
||||
query = (
|
||||
query.order_by(KnowledgeBaseDocumentModel.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def update_document_metadata(
|
||||
self,
|
||||
document_id: int,
|
||||
file_size_bytes: Optional[int] = None,
|
||||
file_hash: Optional[str] = None,
|
||||
mime_type: Optional[str] = None,
|
||||
) -> Optional[KnowledgeBaseDocumentModel]:
|
||||
"""Update document file metadata.
|
||||
|
||||
Args:
|
||||
document_id: ID of the document
|
||||
file_size_bytes: Optional file size in bytes
|
||||
file_hash: Optional SHA-256 hash of the file
|
||||
mime_type: Optional MIME type
|
||||
|
||||
Returns:
|
||||
Updated KnowledgeBaseDocumentModel
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(KnowledgeBaseDocumentModel).where(
|
||||
KnowledgeBaseDocumentModel.id == document_id
|
||||
)
|
||||
result = await session.execute(query)
|
||||
document = result.scalar_one_or_none()
|
||||
|
||||
if not document:
|
||||
return None
|
||||
|
||||
if file_size_bytes is not None:
|
||||
document.file_size_bytes = file_size_bytes
|
||||
if file_hash is not None:
|
||||
document.file_hash = file_hash
|
||||
if mime_type is not None:
|
||||
document.mime_type = mime_type
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
|
||||
logger.info(f"Updated document {document_id} metadata")
|
||||
return document
|
||||
|
||||
async def update_document_status(
|
||||
self,
|
||||
document_id: int,
|
||||
status: str,
|
||||
error_message: Optional[str] = None,
|
||||
total_chunks: Optional[int] = None,
|
||||
docling_metadata: Optional[dict] = None,
|
||||
) -> Optional[KnowledgeBaseDocumentModel]:
|
||||
"""Update document processing status.
|
||||
|
||||
Args:
|
||||
document_id: ID of the document
|
||||
status: New status (pending, processing, completed, failed)
|
||||
error_message: Optional error message if status is failed
|
||||
total_chunks: Optional total number of chunks
|
||||
docling_metadata: Optional docling metadata
|
||||
|
||||
Returns:
|
||||
Updated KnowledgeBaseDocumentModel
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(KnowledgeBaseDocumentModel).where(
|
||||
KnowledgeBaseDocumentModel.id == document_id
|
||||
)
|
||||
result = await session.execute(query)
|
||||
document = result.scalar_one_or_none()
|
||||
|
||||
if not document:
|
||||
return None
|
||||
|
||||
document.processing_status = status
|
||||
if error_message:
|
||||
document.processing_error = error_message
|
||||
if total_chunks is not None:
|
||||
document.total_chunks = total_chunks
|
||||
if docling_metadata:
|
||||
document.docling_metadata = docling_metadata
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
|
||||
logger.info(f"Updated document {document_id} status to {status}")
|
||||
return document
|
||||
|
||||
async def create_chunks_batch(
|
||||
self,
|
||||
chunks: List[KnowledgeBaseChunkModel],
|
||||
) -> List[KnowledgeBaseChunkModel]:
|
||||
"""Create multiple chunks in a batch.
|
||||
|
||||
Args:
|
||||
chunks: List of KnowledgeBaseChunkModel instances
|
||||
|
||||
Returns:
|
||||
List of created chunks with IDs
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
session.add_all(chunks)
|
||||
await session.commit()
|
||||
|
||||
for chunk in chunks:
|
||||
await session.refresh(chunk)
|
||||
|
||||
logger.info(f"Created {len(chunks)} chunks")
|
||||
return chunks
|
||||
|
||||
async def get_chunks_for_document(
|
||||
self,
|
||||
document_id: int,
|
||||
organization_id: int,
|
||||
) -> List[KnowledgeBaseChunkModel]:
|
||||
"""Get all chunks for a document.
|
||||
|
||||
Args:
|
||||
document_id: ID of the document
|
||||
organization_id: ID of the organization (for authorization)
|
||||
|
||||
Returns:
|
||||
List of KnowledgeBaseChunkModel instances
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(KnowledgeBaseChunkModel)
|
||||
.where(
|
||||
KnowledgeBaseChunkModel.document_id == document_id,
|
||||
KnowledgeBaseChunkModel.organization_id == organization_id,
|
||||
)
|
||||
.order_by(KnowledgeBaseChunkModel.chunk_index)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def search_similar_chunks(
|
||||
self,
|
||||
query_embedding: List[float],
|
||||
organization_id: int,
|
||||
limit: int = 5,
|
||||
document_ids: Optional[List[int]] = None,
|
||||
document_uuids: Optional[List[str]] = None,
|
||||
) -> List[dict]:
|
||||
"""Search for similar chunks using vector similarity.
|
||||
|
||||
Returns top-k most similar chunks without any similarity threshold filtering.
|
||||
Filtering and reranking should be done at the application layer.
|
||||
|
||||
Args:
|
||||
query_embedding: The query embedding vector
|
||||
organization_id: Organization ID for scoping
|
||||
limit: Maximum number of results to return
|
||||
document_ids: Optional list of document IDs to filter by
|
||||
document_uuids: Optional list of document UUIDs to filter by
|
||||
|
||||
Returns:
|
||||
List of dictionaries with chunk data and similarity scores, ordered by similarity (highest first)
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Get the raw connection to execute directly with asyncpg
|
||||
# This avoids parameter binding issues with text() and asyncpg
|
||||
connection = await session.connection()
|
||||
raw_connection = await connection.get_raw_connection()
|
||||
|
||||
# Build WHERE clause conditions (no similarity threshold)
|
||||
where_conditions = [
|
||||
"c.organization_id = $2",
|
||||
"d.is_active = true",
|
||||
]
|
||||
params = [None, organization_id, limit] # $1 will be embedding_str, $3 is limit
|
||||
param_index = 4 # Next available parameter index
|
||||
|
||||
# Add document_ids filter if provided
|
||||
if document_ids:
|
||||
placeholders = ", ".join(f"${param_index + i}" for i in range(len(document_ids)))
|
||||
where_conditions.append(f"c.document_id IN ({placeholders})")
|
||||
params.extend(document_ids)
|
||||
param_index += len(document_ids)
|
||||
|
||||
# Add document_uuids filter if provided
|
||||
if document_uuids:
|
||||
placeholders = ", ".join(f"${param_index + i}" for i in range(len(document_uuids)))
|
||||
where_conditions.append(f"d.document_uuid IN ({placeholders})")
|
||||
params.extend(document_uuids)
|
||||
param_index += len(document_uuids)
|
||||
|
||||
# Build the complete SQL query
|
||||
where_clause = " AND ".join(where_conditions)
|
||||
query_sql = f"""
|
||||
SELECT
|
||||
c.id,
|
||||
c.document_id,
|
||||
c.chunk_text,
|
||||
c.contextualized_text,
|
||||
c.chunk_metadata,
|
||||
c.chunk_index,
|
||||
d.filename,
|
||||
d.document_uuid,
|
||||
1 - (c.embedding <=> $1::vector) as similarity
|
||||
FROM knowledge_base_chunks c
|
||||
JOIN knowledge_base_documents d ON c.document_id = d.id
|
||||
WHERE {where_clause}
|
||||
ORDER BY c.embedding <=> $1::vector
|
||||
LIMIT $3
|
||||
"""
|
||||
|
||||
# Convert embedding to string format for PostgreSQL vector type
|
||||
embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"
|
||||
params[0] = embedding_str # Set $1
|
||||
|
||||
# Execute query directly with asyncpg
|
||||
rows = await raw_connection.driver_connection.fetch(
|
||||
query_sql,
|
||||
*params,
|
||||
)
|
||||
|
||||
# Convert asyncpg records to dictionaries
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
async def delete_document(
|
||||
self,
|
||||
document_uuid: str,
|
||||
organization_id: int,
|
||||
) -> bool:
|
||||
"""Soft delete a document by setting is_active to False.
|
||||
|
||||
This will also cascade delete all chunks via the database foreign key.
|
||||
|
||||
Args:
|
||||
document_uuid: The unique document UUID
|
||||
organization_id: ID of the organization (for authorization)
|
||||
|
||||
Returns:
|
||||
True if document was deleted, False if not found
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(KnowledgeBaseDocumentModel).where(
|
||||
KnowledgeBaseDocumentModel.document_uuid == document_uuid,
|
||||
KnowledgeBaseDocumentModel.organization_id == organization_id,
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
document = result.scalar_one_or_none()
|
||||
|
||||
if not document:
|
||||
return False
|
||||
|
||||
document.is_active = False
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Deleted document {document_uuid} for organization {organization_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def compute_file_hash(file_path: str) -> str:
|
||||
"""Compute SHA-256 hash of a file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
SHA-256 hash as hex string
|
||||
"""
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(file_path, "rb") as f:
|
||||
for byte_block in iter(lambda: f.read(4096), b""):
|
||||
sha256_hash.update(byte_block)
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def get_mime_type(file_path: str) -> str:
|
||||
"""Get MIME type based on file extension.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
MIME type string
|
||||
"""
|
||||
extension = Path(file_path).suffix.lower()
|
||||
mime_types = {
|
||||
".pdf": "application/pdf",
|
||||
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
".doc": "application/msword",
|
||||
".txt": "text/plain",
|
||||
".html": "text/html",
|
||||
".md": "text/markdown",
|
||||
}
|
||||
return mime_types.get(extension, "application/octet-stream")
|
||||
157
api/db/models.py
157
api/db/models.py
|
|
@ -14,10 +14,12 @@ from sqlalchemy import (
|
|||
Integer,
|
||||
String,
|
||||
Table,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
and_,
|
||||
text,
|
||||
)
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
|
||||
from ..enums import (
|
||||
|
|
@ -890,3 +892,158 @@ class ToolModel(Base):
|
|||
Index("ix_tools_status", "status"),
|
||||
Index("ix_tools_category", "category"),
|
||||
)
|
||||
|
||||
|
||||
class KnowledgeBaseDocumentModel(Base):
|
||||
"""Model for storing document-level metadata in the knowledge base.
|
||||
|
||||
Each document represents a source file (PDF, DOCX, etc.) that has been
|
||||
processed and chunked for retrieval.
|
||||
"""
|
||||
|
||||
__tablename__ = "knowledge_base_documents"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
# Public identifier for API references
|
||||
document_uuid = Column(
|
||||
String(36),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
default=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
# Organization scoping
|
||||
organization_id = Column(
|
||||
Integer, ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# Document metadata
|
||||
filename = Column(String(500), nullable=False)
|
||||
file_size_bytes = Column(Integer, nullable=True)
|
||||
file_hash = Column(String(64), nullable=True) # SHA-256 hash for deduplication
|
||||
mime_type = Column(String(100), nullable=True)
|
||||
|
||||
# Processing metadata
|
||||
source_url = Column(String, nullable=True) # If document was fetched from URL
|
||||
total_chunks = Column(Integer, nullable=False, default=0)
|
||||
processing_status = Column(
|
||||
Enum("pending", "processing", "completed", "failed", name="document_processing_status"),
|
||||
nullable=False,
|
||||
default="pending",
|
||||
server_default=text("'pending'::document_processing_status"),
|
||||
)
|
||||
processing_error = Column(Text, nullable=True)
|
||||
|
||||
# Docling conversion metadata
|
||||
docling_metadata = Column(JSON, nullable=False, default=dict) # Store docling document metadata
|
||||
|
||||
# Custom metadata (user-defined tags, categories, etc.)
|
||||
custom_metadata = Column(JSON, nullable=False, default=dict)
|
||||
|
||||
# Audit fields
|
||||
created_by = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
)
|
||||
|
||||
# Soft delete
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
archived_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Relationships
|
||||
organization = relationship("OrganizationModel")
|
||||
created_by_user = relationship("UserModel")
|
||||
chunks = relationship(
|
||||
"KnowledgeBaseChunkModel",
|
||||
back_populates="document",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Indexes and constraints
|
||||
__table_args__ = (
|
||||
Index("ix_kb_documents_organization_id", "organization_id"),
|
||||
Index("ix_kb_documents_uuid", "document_uuid"),
|
||||
Index("ix_kb_documents_status", "processing_status"),
|
||||
Index("ix_kb_documents_created_at", "created_at"),
|
||||
)
|
||||
|
||||
|
||||
class KnowledgeBaseChunkModel(Base):
|
||||
"""Model for storing document chunks with vector embeddings.
|
||||
|
||||
Each chunk represents a portion of a document that has been:
|
||||
1. Extracted and chunked by docling's HybridChunker
|
||||
2. Optionally contextualized with surrounding information
|
||||
3. Embedded into a vector representation for semantic search
|
||||
"""
|
||||
|
||||
__tablename__ = "knowledge_base_chunks"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
# Link to parent document
|
||||
document_id = Column(
|
||||
Integer,
|
||||
ForeignKey("knowledge_base_documents.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Organization scoping (denormalized for efficient querying)
|
||||
organization_id = Column(
|
||||
Integer, ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# Chunk content
|
||||
chunk_text = Column(Text, nullable=False) # The actual chunk text
|
||||
contextualized_text = Column(Text, nullable=True) # Enriched text from chunker.contextualize()
|
||||
|
||||
# Chunk positioning and metadata
|
||||
chunk_index = Column(Integer, nullable=False) # Position in document (0-based)
|
||||
|
||||
# Docling chunk metadata
|
||||
chunk_metadata = Column(JSON, nullable=False, default=dict) # Store chunk.meta if available
|
||||
|
||||
# Embedding configuration
|
||||
embedding_model = Column(String(200), nullable=False) # e.g., "sentence-transformers/all-MiniLM-L6-v2"
|
||||
embedding_dimension = Column(Integer, nullable=False) # e.g., 384 for all-MiniLM-L6-v2
|
||||
|
||||
# Vector embedding (pgvector column)
|
||||
# The dimension should match the embedding_dimension field
|
||||
embedding = Column(Vector(384), nullable=True) # Default to 384 for all-MiniLM-L6-v2
|
||||
|
||||
# Token count (useful for chunking strategy analysis)
|
||||
token_count = Column(Integer, nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
)
|
||||
|
||||
# Relationships
|
||||
document = relationship("KnowledgeBaseDocumentModel", back_populates="chunks")
|
||||
organization = relationship("OrganizationModel")
|
||||
|
||||
# Indexes and constraints
|
||||
__table_args__ = (
|
||||
Index("ix_kb_chunks_document_id", "document_id"),
|
||||
Index("ix_kb_chunks_organization_id", "organization_id"),
|
||||
Index("ix_kb_chunks_chunk_index", "chunk_index"),
|
||||
# Vector similarity search index (using IVFFlat or HNSW)
|
||||
# IVFFlat is good for datasets with 10k-1M vectors
|
||||
# HNSW is better for larger datasets but uses more memory
|
||||
Index(
|
||||
"ix_kb_chunks_embedding_ivfflat",
|
||||
"embedding",
|
||||
postgresql_using="ivfflat",
|
||||
postgresql_with={"lists": 100}, # Adjust based on dataset size
|
||||
postgresql_ops={"embedding": "vector_cosine_ops"},
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,3 +13,6 @@ python-multipart==0.0.20
|
|||
sentry-sdk[fastapi]==2.38.0
|
||||
sqlalchemy[asyncio]==2.0.43
|
||||
msgpack==1.1.2
|
||||
docling[rapidocr]==2.68.0
|
||||
sentence-transformers==5.2.0
|
||||
pgvector==0.4.2
|
||||
|
|
|
|||
390
api/routes/knowledge_base.py
Normal file
390
api/routes/knowledge_base.py
Normal file
|
|
@ -0,0 +1,390 @@
|
|||
"""API routes for knowledge base operations."""
|
||||
|
||||
import uuid
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.schemas.knowledge_base import (
|
||||
ChunkSearchRequestSchema,
|
||||
ChunkSearchResponseSchema,
|
||||
DocumentListResponseSchema,
|
||||
DocumentResponseSchema,
|
||||
DocumentUploadRequestSchema,
|
||||
DocumentUploadResponseSchema,
|
||||
ProcessDocumentRequestSchema,
|
||||
)
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.storage import storage_fs
|
||||
from api.tasks.arq import enqueue_job
|
||||
from api.tasks.function_names import FunctionNames
|
||||
|
||||
router = APIRouter(prefix="/knowledge-base", tags=["knowledge-base"])
|
||||
|
||||
|
||||
@router.post(
|
||||
"/upload-url",
|
||||
response_model=DocumentUploadResponseSchema,
|
||||
summary="Get presigned URL for document upload",
|
||||
)
|
||||
async def get_upload_url(
|
||||
request: DocumentUploadRequestSchema,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Generate a presigned PUT URL for uploading a document.
|
||||
|
||||
This endpoint:
|
||||
1. Generates a unique document UUID for organizing the S3 key
|
||||
2. Generates a presigned S3/MinIO URL for uploading the file
|
||||
3. Returns the upload URL and document metadata
|
||||
|
||||
After uploading to the returned URL, call /process-document to create
|
||||
the document record and trigger processing.
|
||||
|
||||
Access Control:
|
||||
* All authenticated users can upload documents scoped to their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Generate unique document UUID for S3 organization
|
||||
document_uuid = str(uuid.uuid4())
|
||||
|
||||
# Generate S3 key: knowledge_base/{org_id}/{document_uuid}/{filename}
|
||||
s3_key = f"knowledge_base/{user.selected_organization_id}/{document_uuid}/{request.filename}"
|
||||
|
||||
# Generate presigned PUT URL (valid for 30 minutes)
|
||||
upload_url = await storage_fs.aget_presigned_put_url(
|
||||
file_path=s3_key,
|
||||
expiration=1800, # 30 minutes
|
||||
content_type=request.mime_type,
|
||||
max_size=100_000_000, # 100MB max
|
||||
)
|
||||
|
||||
if not upload_url:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to generate presigned upload URL"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Generated upload URL for document {document_uuid}, "
|
||||
f"user {user.id}, org {user.selected_organization_id}"
|
||||
)
|
||||
|
||||
return DocumentUploadResponseSchema(
|
||||
upload_url=upload_url,
|
||||
document_uuid=document_uuid,
|
||||
s3_key=s3_key,
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error generating upload URL: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to generate upload URL"
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/process-document",
|
||||
response_model=DocumentResponseSchema,
|
||||
summary="Trigger document processing",
|
||||
)
|
||||
async def process_document(
|
||||
request: ProcessDocumentRequestSchema,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Trigger asynchronous processing of an uploaded document.
|
||||
|
||||
This endpoint should be called after successfully uploading a file to the presigned URL.
|
||||
It will:
|
||||
1. Create a document record in the database with the specified UUID
|
||||
2. Enqueue a background task to process the document (chunking and embedding)
|
||||
|
||||
The document status will be updated from 'pending' -> 'processing' -> 'completed' or 'failed'.
|
||||
|
||||
Access Control:
|
||||
* Users can only process documents in their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Extract filename from s3_key
|
||||
filename = request.s3_key.split("/")[-1]
|
||||
|
||||
# Create document record with the specific UUID from upload
|
||||
document = await db_client.create_document(
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.id,
|
||||
filename=filename,
|
||||
file_size_bytes=0, # Will be updated by background task
|
||||
file_hash="", # Will be computed by background task
|
||||
mime_type="application/octet-stream", # Will be detected by background task
|
||||
custom_metadata={"s3_key": request.s3_key},
|
||||
document_uuid=request.document_uuid, # Use UUID from upload
|
||||
)
|
||||
|
||||
# Enqueue background task for processing
|
||||
await enqueue_job(
|
||||
FunctionNames.PROCESS_KNOWLEDGE_BASE_DOCUMENT,
|
||||
document.id,
|
||||
request.s3_key,
|
||||
user.selected_organization_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created document {request.document_uuid} (id={document.id}) and enqueued processing, "
|
||||
f"org {user.selected_organization_id}"
|
||||
)
|
||||
|
||||
return DocumentResponseSchema(
|
||||
id=document.id,
|
||||
document_uuid=request.document_uuid,
|
||||
filename=filename,
|
||||
file_size_bytes=0,
|
||||
file_hash="",
|
||||
mime_type="application/octet-stream",
|
||||
processing_status="pending",
|
||||
processing_error=None,
|
||||
total_chunks=0,
|
||||
custom_metadata={"s3_key": request.s3_key},
|
||||
docling_metadata={},
|
||||
source_url=None,
|
||||
created_at=document.created_at,
|
||||
updated_at=document.updated_at,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.id,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Error processing document: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to process document"
|
||||
) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/documents",
|
||||
response_model=DocumentListResponseSchema,
|
||||
summary="List documents",
|
||||
)
|
||||
async def list_documents(
|
||||
status: Annotated[
|
||||
Optional[str],
|
||||
Query(description="Filter by processing status"),
|
||||
] = None,
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 100,
|
||||
offset: Annotated[int, Query(ge=0)] = 0,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""List all documents for the user's organization.
|
||||
|
||||
Access Control:
|
||||
* Users can only see documents from their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
documents = await db_client.get_documents_for_organization(
|
||||
organization_id=user.selected_organization_id,
|
||||
processing_status=status,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
# Convert to response schema
|
||||
document_list = [
|
||||
DocumentResponseSchema(
|
||||
id=doc.id,
|
||||
document_uuid=doc.document_uuid,
|
||||
filename=doc.filename,
|
||||
file_size_bytes=doc.file_size_bytes,
|
||||
file_hash=doc.file_hash,
|
||||
mime_type=doc.mime_type,
|
||||
processing_status=doc.processing_status,
|
||||
processing_error=doc.processing_error,
|
||||
total_chunks=doc.total_chunks,
|
||||
custom_metadata=doc.custom_metadata,
|
||||
docling_metadata=doc.docling_metadata,
|
||||
source_url=doc.source_url,
|
||||
created_at=doc.created_at,
|
||||
updated_at=doc.updated_at,
|
||||
organization_id=doc.organization_id,
|
||||
created_by=doc.created_by,
|
||||
is_active=doc.is_active,
|
||||
)
|
||||
for doc in documents
|
||||
]
|
||||
|
||||
return DocumentListResponseSchema(
|
||||
documents=document_list,
|
||||
total=len(document_list),
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error listing documents: {exc}")
|
||||
raise HTTPException(status_code=500, detail="Failed to list documents") from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/documents/{document_uuid}",
|
||||
response_model=DocumentResponseSchema,
|
||||
summary="Get document details",
|
||||
)
|
||||
async def get_document(
|
||||
document_uuid: str,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Get details of a specific document.
|
||||
|
||||
Access Control:
|
||||
* Users can only access documents from their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
document = await db_client.get_document_by_uuid(
|
||||
document_uuid=document_uuid,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
|
||||
if not document:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
return DocumentResponseSchema(
|
||||
id=document.id,
|
||||
document_uuid=document.document_uuid,
|
||||
filename=document.filename,
|
||||
file_size_bytes=document.file_size_bytes,
|
||||
file_hash=document.file_hash,
|
||||
mime_type=document.mime_type,
|
||||
processing_status=document.processing_status,
|
||||
processing_error=document.processing_error,
|
||||
total_chunks=document.total_chunks,
|
||||
custom_metadata=document.custom_metadata,
|
||||
docling_metadata=document.docling_metadata,
|
||||
source_url=document.source_url,
|
||||
created_at=document.created_at,
|
||||
updated_at=document.updated_at,
|
||||
organization_id=document.organization_id,
|
||||
created_by=document.created_by,
|
||||
is_active=document.is_active,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Error getting document: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to get document"
|
||||
) from exc
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/documents/{document_uuid}",
|
||||
summary="Delete document",
|
||||
)
|
||||
async def delete_document(
|
||||
document_uuid: str,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Soft delete a document and its chunks.
|
||||
|
||||
Access Control:
|
||||
* Users can only delete documents from their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
success = await db_client.delete_document(
|
||||
document_uuid=document_uuid,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
logger.info(
|
||||
f"Deleted document {document_uuid}, "
|
||||
f"user {user.id}, org {user.selected_organization_id}"
|
||||
)
|
||||
|
||||
return {"success": True, "message": "Document deleted successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Error deleting document: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to delete document"
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/search",
|
||||
response_model=ChunkSearchResponseSchema,
|
||||
summary="Search for similar chunks",
|
||||
)
|
||||
async def search_chunks(
|
||||
request: ChunkSearchRequestSchema,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Search for document chunks similar to the query.
|
||||
|
||||
This endpoint uses vector similarity search to find relevant chunks.
|
||||
Results are returned without threshold filtering - apply similarity
|
||||
thresholds at the application layer after optional reranking.
|
||||
|
||||
Access Control:
|
||||
* Users can only search documents from their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Import here to avoid circular dependency
|
||||
from api.services.admin_utils.local_exec import DocumentProcessor
|
||||
|
||||
# Initialize processor (reuses cached models)
|
||||
processor = DocumentProcessor(
|
||||
db_client=db_client,
|
||||
)
|
||||
|
||||
# Perform search
|
||||
results = await processor.search_similar_chunks(
|
||||
query=request.query,
|
||||
organization_id=user.selected_organization_id,
|
||||
limit=request.limit,
|
||||
document_uuids=request.document_uuids,
|
||||
)
|
||||
|
||||
# Apply similarity threshold if provided
|
||||
if request.min_similarity is not None:
|
||||
results = [r for r in results if r["similarity"] >= request.min_similarity]
|
||||
|
||||
# Convert to response schema
|
||||
from api.schemas.knowledge_base import ChunkResponseSchema
|
||||
|
||||
chunks = [
|
||||
ChunkResponseSchema(
|
||||
id=r["id"],
|
||||
document_id=r["document_id"],
|
||||
chunk_text=r["chunk_text"],
|
||||
contextualized_text=r.get("contextualized_text"),
|
||||
chunk_index=r["chunk_index"],
|
||||
chunk_metadata=r["chunk_metadata"],
|
||||
filename=r["filename"],
|
||||
document_uuid=r["document_uuid"],
|
||||
similarity=r["similarity"],
|
||||
)
|
||||
for r in results
|
||||
]
|
||||
|
||||
return ChunkSearchResponseSchema(
|
||||
chunks=chunks,
|
||||
query=request.query,
|
||||
total_results=len(chunks),
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error searching chunks: {exc}")
|
||||
raise HTTPException(status_code=500, detail="Failed to search chunks") from exc
|
||||
|
|
@ -4,6 +4,7 @@ from loguru import logger
|
|||
from api.routes.campaign import router as campaign_router
|
||||
from api.routes.credentials import router as credentials_router
|
||||
from api.routes.integration import router as integration_router
|
||||
from api.routes.knowledge_base import router as knowledge_base_router
|
||||
from api.routes.looptalk import router as looptalk_router
|
||||
from api.routes.organization import router as organization_router
|
||||
from api.routes.organization_usage import router as organization_usage_router
|
||||
|
|
@ -43,6 +44,7 @@ router.include_router(webrtc_signaling_router)
|
|||
router.include_router(public_embed_router)
|
||||
router.include_router(public_agent_router)
|
||||
router.include_router(workflow_embed_router)
|
||||
router.include_router(knowledge_base_router)
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
|
|
|
|||
97
api/schemas/knowledge_base.py
Normal file
97
api/schemas/knowledge_base.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
"""Pydantic schemas for knowledge base operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DocumentUploadRequestSchema(BaseModel):
|
||||
"""Request schema for initiating document upload."""
|
||||
|
||||
filename: str = Field(..., description="Name of the file to upload")
|
||||
mime_type: str = Field(..., description="MIME type of the file")
|
||||
custom_metadata: Optional[Dict[str, Any]] = Field(
|
||||
default=None, description="Optional custom metadata"
|
||||
)
|
||||
|
||||
|
||||
class DocumentUploadResponseSchema(BaseModel):
|
||||
"""Response schema containing upload URL and document metadata."""
|
||||
|
||||
upload_url: str = Field(..., description="Signed URL for uploading the file")
|
||||
document_uuid: str = Field(..., description="Unique identifier for the document")
|
||||
s3_key: str = Field(..., description="S3 key where file should be uploaded")
|
||||
|
||||
|
||||
class ProcessDocumentRequestSchema(BaseModel):
|
||||
"""Request schema for triggering document processing."""
|
||||
|
||||
document_uuid: str = Field(..., description="Document UUID to process")
|
||||
s3_key: str = Field(..., description="S3 key of the uploaded file")
|
||||
|
||||
|
||||
class DocumentResponseSchema(BaseModel):
|
||||
"""Response schema for document metadata."""
|
||||
|
||||
id: int
|
||||
document_uuid: str
|
||||
filename: str
|
||||
file_size_bytes: int
|
||||
file_hash: str
|
||||
mime_type: str
|
||||
processing_status: str # pending, processing, completed, failed
|
||||
processing_error: Optional[str] = None
|
||||
total_chunks: int
|
||||
custom_metadata: Dict[str, Any]
|
||||
docling_metadata: Dict[str, Any]
|
||||
source_url: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
organization_id: int
|
||||
created_by: int
|
||||
is_active: bool
|
||||
|
||||
|
||||
class DocumentListResponseSchema(BaseModel):
|
||||
"""Response schema for list of documents."""
|
||||
|
||||
documents: List[DocumentResponseSchema]
|
||||
total: int
|
||||
limit: int
|
||||
offset: int
|
||||
|
||||
|
||||
class ChunkSearchRequestSchema(BaseModel):
|
||||
"""Request schema for searching similar chunks."""
|
||||
|
||||
query: str = Field(..., description="Search query text")
|
||||
limit: int = Field(default=5, ge=1, le=50, description="Maximum number of results")
|
||||
document_uuids: Optional[List[str]] = Field(
|
||||
default=None, description="Filter by specific document UUIDs"
|
||||
)
|
||||
min_similarity: Optional[float] = Field(
|
||||
default=None, ge=0.0, le=1.0, description="Minimum similarity threshold"
|
||||
)
|
||||
|
||||
|
||||
class ChunkResponseSchema(BaseModel):
|
||||
"""Response schema for a document chunk."""
|
||||
|
||||
id: int
|
||||
document_id: int
|
||||
chunk_text: str
|
||||
contextualized_text: Optional[str]
|
||||
chunk_index: int
|
||||
chunk_metadata: Dict[str, Any]
|
||||
filename: str
|
||||
document_uuid: str
|
||||
similarity: float
|
||||
|
||||
|
||||
class ChunkSearchResponseSchema(BaseModel):
|
||||
"""Response schema for chunk search results."""
|
||||
|
||||
chunks: List[ChunkResponseSchema]
|
||||
query: str
|
||||
total_results: int
|
||||
|
|
@ -85,3 +85,16 @@ class BaseFileSystem(ABC):
|
|||
Optional[str]: Presigned PUT URL if successful, None otherwise
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def adownload_file(self, source_path: str, local_path: str) -> bool:
|
||||
"""Download a file from storage to local path.
|
||||
|
||||
Args:
|
||||
source_path: Path to the file in storage
|
||||
local_path: Local path where file should be downloaded
|
||||
|
||||
Returns:
|
||||
bool: True if file was downloaded successfully, False otherwise
|
||||
"""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -170,3 +170,15 @@ class MinioFileSystem(BaseFileSystem):
|
|||
except Exception as e:
|
||||
logger.error(f"Error generating MinIO upload URL: {e}")
|
||||
return None
|
||||
|
||||
async def adownload_file(self, source_path: str, local_path: str) -> bool:
|
||||
"""Download a file from MinIO to local path."""
|
||||
try:
|
||||
|
||||
def _fget():
|
||||
self.client.fget_object(self.bucket_name, source_path, local_path)
|
||||
|
||||
await asyncio.to_thread(_fget)
|
||||
return True
|
||||
except S3Error:
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -126,3 +126,16 @@ class S3FileSystem(BaseFileSystem):
|
|||
return url
|
||||
except ClientError:
|
||||
return None
|
||||
|
||||
async def adownload_file(self, source_path: str, local_path: str) -> bool:
|
||||
"""Download a file from S3 to local path."""
|
||||
try:
|
||||
async with self.session.client(
|
||||
"s3", region_name=self.region_name
|
||||
) as s3_client:
|
||||
await s3_client.download_file(
|
||||
self.bucket_name, source_path, local_path
|
||||
)
|
||||
return True
|
||||
except ClientError:
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ from api.tasks.campaign_tasks import (
|
|||
process_campaign_batch,
|
||||
sync_campaign_source,
|
||||
)
|
||||
from api.tasks.knowledge_base_processing import process_knowledge_base_document
|
||||
from api.tasks.run_integrations import run_integrations_post_workflow_run
|
||||
from api.tasks.s3_upload import (
|
||||
upload_audio_to_s3,
|
||||
|
|
@ -64,6 +65,7 @@ class WorkerSettings:
|
|||
sync_campaign_source,
|
||||
process_campaign_batch,
|
||||
monitor_campaign_progress,
|
||||
process_knowledge_base_document,
|
||||
]
|
||||
cron_jobs = []
|
||||
redis_settings = REDIS_SETTINGS
|
||||
|
|
|
|||
|
|
@ -7,3 +7,4 @@ class FunctionNames:
|
|||
SYNC_CAMPAIGN_SOURCE = "sync_campaign_source"
|
||||
PROCESS_CAMPAIGN_BATCH = "process_campaign_batch"
|
||||
MONITOR_CAMPAIGN_PROGRESS = "monitor_campaign_progress"
|
||||
PROCESS_KNOWLEDGE_BASE_DOCUMENT = "process_knowledge_base_document"
|
||||
|
|
|
|||
252
api/tasks/knowledge_base_processing.py
Normal file
252
api/tasks/knowledge_base_processing.py
Normal file
|
|
@ -0,0 +1,252 @@
|
|||
"""ARQ background task for processing knowledge base documents."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from docling.chunking import HybridChunker
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
|
||||
from loguru import logger
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import KnowledgeBaseChunkModel, KnowledgeBaseDocumentModel
|
||||
from api.services.storage import storage_fs
|
||||
|
||||
# Constants
|
||||
EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
EMBEDDING_DIMENSION = 384
|
||||
|
||||
|
||||
async def process_knowledge_base_document(
|
||||
ctx, document_id: int, s3_key: str, organization_id: int, max_tokens: int = 128
|
||||
):
|
||||
"""Process a knowledge base document: download, chunk, embed, and store.
|
||||
|
||||
Args:
|
||||
ctx: ARQ context
|
||||
document_id: Database ID of the document
|
||||
s3_key: S3 key where the file is stored
|
||||
organization_id: Organization ID
|
||||
max_tokens: Maximum number of tokens per chunk (default: 128)
|
||||
"""
|
||||
logger.info(
|
||||
f"Starting knowledge base document processing for document_id={document_id}, "
|
||||
f"s3_key={s3_key}, organization_id={organization_id}"
|
||||
)
|
||||
|
||||
temp_file_path = None
|
||||
|
||||
try:
|
||||
# Update status to processing
|
||||
await db_client.update_document_status(document_id, "processing")
|
||||
|
||||
# Create temp file for download
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
|
||||
temp_file_path = temp_file.name
|
||||
temp_file.close()
|
||||
|
||||
# Download file from S3
|
||||
logger.info(f"Downloading file from S3: {s3_key}")
|
||||
download_success = await storage_fs.adownload_file(s3_key, temp_file_path)
|
||||
|
||||
if not download_success:
|
||||
raise Exception(f"Failed to download file from S3: {s3_key}")
|
||||
|
||||
if not os.path.exists(temp_file_path):
|
||||
raise FileNotFoundError(f"Downloaded file not found: {temp_file_path}")
|
||||
|
||||
file_size = os.path.getsize(temp_file_path)
|
||||
logger.info(f"Downloaded file size: {file_size} bytes")
|
||||
|
||||
# Compute file hash and get mime type
|
||||
file_hash = db_client.compute_file_hash(temp_file_path)
|
||||
mime_type = db_client.get_mime_type(temp_file_path)
|
||||
filename = s3_key.split("/")[-1]
|
||||
|
||||
# Get document record
|
||||
document = await db_client.get_document_by_id(document_id)
|
||||
if not document:
|
||||
raise Exception(f"Document {document_id} not found")
|
||||
|
||||
# Check if a document with this hash already exists (reject duplicates)
|
||||
existing_doc = await db_client.get_document_by_hash(file_hash, organization_id)
|
||||
if existing_doc and existing_doc.id != document_id:
|
||||
error_message = (
|
||||
f"This file is a duplicate of '{existing_doc.filename}'. "
|
||||
f"Please delete the duplicate files and consolidate them into a single unique file before uploading."
|
||||
)
|
||||
logger.warning(
|
||||
f"Duplicate document detected: {document_id} is duplicate of {existing_doc.id} "
|
||||
f"({existing_doc.filename})"
|
||||
)
|
||||
# Update file metadata
|
||||
await db_client.update_document_metadata(
|
||||
document_id,
|
||||
file_size_bytes=file_size,
|
||||
file_hash=file_hash,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
# Mark as failed with duplicate error message
|
||||
await db_client.update_document_status(
|
||||
document_id,
|
||||
"failed",
|
||||
error_message=error_message,
|
||||
docling_metadata={
|
||||
"duplicate_of": existing_doc.document_uuid,
|
||||
"duplicate_filename": existing_doc.filename,
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
# Update document with file metadata
|
||||
await db_client.update_document_metadata(
|
||||
document_id,
|
||||
file_size_bytes=file_size,
|
||||
file_hash=file_hash,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
|
||||
# Initialize models for processing
|
||||
cache_dir = os.path.expanduser("~/.cache/hf_models")
|
||||
logger.info(f"Loading embedding model: {EMBED_MODEL_ID} (cache: {cache_dir})")
|
||||
embedding_model = SentenceTransformer(
|
||||
EMBED_MODEL_ID,
|
||||
cache_folder=cache_dir,
|
||||
)
|
||||
|
||||
logger.info(f"Loading tokenizer: {EMBED_MODEL_ID} (cache: {cache_dir})")
|
||||
tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=AutoTokenizer.from_pretrained(
|
||||
EMBED_MODEL_ID,
|
||||
cache_dir=cache_dir,
|
||||
),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
|
||||
chunker = HybridChunker(tokenizer=tokenizer)
|
||||
|
||||
# Convert document with docling
|
||||
logger.info("Converting document with docling")
|
||||
converter = DocumentConverter()
|
||||
conversion_result = converter.convert(temp_file_path)
|
||||
doc = conversion_result.document
|
||||
|
||||
# Store docling metadata
|
||||
docling_metadata = {
|
||||
"num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
|
||||
"document_type": type(doc).__name__,
|
||||
}
|
||||
|
||||
# Chunk the document
|
||||
logger.info(f"Chunking document with max_tokens={max_tokens}")
|
||||
chunks = list(chunker.chunk(dl_doc=doc))
|
||||
total_chunks = len(chunks)
|
||||
logger.info(f"Generated {total_chunks} chunks")
|
||||
|
||||
# Process each chunk
|
||||
chunk_texts = []
|
||||
chunk_records = []
|
||||
token_counts = []
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
chunk_text = chunk.text
|
||||
contextualized_text = chunker.contextualize(chunk=chunk)
|
||||
|
||||
# Calculate token count
|
||||
text_to_tokenize = contextualized_text if contextualized_text else chunk_text
|
||||
token_count = len(
|
||||
tokenizer.tokenizer.encode(text_to_tokenize, add_special_tokens=False)
|
||||
)
|
||||
token_counts.append(token_count)
|
||||
|
||||
# Prepare chunk metadata
|
||||
chunk_metadata = {}
|
||||
if hasattr(chunk, "meta") and chunk.meta:
|
||||
chunk_metadata = {
|
||||
"doc_items": (
|
||||
[str(item) for item in chunk.meta.doc_items]
|
||||
if hasattr(chunk.meta, "doc_items")
|
||||
else []
|
||||
),
|
||||
"headings": (
|
||||
chunk.meta.headings if hasattr(chunk.meta, "headings") else []
|
||||
),
|
||||
}
|
||||
|
||||
# Create chunk record
|
||||
chunk_record = KnowledgeBaseChunkModel(
|
||||
document_id=document_id,
|
||||
organization_id=organization_id,
|
||||
chunk_text=chunk_text,
|
||||
contextualized_text=contextualized_text,
|
||||
chunk_index=i,
|
||||
chunk_metadata=chunk_metadata,
|
||||
embedding_model=EMBED_MODEL_ID,
|
||||
embedding_dimension=EMBEDDING_DIMENSION,
|
||||
token_count=token_count,
|
||||
)
|
||||
|
||||
chunk_records.append(chunk_record)
|
||||
chunk_texts.append(text_to_tokenize)
|
||||
|
||||
# Log chunk statistics
|
||||
if token_counts:
|
||||
avg_tokens = sum(token_counts) / len(token_counts)
|
||||
min_tokens = min(token_counts)
|
||||
max_tokens = max(token_counts)
|
||||
logger.info(f"Chunk token statistics:")
|
||||
logger.info(f" - Average: {avg_tokens:.1f} tokens")
|
||||
logger.info(f" - Min: {min_tokens} tokens")
|
||||
logger.info(f" - Max: {max_tokens} tokens")
|
||||
|
||||
# Generate embeddings in batch
|
||||
logger.info("Generating embeddings")
|
||||
embeddings = embedding_model.encode(
|
||||
chunk_texts,
|
||||
show_progress_bar=False,
|
||||
convert_to_numpy=True,
|
||||
)
|
||||
|
||||
# Attach embeddings to chunk records
|
||||
for chunk_record, embedding in zip(chunk_records, embeddings):
|
||||
chunk_record.embedding = embedding.tolist()
|
||||
|
||||
# Save chunks in database
|
||||
logger.info("Storing chunks in database")
|
||||
await db_client.create_chunks_batch(chunk_records)
|
||||
|
||||
# Update document status to completed
|
||||
await db_client.update_document_status(
|
||||
document_id,
|
||||
"completed",
|
||||
total_chunks=total_chunks,
|
||||
docling_metadata=docling_metadata,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Successfully processed knowledge base document {document_id}. "
|
||||
f"Total chunks: {total_chunks}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing knowledge base document {document_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Update document status to failed
|
||||
await db_client.update_document_status(
|
||||
document_id, "failed", error_message=str(e)
|
||||
)
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Clean up temp file
|
||||
if temp_file_path and os.path.exists(temp_file_path):
|
||||
try:
|
||||
os.remove(temp_file_path)
|
||||
logger.debug(f"Cleaned up temp file: {temp_file_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp file {temp_file_path}: {e}")
|
||||
Loading…
Add table
Add a link
Reference in a new issue