mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: knowledge base functionality for the voice agent (#120)
* feat: upload file and store embedding * feat: add documents in nodes * feat: add openai embedding service
This commit is contained in:
parent
e2fa4bbb98
commit
ef5b9e40a9
52 changed files with 4551 additions and 114 deletions
194
api/alembic/versions/dc33eef8dabe_add_document_tables.py
Normal file
194
api/alembic/versions/dc33eef8dabe_add_document_tables.py
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
"""add document tables
|
||||
|
||||
Revision ID: dc33eef8dabe
|
||||
Revises: dcb0a27d98c6
|
||||
Create Date: 2026-01-16 13:40:17.808807
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "dc33eef8dabe"
|
||||
down_revision: Union[str, None] = "dcb0a27d98c6"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
# Enable pgvector extension
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
||||
|
||||
sa.Enum(
|
||||
"pending",
|
||||
"processing",
|
||||
"completed",
|
||||
"failed",
|
||||
name="document_processing_status",
|
||||
).create(op.get_bind())
|
||||
op.create_table(
|
||||
"knowledge_base_documents",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("document_uuid", sa.String(length=36), nullable=False),
|
||||
sa.Column("organization_id", sa.Integer(), nullable=False),
|
||||
sa.Column("filename", sa.String(length=500), nullable=False),
|
||||
sa.Column("file_size_bytes", sa.Integer(), nullable=True),
|
||||
sa.Column("file_hash", sa.String(length=64), nullable=True),
|
||||
sa.Column("mime_type", sa.String(length=100), nullable=True),
|
||||
sa.Column("source_url", sa.String(), nullable=True),
|
||||
sa.Column("total_chunks", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"processing_status",
|
||||
postgresql.ENUM(
|
||||
"pending",
|
||||
"processing",
|
||||
"completed",
|
||||
"failed",
|
||||
name="document_processing_status",
|
||||
create_type=False,
|
||||
),
|
||||
server_default=sa.text("'pending'::document_processing_status"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("processing_error", sa.Text(), nullable=True),
|
||||
sa.Column("docling_metadata", sa.JSON(), nullable=False),
|
||||
sa.Column("custom_metadata", sa.JSON(), nullable=False),
|
||||
sa.Column("created_by", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("archived_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["created_by"],
|
||||
["users.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"], ["organizations.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_documents_created_at",
|
||||
"knowledge_base_documents",
|
||||
["created_at"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_documents_organization_id",
|
||||
"knowledge_base_documents",
|
||||
["organization_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_documents_status",
|
||||
"knowledge_base_documents",
|
||||
["processing_status"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_documents_uuid",
|
||||
"knowledge_base_documents",
|
||||
["document_uuid"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_knowledge_base_documents_document_uuid"),
|
||||
"knowledge_base_documents",
|
||||
["document_uuid"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_table(
|
||||
"knowledge_base_chunks",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("document_id", sa.Integer(), nullable=False),
|
||||
sa.Column("organization_id", sa.Integer(), nullable=False),
|
||||
sa.Column("chunk_text", sa.Text(), nullable=False),
|
||||
sa.Column("contextualized_text", sa.Text(), nullable=True),
|
||||
sa.Column("chunk_index", sa.Integer(), nullable=False),
|
||||
sa.Column("chunk_metadata", sa.JSON(), nullable=False),
|
||||
sa.Column("embedding_model", sa.String(length=200), nullable=False),
|
||||
sa.Column("embedding_dimension", sa.Integer(), nullable=False),
|
||||
sa.Column("embedding", Vector(1536), nullable=True),
|
||||
sa.Column("token_count", sa.Integer(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["document_id"], ["knowledge_base_documents.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"], ["organizations.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_chunks_chunk_index",
|
||||
"knowledge_base_chunks",
|
||||
["chunk_index"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_chunks_document_id",
|
||||
"knowledge_base_chunks",
|
||||
["document_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_chunks_embedding_ivfflat",
|
||||
"knowledge_base_chunks",
|
||||
["embedding"],
|
||||
unique=False,
|
||||
postgresql_using="ivfflat",
|
||||
postgresql_with={"lists": 100},
|
||||
postgresql_ops={"embedding": "vector_cosine_ops"},
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_chunks_organization_id",
|
||||
"knowledge_base_chunks",
|
||||
["organization_id"],
|
||||
unique=False,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("ix_kb_chunks_organization_id", table_name="knowledge_base_chunks")
|
||||
op.drop_index(
|
||||
"ix_kb_chunks_embedding_ivfflat",
|
||||
table_name="knowledge_base_chunks",
|
||||
postgresql_using="ivfflat",
|
||||
postgresql_with={"lists": 100},
|
||||
postgresql_ops={"embedding": "vector_cosine_ops"},
|
||||
)
|
||||
op.drop_index("ix_kb_chunks_document_id", table_name="knowledge_base_chunks")
|
||||
op.drop_index("ix_kb_chunks_chunk_index", table_name="knowledge_base_chunks")
|
||||
op.drop_table("knowledge_base_chunks")
|
||||
op.drop_index(
|
||||
op.f("ix_knowledge_base_documents_document_uuid"),
|
||||
table_name="knowledge_base_documents",
|
||||
)
|
||||
op.drop_index("ix_kb_documents_uuid", table_name="knowledge_base_documents")
|
||||
op.drop_index("ix_kb_documents_status", table_name="knowledge_base_documents")
|
||||
op.drop_index(
|
||||
"ix_kb_documents_organization_id", table_name="knowledge_base_documents"
|
||||
)
|
||||
op.drop_index("ix_kb_documents_created_at", table_name="knowledge_base_documents")
|
||||
op.drop_table("knowledge_base_documents")
|
||||
sa.Enum(
|
||||
"pending",
|
||||
"processing",
|
||||
"completed",
|
||||
"failed",
|
||||
name="document_processing_status",
|
||||
).drop(op.get_bind())
|
||||
|
||||
# Note: We don't drop the vector extension as it may be used by other tables
|
||||
# If you want to drop it, uncomment the following line:
|
||||
# op.execute('DROP EXTENSION IF EXISTS vector')
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -3,6 +3,7 @@ from api.db.api_key_client import APIKeyClient
|
|||
from api.db.campaign_client import CampaignClient
|
||||
from api.db.embed_token_client import EmbedTokenClient
|
||||
from api.db.integration_client import IntegrationClient
|
||||
from api.db.knowledge_base_client import KnowledgeBaseClient
|
||||
from api.db.looptalk_client import LoopTalkClient
|
||||
from api.db.organization_client import OrganizationClient
|
||||
from api.db.organization_configuration_client import OrganizationConfigurationClient
|
||||
|
|
@ -33,6 +34,7 @@ class DBClient(
|
|||
AgentTriggerClient,
|
||||
WebhookCredentialClient,
|
||||
ToolClient,
|
||||
KnowledgeBaseClient,
|
||||
):
|
||||
"""
|
||||
Unified database client that combines all specialized database operations.
|
||||
|
|
@ -54,6 +56,7 @@ class DBClient(
|
|||
- AgentTriggerClient: handles agent trigger operations for API-based call triggering
|
||||
- WebhookCredentialClient: handles webhook credential operations
|
||||
- ToolClient: handles tool operations for reusable HTTP API tools
|
||||
- KnowledgeBaseClient: handles knowledge base document and vector search operations
|
||||
"""
|
||||
|
||||
pass
|
||||
|
|
|
|||
499
api/db/knowledge_base_client.py
Normal file
499
api/db/knowledge_base_client.py
Normal file
|
|
@ -0,0 +1,499 @@
|
|||
"""Database client for managing knowledge base documents and chunks."""
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import KnowledgeBaseChunkModel, KnowledgeBaseDocumentModel
|
||||
|
||||
|
||||
class KnowledgeBaseClient(BaseDBClient):
|
||||
"""Client for managing knowledge base documents and vector embeddings."""
|
||||
|
||||
async def create_document(
|
||||
self,
|
||||
organization_id: int,
|
||||
created_by: int,
|
||||
filename: str,
|
||||
file_size_bytes: int,
|
||||
file_hash: str,
|
||||
mime_type: str,
|
||||
source_url: Optional[str] = None,
|
||||
custom_metadata: Optional[dict] = None,
|
||||
docling_metadata: Optional[dict] = None,
|
||||
document_uuid: Optional[str] = None,
|
||||
) -> KnowledgeBaseDocumentModel:
|
||||
"""Create a new knowledge base document record.
|
||||
|
||||
Args:
|
||||
organization_id: ID of the organization
|
||||
created_by: ID of the user uploading the document
|
||||
filename: Name of the file
|
||||
file_size_bytes: Size of the file in bytes
|
||||
file_hash: SHA-256 hash of the file
|
||||
mime_type: MIME type of the file
|
||||
source_url: Optional URL if document was fetched from web
|
||||
custom_metadata: Optional custom metadata dictionary
|
||||
docling_metadata: Optional docling processing metadata
|
||||
document_uuid: Optional UUID to use (if not provided, one will be generated)
|
||||
|
||||
Returns:
|
||||
The created KnowledgeBaseDocumentModel
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
document = KnowledgeBaseDocumentModel(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
filename=filename,
|
||||
file_size_bytes=file_size_bytes,
|
||||
file_hash=file_hash,
|
||||
mime_type=mime_type,
|
||||
source_url=source_url,
|
||||
custom_metadata=custom_metadata or {},
|
||||
docling_metadata=docling_metadata or {},
|
||||
processing_status="pending",
|
||||
total_chunks=0,
|
||||
)
|
||||
|
||||
# Use provided UUID or let the model generate one
|
||||
if document_uuid:
|
||||
document.document_uuid = document_uuid
|
||||
|
||||
session.add(document)
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
|
||||
logger.info(
|
||||
f"Created document '{filename}' ({document.document_uuid}) "
|
||||
f"for organization {organization_id}"
|
||||
)
|
||||
return document
|
||||
|
||||
async def get_document_by_id(
|
||||
self,
|
||||
document_id: int,
|
||||
) -> Optional[KnowledgeBaseDocumentModel]:
|
||||
"""Get a document by its database ID.
|
||||
|
||||
Args:
|
||||
document_id: The database ID of the document
|
||||
|
||||
Returns:
|
||||
KnowledgeBaseDocumentModel if found, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(KnowledgeBaseDocumentModel).where(
|
||||
KnowledgeBaseDocumentModel.id == document_id
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_document_by_uuid(
|
||||
self,
|
||||
document_uuid: str,
|
||||
organization_id: int,
|
||||
) -> Optional[KnowledgeBaseDocumentModel]:
|
||||
"""Get a document by its UUID, scoped to organization.
|
||||
|
||||
Args:
|
||||
document_uuid: The unique document UUID
|
||||
organization_id: ID of the organization
|
||||
|
||||
Returns:
|
||||
KnowledgeBaseDocumentModel if found, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(KnowledgeBaseDocumentModel)
|
||||
.where(
|
||||
KnowledgeBaseDocumentModel.document_uuid == document_uuid,
|
||||
KnowledgeBaseDocumentModel.organization_id == organization_id,
|
||||
KnowledgeBaseDocumentModel.is_active == True,
|
||||
)
|
||||
.options(selectinload(KnowledgeBaseDocumentModel.created_by_user))
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_document_by_hash(
|
||||
self,
|
||||
file_hash: str,
|
||||
organization_id: int,
|
||||
) -> Optional[KnowledgeBaseDocumentModel]:
|
||||
"""Check if a document with the same hash already exists.
|
||||
|
||||
Returns the first matching document if multiple exist (can happen with duplicates).
|
||||
|
||||
Args:
|
||||
file_hash: SHA-256 hash of the file
|
||||
organization_id: ID of the organization
|
||||
|
||||
Returns:
|
||||
KnowledgeBaseDocumentModel if found, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(KnowledgeBaseDocumentModel)
|
||||
.where(
|
||||
KnowledgeBaseDocumentModel.file_hash == file_hash,
|
||||
KnowledgeBaseDocumentModel.organization_id == organization_id,
|
||||
KnowledgeBaseDocumentModel.is_active == True,
|
||||
)
|
||||
.order_by(KnowledgeBaseDocumentModel.created_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_documents_for_organization(
|
||||
self,
|
||||
organization_id: int,
|
||||
processing_status: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> List[KnowledgeBaseDocumentModel]:
|
||||
"""Get all documents for an organization.
|
||||
|
||||
Args:
|
||||
organization_id: ID of the organization
|
||||
processing_status: Optional filter by status
|
||||
limit: Maximum number of documents to return
|
||||
offset: Number of documents to skip
|
||||
|
||||
Returns:
|
||||
List of KnowledgeBaseDocumentModel instances
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(KnowledgeBaseDocumentModel).where(
|
||||
KnowledgeBaseDocumentModel.organization_id == organization_id,
|
||||
KnowledgeBaseDocumentModel.is_active == True,
|
||||
)
|
||||
|
||||
if processing_status:
|
||||
query = query.where(
|
||||
KnowledgeBaseDocumentModel.processing_status == processing_status
|
||||
)
|
||||
|
||||
query = (
|
||||
query.order_by(KnowledgeBaseDocumentModel.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def update_document_metadata(
|
||||
self,
|
||||
document_id: int,
|
||||
file_size_bytes: Optional[int] = None,
|
||||
file_hash: Optional[str] = None,
|
||||
mime_type: Optional[str] = None,
|
||||
) -> Optional[KnowledgeBaseDocumentModel]:
|
||||
"""Update document file metadata.
|
||||
|
||||
Args:
|
||||
document_id: ID of the document
|
||||
file_size_bytes: Optional file size in bytes
|
||||
file_hash: Optional SHA-256 hash of the file
|
||||
mime_type: Optional MIME type
|
||||
|
||||
Returns:
|
||||
Updated KnowledgeBaseDocumentModel
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(KnowledgeBaseDocumentModel).where(
|
||||
KnowledgeBaseDocumentModel.id == document_id
|
||||
)
|
||||
result = await session.execute(query)
|
||||
document = result.scalar_one_or_none()
|
||||
|
||||
if not document:
|
||||
return None
|
||||
|
||||
if file_size_bytes is not None:
|
||||
document.file_size_bytes = file_size_bytes
|
||||
if file_hash is not None:
|
||||
document.file_hash = file_hash
|
||||
if mime_type is not None:
|
||||
document.mime_type = mime_type
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
|
||||
logger.info(f"Updated document {document_id} metadata")
|
||||
return document
|
||||
|
||||
async def update_document_status(
|
||||
self,
|
||||
document_id: int,
|
||||
status: str,
|
||||
error_message: Optional[str] = None,
|
||||
total_chunks: Optional[int] = None,
|
||||
docling_metadata: Optional[dict] = None,
|
||||
) -> Optional[KnowledgeBaseDocumentModel]:
|
||||
"""Update document processing status.
|
||||
|
||||
Args:
|
||||
document_id: ID of the document
|
||||
status: New status (pending, processing, completed, failed)
|
||||
error_message: Optional error message if status is failed
|
||||
total_chunks: Optional total number of chunks
|
||||
docling_metadata: Optional docling metadata
|
||||
|
||||
Returns:
|
||||
Updated KnowledgeBaseDocumentModel
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(KnowledgeBaseDocumentModel).where(
|
||||
KnowledgeBaseDocumentModel.id == document_id
|
||||
)
|
||||
result = await session.execute(query)
|
||||
document = result.scalar_one_or_none()
|
||||
|
||||
if not document:
|
||||
return None
|
||||
|
||||
document.processing_status = status
|
||||
if error_message:
|
||||
document.processing_error = error_message
|
||||
if total_chunks is not None:
|
||||
document.total_chunks = total_chunks
|
||||
if docling_metadata:
|
||||
document.docling_metadata = docling_metadata
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
|
||||
logger.info(f"Updated document {document_id} status to {status}")
|
||||
return document
|
||||
|
||||
async def create_chunks_batch(
|
||||
self,
|
||||
chunks: List[KnowledgeBaseChunkModel],
|
||||
) -> List[KnowledgeBaseChunkModel]:
|
||||
"""Create multiple chunks in a batch.
|
||||
|
||||
Args:
|
||||
chunks: List of KnowledgeBaseChunkModel instances
|
||||
|
||||
Returns:
|
||||
List of created chunks with IDs
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
session.add_all(chunks)
|
||||
await session.commit()
|
||||
|
||||
for chunk in chunks:
|
||||
await session.refresh(chunk)
|
||||
|
||||
logger.info(f"Created {len(chunks)} chunks")
|
||||
return chunks
|
||||
|
||||
async def get_chunks_for_document(
|
||||
self,
|
||||
document_id: int,
|
||||
organization_id: int,
|
||||
) -> List[KnowledgeBaseChunkModel]:
|
||||
"""Get all chunks for a document.
|
||||
|
||||
Args:
|
||||
document_id: ID of the document
|
||||
organization_id: ID of the organization (for authorization)
|
||||
|
||||
Returns:
|
||||
List of KnowledgeBaseChunkModel instances
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(KnowledgeBaseChunkModel)
|
||||
.where(
|
||||
KnowledgeBaseChunkModel.document_id == document_id,
|
||||
KnowledgeBaseChunkModel.organization_id == organization_id,
|
||||
)
|
||||
.order_by(KnowledgeBaseChunkModel.chunk_index)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def search_similar_chunks(
|
||||
self,
|
||||
query_embedding: List[float],
|
||||
organization_id: int,
|
||||
limit: int = 5,
|
||||
document_ids: Optional[List[int]] = None,
|
||||
document_uuids: Optional[List[str]] = None,
|
||||
embedding_model: Optional[str] = None,
|
||||
) -> List[dict]:
|
||||
"""Search for similar chunks using vector similarity.
|
||||
|
||||
Returns top-k most similar chunks without any similarity threshold filtering.
|
||||
Filtering and reranking should be done at the application layer.
|
||||
|
||||
Args:
|
||||
query_embedding: The query embedding vector
|
||||
organization_id: Organization ID for scoping
|
||||
limit: Maximum number of results to return
|
||||
document_ids: Optional list of document IDs to filter by
|
||||
document_uuids: Optional list of document UUIDs to filter by
|
||||
embedding_model: Optional embedding model to filter by (for dimension compatibility)
|
||||
|
||||
Returns:
|
||||
List of dictionaries with chunk data and similarity scores, ordered by similarity (highest first)
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Get the raw connection to execute directly with asyncpg
|
||||
# This avoids parameter binding issues with text() and asyncpg
|
||||
connection = await session.connection()
|
||||
raw_connection = await connection.get_raw_connection()
|
||||
|
||||
# Build WHERE clause conditions (no similarity threshold)
|
||||
where_conditions = [
|
||||
"c.organization_id = $2",
|
||||
"d.is_active = true",
|
||||
]
|
||||
params = [
|
||||
None,
|
||||
organization_id,
|
||||
limit,
|
||||
] # $1 will be embedding_str, $3 is limit
|
||||
param_index = 4 # Next available parameter index
|
||||
|
||||
# Add document_ids filter if provided
|
||||
if document_ids:
|
||||
placeholders = ", ".join(
|
||||
f"${param_index + i}" for i in range(len(document_ids))
|
||||
)
|
||||
where_conditions.append(f"c.document_id IN ({placeholders})")
|
||||
params.extend(document_ids)
|
||||
param_index += len(document_ids)
|
||||
|
||||
# Add document_uuids filter if provided
|
||||
if document_uuids:
|
||||
placeholders = ", ".join(
|
||||
f"${param_index + i}" for i in range(len(document_uuids))
|
||||
)
|
||||
where_conditions.append(f"d.document_uuid IN ({placeholders})")
|
||||
params.extend(document_uuids)
|
||||
param_index += len(document_uuids)
|
||||
|
||||
# Add embedding_model filter if provided (for dimension compatibility)
|
||||
if embedding_model:
|
||||
where_conditions.append(f"c.embedding_model = ${param_index}")
|
||||
params.append(embedding_model)
|
||||
param_index += 1
|
||||
|
||||
# Build the complete SQL query
|
||||
where_clause = " AND ".join(where_conditions)
|
||||
query_sql = f"""
|
||||
SELECT
|
||||
c.id,
|
||||
c.document_id,
|
||||
c.chunk_text,
|
||||
c.contextualized_text,
|
||||
c.chunk_metadata,
|
||||
c.chunk_index,
|
||||
d.filename,
|
||||
d.document_uuid,
|
||||
1 - (c.embedding <=> $1::vector) as similarity
|
||||
FROM knowledge_base_chunks c
|
||||
JOIN knowledge_base_documents d ON c.document_id = d.id
|
||||
WHERE {where_clause}
|
||||
ORDER BY c.embedding <=> $1::vector
|
||||
LIMIT $3
|
||||
"""
|
||||
|
||||
# Convert embedding to string format for PostgreSQL vector type
|
||||
embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"
|
||||
params[0] = embedding_str # Set $1
|
||||
|
||||
# Execute query directly with asyncpg
|
||||
rows = await raw_connection.driver_connection.fetch(
|
||||
query_sql,
|
||||
*params,
|
||||
)
|
||||
|
||||
# Convert asyncpg records to dictionaries
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
async def delete_document(
|
||||
self,
|
||||
document_uuid: str,
|
||||
organization_id: int,
|
||||
) -> bool:
|
||||
"""Soft delete a document by setting is_active to False.
|
||||
|
||||
This will also cascade delete all chunks via the database foreign key.
|
||||
|
||||
Args:
|
||||
document_uuid: The unique document UUID
|
||||
organization_id: ID of the organization (for authorization)
|
||||
|
||||
Returns:
|
||||
True if document was deleted, False if not found
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(KnowledgeBaseDocumentModel).where(
|
||||
KnowledgeBaseDocumentModel.document_uuid == document_uuid,
|
||||
KnowledgeBaseDocumentModel.organization_id == organization_id,
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
document = result.scalar_one_or_none()
|
||||
|
||||
if not document:
|
||||
return False
|
||||
|
||||
document.is_active = False
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Deleted document {document_uuid} for organization {organization_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def compute_file_hash(file_path: str) -> str:
|
||||
"""Compute SHA-256 hash of a file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
SHA-256 hash as hex string
|
||||
"""
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(file_path, "rb") as f:
|
||||
for byte_block in iter(lambda: f.read(4096), b""):
|
||||
sha256_hash.update(byte_block)
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def get_mime_type(file_path: str) -> str:
|
||||
"""Get MIME type based on file extension.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
MIME type string
|
||||
"""
|
||||
extension = Path(file_path).suffix.lower()
|
||||
mime_types = {
|
||||
".pdf": "application/pdf",
|
||||
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
".doc": "application/msword",
|
||||
".txt": "text/plain",
|
||||
".html": "text/html",
|
||||
".md": "text/markdown",
|
||||
}
|
||||
return mime_types.get(extension, "application/octet-stream")
|
||||
178
api/db/models.py
178
api/db/models.py
|
|
@ -2,6 +2,7 @@ import uuid
|
|||
from datetime import UTC, datetime
|
||||
|
||||
from loguru import logger
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
Boolean,
|
||||
|
|
@ -14,6 +15,7 @@ from sqlalchemy import (
|
|||
Integer,
|
||||
String,
|
||||
Table,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
and_,
|
||||
text,
|
||||
|
|
@ -890,3 +892,179 @@ class ToolModel(Base):
|
|||
Index("ix_tools_status", "status"),
|
||||
Index("ix_tools_category", "category"),
|
||||
)
|
||||
|
||||
|
||||
class KnowledgeBaseDocumentModel(Base):
|
||||
"""Model for storing document-level metadata in the knowledge base.
|
||||
|
||||
Each document represents a source file (PDF, DOCX, etc.) that has been
|
||||
processed and chunked for retrieval.
|
||||
"""
|
||||
|
||||
__tablename__ = "knowledge_base_documents"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
# Public identifier for API references
|
||||
document_uuid = Column(
|
||||
String(36),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
default=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
# Organization scoping
|
||||
organization_id = Column(
|
||||
Integer, ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# Document metadata
|
||||
filename = Column(String(500), nullable=False)
|
||||
file_size_bytes = Column(Integer, nullable=True)
|
||||
file_hash = Column(String(64), nullable=True) # SHA-256 hash for deduplication
|
||||
mime_type = Column(String(100), nullable=True)
|
||||
|
||||
# Processing metadata
|
||||
source_url = Column(String, nullable=True) # If document was fetched from URL
|
||||
total_chunks = Column(Integer, nullable=False, default=0)
|
||||
processing_status = Column(
|
||||
Enum(
|
||||
"pending",
|
||||
"processing",
|
||||
"completed",
|
||||
"failed",
|
||||
name="document_processing_status",
|
||||
),
|
||||
nullable=False,
|
||||
default="pending",
|
||||
server_default=text("'pending'::document_processing_status"),
|
||||
)
|
||||
processing_error = Column(Text, nullable=True)
|
||||
|
||||
# Docling conversion metadata
|
||||
docling_metadata = Column(
|
||||
JSON, nullable=False, default=dict
|
||||
) # Store docling document metadata
|
||||
|
||||
# Custom metadata (user-defined tags, categories, etc.)
|
||||
custom_metadata = Column(JSON, nullable=False, default=dict)
|
||||
|
||||
# Audit fields
|
||||
created_by = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
)
|
||||
|
||||
# Soft delete
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
archived_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Relationships
|
||||
organization = relationship("OrganizationModel")
|
||||
created_by_user = relationship("UserModel")
|
||||
chunks = relationship(
|
||||
"KnowledgeBaseChunkModel",
|
||||
back_populates="document",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Indexes and constraints
|
||||
__table_args__ = (
|
||||
Index("ix_kb_documents_organization_id", "organization_id"),
|
||||
Index("ix_kb_documents_uuid", "document_uuid"),
|
||||
Index("ix_kb_documents_status", "processing_status"),
|
||||
Index("ix_kb_documents_created_at", "created_at"),
|
||||
)
|
||||
|
||||
|
||||
class KnowledgeBaseChunkModel(Base):
|
||||
"""Model for storing document chunks with vector embeddings.
|
||||
|
||||
Each chunk represents a portion of a document that has been:
|
||||
1. Extracted and chunked by docling's HybridChunker
|
||||
2. Optionally contextualized with surrounding information
|
||||
3. Embedded into a vector representation for semantic search
|
||||
"""
|
||||
|
||||
__tablename__ = "knowledge_base_chunks"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
# Link to parent document
|
||||
document_id = Column(
|
||||
Integer,
|
||||
ForeignKey("knowledge_base_documents.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Organization scoping (denormalized for efficient querying)
|
||||
organization_id = Column(
|
||||
Integer, ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# Chunk content
|
||||
chunk_text = Column(Text, nullable=False) # The actual chunk text
|
||||
contextualized_text = Column(
|
||||
Text, nullable=True
|
||||
) # Enriched text from chunker.contextualize()
|
||||
|
||||
# Chunk positioning and metadata
|
||||
chunk_index = Column(Integer, nullable=False) # Position in document (0-based)
|
||||
|
||||
# Docling chunk metadata
|
||||
chunk_metadata = Column(
|
||||
JSON, nullable=False, default=dict
|
||||
) # Store chunk.meta if available
|
||||
|
||||
# Embedding configuration
|
||||
embedding_model = Column(
|
||||
String(200), nullable=False
|
||||
) # e.g., "sentence-transformers/all-MiniLM-L6-v2"
|
||||
embedding_dimension = Column(
|
||||
Integer, nullable=False
|
||||
) # e.g., 384 for all-MiniLM-L6-v2
|
||||
|
||||
# Vector embedding (pgvector column)
|
||||
# The dimension should match the embedding_dimension field
|
||||
# Default: 1536 dimensions for OpenAI text-embedding-3-small
|
||||
# SentenceTransformer (384-dim) also supported but stored as 384-dim vectors
|
||||
embedding = Column(Vector(1536), nullable=True)
|
||||
|
||||
# Token count (useful for chunking strategy analysis)
|
||||
token_count = Column(Integer, nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
)
|
||||
|
||||
# Relationships
|
||||
document = relationship("KnowledgeBaseDocumentModel", back_populates="chunks")
|
||||
organization = relationship("OrganizationModel")
|
||||
|
||||
# Indexes and constraints
|
||||
__table_args__ = (
|
||||
Index("ix_kb_chunks_document_id", "document_id"),
|
||||
Index("ix_kb_chunks_organization_id", "organization_id"),
|
||||
Index("ix_kb_chunks_chunk_index", "chunk_index"),
|
||||
Index(
|
||||
"ix_kb_chunks_embedding_model", "embedding_model"
|
||||
), # For filtering by model
|
||||
# Vector similarity search index (using IVFFlat or HNSW)
|
||||
# IVFFlat is good for datasets with 10k-1M vectors
|
||||
# HNSW is better for larger datasets but uses more memory
|
||||
Index(
|
||||
"ix_kb_chunks_embedding_ivfflat",
|
||||
"embedding",
|
||||
postgresql_using="ivfflat",
|
||||
postgresql_with={"lists": 100}, # Adjust based on dataset size
|
||||
postgresql_ops={"embedding": "vector_cosine_ops"},
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,3 +13,6 @@ python-multipart==0.0.20
|
|||
sentry-sdk[fastapi]==2.38.0
|
||||
sqlalchemy[asyncio]==2.0.43
|
||||
msgpack==1.1.2
|
||||
docling[rapidocr]==2.68.0
|
||||
sentence-transformers==5.2.0
|
||||
pgvector==0.4.2
|
||||
|
|
|
|||
405
api/routes/knowledge_base.py
Normal file
405
api/routes/knowledge_base.py
Normal file
|
|
@ -0,0 +1,405 @@
|
|||
"""API routes for knowledge base operations."""
|
||||
|
||||
import uuid
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.schemas.knowledge_base import (
|
||||
ChunkSearchRequestSchema,
|
||||
ChunkSearchResponseSchema,
|
||||
DocumentListResponseSchema,
|
||||
DocumentResponseSchema,
|
||||
DocumentUploadRequestSchema,
|
||||
DocumentUploadResponseSchema,
|
||||
ProcessDocumentRequestSchema,
|
||||
)
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.storage import storage_fs
|
||||
from api.tasks.arq import enqueue_job
|
||||
from api.tasks.function_names import FunctionNames
|
||||
|
||||
router = APIRouter(prefix="/knowledge-base", tags=["knowledge-base"])
|
||||
|
||||
|
||||
@router.post(
|
||||
"/upload-url",
|
||||
response_model=DocumentUploadResponseSchema,
|
||||
summary="Get presigned URL for document upload",
|
||||
)
|
||||
async def get_upload_url(
|
||||
request: DocumentUploadRequestSchema,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Generate a presigned PUT URL for uploading a document.
|
||||
|
||||
This endpoint:
|
||||
1. Generates a unique document UUID for organizing the S3 key
|
||||
2. Generates a presigned S3/MinIO URL for uploading the file
|
||||
3. Returns the upload URL and document metadata
|
||||
|
||||
After uploading to the returned URL, call /process-document to create
|
||||
the document record and trigger processing.
|
||||
|
||||
Access Control:
|
||||
* All authenticated users can upload documents scoped to their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Generate unique document UUID for S3 organization
|
||||
document_uuid = str(uuid.uuid4())
|
||||
|
||||
# Generate S3 key: knowledge_base/{org_id}/{document_uuid}/{filename}
|
||||
s3_key = f"knowledge_base/{user.selected_organization_id}/{document_uuid}/{request.filename}"
|
||||
|
||||
# Generate presigned PUT URL (valid for 30 minutes)
|
||||
upload_url = await storage_fs.aget_presigned_put_url(
|
||||
file_path=s3_key,
|
||||
expiration=1800, # 30 minutes
|
||||
content_type=request.mime_type,
|
||||
max_size=100_000_000, # 100MB max
|
||||
)
|
||||
|
||||
if not upload_url:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to generate presigned upload URL"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Generated upload URL for document {document_uuid}, "
|
||||
f"user {user.id}, org {user.selected_organization_id}"
|
||||
)
|
||||
|
||||
return DocumentUploadResponseSchema(
|
||||
upload_url=upload_url,
|
||||
document_uuid=document_uuid,
|
||||
s3_key=s3_key,
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error generating upload URL: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to generate upload URL"
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/process-document",
|
||||
response_model=DocumentResponseSchema,
|
||||
summary="Trigger document processing",
|
||||
)
|
||||
async def process_document(
|
||||
request: ProcessDocumentRequestSchema,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Trigger asynchronous processing of an uploaded document.
|
||||
|
||||
This endpoint should be called after successfully uploading a file to the presigned URL.
|
||||
It will:
|
||||
1. Create a document record in the database with the specified UUID
|
||||
2. Enqueue a background task to process the document (chunking and embedding)
|
||||
|
||||
The document status will be updated from 'pending' -> 'processing' -> 'completed' or 'failed'.
|
||||
|
||||
Embedding Services:
|
||||
* openai (default): High-quality 1536-dimensional embeddings (requires OPENAI_API_KEY)
|
||||
* sentence_transformer: Free, offline-capable, 384-dimensional embeddings
|
||||
|
||||
Access Control:
|
||||
* Users can only process documents in their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Extract filename from s3_key
|
||||
filename = request.s3_key.split("/")[-1]
|
||||
|
||||
# Create document record with the specific UUID from upload
|
||||
document = await db_client.create_document(
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.id,
|
||||
filename=filename,
|
||||
file_size_bytes=0, # Will be updated by background task
|
||||
file_hash="", # Will be computed by background task
|
||||
mime_type="application/octet-stream", # Will be detected by background task
|
||||
custom_metadata={"s3_key": request.s3_key},
|
||||
document_uuid=request.document_uuid, # Use UUID from upload
|
||||
)
|
||||
|
||||
# Enqueue background task for processing
|
||||
await enqueue_job(
|
||||
FunctionNames.PROCESS_KNOWLEDGE_BASE_DOCUMENT,
|
||||
document.id,
|
||||
request.s3_key,
|
||||
user.selected_organization_id,
|
||||
128, # max_tokens (default)
|
||||
request.embedding_service,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created document {request.document_uuid} (id={document.id}) and enqueued processing "
|
||||
f"with {request.embedding_service} embeddings, org {user.selected_organization_id}"
|
||||
)
|
||||
|
||||
return DocumentResponseSchema(
|
||||
id=document.id,
|
||||
document_uuid=request.document_uuid,
|
||||
filename=filename,
|
||||
file_size_bytes=0,
|
||||
file_hash="",
|
||||
mime_type="application/octet-stream",
|
||||
processing_status="pending",
|
||||
processing_error=None,
|
||||
total_chunks=0,
|
||||
custom_metadata={"s3_key": request.s3_key},
|
||||
docling_metadata={},
|
||||
source_url=None,
|
||||
created_at=document.created_at,
|
||||
updated_at=document.updated_at,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.id,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Error processing document: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to process document"
|
||||
) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/documents",
|
||||
response_model=DocumentListResponseSchema,
|
||||
summary="List documents",
|
||||
)
|
||||
async def list_documents(
|
||||
status: Annotated[
|
||||
Optional[str],
|
||||
Query(description="Filter by processing status"),
|
||||
] = None,
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 100,
|
||||
offset: Annotated[int, Query(ge=0)] = 0,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""List all documents for the user's organization.
|
||||
|
||||
Access Control:
|
||||
* Users can only see documents from their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
documents = await db_client.get_documents_for_organization(
|
||||
organization_id=user.selected_organization_id,
|
||||
processing_status=status,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
# Convert to response schema
|
||||
document_list = [
|
||||
DocumentResponseSchema(
|
||||
id=doc.id,
|
||||
document_uuid=doc.document_uuid,
|
||||
filename=doc.filename,
|
||||
file_size_bytes=doc.file_size_bytes,
|
||||
file_hash=doc.file_hash,
|
||||
mime_type=doc.mime_type,
|
||||
processing_status=doc.processing_status,
|
||||
processing_error=doc.processing_error,
|
||||
total_chunks=doc.total_chunks,
|
||||
custom_metadata=doc.custom_metadata,
|
||||
docling_metadata=doc.docling_metadata,
|
||||
source_url=doc.source_url,
|
||||
created_at=doc.created_at,
|
||||
updated_at=doc.updated_at,
|
||||
organization_id=doc.organization_id,
|
||||
created_by=doc.created_by,
|
||||
is_active=doc.is_active,
|
||||
)
|
||||
for doc in documents
|
||||
]
|
||||
|
||||
return DocumentListResponseSchema(
|
||||
documents=document_list,
|
||||
total=len(document_list),
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error listing documents: {exc}")
|
||||
raise HTTPException(status_code=500, detail="Failed to list documents") from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/documents/{document_uuid}",
|
||||
response_model=DocumentResponseSchema,
|
||||
summary="Get document details",
|
||||
)
|
||||
async def get_document(
|
||||
document_uuid: str,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Get details of a specific document.
|
||||
|
||||
Access Control:
|
||||
* Users can only access documents from their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
document = await db_client.get_document_by_uuid(
|
||||
document_uuid=document_uuid,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
|
||||
if not document:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
return DocumentResponseSchema(
|
||||
id=document.id,
|
||||
document_uuid=document.document_uuid,
|
||||
filename=document.filename,
|
||||
file_size_bytes=document.file_size_bytes,
|
||||
file_hash=document.file_hash,
|
||||
mime_type=document.mime_type,
|
||||
processing_status=document.processing_status,
|
||||
processing_error=document.processing_error,
|
||||
total_chunks=document.total_chunks,
|
||||
custom_metadata=document.custom_metadata,
|
||||
docling_metadata=document.docling_metadata,
|
||||
source_url=document.source_url,
|
||||
created_at=document.created_at,
|
||||
updated_at=document.updated_at,
|
||||
organization_id=document.organization_id,
|
||||
created_by=document.created_by,
|
||||
is_active=document.is_active,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Error getting document: {exc}")
|
||||
raise HTTPException(status_code=500, detail="Failed to get document") from exc
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/documents/{document_uuid}",
|
||||
summary="Delete document",
|
||||
)
|
||||
async def delete_document(
|
||||
document_uuid: str,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Soft delete a document and its chunks.
|
||||
|
||||
Access Control:
|
||||
* Users can only delete documents from their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
success = await db_client.delete_document(
|
||||
document_uuid=document_uuid,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
logger.info(
|
||||
f"Deleted document {document_uuid}, "
|
||||
f"user {user.id}, org {user.selected_organization_id}"
|
||||
)
|
||||
|
||||
return {"success": True, "message": "Document deleted successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Error deleting document: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to delete document"
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/search",
|
||||
response_model=ChunkSearchResponseSchema,
|
||||
summary="Search for similar chunks",
|
||||
)
|
||||
async def search_chunks(
|
||||
request: ChunkSearchRequestSchema,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Search for document chunks similar to the query.
|
||||
|
||||
This endpoint uses vector similarity search to find relevant chunks.
|
||||
Results are returned without threshold filtering - apply similarity
|
||||
thresholds at the application layer after optional reranking.
|
||||
|
||||
Access Control:
|
||||
* Users can only search documents from their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Import here to avoid circular dependency
|
||||
from api.services.gen_ai import OpenAIEmbeddingService
|
||||
|
||||
# Try to get user's embeddings configuration
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
|
||||
if user_config.embeddings:
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
|
||||
# Initialize embedding service with user config or fallback to env
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
)
|
||||
|
||||
# Perform search
|
||||
results = await embedding_service.search_similar_chunks(
|
||||
query=request.query,
|
||||
organization_id=user.selected_organization_id,
|
||||
limit=request.limit,
|
||||
document_uuids=request.document_uuids,
|
||||
)
|
||||
|
||||
# Apply similarity threshold if provided
|
||||
if request.min_similarity is not None:
|
||||
results = [r for r in results if r["similarity"] >= request.min_similarity]
|
||||
|
||||
# Convert to response schema
|
||||
from api.schemas.knowledge_base import ChunkResponseSchema
|
||||
|
||||
chunks = [
|
||||
ChunkResponseSchema(
|
||||
id=r["id"],
|
||||
document_id=r["document_id"],
|
||||
chunk_text=r["chunk_text"],
|
||||
contextualized_text=r.get("contextualized_text"),
|
||||
chunk_index=r["chunk_index"],
|
||||
chunk_metadata=r["chunk_metadata"],
|
||||
filename=r["filename"],
|
||||
document_uuid=r["document_uuid"],
|
||||
similarity=r["similarity"],
|
||||
)
|
||||
for r in results
|
||||
]
|
||||
|
||||
return ChunkSearchResponseSchema(
|
||||
chunks=chunks,
|
||||
query=request.query,
|
||||
total_results=len(chunks),
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error searching chunks: {exc}")
|
||||
raise HTTPException(status_code=500, detail="Failed to search chunks") from exc
|
||||
|
|
@ -4,6 +4,7 @@ from loguru import logger
|
|||
from api.routes.campaign import router as campaign_router
|
||||
from api.routes.credentials import router as credentials_router
|
||||
from api.routes.integration import router as integration_router
|
||||
from api.routes.knowledge_base import router as knowledge_base_router
|
||||
from api.routes.looptalk import router as looptalk_router
|
||||
from api.routes.organization import router as organization_router
|
||||
from api.routes.organization_usage import router as organization_usage_router
|
||||
|
|
@ -43,6 +44,7 @@ router.include_router(webrtc_signaling_router)
|
|||
router.include_router(public_embed_router)
|
||||
router.include_router(public_agent_router)
|
||||
router.include_router(workflow_embed_router)
|
||||
router.include_router(knowledge_base_router)
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ class DefaultConfigurationsResponse(TypedDict):
|
|||
llm: dict[str, dict]
|
||||
tts: dict[str, dict]
|
||||
stt: dict[str, dict]
|
||||
embeddings: dict[str, dict]
|
||||
default_providers: dict[str, str]
|
||||
|
||||
|
||||
|
|
@ -50,6 +51,10 @@ async def get_default_configurations() -> DefaultConfigurationsResponse:
|
|||
provider: model_cls.model_json_schema()
|
||||
for provider, model_cls in REGISTRY[ServiceType.STT].items()
|
||||
},
|
||||
"embeddings": {
|
||||
provider: model_cls.model_json_schema()
|
||||
for provider, model_cls in REGISTRY[ServiceType.EMBEDDINGS].items()
|
||||
},
|
||||
"default_providers": DEFAULT_SERVICE_PROVIDERS,
|
||||
}
|
||||
return configurations
|
||||
|
|
@ -69,6 +74,7 @@ class UserConfigurationRequestResponseSchema(BaseModel):
|
|||
llm: dict[str, Union[str, float]] | None = None
|
||||
tts: dict[str, Union[str, float]] | None = None
|
||||
stt: dict[str, Union[str, float]] | None = None
|
||||
embeddings: dict[str, Union[str, float]] | None = None
|
||||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
organization_pricing: dict[str, Union[float, str, bool]] | None = None
|
||||
|
|
|
|||
102
api/schemas/knowledge_base.py
Normal file
102
api/schemas/knowledge_base.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
"""Pydantic schemas for knowledge base operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DocumentUploadRequestSchema(BaseModel):
|
||||
"""Request schema for initiating document upload."""
|
||||
|
||||
filename: str = Field(..., description="Name of the file to upload")
|
||||
mime_type: str = Field(..., description="MIME type of the file")
|
||||
custom_metadata: Optional[Dict[str, Any]] = Field(
|
||||
default=None, description="Optional custom metadata"
|
||||
)
|
||||
|
||||
|
||||
class DocumentUploadResponseSchema(BaseModel):
|
||||
"""Response schema containing upload URL and document metadata."""
|
||||
|
||||
upload_url: str = Field(..., description="Signed URL for uploading the file")
|
||||
document_uuid: str = Field(..., description="Unique identifier for the document")
|
||||
s3_key: str = Field(..., description="S3 key where file should be uploaded")
|
||||
|
||||
|
||||
class ProcessDocumentRequestSchema(BaseModel):
|
||||
"""Request schema for triggering document processing."""
|
||||
|
||||
document_uuid: str = Field(..., description="Document UUID to process")
|
||||
s3_key: str = Field(..., description="S3 key of the uploaded file")
|
||||
embedding_service: Literal["sentence_transformer", "openai"] = Field(
|
||||
default="openai",
|
||||
description="Embedding service to use for processing. "
|
||||
"Options: 'openai' (default, 1536-dim, requires API key) or 'sentence_transformer' (free, 384-dim)",
|
||||
)
|
||||
|
||||
|
||||
class DocumentResponseSchema(BaseModel):
|
||||
"""Response schema for document metadata."""
|
||||
|
||||
id: int
|
||||
document_uuid: str
|
||||
filename: str
|
||||
file_size_bytes: int
|
||||
file_hash: str
|
||||
mime_type: str
|
||||
processing_status: str # pending, processing, completed, failed
|
||||
processing_error: Optional[str] = None
|
||||
total_chunks: int
|
||||
custom_metadata: Dict[str, Any]
|
||||
docling_metadata: Dict[str, Any]
|
||||
source_url: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
organization_id: int
|
||||
created_by: int
|
||||
is_active: bool
|
||||
|
||||
|
||||
class DocumentListResponseSchema(BaseModel):
|
||||
"""Response schema for list of documents."""
|
||||
|
||||
documents: List[DocumentResponseSchema]
|
||||
total: int
|
||||
limit: int
|
||||
offset: int
|
||||
|
||||
|
||||
class ChunkSearchRequestSchema(BaseModel):
|
||||
"""Request schema for searching similar chunks."""
|
||||
|
||||
query: str = Field(..., description="Search query text")
|
||||
limit: int = Field(default=5, ge=1, le=50, description="Maximum number of results")
|
||||
document_uuids: Optional[List[str]] = Field(
|
||||
default=None, description="Filter by specific document UUIDs"
|
||||
)
|
||||
min_similarity: Optional[float] = Field(
|
||||
default=None, ge=0.0, le=1.0, description="Minimum similarity threshold"
|
||||
)
|
||||
|
||||
|
||||
class ChunkResponseSchema(BaseModel):
|
||||
"""Response schema for a document chunk."""
|
||||
|
||||
id: int
|
||||
document_id: int
|
||||
chunk_text: str
|
||||
contextualized_text: Optional[str]
|
||||
chunk_index: int
|
||||
chunk_metadata: Dict[str, Any]
|
||||
filename: str
|
||||
document_uuid: str
|
||||
similarity: float
|
||||
|
||||
|
||||
class ChunkSearchResponseSchema(BaseModel):
|
||||
"""Response schema for chunk search results."""
|
||||
|
||||
chunks: List[ChunkResponseSchema]
|
||||
query: str
|
||||
total_results: int
|
||||
|
|
@ -3,6 +3,7 @@ from datetime import datetime
|
|||
from pydantic import BaseModel
|
||||
|
||||
from api.services.configuration.registry import (
|
||||
EmbeddingsConfig,
|
||||
LLMConfig,
|
||||
STTConfig,
|
||||
TTSConfig,
|
||||
|
|
@ -13,6 +14,7 @@ class UserConfiguration(BaseModel):
|
|||
llm: LLMConfig | None = None
|
||||
stt: STTConfig | None = None
|
||||
tts: TTSConfig | None = None
|
||||
embeddings: EmbeddingsConfig | None = None
|
||||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
last_validated_at: datetime | None = None
|
||||
|
|
|
|||
|
|
@ -48,6 +48,12 @@ class UserConfigurationValidator:
|
|||
status_list.extend(self._validate_service(configuration.llm, "llm"))
|
||||
status_list.extend(self._validate_service(configuration.stt, "stt"))
|
||||
status_list.extend(self._validate_service(configuration.tts, "tts"))
|
||||
# Embeddings is optional - only validate if configured
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.embeddings, "embeddings", required=False
|
||||
)
|
||||
)
|
||||
|
||||
if status_list:
|
||||
raise ValueError(status_list)
|
||||
|
|
@ -55,11 +61,16 @@ class UserConfigurationValidator:
|
|||
return {"status": [{"model": "all", "message": "ok"}]}
|
||||
|
||||
def _validate_service(
|
||||
self, service_config: Optional[ServiceConfig], service_name: str
|
||||
self,
|
||||
service_config: Optional[ServiceConfig],
|
||||
service_name: str,
|
||||
required: bool = True,
|
||||
) -> list[APIKeyStatus]:
|
||||
"""Validate a service configuration and return any error statuses."""
|
||||
if not service_config:
|
||||
return [{"model": service_name, "message": "API key is missing"}]
|
||||
if required:
|
||||
return [{"model": service_name, "message": "API key is missing"}]
|
||||
return [] # Optional service not configured is OK
|
||||
|
||||
provider = service_config.provider
|
||||
api_key = service_config.api_key
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ left as ``None``.
|
|||
from api.services.configuration.registry import (
|
||||
DeepgramSTTConfiguration,
|
||||
ElevenlabsTTSConfiguration,
|
||||
OpenAIEmbeddingsConfiguration,
|
||||
OpenAILLMService,
|
||||
ServiceProviders,
|
||||
)
|
||||
|
|
@ -22,6 +23,7 @@ _DEFAULTS = {
|
|||
"llm": (ServiceProviders.OPENAI, OpenAILLMService),
|
||||
"tts": (ServiceProviders.ELEVENLABS, ElevenlabsTTSConfiguration),
|
||||
"stt": (ServiceProviders.DEEPGRAM, DeepgramSTTConfiguration),
|
||||
"embeddings": (ServiceProviders.OPENAI, OpenAIEmbeddingsConfiguration),
|
||||
}
|
||||
|
||||
# Public mapping of service name -> default provider
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
|
|||
"llm": _mask_service(config.llm),
|
||||
"tts": _mask_service(config.tts),
|
||||
"stt": _mask_service(config.stt),
|
||||
"embeddings": _mask_service(config.embeddings),
|
||||
"test_phone_number": config.test_phone_number,
|
||||
"timezone": config.timezone,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from typing import Dict
|
|||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.masking import is_mask_of
|
||||
|
||||
SERVICE_FIELDS = ("llm", "tts", "stt")
|
||||
SERVICE_FIELDS = ("llm", "tts", "stt", "embeddings")
|
||||
|
||||
|
||||
def merge_user_configurations(
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ class ServiceType(Enum):
|
|||
LLM = auto()
|
||||
TTS = auto()
|
||||
STT = auto()
|
||||
EMBEDDINGS = auto()
|
||||
|
||||
|
||||
class ServiceProviders(str, Enum):
|
||||
|
|
@ -50,11 +51,16 @@ class BaseSTTConfiguration(BaseServiceConfiguration):
|
|||
model: str
|
||||
|
||||
|
||||
class BaseEmbeddingsConfiguration(BaseServiceConfiguration):
|
||||
model: str
|
||||
|
||||
|
||||
# Unified registry for all service types
|
||||
REGISTRY: Dict[ServiceType, Dict[str, Type[BaseServiceConfiguration]]] = {
|
||||
ServiceType.LLM: {},
|
||||
ServiceType.TTS: {},
|
||||
ServiceType.STT: {},
|
||||
ServiceType.EMBEDDINGS: {},
|
||||
}
|
||||
|
||||
T = TypeVar("T", bound=BaseServiceConfiguration)
|
||||
|
|
@ -93,6 +99,10 @@ def register_stt(cls: Type[BaseSTTConfiguration]):
|
|||
return register_service(ServiceType.STT)(cls)
|
||||
|
||||
|
||||
def register_embeddings(cls: Type[BaseEmbeddingsConfiguration]):
|
||||
return register_service(ServiceType.EMBEDDINGS)(cls)
|
||||
|
||||
|
||||
###################################################### LLM ########################################################################
|
||||
|
||||
# Suggested models for each provider (used for UI dropdown)
|
||||
|
|
@ -436,6 +446,27 @@ STTConfig = Annotated[
|
|||
Field(discriminator="provider"),
|
||||
]
|
||||
|
||||
ServiceConfig = Annotated[
|
||||
Union[LLMConfig, TTSConfig, STTConfig], Field(discriminator="provider")
|
||||
###################################################### EMBEDDINGS ########################################################################
|
||||
|
||||
OPENAI_EMBEDDING_MODELS = ["text-embedding-3-small"]
|
||||
|
||||
|
||||
@register_embeddings
|
||||
class OpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
||||
provider: Literal[ServiceProviders.OPENAI] = ServiceProviders.OPENAI
|
||||
model: str = Field(
|
||||
default="text-embedding-3-small",
|
||||
json_schema_extra={"examples": OPENAI_EMBEDDING_MODELS},
|
||||
)
|
||||
api_key: str
|
||||
|
||||
|
||||
EmbeddingsConfig = Annotated[
|
||||
Union[OpenAIEmbeddingsConfiguration],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
||||
ServiceConfig = Annotated[
|
||||
Union[LLMConfig, TTSConfig, STTConfig, EmbeddingsConfig],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -85,3 +85,16 @@ class BaseFileSystem(ABC):
|
|||
Optional[str]: Presigned PUT URL if successful, None otherwise
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def adownload_file(self, source_path: str, local_path: str) -> bool:
|
||||
"""Download a file from storage to local path.
|
||||
|
||||
Args:
|
||||
source_path: Path to the file in storage
|
||||
local_path: Local path where file should be downloaded
|
||||
|
||||
Returns:
|
||||
bool: True if file was downloaded successfully, False otherwise
|
||||
"""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -170,3 +170,15 @@ class MinioFileSystem(BaseFileSystem):
|
|||
except Exception as e:
|
||||
logger.error(f"Error generating MinIO upload URL: {e}")
|
||||
return None
|
||||
|
||||
async def adownload_file(self, source_path: str, local_path: str) -> bool:
|
||||
"""Download a file from MinIO to local path."""
|
||||
try:
|
||||
|
||||
def _fget():
|
||||
self.client.fget_object(self.bucket_name, source_path, local_path)
|
||||
|
||||
await asyncio.to_thread(_fget)
|
||||
return True
|
||||
except S3Error:
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -126,3 +126,14 @@ class S3FileSystem(BaseFileSystem):
|
|||
return url
|
||||
except ClientError:
|
||||
return None
|
||||
|
||||
async def adownload_file(self, source_path: str, local_path: str) -> bool:
|
||||
"""Download a file from S3 to local path."""
|
||||
try:
|
||||
async with self.session.client(
|
||||
"s3", region_name=self.region_name
|
||||
) as s3_client:
|
||||
await s3_client.download_file(self.bucket_name, source_path, local_path)
|
||||
return True
|
||||
except ClientError:
|
||||
return False
|
||||
|
|
|
|||
15
api/services/gen_ai/__init__.py
Normal file
15
api/services/gen_ai/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
"""Generative AI services for embeddings and document processing."""
|
||||
|
||||
from .embedding import (
|
||||
BaseEmbeddingService,
|
||||
EmbeddingAPIKeyNotConfiguredError,
|
||||
OpenAIEmbeddingService,
|
||||
SentenceTransformerEmbeddingService,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseEmbeddingService",
|
||||
"EmbeddingAPIKeyNotConfiguredError",
|
||||
"SentenceTransformerEmbeddingService",
|
||||
"OpenAIEmbeddingService",
|
||||
]
|
||||
12
api/services/gen_ai/embedding/__init__.py
Normal file
12
api/services/gen_ai/embedding/__init__.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
"""Embedding services for document processing and retrieval."""
|
||||
|
||||
from .base import BaseEmbeddingService
|
||||
from .openai_service import EmbeddingAPIKeyNotConfiguredError, OpenAIEmbeddingService
|
||||
from .sentence_transformer_service import SentenceTransformerEmbeddingService
|
||||
|
||||
__all__ = [
|
||||
"BaseEmbeddingService",
|
||||
"EmbeddingAPIKeyNotConfiguredError",
|
||||
"SentenceTransformerEmbeddingService",
|
||||
"OpenAIEmbeddingService",
|
||||
]
|
||||
75
api/services/gen_ai/embedding/base.py
Normal file
75
api/services/gen_ai/embedding/base.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
"""Base class for embedding services."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class BaseEmbeddingService(ABC):
|
||||
"""Abstract base class for embedding services.
|
||||
|
||||
All embedding services (SentenceTransformer, OpenAI, etc.) should inherit from this class
|
||||
and implement the required methods.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_model_id(self) -> str:
|
||||
"""Return the model identifier.
|
||||
|
||||
Returns:
|
||||
String identifier for the model (e.g., 'sentence-transformers/all-MiniLM-L6-v2')
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_embedding_dimension(self) -> int:
|
||||
"""Return the embedding dimension.
|
||||
|
||||
Returns:
|
||||
Integer dimension of the embedding vectors
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a batch of texts.
|
||||
|
||||
Args:
|
||||
texts: List of text strings to embed
|
||||
|
||||
Returns:
|
||||
List of embedding vectors (each vector is a list of floats)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def embed_query(self, query: str) -> List[float]:
|
||||
"""Embed a single query text.
|
||||
|
||||
Args:
|
||||
query: Query text to embed
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def search_similar_chunks(
|
||||
self,
|
||||
query: str,
|
||||
organization_id: int,
|
||||
limit: int = 5,
|
||||
document_uuids: Optional[List[str]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search for similar chunks using vector similarity.
|
||||
|
||||
Args:
|
||||
query: Search query text
|
||||
organization_id: Organization ID for scoping
|
||||
limit: Maximum number of results to return
|
||||
document_uuids: Optional list of document UUIDs to filter by
|
||||
|
||||
Returns:
|
||||
List of dictionaries containing chunk data and similarity scores
|
||||
"""
|
||||
pass
|
||||
372
api/services/gen_ai/embedding/openai_service.py
Normal file
372
api/services/gen_ai/embedding/openai_service.py
Normal file
|
|
@ -0,0 +1,372 @@
|
|||
"""OpenAI embedding service.
|
||||
|
||||
This module provides document processing capabilities using:
|
||||
- OpenAI's text-embedding-3-small for embeddings (1536 dimensions)
|
||||
- Docling for document conversion and chunking
|
||||
- pgvector for vector similarity search
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from docling.chunking import HybridChunker
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
|
||||
from loguru import logger
|
||||
from openai import AsyncOpenAI
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from api.db.db_client import DBClient
|
||||
from api.db.models import KnowledgeBaseChunkModel
|
||||
|
||||
from .base import BaseEmbeddingService
|
||||
|
||||
# Model configuration
|
||||
DEFAULT_MODEL_ID = "text-embedding-3-small"
|
||||
EMBEDDING_DIMENSION = 1536 # Dimension for text-embedding-3-small
|
||||
|
||||
# For chunking, we'll use the same tokenizer as SentenceTransformer
|
||||
# since OpenAI uses similar tokenization
|
||||
TOKENIZER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
|
||||
class EmbeddingAPIKeyNotConfiguredError(Exception):
|
||||
"""Raised when OpenAI API key is not configured for embeddings."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
"OpenAI API key not configured. Please set your API key in "
|
||||
"Model Configurations > Embedding to use document processing."
|
||||
)
|
||||
|
||||
|
||||
class OpenAIEmbeddingService(BaseEmbeddingService):
|
||||
"""Embedding service using OpenAI's text-embedding-3-small."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_client: DBClient,
|
||||
api_key: Optional[str] = None,
|
||||
model_id: str = DEFAULT_MODEL_ID,
|
||||
max_tokens: int = 512,
|
||||
):
|
||||
"""Initialize the OpenAI embedding service.
|
||||
|
||||
Args:
|
||||
db_client: Database client for storing documents and chunks
|
||||
api_key: OpenAI API key. If not provided, the client will not be
|
||||
initialized and operations will fail with a clear error.
|
||||
model_id: OpenAI embedding model ID (default: text-embedding-3-small)
|
||||
max_tokens: Maximum number of tokens per chunk (default: 512)
|
||||
"""
|
||||
self.db = db_client
|
||||
self.model_id = model_id
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
# Only initialize OpenAI client if API key is provided
|
||||
self._api_key_configured = bool(api_key)
|
||||
if self._api_key_configured:
|
||||
self.client = AsyncOpenAI(api_key=api_key)
|
||||
logger.info(f"OpenAI embedding service initialized with model: {model_id}")
|
||||
else:
|
||||
self.client = None
|
||||
logger.warning(
|
||||
"OpenAI embedding service initialized without API key. "
|
||||
"Operations will fail until API key is configured in Model Configurations."
|
||||
)
|
||||
|
||||
# Initialize tokenizer for chunking
|
||||
# We use a HuggingFace tokenizer for consistent chunking
|
||||
logger.info(
|
||||
f"Loading tokenizer for chunking: {TOKENIZER_MODEL} with max_tokens={max_tokens}"
|
||||
)
|
||||
try:
|
||||
self.tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=AutoTokenizer.from_pretrained(
|
||||
TOKENIZER_MODEL,
|
||||
local_files_only=True,
|
||||
),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
logger.info("Loaded tokenizer from cache")
|
||||
except Exception as e:
|
||||
logger.warning(f"Tokenizer not in cache, downloading: {e}")
|
||||
self.tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=AutoTokenizer.from_pretrained(TOKENIZER_MODEL),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
logger.info("Tokenizer downloaded and cached")
|
||||
|
||||
# Initialize chunker
|
||||
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
|
||||
self.chunker = HybridChunker(tokenizer=self.tokenizer)
|
||||
|
||||
# Initialize document converter
|
||||
self.converter = DocumentConverter()
|
||||
|
||||
def get_model_id(self) -> str:
|
||||
"""Return the model identifier."""
|
||||
return self.model_id
|
||||
|
||||
def get_embedding_dimension(self) -> int:
|
||||
"""Return the embedding dimension."""
|
||||
return EMBEDDING_DIMENSION
|
||||
|
||||
def _ensure_api_key_configured(self):
|
||||
"""Check if API key is configured and raise error if not."""
|
||||
if not self._api_key_configured or self.client is None:
|
||||
raise EmbeddingAPIKeyNotConfiguredError()
|
||||
|
||||
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a batch of texts using OpenAI API.
|
||||
|
||||
Args:
|
||||
texts: List of text strings to embed
|
||||
|
||||
Returns:
|
||||
List of embedding vectors (each vector is a list of floats)
|
||||
|
||||
Raises:
|
||||
EmbeddingAPIKeyNotConfiguredError: If API key is not configured
|
||||
"""
|
||||
self._ensure_api_key_configured()
|
||||
|
||||
try:
|
||||
# OpenAI API call
|
||||
response = await self.client.embeddings.create(
|
||||
input=texts,
|
||||
model=self.model_id,
|
||||
)
|
||||
|
||||
# Extract embeddings from response
|
||||
embeddings = [item.embedding for item in response.data]
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating OpenAI embeddings: {e}")
|
||||
raise
|
||||
|
||||
async def embed_query(self, query: str) -> List[float]:
|
||||
"""Embed a single query text using OpenAI API.
|
||||
|
||||
Args:
|
||||
query: Query text to embed
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats
|
||||
|
||||
Raises:
|
||||
EmbeddingAPIKeyNotConfiguredError: If API key is not configured
|
||||
"""
|
||||
self._ensure_api_key_configured()
|
||||
embeddings = await self.embed_texts([query])
|
||||
return embeddings[0]
|
||||
|
||||
async def search_similar_chunks(
|
||||
self,
|
||||
query: str,
|
||||
organization_id: int,
|
||||
limit: int = 5,
|
||||
document_uuids: Optional[List[str]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search for similar chunks using vector similarity.
|
||||
|
||||
Args:
|
||||
query: Search query text
|
||||
organization_id: Organization ID for scoping
|
||||
limit: Maximum number of results to return
|
||||
document_uuids: Optional list of document UUIDs to filter by
|
||||
|
||||
Returns:
|
||||
List of dictionaries with chunk data and similarity scores
|
||||
|
||||
Raises:
|
||||
EmbeddingAPIKeyNotConfiguredError: If API key is not configured
|
||||
"""
|
||||
self._ensure_api_key_configured()
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = await self.embed_query(query)
|
||||
|
||||
# Perform vector similarity search
|
||||
results = await self.db.search_similar_chunks(
|
||||
query_embedding=query_embedding,
|
||||
organization_id=organization_id,
|
||||
limit=limit,
|
||||
document_uuids=document_uuids,
|
||||
embedding_model=self.model_id,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def process_document(
|
||||
self,
|
||||
file_path: str,
|
||||
organization_id: int,
|
||||
created_by: int,
|
||||
custom_metadata: dict = None,
|
||||
):
|
||||
"""Process a document: convert, chunk, embed, and store in database.
|
||||
|
||||
Args:
|
||||
file_path: Path to the document file
|
||||
organization_id: Organization ID for scoping
|
||||
created_by: User ID who uploaded the document
|
||||
custom_metadata: Optional custom metadata dictionary
|
||||
|
||||
Returns:
|
||||
The created document record
|
||||
"""
|
||||
try:
|
||||
# Extract file metadata
|
||||
filename = Path(file_path).name
|
||||
file_hash = self.db.compute_file_hash(file_path)
|
||||
file_size = os.path.getsize(file_path)
|
||||
mime_type = self.db.get_mime_type(file_path)
|
||||
|
||||
# Check if document already exists
|
||||
existing_doc = await self.db.get_document_by_hash(
|
||||
file_hash, organization_id
|
||||
)
|
||||
if existing_doc:
|
||||
logger.info(f"Document already exists: {filename} (hash: {file_hash})")
|
||||
return existing_doc
|
||||
|
||||
# Create document record
|
||||
doc_record = await self.db.create_document(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
filename=filename,
|
||||
file_size_bytes=file_size,
|
||||
file_hash=file_hash,
|
||||
mime_type=mime_type,
|
||||
custom_metadata=custom_metadata or {},
|
||||
)
|
||||
|
||||
logger.info(f"Processing document with OpenAI embeddings: {filename}")
|
||||
|
||||
# Update status to processing
|
||||
await self.db.update_document_status(doc_record.id, "processing")
|
||||
|
||||
# Step 1: Convert document using docling
|
||||
logger.info("Converting document with docling...")
|
||||
conversion_result = self.converter.convert(file_path)
|
||||
doc = conversion_result.document
|
||||
|
||||
# Store docling metadata
|
||||
docling_metadata = {
|
||||
"num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
|
||||
"document_type": type(doc).__name__,
|
||||
}
|
||||
|
||||
# Step 2: Chunk the document
|
||||
logger.info(f"Chunking document with max_tokens={self.max_tokens}...")
|
||||
chunks = list(self.chunker.chunk(dl_doc=doc))
|
||||
total_chunks = len(chunks)
|
||||
|
||||
logger.info(f"Generated {total_chunks} chunks")
|
||||
|
||||
# Step 3: Process each chunk
|
||||
chunk_texts = []
|
||||
chunk_records = []
|
||||
token_counts = []
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Get chunk text
|
||||
chunk_text = chunk.text
|
||||
|
||||
# Get contextualized text
|
||||
contextualized_text = self.chunker.contextualize(chunk=chunk)
|
||||
|
||||
# Calculate token count
|
||||
text_to_tokenize = (
|
||||
contextualized_text if contextualized_text else chunk_text
|
||||
)
|
||||
token_count = len(
|
||||
self.tokenizer.tokenizer.encode(
|
||||
text_to_tokenize, add_special_tokens=False
|
||||
)
|
||||
)
|
||||
token_counts.append(token_count)
|
||||
|
||||
# Prepare chunk metadata
|
||||
chunk_metadata = {}
|
||||
if hasattr(chunk, "meta") and chunk.meta:
|
||||
chunk_metadata = {
|
||||
"doc_items": (
|
||||
[str(item) for item in chunk.meta.doc_items]
|
||||
if hasattr(chunk.meta, "doc_items")
|
||||
else []
|
||||
),
|
||||
"headings": (
|
||||
chunk.meta.headings
|
||||
if hasattr(chunk.meta, "headings")
|
||||
else []
|
||||
),
|
||||
}
|
||||
|
||||
# Create chunk record (without embedding yet)
|
||||
chunk_record = KnowledgeBaseChunkModel(
|
||||
document_id=doc_record.id,
|
||||
organization_id=organization_id,
|
||||
chunk_text=chunk_text,
|
||||
contextualized_text=contextualized_text,
|
||||
chunk_index=i,
|
||||
chunk_metadata=chunk_metadata,
|
||||
embedding_model=self.model_id,
|
||||
embedding_dimension=EMBEDDING_DIMENSION,
|
||||
token_count=token_count,
|
||||
)
|
||||
|
||||
chunk_records.append(chunk_record)
|
||||
chunk_texts.append(text_to_tokenize)
|
||||
|
||||
# Log chunk statistics
|
||||
if token_counts:
|
||||
avg_tokens = sum(token_counts) / len(token_counts)
|
||||
min_tokens = min(token_counts)
|
||||
max_tokens = max(token_counts)
|
||||
logger.info("Chunk token statistics:")
|
||||
logger.info(f" - Average: {avg_tokens:.1f} tokens")
|
||||
logger.info(f" - Min: {min_tokens} tokens")
|
||||
logger.info(f" - Max: {max_tokens} tokens")
|
||||
|
||||
# Step 4: Generate embeddings using OpenAI API
|
||||
logger.info(f"Generating embeddings using OpenAI ({self.model_id})...")
|
||||
embeddings = await self.embed_texts(chunk_texts)
|
||||
|
||||
# Step 5: Attach embeddings to chunk records
|
||||
for chunk_record, embedding in zip(chunk_records, embeddings):
|
||||
chunk_record.embedding = embedding
|
||||
|
||||
# Step 6: Save all chunks in batch
|
||||
logger.info("Storing chunks in database...")
|
||||
await self.db.create_chunks_batch(chunk_records)
|
||||
|
||||
# Update document status to completed
|
||||
await self.db.update_document_status(
|
||||
doc_record.id,
|
||||
"completed",
|
||||
total_chunks=total_chunks,
|
||||
docling_metadata=docling_metadata,
|
||||
)
|
||||
|
||||
logger.info(f"Successfully processed document: {filename}")
|
||||
logger.info(f" - Total chunks: {total_chunks}")
|
||||
logger.info(f" - Embedding model: {self.model_id}")
|
||||
logger.info(f" - Document ID: {doc_record.id}")
|
||||
logger.info(f" - Document UUID: {doc_record.document_uuid}")
|
||||
|
||||
return doc_record
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing document with OpenAI: {e}")
|
||||
|
||||
# Update document status to failed if it exists
|
||||
if "doc_record" in locals():
|
||||
await self.db.update_document_status(
|
||||
doc_record.id, "failed", error_message=str(e)
|
||||
)
|
||||
|
||||
raise
|
||||
350
api/services/gen_ai/embedding/sentence_transformer_service.py
Normal file
350
api/services/gen_ai/embedding/sentence_transformer_service.py
Normal file
|
|
@ -0,0 +1,350 @@
|
|||
"""Sentence Transformer embedding service.
|
||||
|
||||
This module provides document processing capabilities using:
|
||||
- Sentence-transformers for embeddings (all-MiniLM-L6-v2)
|
||||
- Docling for document conversion and chunking
|
||||
- pgvector for vector similarity search
|
||||
|
||||
Setup for offline usage:
|
||||
1. First run: Downloads and caches models to ~/.cache/sentence_transformers
|
||||
2. Subsequent runs: Uses cached models (no internet needed)
|
||||
3. For fully offline mode: Set TRANSFORMERS_OFFLINE=1 and HF_HUB_OFFLINE=1
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from docling.chunking import HybridChunker
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
|
||||
from loguru import logger
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from api.db.db_client import DBClient
|
||||
from api.db.models import KnowledgeBaseChunkModel
|
||||
|
||||
from .base import BaseEmbeddingService
|
||||
|
||||
# Set environment variables for model caching
|
||||
os.environ.setdefault("TRANSFORMERS_OFFLINE", "0")
|
||||
os.environ.setdefault("HF_HUB_OFFLINE", "0")
|
||||
os.environ.setdefault(
|
||||
"SENTENCE_TRANSFORMERS_HOME", os.path.expanduser("~/.cache/sentence_transformers")
|
||||
)
|
||||
|
||||
# Model configuration
|
||||
DEFAULT_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
EMBEDDING_DIMENSION = 384 # Dimension for all-MiniLM-L6-v2
|
||||
|
||||
|
||||
class SentenceTransformerEmbeddingService(BaseEmbeddingService):
|
||||
"""Embedding service using Sentence Transformers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_client: DBClient,
|
||||
model_id: str = DEFAULT_MODEL_ID,
|
||||
max_tokens: int = 512,
|
||||
):
|
||||
"""Initialize the Sentence Transformer embedding service.
|
||||
|
||||
Args:
|
||||
db_client: Database client for storing documents and chunks
|
||||
model_id: Sentence-transformers model ID (default: all-MiniLM-L6-v2)
|
||||
max_tokens: Maximum number of tokens per chunk (default: 512)
|
||||
Note: This applies to the contextualized text (with headings/captions)
|
||||
"""
|
||||
self.db = db_client
|
||||
self.model_id = model_id
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
# Initialize embedding model
|
||||
logger.info(f"Loading embedding model: {model_id}")
|
||||
try:
|
||||
# Try to load from cache first (local_files_only=True)
|
||||
self.embedding_model = SentenceTransformer(
|
||||
model_id,
|
||||
cache_folder=os.environ.get("SENTENCE_TRANSFORMERS_HOME"),
|
||||
local_files_only=True,
|
||||
)
|
||||
logger.info("Loaded model from cache")
|
||||
except Exception as e:
|
||||
logger.warning(f"Model not in cache, downloading: {e}")
|
||||
# If not in cache, download it (this will cache it for next time)
|
||||
self.embedding_model = SentenceTransformer(
|
||||
model_id,
|
||||
cache_folder=os.environ.get("SENTENCE_TRANSFORMERS_HOME"),
|
||||
)
|
||||
logger.info("Model downloaded and cached")
|
||||
|
||||
# Initialize tokenizer for chunking with max_tokens
|
||||
logger.info(f"Loading tokenizer: {model_id} with max_tokens={max_tokens}")
|
||||
try:
|
||||
# Try to load from cache first
|
||||
self.tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
local_files_only=True,
|
||||
),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
logger.info("Loaded tokenizer from cache")
|
||||
except Exception as e:
|
||||
logger.warning(f"Tokenizer not in cache, downloading: {e}")
|
||||
# If not in cache, download it
|
||||
self.tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=AutoTokenizer.from_pretrained(model_id),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
logger.info("Tokenizer downloaded and cached")
|
||||
|
||||
# Initialize chunker
|
||||
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
|
||||
self.chunker = HybridChunker(tokenizer=self.tokenizer)
|
||||
|
||||
# Initialize document converter
|
||||
self.converter = DocumentConverter()
|
||||
|
||||
def get_model_id(self) -> str:
|
||||
"""Return the model identifier."""
|
||||
return self.model_id
|
||||
|
||||
def get_embedding_dimension(self) -> int:
|
||||
"""Return the embedding dimension."""
|
||||
return EMBEDDING_DIMENSION
|
||||
|
||||
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a batch of texts.
|
||||
|
||||
Args:
|
||||
texts: List of text strings to embed
|
||||
|
||||
Returns:
|
||||
List of embedding vectors (each vector is a list of floats)
|
||||
"""
|
||||
embeddings = self.embedding_model.encode(
|
||||
texts,
|
||||
show_progress_bar=False,
|
||||
convert_to_numpy=True,
|
||||
)
|
||||
return [embedding.tolist() for embedding in embeddings]
|
||||
|
||||
async def embed_query(self, query: str) -> List[float]:
|
||||
"""Embed a single query text.
|
||||
|
||||
Args:
|
||||
query: Query text to embed
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats
|
||||
"""
|
||||
embedding = self.embedding_model.encode([query])[0]
|
||||
return embedding.tolist()
|
||||
|
||||
async def search_similar_chunks(
|
||||
self,
|
||||
query: str,
|
||||
organization_id: int,
|
||||
limit: int = 5,
|
||||
document_uuids: Optional[List[str]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search for similar chunks using vector similarity.
|
||||
|
||||
Returns top-k most similar chunks without any threshold filtering.
|
||||
Apply similarity thresholds and reranking at the application layer.
|
||||
|
||||
Args:
|
||||
query: Search query text
|
||||
organization_id: Organization ID for scoping
|
||||
limit: Maximum number of results to return
|
||||
document_uuids: Optional list of document UUIDs to filter by
|
||||
|
||||
Returns:
|
||||
List of dictionaries with chunk data and similarity scores
|
||||
"""
|
||||
# Generate query embedding
|
||||
query_embedding = await self.embed_query(query)
|
||||
|
||||
# Perform vector similarity search
|
||||
results = await self.db.search_similar_chunks(
|
||||
query_embedding=query_embedding,
|
||||
organization_id=organization_id,
|
||||
limit=limit,
|
||||
document_uuids=document_uuids,
|
||||
embedding_model=self.model_id,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def process_document(
|
||||
self,
|
||||
file_path: str,
|
||||
organization_id: int,
|
||||
created_by: int,
|
||||
custom_metadata: dict = None,
|
||||
):
|
||||
"""Process a document: convert, chunk, embed, and store in database.
|
||||
|
||||
Args:
|
||||
file_path: Path to the document file
|
||||
organization_id: Organization ID for scoping
|
||||
created_by: User ID who uploaded the document
|
||||
custom_metadata: Optional custom metadata dictionary
|
||||
|
||||
Returns:
|
||||
The created document record
|
||||
"""
|
||||
try:
|
||||
# Extract file metadata
|
||||
filename = Path(file_path).name
|
||||
file_hash = self.db.compute_file_hash(file_path)
|
||||
file_size = os.path.getsize(file_path)
|
||||
mime_type = self.db.get_mime_type(file_path)
|
||||
|
||||
# Check if document already exists
|
||||
existing_doc = await self.db.get_document_by_hash(
|
||||
file_hash, organization_id
|
||||
)
|
||||
if existing_doc:
|
||||
logger.info(f"Document already exists: {filename} (hash: {file_hash})")
|
||||
return existing_doc
|
||||
|
||||
# Create document record
|
||||
doc_record = await self.db.create_document(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
filename=filename,
|
||||
file_size_bytes=file_size,
|
||||
file_hash=file_hash,
|
||||
mime_type=mime_type,
|
||||
custom_metadata=custom_metadata or {},
|
||||
)
|
||||
|
||||
logger.info(f"Processing document: {filename}")
|
||||
|
||||
# Update status to processing
|
||||
await self.db.update_document_status(doc_record.id, "processing")
|
||||
|
||||
# Step 1: Convert document using docling
|
||||
logger.info("Converting document with docling...")
|
||||
conversion_result = self.converter.convert(file_path)
|
||||
doc = conversion_result.document
|
||||
|
||||
# Store docling metadata
|
||||
docling_metadata = {
|
||||
"num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
|
||||
"document_type": type(doc).__name__,
|
||||
}
|
||||
|
||||
# Step 2: Chunk the document
|
||||
logger.info(f"Chunking document with max_tokens={self.max_tokens}...")
|
||||
chunks = list(self.chunker.chunk(dl_doc=doc))
|
||||
total_chunks = len(chunks)
|
||||
|
||||
logger.info(f"Generated {total_chunks} chunks")
|
||||
|
||||
# Step 3: Process each chunk
|
||||
chunk_texts = []
|
||||
chunk_records = []
|
||||
token_counts = []
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Get chunk text
|
||||
chunk_text = chunk.text
|
||||
|
||||
# Get contextualized text (enriched with surrounding context)
|
||||
contextualized_text = self.chunker.contextualize(chunk=chunk)
|
||||
|
||||
# Calculate actual token count using the tokenizer
|
||||
text_to_tokenize = (
|
||||
contextualized_text if contextualized_text else chunk_text
|
||||
)
|
||||
token_count = len(
|
||||
self.tokenizer.tokenizer.encode(
|
||||
text_to_tokenize, add_special_tokens=False
|
||||
)
|
||||
)
|
||||
token_counts.append(token_count)
|
||||
|
||||
# Prepare chunk metadata
|
||||
chunk_metadata = {}
|
||||
if hasattr(chunk, "meta") and chunk.meta:
|
||||
chunk_metadata = {
|
||||
"doc_items": (
|
||||
[str(item) for item in chunk.meta.doc_items]
|
||||
if hasattr(chunk.meta, "doc_items")
|
||||
else []
|
||||
),
|
||||
"headings": (
|
||||
chunk.meta.headings
|
||||
if hasattr(chunk.meta, "headings")
|
||||
else []
|
||||
),
|
||||
}
|
||||
|
||||
# Create chunk record (without embedding yet)
|
||||
chunk_record = KnowledgeBaseChunkModel(
|
||||
document_id=doc_record.id,
|
||||
organization_id=organization_id,
|
||||
chunk_text=chunk_text,
|
||||
contextualized_text=contextualized_text,
|
||||
chunk_index=i,
|
||||
chunk_metadata=chunk_metadata,
|
||||
embedding_model=self.model_id,
|
||||
embedding_dimension=EMBEDDING_DIMENSION,
|
||||
token_count=token_count,
|
||||
)
|
||||
|
||||
chunk_records.append(chunk_record)
|
||||
# Use contextualized text for embedding if available
|
||||
chunk_texts.append(text_to_tokenize)
|
||||
|
||||
# Log chunk statistics
|
||||
if token_counts:
|
||||
avg_tokens = sum(token_counts) / len(token_counts)
|
||||
min_tokens = min(token_counts)
|
||||
max_tokens = max(token_counts)
|
||||
logger.info("Chunk token statistics:")
|
||||
logger.info(f" - Average: {avg_tokens:.1f} tokens")
|
||||
logger.info(f" - Min: {min_tokens} tokens")
|
||||
logger.info(f" - Max: {max_tokens} tokens")
|
||||
|
||||
# Step 4: Generate embeddings in batch
|
||||
logger.info("Generating embeddings...")
|
||||
embeddings = await self.embed_texts(chunk_texts)
|
||||
|
||||
# Step 5: Attach embeddings to chunk records
|
||||
for chunk_record, embedding in zip(chunk_records, embeddings):
|
||||
chunk_record.embedding = embedding
|
||||
|
||||
# Step 6: Save all chunks in batch
|
||||
logger.info("Storing chunks in database...")
|
||||
await self.db.create_chunks_batch(chunk_records)
|
||||
|
||||
# Update document status to completed
|
||||
await self.db.update_document_status(
|
||||
doc_record.id,
|
||||
"completed",
|
||||
total_chunks=total_chunks,
|
||||
docling_metadata=docling_metadata,
|
||||
)
|
||||
|
||||
logger.info(f"Successfully processed document: {filename}")
|
||||
logger.info(f" - Total chunks: {total_chunks}")
|
||||
logger.info(f" - Document ID: {doc_record.id}")
|
||||
logger.info(f" - Document UUID: {doc_record.document_uuid}")
|
||||
|
||||
return doc_record
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing document: {e}")
|
||||
|
||||
# Update document status to failed if it exists
|
||||
if "doc_record" in locals():
|
||||
await self.db.update_document_status(
|
||||
doc_record.id, "failed", error_message=str(e)
|
||||
)
|
||||
|
||||
raise
|
||||
44
api/services/pricing/embeddings.py
Normal file
44
api/services/pricing/embeddings.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"""
|
||||
Embeddings pricing models for different providers.
|
||||
|
||||
Prices are per token for embedding models.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import PricingModel
|
||||
|
||||
|
||||
class EmbeddingPricingModel(PricingModel):
|
||||
"""Pricing model for token-based embedding services."""
|
||||
|
||||
def __init__(self, token_price: Decimal):
|
||||
"""Initialize with price per token.
|
||||
|
||||
Args:
|
||||
token_price: Cost per token for embedding
|
||||
"""
|
||||
self.token_price = token_price
|
||||
|
||||
def calculate_cost(self, token_count: int) -> Decimal:
|
||||
"""Calculate cost for embedding token usage."""
|
||||
return Decimal(token_count) * self.token_price
|
||||
|
||||
|
||||
# Embeddings pricing registry
|
||||
EMBEDDINGS_PRICING: Dict[str, Dict[str, EmbeddingPricingModel]] = {
|
||||
ServiceProviders.OPENAI: {
|
||||
"text-embedding-3-small": EmbeddingPricingModel(
|
||||
token_price=Decimal("0.02") / 1_000_000, # $0.02 per 1M tokens
|
||||
),
|
||||
"text-embedding-3-large": EmbeddingPricingModel(
|
||||
token_price=Decimal("0.13") / 1_000_000, # $0.13 per 1M tokens
|
||||
),
|
||||
"text-embedding-ada-002": EmbeddingPricingModel(
|
||||
token_price=Decimal("0.10") / 1_000_000, # $0.10 per 1M tokens (legacy)
|
||||
),
|
||||
},
|
||||
}
|
||||
|
|
@ -4,6 +4,7 @@ Main pricing registry that combines all service type pricing models.
|
|||
|
||||
from typing import Dict
|
||||
|
||||
from .embeddings import EMBEDDINGS_PRICING
|
||||
from .llm import LLM_PRICING
|
||||
from .stt import STT_PRICING
|
||||
from .tts import TTS_PRICING
|
||||
|
|
@ -13,4 +14,5 @@ PRICING_REGISTRY: Dict = {
|
|||
"llm": LLM_PRICING,
|
||||
"tts": TTS_PRICING,
|
||||
"stt": STT_PRICING,
|
||||
"embeddings": EMBEDDINGS_PRICING,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ class NodeDataDTO(BaseModel):
|
|||
delayed_start: bool = False
|
||||
delayed_start_duration: Optional[float] = None
|
||||
tool_uuids: Optional[List[str]] = None
|
||||
document_uuids: Optional[List[str]] = None
|
||||
trigger_path: Optional[str] = None
|
||||
# Webhook node specific fields
|
||||
enabled: bool = True
|
||||
|
|
|
|||
|
|
@ -41,6 +41,10 @@ from api.services.workflow.pipecat_engine_variable_extractor import (
|
|||
VariableExtractionManager,
|
||||
)
|
||||
from api.services.workflow.tools.calculator import get_calculator_tools, safe_calculator
|
||||
from api.services.workflow.tools.knowledge_base import (
|
||||
get_knowledge_base_tool,
|
||||
retrieve_from_knowledge_base,
|
||||
)
|
||||
from api.services.workflow.tools.timezone import (
|
||||
convert_time,
|
||||
get_current_time,
|
||||
|
|
@ -290,6 +294,48 @@ class PipecatEngine:
|
|||
self.llm.register_function("get_current_time", get_current_time_func)
|
||||
self.llm.register_function("convert_time", convert_time_func)
|
||||
|
||||
async def _register_knowledge_base_function(
|
||||
self, document_uuids: list[str]
|
||||
) -> None:
|
||||
"""Register knowledge base retrieval function with the LLM.
|
||||
|
||||
Args:
|
||||
document_uuids: List of document UUIDs to filter the search by
|
||||
"""
|
||||
logger.debug(
|
||||
f"Registering knowledge base retrieval function with {len(document_uuids)} document(s)"
|
||||
)
|
||||
|
||||
async def retrieve_kb_func(function_call_params: FunctionCallParams) -> None:
|
||||
logger.info("LLM Function Call EXECUTED: retrieve_from_knowledge_base")
|
||||
logger.info(f"Arguments: {function_call_params.arguments}")
|
||||
try:
|
||||
query = function_call_params.arguments.get("query", "")
|
||||
organization_id = await self._get_organization_id()
|
||||
|
||||
if not organization_id:
|
||||
raise ValueError(
|
||||
"Organization ID not available for knowledge base retrieval"
|
||||
)
|
||||
|
||||
result = await retrieve_from_knowledge_base(
|
||||
query=query,
|
||||
organization_id=organization_id,
|
||||
document_uuids=document_uuids,
|
||||
limit=3, # Return top 3 most relevant chunks
|
||||
)
|
||||
|
||||
await function_call_params.result_callback(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Knowledge base retrieval failed: {e}")
|
||||
await function_call_params.result_callback(
|
||||
{"error": str(e), "chunks": [], "query": query, "total_results": 0}
|
||||
)
|
||||
|
||||
# Register the function with the LLM
|
||||
self.llm.register_function("retrieve_from_knowledge_base", retrieve_kb_func)
|
||||
|
||||
async def _perform_variable_extraction_if_needed(
|
||||
self, previous_node: Optional[Node]
|
||||
) -> None:
|
||||
|
|
@ -346,6 +392,10 @@ class PipecatEngine:
|
|||
if node.tool_uuids and self._custom_tool_manager:
|
||||
await self._custom_tool_manager.register_handlers(node.tool_uuids)
|
||||
|
||||
# Register knowledge base retrieval handler if node has documents
|
||||
if node.document_uuids:
|
||||
await self._register_knowledge_base_function(node.document_uuids)
|
||||
|
||||
# Set up system message and functions
|
||||
(
|
||||
system_message,
|
||||
|
|
@ -575,6 +625,17 @@ class PipecatEngine:
|
|||
# Add built-in function schemas (calculator and timezone tools)
|
||||
functions.extend(self.builtin_function_schemas)
|
||||
|
||||
# Add knowledge base retrieval tool if node has documents
|
||||
if node.document_uuids:
|
||||
kb_tool_def = get_knowledge_base_tool(node.document_uuids)
|
||||
kb_schema = get_function_schema(
|
||||
kb_tool_def["function"]["name"],
|
||||
kb_tool_def["function"]["description"],
|
||||
properties=kb_tool_def["function"]["parameters"].get("properties", {}),
|
||||
required=kb_tool_def["function"]["parameters"].get("required", []),
|
||||
)
|
||||
functions.append(kb_schema)
|
||||
|
||||
# Add custom tools from node.tool_uuids
|
||||
if node.tool_uuids and self._custom_tool_manager:
|
||||
custom_tool_schemas = await self._custom_tool_manager.get_tool_schemas(
|
||||
|
|
|
|||
305
api/services/workflow/tools/knowledge_base.py
Normal file
305
api/services/workflow/tools/knowledge_base.py
Normal file
|
|
@ -0,0 +1,305 @@
|
|||
"""Knowledge Base retrieval tool for workflow execution.
|
||||
|
||||
This module provides vector similarity search capabilities for retrieving
|
||||
relevant information from the knowledge base during conversations.
|
||||
|
||||
Implements OpenTelemetry tracing for observability in Langfuse.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from opentelemetry import trace
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.gen_ai import OpenAIEmbeddingService
|
||||
from api.services.pipecat.tracing_config import is_tracing_enabled
|
||||
from pipecat.utils.tracing.context_registry import (
|
||||
get_current_conversation_context,
|
||||
get_current_turn_context,
|
||||
)
|
||||
|
||||
|
||||
async def retrieve_from_knowledge_base(
|
||||
query: str,
|
||||
organization_id: int,
|
||||
document_uuids: Optional[List[str]] = None,
|
||||
limit: int = 3,
|
||||
embeddings_api_key: Optional[str] = None,
|
||||
embeddings_model: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Retrieve relevant information from the knowledge base using vector similarity search.
|
||||
|
||||
Uses OpenAI text-embedding-3-small for embeddings by default. This provides
|
||||
high-quality 1536-dimensional embeddings for accurate retrieval.
|
||||
|
||||
This function includes OpenTelemetry tracing for Langfuse observability.
|
||||
|
||||
Args:
|
||||
query: The search query to find relevant information
|
||||
organization_id: Organization ID for scoping the search
|
||||
document_uuids: Optional list of document UUIDs to filter by
|
||||
limit: Maximum number of chunks to return (default: 3)
|
||||
embeddings_api_key: Optional API key for embedding service
|
||||
embeddings_model: Optional model ID for embedding service
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- chunks: List of relevant text chunks with metadata
|
||||
- query: The original query
|
||||
- total_results: Number of results returned
|
||||
"""
|
||||
# Create span for retrieval operation if tracing is enabled
|
||||
if is_tracing_enabled():
|
||||
try:
|
||||
# Get parent context from turn or conversation
|
||||
turn_context = get_current_turn_context()
|
||||
conversation_context = get_current_conversation_context()
|
||||
parent_context = turn_context or conversation_context
|
||||
|
||||
# Get tracer
|
||||
tracer = trace.get_tracer("pipecat")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to setup tracing context: {e}")
|
||||
# Fall back to non-traced execution
|
||||
return await _perform_retrieval(
|
||||
query,
|
||||
organization_id,
|
||||
document_uuids,
|
||||
limit,
|
||||
embeddings_api_key,
|
||||
embeddings_model,
|
||||
)
|
||||
|
||||
# Create span with parent context
|
||||
if parent_context:
|
||||
with tracer.start_as_current_span(
|
||||
"knowledge_base_retrieval", context=parent_context
|
||||
) as span:
|
||||
try:
|
||||
# Mark trace as public for Langfuse
|
||||
span.set_attribute("langfuse.trace.public", True)
|
||||
|
||||
# Add operation metadata
|
||||
span.set_attribute(
|
||||
"gen_ai.operation.name", "knowledge_base_retrieval"
|
||||
)
|
||||
span.set_attribute("retrieval.query", query)
|
||||
span.set_attribute("retrieval.limit", limit)
|
||||
span.set_attribute("retrieval.organization_id", organization_id)
|
||||
|
||||
# Add document filter info
|
||||
if document_uuids:
|
||||
span.set_attribute(
|
||||
"retrieval.document_count", len(document_uuids)
|
||||
)
|
||||
span.set_attribute(
|
||||
"retrieval.document_uuids", json.dumps(document_uuids)
|
||||
)
|
||||
|
||||
# Perform the actual retrieval
|
||||
result = await _perform_retrieval(
|
||||
query,
|
||||
organization_id,
|
||||
document_uuids,
|
||||
limit,
|
||||
embeddings_api_key,
|
||||
embeddings_model,
|
||||
)
|
||||
|
||||
# Add result metadata to span
|
||||
span.set_attribute(
|
||||
"retrieval.results_count", result["total_results"]
|
||||
)
|
||||
|
||||
if result.get("error"):
|
||||
span.set_attribute("retrieval.error", result["error"])
|
||||
span.set_status(
|
||||
trace.Status(trace.StatusCode.ERROR, result["error"])
|
||||
)
|
||||
else:
|
||||
# Add similarity scores
|
||||
if result["chunks"]:
|
||||
similarities = [
|
||||
chunk["similarity"] for chunk in result["chunks"]
|
||||
]
|
||||
span.set_attribute(
|
||||
"retrieval.avg_similarity",
|
||||
round(sum(similarities) / len(similarities), 4),
|
||||
)
|
||||
span.set_attribute(
|
||||
"retrieval.max_similarity", max(similarities)
|
||||
)
|
||||
span.set_attribute(
|
||||
"retrieval.min_similarity", min(similarities)
|
||||
)
|
||||
|
||||
# Add retrieved documents info
|
||||
filenames = list(
|
||||
set(chunk["filename"] for chunk in result["chunks"])
|
||||
)
|
||||
span.set_attribute(
|
||||
"retrieval.source_files", json.dumps(filenames)
|
||||
)
|
||||
|
||||
# Add output as JSON for Langfuse
|
||||
output_data = {
|
||||
"query": query,
|
||||
"chunks_retrieved": len(result["chunks"]),
|
||||
"chunks": [
|
||||
{
|
||||
"text": chunk["text"][:200] + "..."
|
||||
if len(chunk["text"]) > 200
|
||||
else chunk["text"],
|
||||
"filename": chunk["filename"],
|
||||
"similarity": chunk["similarity"],
|
||||
}
|
||||
for chunk in result["chunks"]
|
||||
],
|
||||
}
|
||||
span.set_attribute("output", json.dumps(output_data))
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in traced retrieval: {e}")
|
||||
span.record_exception(e)
|
||||
span.set_status(trace.Status(trace.StatusCode.ERROR, str(e)))
|
||||
raise
|
||||
else:
|
||||
# No parent context - perform retrieval without tracing
|
||||
logger.debug(
|
||||
"No parent context available for knowledge base retrieval tracing"
|
||||
)
|
||||
return await _perform_retrieval(
|
||||
query,
|
||||
organization_id,
|
||||
document_uuids,
|
||||
limit,
|
||||
embeddings_api_key,
|
||||
embeddings_model,
|
||||
)
|
||||
else:
|
||||
# Tracing is disabled - perform retrieval without tracing
|
||||
return await _perform_retrieval(
|
||||
query,
|
||||
organization_id,
|
||||
document_uuids,
|
||||
limit,
|
||||
embeddings_api_key,
|
||||
embeddings_model,
|
||||
)
|
||||
|
||||
|
||||
async def _perform_retrieval(
|
||||
query: str,
|
||||
organization_id: int,
|
||||
document_uuids: Optional[List[str]],
|
||||
limit: int,
|
||||
embeddings_api_key: Optional[str] = None,
|
||||
embeddings_model: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Internal function to perform the actual retrieval operation.
|
||||
|
||||
Separated from tracing logic for cleaner code organization.
|
||||
Uses OpenAI embeddings by default for high-quality retrieval.
|
||||
"""
|
||||
try:
|
||||
# Create a new embedding service instance
|
||||
# Uses OpenAI text-embedding-3-small by default, or user-provided config
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
max_tokens=128, # This is only used for chunking, not for retrieval
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
)
|
||||
|
||||
# Perform vector similarity search
|
||||
results = await embedding_service.search_similar_chunks(
|
||||
query=query,
|
||||
organization_id=organization_id,
|
||||
limit=limit,
|
||||
document_uuids=document_uuids,
|
||||
)
|
||||
|
||||
# Format results for LLM consumption
|
||||
chunks = []
|
||||
for result in results:
|
||||
chunk_info = {
|
||||
"text": result.get("contextualized_text") or result.get("chunk_text"),
|
||||
"filename": result.get("filename"),
|
||||
"similarity": round(result.get("similarity", 0), 4),
|
||||
"chunk_index": result.get("chunk_index"),
|
||||
}
|
||||
chunks.append(chunk_info)
|
||||
|
||||
logger.info(
|
||||
f"Knowledge base retrieval: query='{query}', "
|
||||
f"results={len(chunks)}, "
|
||||
f"document_filter={document_uuids}"
|
||||
)
|
||||
|
||||
return {
|
||||
"chunks": chunks,
|
||||
"query": query,
|
||||
"total_results": len(chunks),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving from knowledge base: {e}")
|
||||
return {
|
||||
"error": str(e),
|
||||
"chunks": [],
|
||||
"query": query,
|
||||
"total_results": 0,
|
||||
}
|
||||
|
||||
|
||||
def get_knowledge_base_tool(
|
||||
document_uuids: Optional[List[str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Get knowledge base retrieval tool definition for LLM function calling.
|
||||
|
||||
Args:
|
||||
document_uuids: Optional list of document UUIDs to include in description
|
||||
|
||||
Returns:
|
||||
Tool definition compatible with LLM function calling
|
||||
"""
|
||||
# Build description based on whether specific documents are filtered
|
||||
if document_uuids and len(document_uuids) > 0:
|
||||
description = (
|
||||
"Retrieve relevant information from specific documents in the knowledge base. "
|
||||
"Use this tool when you need to look up facts, policies, procedures, or any information "
|
||||
"that might be stored in the available documents. The search will only look in the "
|
||||
f"documents associated with this conversation step ({len(document_uuids)} document(s) available)."
|
||||
)
|
||||
else:
|
||||
description = (
|
||||
"Retrieve relevant information from the knowledge base. "
|
||||
"Use this tool when you need to look up facts, policies, procedures, or any information "
|
||||
"that might be stored in the knowledge base documents."
|
||||
)
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "retrieve_from_knowledge_base",
|
||||
"description": description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The search query to find relevant information. "
|
||||
"Be specific and use natural language. "
|
||||
"Example: 'What is the refund policy for canceled orders?'"
|
||||
),
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -48,6 +48,7 @@ class Node:
|
|||
self.delayed_start = data.delayed_start
|
||||
self.delayed_start_duration = data.delayed_start_duration
|
||||
self.tool_uuids = data.tool_uuids
|
||||
self.document_uuids = data.document_uuids
|
||||
|
||||
self.data = data
|
||||
|
||||
|
|
@ -189,16 +190,6 @@ class WorkflowGraph:
|
|||
in_d, out_d = in_deg[n.id], out_deg[n.id]
|
||||
|
||||
match n.node_type:
|
||||
case NodeType.startNode:
|
||||
if in_d != 0 or out_d < 1:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.node,
|
||||
id=n.id,
|
||||
field=None,
|
||||
message=f"StartNode must have at least 1 outgoing edge",
|
||||
)
|
||||
)
|
||||
case NodeType.endNode:
|
||||
if in_d < 1 or out_d != 0:
|
||||
errors.append(
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ from api.tasks.campaign_tasks import (
|
|||
process_campaign_batch,
|
||||
sync_campaign_source,
|
||||
)
|
||||
from api.tasks.knowledge_base_processing import process_knowledge_base_document
|
||||
from api.tasks.run_integrations import run_integrations_post_workflow_run
|
||||
from api.tasks.s3_upload import (
|
||||
upload_audio_to_s3,
|
||||
|
|
@ -64,6 +65,7 @@ class WorkerSettings:
|
|||
sync_campaign_source,
|
||||
process_campaign_batch,
|
||||
monitor_campaign_progress,
|
||||
process_knowledge_base_document,
|
||||
]
|
||||
cron_jobs = []
|
||||
redis_settings = REDIS_SETTINGS
|
||||
|
|
|
|||
|
|
@ -7,3 +7,4 @@ class FunctionNames:
|
|||
SYNC_CAMPAIGN_SOURCE = "sync_campaign_source"
|
||||
PROCESS_CAMPAIGN_BATCH = "process_campaign_batch"
|
||||
MONITOR_CAMPAIGN_PROGRESS = "monitor_campaign_progress"
|
||||
PROCESS_KNOWLEDGE_BASE_DOCUMENT = "process_knowledge_base_document"
|
||||
|
|
|
|||
311
api/tasks/knowledge_base_processing.py
Normal file
311
api/tasks/knowledge_base_processing.py
Normal file
|
|
@ -0,0 +1,311 @@
|
|||
"""ARQ background task for processing knowledge base documents."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Literal
|
||||
|
||||
from docling.chunking import HybridChunker
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
|
||||
from loguru import logger
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import KnowledgeBaseChunkModel
|
||||
from api.services.gen_ai import (
|
||||
OpenAIEmbeddingService,
|
||||
SentenceTransformerEmbeddingService,
|
||||
)
|
||||
from api.services.storage import storage_fs
|
||||
|
||||
# For tokenization/chunking - use SentenceTransformer tokenizer as baseline
|
||||
TOKENIZER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
|
||||
async def process_knowledge_base_document(
|
||||
ctx,
|
||||
document_id: int,
|
||||
s3_key: str,
|
||||
organization_id: int,
|
||||
max_tokens: int = 128,
|
||||
embedding_service: Literal["sentence_transformer", "openai"] = "openai",
|
||||
):
|
||||
"""Process a knowledge base document: download, chunk, embed, and store.
|
||||
|
||||
Args:
|
||||
ctx: ARQ context
|
||||
document_id: Database ID of the document
|
||||
s3_key: S3 key where the file is stored
|
||||
organization_id: Organization ID
|
||||
max_tokens: Maximum number of tokens per chunk (default: 128)
|
||||
embedding_service: Embedding service to use (default: "openai")
|
||||
- "openai": Use OpenAI text-embedding-3-small (1536-dim, requires API key)
|
||||
- "sentence_transformer": Use SentenceTransformer (all-MiniLM-L6-v2, 384-dim, free)
|
||||
"""
|
||||
logger.info(
|
||||
f"Starting knowledge base document processing for document_id={document_id}, "
|
||||
f"s3_key={s3_key}, organization_id={organization_id}"
|
||||
)
|
||||
|
||||
temp_file_path = None
|
||||
|
||||
try:
|
||||
# Update status to processing
|
||||
await db_client.update_document_status(document_id, "processing")
|
||||
|
||||
# Extract file extension from S3 key
|
||||
filename = s3_key.split("/")[-1]
|
||||
file_extension = (
|
||||
os.path.splitext(filename)[1] or ".bin"
|
||||
) # Default to .bin if no extension
|
||||
|
||||
# Create temp file for download with correct extension
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=file_extension)
|
||||
temp_file_path = temp_file.name
|
||||
temp_file.close()
|
||||
|
||||
# Download file from S3
|
||||
logger.info(f"Downloading file from S3: {s3_key}")
|
||||
download_success = await storage_fs.adownload_file(s3_key, temp_file_path)
|
||||
|
||||
if not download_success:
|
||||
raise Exception(f"Failed to download file from S3: {s3_key}")
|
||||
|
||||
if not os.path.exists(temp_file_path):
|
||||
raise FileNotFoundError(f"Downloaded file not found: {temp_file_path}")
|
||||
|
||||
file_size = os.path.getsize(temp_file_path)
|
||||
logger.info(f"Downloaded file size: {file_size} bytes")
|
||||
|
||||
# Compute file hash and get mime type
|
||||
file_hash = db_client.compute_file_hash(temp_file_path)
|
||||
mime_type = db_client.get_mime_type(temp_file_path)
|
||||
filename = s3_key.split("/")[-1]
|
||||
|
||||
# Get document record
|
||||
document = await db_client.get_document_by_id(document_id)
|
||||
if not document:
|
||||
raise Exception(f"Document {document_id} not found")
|
||||
|
||||
# Check if a document with this hash already exists (reject duplicates)
|
||||
existing_doc = await db_client.get_document_by_hash(file_hash, organization_id)
|
||||
if existing_doc and existing_doc.id != document_id:
|
||||
error_message = (
|
||||
f"This file is a duplicate of '{existing_doc.filename}'. "
|
||||
f"Please delete the duplicate files and consolidate them into a single unique file before uploading."
|
||||
)
|
||||
logger.warning(
|
||||
f"Duplicate document detected: {document_id} is duplicate of {existing_doc.id} "
|
||||
f"({existing_doc.filename})"
|
||||
)
|
||||
# Update file metadata
|
||||
await db_client.update_document_metadata(
|
||||
document_id,
|
||||
file_size_bytes=file_size,
|
||||
file_hash=file_hash,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
# Mark as failed with duplicate error message
|
||||
await db_client.update_document_status(
|
||||
document_id,
|
||||
"failed",
|
||||
error_message=error_message,
|
||||
docling_metadata={
|
||||
"duplicate_of": existing_doc.document_uuid,
|
||||
"duplicate_filename": existing_doc.filename,
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
# Update document with file metadata
|
||||
await db_client.update_document_metadata(
|
||||
document_id,
|
||||
file_size_bytes=file_size,
|
||||
file_hash=file_hash,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
|
||||
# Initialize the embedding service based on the parameter
|
||||
if embedding_service == "openai":
|
||||
logger.info(
|
||||
f"Initializing OpenAI embedding service with max_tokens={max_tokens}"
|
||||
)
|
||||
# Try to get user's embeddings configuration
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
if document.created_by:
|
||||
user_config = await db_client.get_user_configurations(
|
||||
document.created_by
|
||||
)
|
||||
if user_config.embeddings:
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
logger.info(
|
||||
f"Using user embeddings config: model={embeddings_model}"
|
||||
)
|
||||
|
||||
# Check if API key is configured
|
||||
if not embeddings_api_key:
|
||||
error_message = (
|
||||
"OpenAI API key not configured. Please set your API key in "
|
||||
"Model Configurations > Embedding to process documents."
|
||||
)
|
||||
logger.warning(f"Document {document_id}: {error_message}")
|
||||
await db_client.update_document_status(
|
||||
document_id, "failed", error_message=error_message
|
||||
)
|
||||
return
|
||||
|
||||
service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
max_tokens=max_tokens,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
)
|
||||
elif embedding_service == "sentence_transformer":
|
||||
logger.info(
|
||||
f"Initializing SentenceTransformer embedding service with max_tokens={max_tokens}"
|
||||
)
|
||||
service = SentenceTransformerEmbeddingService(
|
||||
db_client=db_client,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid embedding_service: {embedding_service}. "
|
||||
f"Must be 'sentence_transformer' or 'openai'"
|
||||
)
|
||||
|
||||
# Step 1: Convert document with docling
|
||||
logger.info("Converting document with docling")
|
||||
converter = DocumentConverter()
|
||||
conversion_result = converter.convert(temp_file_path)
|
||||
doc = conversion_result.document
|
||||
|
||||
# Store docling metadata
|
||||
docling_metadata = {
|
||||
"num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
|
||||
"document_type": type(doc).__name__,
|
||||
}
|
||||
|
||||
# Step 2: Initialize tokenizer for chunking
|
||||
logger.info(
|
||||
f"Loading tokenizer: {TOKENIZER_MODEL} with max_tokens={max_tokens}"
|
||||
)
|
||||
tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=AutoTokenizer.from_pretrained(TOKENIZER_MODEL),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
# Step 3: Initialize chunker
|
||||
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
|
||||
chunker = HybridChunker(tokenizer=tokenizer)
|
||||
|
||||
# Step 4: Chunk the document
|
||||
logger.info(f"Chunking document with max_tokens={max_tokens}")
|
||||
chunks = list(chunker.chunk(dl_doc=doc))
|
||||
total_chunks = len(chunks)
|
||||
logger.info(f"Generated {total_chunks} chunks")
|
||||
|
||||
# Step 5: Process each chunk
|
||||
chunk_texts = []
|
||||
chunk_records = []
|
||||
token_counts = []
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
chunk_text = chunk.text
|
||||
contextualized_text = chunker.contextualize(chunk=chunk)
|
||||
|
||||
# Calculate token count
|
||||
text_to_tokenize = (
|
||||
contextualized_text if contextualized_text else chunk_text
|
||||
)
|
||||
token_count = len(
|
||||
tokenizer.tokenizer.encode(text_to_tokenize, add_special_tokens=False)
|
||||
)
|
||||
token_counts.append(token_count)
|
||||
|
||||
# Prepare chunk metadata
|
||||
chunk_metadata = {}
|
||||
if hasattr(chunk, "meta") and chunk.meta:
|
||||
chunk_metadata = {
|
||||
"doc_items": (
|
||||
[str(item) for item in chunk.meta.doc_items]
|
||||
if hasattr(chunk.meta, "doc_items")
|
||||
else []
|
||||
),
|
||||
"headings": (
|
||||
chunk.meta.headings if hasattr(chunk.meta, "headings") else []
|
||||
),
|
||||
}
|
||||
|
||||
# Create chunk record (without embedding yet)
|
||||
chunk_record = KnowledgeBaseChunkModel(
|
||||
document_id=document_id,
|
||||
organization_id=organization_id,
|
||||
chunk_text=chunk_text,
|
||||
contextualized_text=contextualized_text,
|
||||
chunk_index=i,
|
||||
chunk_metadata=chunk_metadata,
|
||||
embedding_model=service.get_model_id(),
|
||||
embedding_dimension=service.get_embedding_dimension(),
|
||||
token_count=token_count,
|
||||
)
|
||||
|
||||
chunk_records.append(chunk_record)
|
||||
chunk_texts.append(text_to_tokenize)
|
||||
|
||||
# Log chunk statistics
|
||||
if token_counts:
|
||||
avg_tokens = sum(token_counts) / len(token_counts)
|
||||
min_tokens = min(token_counts)
|
||||
max_tokens_actual = max(token_counts)
|
||||
logger.info("Chunk token statistics:")
|
||||
logger.info(f" - Average: {avg_tokens:.1f} tokens")
|
||||
logger.info(f" - Min: {min_tokens} tokens")
|
||||
logger.info(f" - Max: {max_tokens_actual} tokens")
|
||||
|
||||
# Step 6: Generate embeddings using the embedding service
|
||||
logger.info(f"Generating embeddings using {embedding_service}")
|
||||
embeddings = await service.embed_texts(chunk_texts)
|
||||
|
||||
# Step 7: Attach embeddings to chunk records
|
||||
for chunk_record, embedding in zip(chunk_records, embeddings):
|
||||
chunk_record.embedding = embedding
|
||||
|
||||
# Step 8: Save chunks in database
|
||||
logger.info("Storing chunks in database")
|
||||
await db_client.create_chunks_batch(chunk_records)
|
||||
|
||||
# Step 9: Update document status to completed
|
||||
await db_client.update_document_status(
|
||||
document_id,
|
||||
"completed",
|
||||
total_chunks=total_chunks,
|
||||
docling_metadata=docling_metadata,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Successfully processed knowledge base document {document_id}. "
|
||||
f"Total chunks: {total_chunks}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing knowledge base document {document_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Update document status to failed
|
||||
await db_client.update_document_status(
|
||||
document_id, "failed", error_message=str(e)
|
||||
)
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Clean up temp file
|
||||
if temp_file_path and os.path.exists(temp_file_path):
|
||||
try:
|
||||
os.remove(temp_file_path)
|
||||
logger.debug(f"Cleaned up temp file: {temp_file_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp file {temp_file_path}: {e}")
|
||||
Loading…
Add table
Add a link
Reference in a new issue