mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: knowledge base functionality for the voice agent (#120)
* feat: upload file and store embedding * feat: add documents in nodes * feat: add openai embedding service
This commit is contained in:
parent
e2fa4bbb98
commit
ef5b9e40a9
52 changed files with 4551 additions and 114 deletions
194
api/alembic/versions/dc33eef8dabe_add_document_tables.py
Normal file
194
api/alembic/versions/dc33eef8dabe_add_document_tables.py
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
"""add document tables
|
||||
|
||||
Revision ID: dc33eef8dabe
|
||||
Revises: dcb0a27d98c6
|
||||
Create Date: 2026-01-16 13:40:17.808807
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "dc33eef8dabe"
|
||||
down_revision: Union[str, None] = "dcb0a27d98c6"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
# Enable pgvector extension
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
||||
|
||||
sa.Enum(
|
||||
"pending",
|
||||
"processing",
|
||||
"completed",
|
||||
"failed",
|
||||
name="document_processing_status",
|
||||
).create(op.get_bind())
|
||||
op.create_table(
|
||||
"knowledge_base_documents",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("document_uuid", sa.String(length=36), nullable=False),
|
||||
sa.Column("organization_id", sa.Integer(), nullable=False),
|
||||
sa.Column("filename", sa.String(length=500), nullable=False),
|
||||
sa.Column("file_size_bytes", sa.Integer(), nullable=True),
|
||||
sa.Column("file_hash", sa.String(length=64), nullable=True),
|
||||
sa.Column("mime_type", sa.String(length=100), nullable=True),
|
||||
sa.Column("source_url", sa.String(), nullable=True),
|
||||
sa.Column("total_chunks", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"processing_status",
|
||||
postgresql.ENUM(
|
||||
"pending",
|
||||
"processing",
|
||||
"completed",
|
||||
"failed",
|
||||
name="document_processing_status",
|
||||
create_type=False,
|
||||
),
|
||||
server_default=sa.text("'pending'::document_processing_status"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("processing_error", sa.Text(), nullable=True),
|
||||
sa.Column("docling_metadata", sa.JSON(), nullable=False),
|
||||
sa.Column("custom_metadata", sa.JSON(), nullable=False),
|
||||
sa.Column("created_by", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("archived_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["created_by"],
|
||||
["users.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"], ["organizations.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_documents_created_at",
|
||||
"knowledge_base_documents",
|
||||
["created_at"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_documents_organization_id",
|
||||
"knowledge_base_documents",
|
||||
["organization_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_documents_status",
|
||||
"knowledge_base_documents",
|
||||
["processing_status"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_documents_uuid",
|
||||
"knowledge_base_documents",
|
||||
["document_uuid"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_knowledge_base_documents_document_uuid"),
|
||||
"knowledge_base_documents",
|
||||
["document_uuid"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_table(
|
||||
"knowledge_base_chunks",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("document_id", sa.Integer(), nullable=False),
|
||||
sa.Column("organization_id", sa.Integer(), nullable=False),
|
||||
sa.Column("chunk_text", sa.Text(), nullable=False),
|
||||
sa.Column("contextualized_text", sa.Text(), nullable=True),
|
||||
sa.Column("chunk_index", sa.Integer(), nullable=False),
|
||||
sa.Column("chunk_metadata", sa.JSON(), nullable=False),
|
||||
sa.Column("embedding_model", sa.String(length=200), nullable=False),
|
||||
sa.Column("embedding_dimension", sa.Integer(), nullable=False),
|
||||
sa.Column("embedding", Vector(1536), nullable=True),
|
||||
sa.Column("token_count", sa.Integer(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["document_id"], ["knowledge_base_documents.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"], ["organizations.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_chunks_chunk_index",
|
||||
"knowledge_base_chunks",
|
||||
["chunk_index"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_chunks_document_id",
|
||||
"knowledge_base_chunks",
|
||||
["document_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_chunks_embedding_ivfflat",
|
||||
"knowledge_base_chunks",
|
||||
["embedding"],
|
||||
unique=False,
|
||||
postgresql_using="ivfflat",
|
||||
postgresql_with={"lists": 100},
|
||||
postgresql_ops={"embedding": "vector_cosine_ops"},
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_chunks_organization_id",
|
||||
"knowledge_base_chunks",
|
||||
["organization_id"],
|
||||
unique=False,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("ix_kb_chunks_organization_id", table_name="knowledge_base_chunks")
|
||||
op.drop_index(
|
||||
"ix_kb_chunks_embedding_ivfflat",
|
||||
table_name="knowledge_base_chunks",
|
||||
postgresql_using="ivfflat",
|
||||
postgresql_with={"lists": 100},
|
||||
postgresql_ops={"embedding": "vector_cosine_ops"},
|
||||
)
|
||||
op.drop_index("ix_kb_chunks_document_id", table_name="knowledge_base_chunks")
|
||||
op.drop_index("ix_kb_chunks_chunk_index", table_name="knowledge_base_chunks")
|
||||
op.drop_table("knowledge_base_chunks")
|
||||
op.drop_index(
|
||||
op.f("ix_knowledge_base_documents_document_uuid"),
|
||||
table_name="knowledge_base_documents",
|
||||
)
|
||||
op.drop_index("ix_kb_documents_uuid", table_name="knowledge_base_documents")
|
||||
op.drop_index("ix_kb_documents_status", table_name="knowledge_base_documents")
|
||||
op.drop_index(
|
||||
"ix_kb_documents_organization_id", table_name="knowledge_base_documents"
|
||||
)
|
||||
op.drop_index("ix_kb_documents_created_at", table_name="knowledge_base_documents")
|
||||
op.drop_table("knowledge_base_documents")
|
||||
sa.Enum(
|
||||
"pending",
|
||||
"processing",
|
||||
"completed",
|
||||
"failed",
|
||||
name="document_processing_status",
|
||||
).drop(op.get_bind())
|
||||
|
||||
# Note: We don't drop the vector extension as it may be used by other tables
|
||||
# If you want to drop it, uncomment the following line:
|
||||
# op.execute('DROP EXTENSION IF EXISTS vector')
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -3,6 +3,7 @@ from api.db.api_key_client import APIKeyClient
|
|||
from api.db.campaign_client import CampaignClient
|
||||
from api.db.embed_token_client import EmbedTokenClient
|
||||
from api.db.integration_client import IntegrationClient
|
||||
from api.db.knowledge_base_client import KnowledgeBaseClient
|
||||
from api.db.looptalk_client import LoopTalkClient
|
||||
from api.db.organization_client import OrganizationClient
|
||||
from api.db.organization_configuration_client import OrganizationConfigurationClient
|
||||
|
|
@ -33,6 +34,7 @@ class DBClient(
|
|||
AgentTriggerClient,
|
||||
WebhookCredentialClient,
|
||||
ToolClient,
|
||||
KnowledgeBaseClient,
|
||||
):
|
||||
"""
|
||||
Unified database client that combines all specialized database operations.
|
||||
|
|
@ -54,6 +56,7 @@ class DBClient(
|
|||
- AgentTriggerClient: handles agent trigger operations for API-based call triggering
|
||||
- WebhookCredentialClient: handles webhook credential operations
|
||||
- ToolClient: handles tool operations for reusable HTTP API tools
|
||||
- KnowledgeBaseClient: handles knowledge base document and vector search operations
|
||||
"""
|
||||
|
||||
pass
|
||||
|
|
|
|||
499
api/db/knowledge_base_client.py
Normal file
499
api/db/knowledge_base_client.py
Normal file
|
|
@ -0,0 +1,499 @@
|
|||
"""Database client for managing knowledge base documents and chunks."""
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import KnowledgeBaseChunkModel, KnowledgeBaseDocumentModel
|
||||
|
||||
|
||||
class KnowledgeBaseClient(BaseDBClient):
|
||||
"""Client for managing knowledge base documents and vector embeddings."""
|
||||
|
||||
async def create_document(
|
||||
self,
|
||||
organization_id: int,
|
||||
created_by: int,
|
||||
filename: str,
|
||||
file_size_bytes: int,
|
||||
file_hash: str,
|
||||
mime_type: str,
|
||||
source_url: Optional[str] = None,
|
||||
custom_metadata: Optional[dict] = None,
|
||||
docling_metadata: Optional[dict] = None,
|
||||
document_uuid: Optional[str] = None,
|
||||
) -> KnowledgeBaseDocumentModel:
|
||||
"""Create a new knowledge base document record.
|
||||
|
||||
Args:
|
||||
organization_id: ID of the organization
|
||||
created_by: ID of the user uploading the document
|
||||
filename: Name of the file
|
||||
file_size_bytes: Size of the file in bytes
|
||||
file_hash: SHA-256 hash of the file
|
||||
mime_type: MIME type of the file
|
||||
source_url: Optional URL if document was fetched from web
|
||||
custom_metadata: Optional custom metadata dictionary
|
||||
docling_metadata: Optional docling processing metadata
|
||||
document_uuid: Optional UUID to use (if not provided, one will be generated)
|
||||
|
||||
Returns:
|
||||
The created KnowledgeBaseDocumentModel
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
document = KnowledgeBaseDocumentModel(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
filename=filename,
|
||||
file_size_bytes=file_size_bytes,
|
||||
file_hash=file_hash,
|
||||
mime_type=mime_type,
|
||||
source_url=source_url,
|
||||
custom_metadata=custom_metadata or {},
|
||||
docling_metadata=docling_metadata or {},
|
||||
processing_status="pending",
|
||||
total_chunks=0,
|
||||
)
|
||||
|
||||
# Use provided UUID or let the model generate one
|
||||
if document_uuid:
|
||||
document.document_uuid = document_uuid
|
||||
|
||||
session.add(document)
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
|
||||
logger.info(
|
||||
f"Created document '{filename}' ({document.document_uuid}) "
|
||||
f"for organization {organization_id}"
|
||||
)
|
||||
return document
|
||||
|
||||
async def get_document_by_id(
|
||||
self,
|
||||
document_id: int,
|
||||
) -> Optional[KnowledgeBaseDocumentModel]:
|
||||
"""Get a document by its database ID.
|
||||
|
||||
Args:
|
||||
document_id: The database ID of the document
|
||||
|
||||
Returns:
|
||||
KnowledgeBaseDocumentModel if found, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(KnowledgeBaseDocumentModel).where(
|
||||
KnowledgeBaseDocumentModel.id == document_id
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_document_by_uuid(
|
||||
self,
|
||||
document_uuid: str,
|
||||
organization_id: int,
|
||||
) -> Optional[KnowledgeBaseDocumentModel]:
|
||||
"""Get a document by its UUID, scoped to organization.
|
||||
|
||||
Args:
|
||||
document_uuid: The unique document UUID
|
||||
organization_id: ID of the organization
|
||||
|
||||
Returns:
|
||||
KnowledgeBaseDocumentModel if found, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(KnowledgeBaseDocumentModel)
|
||||
.where(
|
||||
KnowledgeBaseDocumentModel.document_uuid == document_uuid,
|
||||
KnowledgeBaseDocumentModel.organization_id == organization_id,
|
||||
KnowledgeBaseDocumentModel.is_active == True,
|
||||
)
|
||||
.options(selectinload(KnowledgeBaseDocumentModel.created_by_user))
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_document_by_hash(
|
||||
self,
|
||||
file_hash: str,
|
||||
organization_id: int,
|
||||
) -> Optional[KnowledgeBaseDocumentModel]:
|
||||
"""Check if a document with the same hash already exists.
|
||||
|
||||
Returns the first matching document if multiple exist (can happen with duplicates).
|
||||
|
||||
Args:
|
||||
file_hash: SHA-256 hash of the file
|
||||
organization_id: ID of the organization
|
||||
|
||||
Returns:
|
||||
KnowledgeBaseDocumentModel if found, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(KnowledgeBaseDocumentModel)
|
||||
.where(
|
||||
KnowledgeBaseDocumentModel.file_hash == file_hash,
|
||||
KnowledgeBaseDocumentModel.organization_id == organization_id,
|
||||
KnowledgeBaseDocumentModel.is_active == True,
|
||||
)
|
||||
.order_by(KnowledgeBaseDocumentModel.created_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_documents_for_organization(
|
||||
self,
|
||||
organization_id: int,
|
||||
processing_status: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> List[KnowledgeBaseDocumentModel]:
|
||||
"""Get all documents for an organization.
|
||||
|
||||
Args:
|
||||
organization_id: ID of the organization
|
||||
processing_status: Optional filter by status
|
||||
limit: Maximum number of documents to return
|
||||
offset: Number of documents to skip
|
||||
|
||||
Returns:
|
||||
List of KnowledgeBaseDocumentModel instances
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(KnowledgeBaseDocumentModel).where(
|
||||
KnowledgeBaseDocumentModel.organization_id == organization_id,
|
||||
KnowledgeBaseDocumentModel.is_active == True,
|
||||
)
|
||||
|
||||
if processing_status:
|
||||
query = query.where(
|
||||
KnowledgeBaseDocumentModel.processing_status == processing_status
|
||||
)
|
||||
|
||||
query = (
|
||||
query.order_by(KnowledgeBaseDocumentModel.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def update_document_metadata(
|
||||
self,
|
||||
document_id: int,
|
||||
file_size_bytes: Optional[int] = None,
|
||||
file_hash: Optional[str] = None,
|
||||
mime_type: Optional[str] = None,
|
||||
) -> Optional[KnowledgeBaseDocumentModel]:
|
||||
"""Update document file metadata.
|
||||
|
||||
Args:
|
||||
document_id: ID of the document
|
||||
file_size_bytes: Optional file size in bytes
|
||||
file_hash: Optional SHA-256 hash of the file
|
||||
mime_type: Optional MIME type
|
||||
|
||||
Returns:
|
||||
Updated KnowledgeBaseDocumentModel
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(KnowledgeBaseDocumentModel).where(
|
||||
KnowledgeBaseDocumentModel.id == document_id
|
||||
)
|
||||
result = await session.execute(query)
|
||||
document = result.scalar_one_or_none()
|
||||
|
||||
if not document:
|
||||
return None
|
||||
|
||||
if file_size_bytes is not None:
|
||||
document.file_size_bytes = file_size_bytes
|
||||
if file_hash is not None:
|
||||
document.file_hash = file_hash
|
||||
if mime_type is not None:
|
||||
document.mime_type = mime_type
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
|
||||
logger.info(f"Updated document {document_id} metadata")
|
||||
return document
|
||||
|
||||
async def update_document_status(
|
||||
self,
|
||||
document_id: int,
|
||||
status: str,
|
||||
error_message: Optional[str] = None,
|
||||
total_chunks: Optional[int] = None,
|
||||
docling_metadata: Optional[dict] = None,
|
||||
) -> Optional[KnowledgeBaseDocumentModel]:
|
||||
"""Update document processing status.
|
||||
|
||||
Args:
|
||||
document_id: ID of the document
|
||||
status: New status (pending, processing, completed, failed)
|
||||
error_message: Optional error message if status is failed
|
||||
total_chunks: Optional total number of chunks
|
||||
docling_metadata: Optional docling metadata
|
||||
|
||||
Returns:
|
||||
Updated KnowledgeBaseDocumentModel
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(KnowledgeBaseDocumentModel).where(
|
||||
KnowledgeBaseDocumentModel.id == document_id
|
||||
)
|
||||
result = await session.execute(query)
|
||||
document = result.scalar_one_or_none()
|
||||
|
||||
if not document:
|
||||
return None
|
||||
|
||||
document.processing_status = status
|
||||
if error_message:
|
||||
document.processing_error = error_message
|
||||
if total_chunks is not None:
|
||||
document.total_chunks = total_chunks
|
||||
if docling_metadata:
|
||||
document.docling_metadata = docling_metadata
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
|
||||
logger.info(f"Updated document {document_id} status to {status}")
|
||||
return document
|
||||
|
||||
async def create_chunks_batch(
|
||||
self,
|
||||
chunks: List[KnowledgeBaseChunkModel],
|
||||
) -> List[KnowledgeBaseChunkModel]:
|
||||
"""Create multiple chunks in a batch.
|
||||
|
||||
Args:
|
||||
chunks: List of KnowledgeBaseChunkModel instances
|
||||
|
||||
Returns:
|
||||
List of created chunks with IDs
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
session.add_all(chunks)
|
||||
await session.commit()
|
||||
|
||||
for chunk in chunks:
|
||||
await session.refresh(chunk)
|
||||
|
||||
logger.info(f"Created {len(chunks)} chunks")
|
||||
return chunks
|
||||
|
||||
async def get_chunks_for_document(
|
||||
self,
|
||||
document_id: int,
|
||||
organization_id: int,
|
||||
) -> List[KnowledgeBaseChunkModel]:
|
||||
"""Get all chunks for a document.
|
||||
|
||||
Args:
|
||||
document_id: ID of the document
|
||||
organization_id: ID of the organization (for authorization)
|
||||
|
||||
Returns:
|
||||
List of KnowledgeBaseChunkModel instances
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(KnowledgeBaseChunkModel)
|
||||
.where(
|
||||
KnowledgeBaseChunkModel.document_id == document_id,
|
||||
KnowledgeBaseChunkModel.organization_id == organization_id,
|
||||
)
|
||||
.order_by(KnowledgeBaseChunkModel.chunk_index)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def search_similar_chunks(
|
||||
self,
|
||||
query_embedding: List[float],
|
||||
organization_id: int,
|
||||
limit: int = 5,
|
||||
document_ids: Optional[List[int]] = None,
|
||||
document_uuids: Optional[List[str]] = None,
|
||||
embedding_model: Optional[str] = None,
|
||||
) -> List[dict]:
|
||||
"""Search for similar chunks using vector similarity.
|
||||
|
||||
Returns top-k most similar chunks without any similarity threshold filtering.
|
||||
Filtering and reranking should be done at the application layer.
|
||||
|
||||
Args:
|
||||
query_embedding: The query embedding vector
|
||||
organization_id: Organization ID for scoping
|
||||
limit: Maximum number of results to return
|
||||
document_ids: Optional list of document IDs to filter by
|
||||
document_uuids: Optional list of document UUIDs to filter by
|
||||
embedding_model: Optional embedding model to filter by (for dimension compatibility)
|
||||
|
||||
Returns:
|
||||
List of dictionaries with chunk data and similarity scores, ordered by similarity (highest first)
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Get the raw connection to execute directly with asyncpg
|
||||
# This avoids parameter binding issues with text() and asyncpg
|
||||
connection = await session.connection()
|
||||
raw_connection = await connection.get_raw_connection()
|
||||
|
||||
# Build WHERE clause conditions (no similarity threshold)
|
||||
where_conditions = [
|
||||
"c.organization_id = $2",
|
||||
"d.is_active = true",
|
||||
]
|
||||
params = [
|
||||
None,
|
||||
organization_id,
|
||||
limit,
|
||||
] # $1 will be embedding_str, $3 is limit
|
||||
param_index = 4 # Next available parameter index
|
||||
|
||||
# Add document_ids filter if provided
|
||||
if document_ids:
|
||||
placeholders = ", ".join(
|
||||
f"${param_index + i}" for i in range(len(document_ids))
|
||||
)
|
||||
where_conditions.append(f"c.document_id IN ({placeholders})")
|
||||
params.extend(document_ids)
|
||||
param_index += len(document_ids)
|
||||
|
||||
# Add document_uuids filter if provided
|
||||
if document_uuids:
|
||||
placeholders = ", ".join(
|
||||
f"${param_index + i}" for i in range(len(document_uuids))
|
||||
)
|
||||
where_conditions.append(f"d.document_uuid IN ({placeholders})")
|
||||
params.extend(document_uuids)
|
||||
param_index += len(document_uuids)
|
||||
|
||||
# Add embedding_model filter if provided (for dimension compatibility)
|
||||
if embedding_model:
|
||||
where_conditions.append(f"c.embedding_model = ${param_index}")
|
||||
params.append(embedding_model)
|
||||
param_index += 1
|
||||
|
||||
# Build the complete SQL query
|
||||
where_clause = " AND ".join(where_conditions)
|
||||
query_sql = f"""
|
||||
SELECT
|
||||
c.id,
|
||||
c.document_id,
|
||||
c.chunk_text,
|
||||
c.contextualized_text,
|
||||
c.chunk_metadata,
|
||||
c.chunk_index,
|
||||
d.filename,
|
||||
d.document_uuid,
|
||||
1 - (c.embedding <=> $1::vector) as similarity
|
||||
FROM knowledge_base_chunks c
|
||||
JOIN knowledge_base_documents d ON c.document_id = d.id
|
||||
WHERE {where_clause}
|
||||
ORDER BY c.embedding <=> $1::vector
|
||||
LIMIT $3
|
||||
"""
|
||||
|
||||
# Convert embedding to string format for PostgreSQL vector type
|
||||
embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"
|
||||
params[0] = embedding_str # Set $1
|
||||
|
||||
# Execute query directly with asyncpg
|
||||
rows = await raw_connection.driver_connection.fetch(
|
||||
query_sql,
|
||||
*params,
|
||||
)
|
||||
|
||||
# Convert asyncpg records to dictionaries
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
async def delete_document(
|
||||
self,
|
||||
document_uuid: str,
|
||||
organization_id: int,
|
||||
) -> bool:
|
||||
"""Soft delete a document by setting is_active to False.
|
||||
|
||||
This will also cascade delete all chunks via the database foreign key.
|
||||
|
||||
Args:
|
||||
document_uuid: The unique document UUID
|
||||
organization_id: ID of the organization (for authorization)
|
||||
|
||||
Returns:
|
||||
True if document was deleted, False if not found
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(KnowledgeBaseDocumentModel).where(
|
||||
KnowledgeBaseDocumentModel.document_uuid == document_uuid,
|
||||
KnowledgeBaseDocumentModel.organization_id == organization_id,
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
document = result.scalar_one_or_none()
|
||||
|
||||
if not document:
|
||||
return False
|
||||
|
||||
document.is_active = False
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Deleted document {document_uuid} for organization {organization_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def compute_file_hash(file_path: str) -> str:
|
||||
"""Compute SHA-256 hash of a file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
SHA-256 hash as hex string
|
||||
"""
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(file_path, "rb") as f:
|
||||
for byte_block in iter(lambda: f.read(4096), b""):
|
||||
sha256_hash.update(byte_block)
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def get_mime_type(file_path: str) -> str:
|
||||
"""Get MIME type based on file extension.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
MIME type string
|
||||
"""
|
||||
extension = Path(file_path).suffix.lower()
|
||||
mime_types = {
|
||||
".pdf": "application/pdf",
|
||||
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
".doc": "application/msword",
|
||||
".txt": "text/plain",
|
||||
".html": "text/html",
|
||||
".md": "text/markdown",
|
||||
}
|
||||
return mime_types.get(extension, "application/octet-stream")
|
||||
178
api/db/models.py
178
api/db/models.py
|
|
@ -2,6 +2,7 @@ import uuid
|
|||
from datetime import UTC, datetime
|
||||
|
||||
from loguru import logger
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
Boolean,
|
||||
|
|
@ -14,6 +15,7 @@ from sqlalchemy import (
|
|||
Integer,
|
||||
String,
|
||||
Table,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
and_,
|
||||
text,
|
||||
|
|
@ -890,3 +892,179 @@ class ToolModel(Base):
|
|||
Index("ix_tools_status", "status"),
|
||||
Index("ix_tools_category", "category"),
|
||||
)
|
||||
|
||||
|
||||
class KnowledgeBaseDocumentModel(Base):
|
||||
"""Model for storing document-level metadata in the knowledge base.
|
||||
|
||||
Each document represents a source file (PDF, DOCX, etc.) that has been
|
||||
processed and chunked for retrieval.
|
||||
"""
|
||||
|
||||
__tablename__ = "knowledge_base_documents"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
# Public identifier for API references
|
||||
document_uuid = Column(
|
||||
String(36),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
default=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
# Organization scoping
|
||||
organization_id = Column(
|
||||
Integer, ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# Document metadata
|
||||
filename = Column(String(500), nullable=False)
|
||||
file_size_bytes = Column(Integer, nullable=True)
|
||||
file_hash = Column(String(64), nullable=True) # SHA-256 hash for deduplication
|
||||
mime_type = Column(String(100), nullable=True)
|
||||
|
||||
# Processing metadata
|
||||
source_url = Column(String, nullable=True) # If document was fetched from URL
|
||||
total_chunks = Column(Integer, nullable=False, default=0)
|
||||
processing_status = Column(
|
||||
Enum(
|
||||
"pending",
|
||||
"processing",
|
||||
"completed",
|
||||
"failed",
|
||||
name="document_processing_status",
|
||||
),
|
||||
nullable=False,
|
||||
default="pending",
|
||||
server_default=text("'pending'::document_processing_status"),
|
||||
)
|
||||
processing_error = Column(Text, nullable=True)
|
||||
|
||||
# Docling conversion metadata
|
||||
docling_metadata = Column(
|
||||
JSON, nullable=False, default=dict
|
||||
) # Store docling document metadata
|
||||
|
||||
# Custom metadata (user-defined tags, categories, etc.)
|
||||
custom_metadata = Column(JSON, nullable=False, default=dict)
|
||||
|
||||
# Audit fields
|
||||
created_by = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
)
|
||||
|
||||
# Soft delete
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
archived_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Relationships
|
||||
organization = relationship("OrganizationModel")
|
||||
created_by_user = relationship("UserModel")
|
||||
chunks = relationship(
|
||||
"KnowledgeBaseChunkModel",
|
||||
back_populates="document",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Indexes and constraints
|
||||
__table_args__ = (
|
||||
Index("ix_kb_documents_organization_id", "organization_id"),
|
||||
Index("ix_kb_documents_uuid", "document_uuid"),
|
||||
Index("ix_kb_documents_status", "processing_status"),
|
||||
Index("ix_kb_documents_created_at", "created_at"),
|
||||
)
|
||||
|
||||
|
||||
class KnowledgeBaseChunkModel(Base):
|
||||
"""Model for storing document chunks with vector embeddings.
|
||||
|
||||
Each chunk represents a portion of a document that has been:
|
||||
1. Extracted and chunked by docling's HybridChunker
|
||||
2. Optionally contextualized with surrounding information
|
||||
3. Embedded into a vector representation for semantic search
|
||||
"""
|
||||
|
||||
__tablename__ = "knowledge_base_chunks"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
# Link to parent document
|
||||
document_id = Column(
|
||||
Integer,
|
||||
ForeignKey("knowledge_base_documents.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Organization scoping (denormalized for efficient querying)
|
||||
organization_id = Column(
|
||||
Integer, ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# Chunk content
|
||||
chunk_text = Column(Text, nullable=False) # The actual chunk text
|
||||
contextualized_text = Column(
|
||||
Text, nullable=True
|
||||
) # Enriched text from chunker.contextualize()
|
||||
|
||||
# Chunk positioning and metadata
|
||||
chunk_index = Column(Integer, nullable=False) # Position in document (0-based)
|
||||
|
||||
# Docling chunk metadata
|
||||
chunk_metadata = Column(
|
||||
JSON, nullable=False, default=dict
|
||||
) # Store chunk.meta if available
|
||||
|
||||
# Embedding configuration
|
||||
embedding_model = Column(
|
||||
String(200), nullable=False
|
||||
) # e.g., "sentence-transformers/all-MiniLM-L6-v2"
|
||||
embedding_dimension = Column(
|
||||
Integer, nullable=False
|
||||
) # e.g., 384 for all-MiniLM-L6-v2
|
||||
|
||||
# Vector embedding (pgvector column)
|
||||
# The dimension should match the embedding_dimension field
|
||||
# Default: 1536 dimensions for OpenAI text-embedding-3-small
|
||||
# SentenceTransformer (384-dim) also supported but stored as 384-dim vectors
|
||||
embedding = Column(Vector(1536), nullable=True)
|
||||
|
||||
# Token count (useful for chunking strategy analysis)
|
||||
token_count = Column(Integer, nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
)
|
||||
|
||||
# Relationships
|
||||
document = relationship("KnowledgeBaseDocumentModel", back_populates="chunks")
|
||||
organization = relationship("OrganizationModel")
|
||||
|
||||
# Indexes and constraints
|
||||
__table_args__ = (
|
||||
Index("ix_kb_chunks_document_id", "document_id"),
|
||||
Index("ix_kb_chunks_organization_id", "organization_id"),
|
||||
Index("ix_kb_chunks_chunk_index", "chunk_index"),
|
||||
Index(
|
||||
"ix_kb_chunks_embedding_model", "embedding_model"
|
||||
), # For filtering by model
|
||||
# Vector similarity search index (using IVFFlat or HNSW)
|
||||
# IVFFlat is good for datasets with 10k-1M vectors
|
||||
# HNSW is better for larger datasets but uses more memory
|
||||
Index(
|
||||
"ix_kb_chunks_embedding_ivfflat",
|
||||
"embedding",
|
||||
postgresql_using="ivfflat",
|
||||
postgresql_with={"lists": 100}, # Adjust based on dataset size
|
||||
postgresql_ops={"embedding": "vector_cosine_ops"},
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,3 +13,6 @@ python-multipart==0.0.20
|
|||
sentry-sdk[fastapi]==2.38.0
|
||||
sqlalchemy[asyncio]==2.0.43
|
||||
msgpack==1.1.2
|
||||
docling[rapidocr]==2.68.0
|
||||
sentence-transformers==5.2.0
|
||||
pgvector==0.4.2
|
||||
|
|
|
|||
405
api/routes/knowledge_base.py
Normal file
405
api/routes/knowledge_base.py
Normal file
|
|
@ -0,0 +1,405 @@
|
|||
"""API routes for knowledge base operations."""
|
||||
|
||||
import uuid
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.schemas.knowledge_base import (
|
||||
ChunkSearchRequestSchema,
|
||||
ChunkSearchResponseSchema,
|
||||
DocumentListResponseSchema,
|
||||
DocumentResponseSchema,
|
||||
DocumentUploadRequestSchema,
|
||||
DocumentUploadResponseSchema,
|
||||
ProcessDocumentRequestSchema,
|
||||
)
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.storage import storage_fs
|
||||
from api.tasks.arq import enqueue_job
|
||||
from api.tasks.function_names import FunctionNames
|
||||
|
||||
router = APIRouter(prefix="/knowledge-base", tags=["knowledge-base"])
|
||||
|
||||
|
||||
@router.post(
|
||||
"/upload-url",
|
||||
response_model=DocumentUploadResponseSchema,
|
||||
summary="Get presigned URL for document upload",
|
||||
)
|
||||
async def get_upload_url(
|
||||
request: DocumentUploadRequestSchema,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Generate a presigned PUT URL for uploading a document.
|
||||
|
||||
This endpoint:
|
||||
1. Generates a unique document UUID for organizing the S3 key
|
||||
2. Generates a presigned S3/MinIO URL for uploading the file
|
||||
3. Returns the upload URL and document metadata
|
||||
|
||||
After uploading to the returned URL, call /process-document to create
|
||||
the document record and trigger processing.
|
||||
|
||||
Access Control:
|
||||
* All authenticated users can upload documents scoped to their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Generate unique document UUID for S3 organization
|
||||
document_uuid = str(uuid.uuid4())
|
||||
|
||||
# Generate S3 key: knowledge_base/{org_id}/{document_uuid}/{filename}
|
||||
s3_key = f"knowledge_base/{user.selected_organization_id}/{document_uuid}/{request.filename}"
|
||||
|
||||
# Generate presigned PUT URL (valid for 30 minutes)
|
||||
upload_url = await storage_fs.aget_presigned_put_url(
|
||||
file_path=s3_key,
|
||||
expiration=1800, # 30 minutes
|
||||
content_type=request.mime_type,
|
||||
max_size=100_000_000, # 100MB max
|
||||
)
|
||||
|
||||
if not upload_url:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to generate presigned upload URL"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Generated upload URL for document {document_uuid}, "
|
||||
f"user {user.id}, org {user.selected_organization_id}"
|
||||
)
|
||||
|
||||
return DocumentUploadResponseSchema(
|
||||
upload_url=upload_url,
|
||||
document_uuid=document_uuid,
|
||||
s3_key=s3_key,
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error generating upload URL: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to generate upload URL"
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/process-document",
|
||||
response_model=DocumentResponseSchema,
|
||||
summary="Trigger document processing",
|
||||
)
|
||||
async def process_document(
|
||||
request: ProcessDocumentRequestSchema,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Trigger asynchronous processing of an uploaded document.
|
||||
|
||||
This endpoint should be called after successfully uploading a file to the presigned URL.
|
||||
It will:
|
||||
1. Create a document record in the database with the specified UUID
|
||||
2. Enqueue a background task to process the document (chunking and embedding)
|
||||
|
||||
The document status will be updated from 'pending' -> 'processing' -> 'completed' or 'failed'.
|
||||
|
||||
Embedding Services:
|
||||
* openai (default): High-quality 1536-dimensional embeddings (requires OPENAI_API_KEY)
|
||||
* sentence_transformer: Free, offline-capable, 384-dimensional embeddings
|
||||
|
||||
Access Control:
|
||||
* Users can only process documents in their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Extract filename from s3_key
|
||||
filename = request.s3_key.split("/")[-1]
|
||||
|
||||
# Create document record with the specific UUID from upload
|
||||
document = await db_client.create_document(
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.id,
|
||||
filename=filename,
|
||||
file_size_bytes=0, # Will be updated by background task
|
||||
file_hash="", # Will be computed by background task
|
||||
mime_type="application/octet-stream", # Will be detected by background task
|
||||
custom_metadata={"s3_key": request.s3_key},
|
||||
document_uuid=request.document_uuid, # Use UUID from upload
|
||||
)
|
||||
|
||||
# Enqueue background task for processing
|
||||
await enqueue_job(
|
||||
FunctionNames.PROCESS_KNOWLEDGE_BASE_DOCUMENT,
|
||||
document.id,
|
||||
request.s3_key,
|
||||
user.selected_organization_id,
|
||||
128, # max_tokens (default)
|
||||
request.embedding_service,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created document {request.document_uuid} (id={document.id}) and enqueued processing "
|
||||
f"with {request.embedding_service} embeddings, org {user.selected_organization_id}"
|
||||
)
|
||||
|
||||
return DocumentResponseSchema(
|
||||
id=document.id,
|
||||
document_uuid=request.document_uuid,
|
||||
filename=filename,
|
||||
file_size_bytes=0,
|
||||
file_hash="",
|
||||
mime_type="application/octet-stream",
|
||||
processing_status="pending",
|
||||
processing_error=None,
|
||||
total_chunks=0,
|
||||
custom_metadata={"s3_key": request.s3_key},
|
||||
docling_metadata={},
|
||||
source_url=None,
|
||||
created_at=document.created_at,
|
||||
updated_at=document.updated_at,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.id,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Error processing document: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to process document"
|
||||
) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/documents",
|
||||
response_model=DocumentListResponseSchema,
|
||||
summary="List documents",
|
||||
)
|
||||
async def list_documents(
|
||||
status: Annotated[
|
||||
Optional[str],
|
||||
Query(description="Filter by processing status"),
|
||||
] = None,
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 100,
|
||||
offset: Annotated[int, Query(ge=0)] = 0,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""List all documents for the user's organization.
|
||||
|
||||
Access Control:
|
||||
* Users can only see documents from their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
documents = await db_client.get_documents_for_organization(
|
||||
organization_id=user.selected_organization_id,
|
||||
processing_status=status,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
# Convert to response schema
|
||||
document_list = [
|
||||
DocumentResponseSchema(
|
||||
id=doc.id,
|
||||
document_uuid=doc.document_uuid,
|
||||
filename=doc.filename,
|
||||
file_size_bytes=doc.file_size_bytes,
|
||||
file_hash=doc.file_hash,
|
||||
mime_type=doc.mime_type,
|
||||
processing_status=doc.processing_status,
|
||||
processing_error=doc.processing_error,
|
||||
total_chunks=doc.total_chunks,
|
||||
custom_metadata=doc.custom_metadata,
|
||||
docling_metadata=doc.docling_metadata,
|
||||
source_url=doc.source_url,
|
||||
created_at=doc.created_at,
|
||||
updated_at=doc.updated_at,
|
||||
organization_id=doc.organization_id,
|
||||
created_by=doc.created_by,
|
||||
is_active=doc.is_active,
|
||||
)
|
||||
for doc in documents
|
||||
]
|
||||
|
||||
return DocumentListResponseSchema(
|
||||
documents=document_list,
|
||||
total=len(document_list),
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error listing documents: {exc}")
|
||||
raise HTTPException(status_code=500, detail="Failed to list documents") from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/documents/{document_uuid}",
|
||||
response_model=DocumentResponseSchema,
|
||||
summary="Get document details",
|
||||
)
|
||||
async def get_document(
|
||||
document_uuid: str,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Get details of a specific document.
|
||||
|
||||
Access Control:
|
||||
* Users can only access documents from their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
document = await db_client.get_document_by_uuid(
|
||||
document_uuid=document_uuid,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
|
||||
if not document:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
return DocumentResponseSchema(
|
||||
id=document.id,
|
||||
document_uuid=document.document_uuid,
|
||||
filename=document.filename,
|
||||
file_size_bytes=document.file_size_bytes,
|
||||
file_hash=document.file_hash,
|
||||
mime_type=document.mime_type,
|
||||
processing_status=document.processing_status,
|
||||
processing_error=document.processing_error,
|
||||
total_chunks=document.total_chunks,
|
||||
custom_metadata=document.custom_metadata,
|
||||
docling_metadata=document.docling_metadata,
|
||||
source_url=document.source_url,
|
||||
created_at=document.created_at,
|
||||
updated_at=document.updated_at,
|
||||
organization_id=document.organization_id,
|
||||
created_by=document.created_by,
|
||||
is_active=document.is_active,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Error getting document: {exc}")
|
||||
raise HTTPException(status_code=500, detail="Failed to get document") from exc
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/documents/{document_uuid}",
|
||||
summary="Delete document",
|
||||
)
|
||||
async def delete_document(
|
||||
document_uuid: str,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Soft delete a document and its chunks.
|
||||
|
||||
Access Control:
|
||||
* Users can only delete documents from their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
success = await db_client.delete_document(
|
||||
document_uuid=document_uuid,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
logger.info(
|
||||
f"Deleted document {document_uuid}, "
|
||||
f"user {user.id}, org {user.selected_organization_id}"
|
||||
)
|
||||
|
||||
return {"success": True, "message": "Document deleted successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Error deleting document: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to delete document"
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/search",
|
||||
response_model=ChunkSearchResponseSchema,
|
||||
summary="Search for similar chunks",
|
||||
)
|
||||
async def search_chunks(
|
||||
request: ChunkSearchRequestSchema,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Search for document chunks similar to the query.
|
||||
|
||||
This endpoint uses vector similarity search to find relevant chunks.
|
||||
Results are returned without threshold filtering - apply similarity
|
||||
thresholds at the application layer after optional reranking.
|
||||
|
||||
Access Control:
|
||||
* Users can only search documents from their organization.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Import here to avoid circular dependency
|
||||
from api.services.gen_ai import OpenAIEmbeddingService
|
||||
|
||||
# Try to get user's embeddings configuration
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
|
||||
if user_config.embeddings:
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
|
||||
# Initialize embedding service with user config or fallback to env
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
)
|
||||
|
||||
# Perform search
|
||||
results = await embedding_service.search_similar_chunks(
|
||||
query=request.query,
|
||||
organization_id=user.selected_organization_id,
|
||||
limit=request.limit,
|
||||
document_uuids=request.document_uuids,
|
||||
)
|
||||
|
||||
# Apply similarity threshold if provided
|
||||
if request.min_similarity is not None:
|
||||
results = [r for r in results if r["similarity"] >= request.min_similarity]
|
||||
|
||||
# Convert to response schema
|
||||
from api.schemas.knowledge_base import ChunkResponseSchema
|
||||
|
||||
chunks = [
|
||||
ChunkResponseSchema(
|
||||
id=r["id"],
|
||||
document_id=r["document_id"],
|
||||
chunk_text=r["chunk_text"],
|
||||
contextualized_text=r.get("contextualized_text"),
|
||||
chunk_index=r["chunk_index"],
|
||||
chunk_metadata=r["chunk_metadata"],
|
||||
filename=r["filename"],
|
||||
document_uuid=r["document_uuid"],
|
||||
similarity=r["similarity"],
|
||||
)
|
||||
for r in results
|
||||
]
|
||||
|
||||
return ChunkSearchResponseSchema(
|
||||
chunks=chunks,
|
||||
query=request.query,
|
||||
total_results=len(chunks),
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error searching chunks: {exc}")
|
||||
raise HTTPException(status_code=500, detail="Failed to search chunks") from exc
|
||||
|
|
@ -4,6 +4,7 @@ from loguru import logger
|
|||
from api.routes.campaign import router as campaign_router
|
||||
from api.routes.credentials import router as credentials_router
|
||||
from api.routes.integration import router as integration_router
|
||||
from api.routes.knowledge_base import router as knowledge_base_router
|
||||
from api.routes.looptalk import router as looptalk_router
|
||||
from api.routes.organization import router as organization_router
|
||||
from api.routes.organization_usage import router as organization_usage_router
|
||||
|
|
@ -43,6 +44,7 @@ router.include_router(webrtc_signaling_router)
|
|||
router.include_router(public_embed_router)
|
||||
router.include_router(public_agent_router)
|
||||
router.include_router(workflow_embed_router)
|
||||
router.include_router(knowledge_base_router)
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ class DefaultConfigurationsResponse(TypedDict):
|
|||
llm: dict[str, dict]
|
||||
tts: dict[str, dict]
|
||||
stt: dict[str, dict]
|
||||
embeddings: dict[str, dict]
|
||||
default_providers: dict[str, str]
|
||||
|
||||
|
||||
|
|
@ -50,6 +51,10 @@ async def get_default_configurations() -> DefaultConfigurationsResponse:
|
|||
provider: model_cls.model_json_schema()
|
||||
for provider, model_cls in REGISTRY[ServiceType.STT].items()
|
||||
},
|
||||
"embeddings": {
|
||||
provider: model_cls.model_json_schema()
|
||||
for provider, model_cls in REGISTRY[ServiceType.EMBEDDINGS].items()
|
||||
},
|
||||
"default_providers": DEFAULT_SERVICE_PROVIDERS,
|
||||
}
|
||||
return configurations
|
||||
|
|
@ -69,6 +74,7 @@ class UserConfigurationRequestResponseSchema(BaseModel):
|
|||
llm: dict[str, Union[str, float]] | None = None
|
||||
tts: dict[str, Union[str, float]] | None = None
|
||||
stt: dict[str, Union[str, float]] | None = None
|
||||
embeddings: dict[str, Union[str, float]] | None = None
|
||||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
organization_pricing: dict[str, Union[float, str, bool]] | None = None
|
||||
|
|
|
|||
102
api/schemas/knowledge_base.py
Normal file
102
api/schemas/knowledge_base.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
"""Pydantic schemas for knowledge base operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DocumentUploadRequestSchema(BaseModel):
|
||||
"""Request schema for initiating document upload."""
|
||||
|
||||
filename: str = Field(..., description="Name of the file to upload")
|
||||
mime_type: str = Field(..., description="MIME type of the file")
|
||||
custom_metadata: Optional[Dict[str, Any]] = Field(
|
||||
default=None, description="Optional custom metadata"
|
||||
)
|
||||
|
||||
|
||||
class DocumentUploadResponseSchema(BaseModel):
|
||||
"""Response schema containing upload URL and document metadata."""
|
||||
|
||||
upload_url: str = Field(..., description="Signed URL for uploading the file")
|
||||
document_uuid: str = Field(..., description="Unique identifier for the document")
|
||||
s3_key: str = Field(..., description="S3 key where file should be uploaded")
|
||||
|
||||
|
||||
class ProcessDocumentRequestSchema(BaseModel):
|
||||
"""Request schema for triggering document processing."""
|
||||
|
||||
document_uuid: str = Field(..., description="Document UUID to process")
|
||||
s3_key: str = Field(..., description="S3 key of the uploaded file")
|
||||
embedding_service: Literal["sentence_transformer", "openai"] = Field(
|
||||
default="openai",
|
||||
description="Embedding service to use for processing. "
|
||||
"Options: 'openai' (default, 1536-dim, requires API key) or 'sentence_transformer' (free, 384-dim)",
|
||||
)
|
||||
|
||||
|
||||
class DocumentResponseSchema(BaseModel):
|
||||
"""Response schema for document metadata."""
|
||||
|
||||
id: int
|
||||
document_uuid: str
|
||||
filename: str
|
||||
file_size_bytes: int
|
||||
file_hash: str
|
||||
mime_type: str
|
||||
processing_status: str # pending, processing, completed, failed
|
||||
processing_error: Optional[str] = None
|
||||
total_chunks: int
|
||||
custom_metadata: Dict[str, Any]
|
||||
docling_metadata: Dict[str, Any]
|
||||
source_url: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
organization_id: int
|
||||
created_by: int
|
||||
is_active: bool
|
||||
|
||||
|
||||
class DocumentListResponseSchema(BaseModel):
|
||||
"""Response schema for list of documents."""
|
||||
|
||||
documents: List[DocumentResponseSchema]
|
||||
total: int
|
||||
limit: int
|
||||
offset: int
|
||||
|
||||
|
||||
class ChunkSearchRequestSchema(BaseModel):
|
||||
"""Request schema for searching similar chunks."""
|
||||
|
||||
query: str = Field(..., description="Search query text")
|
||||
limit: int = Field(default=5, ge=1, le=50, description="Maximum number of results")
|
||||
document_uuids: Optional[List[str]] = Field(
|
||||
default=None, description="Filter by specific document UUIDs"
|
||||
)
|
||||
min_similarity: Optional[float] = Field(
|
||||
default=None, ge=0.0, le=1.0, description="Minimum similarity threshold"
|
||||
)
|
||||
|
||||
|
||||
class ChunkResponseSchema(BaseModel):
|
||||
"""Response schema for a document chunk."""
|
||||
|
||||
id: int
|
||||
document_id: int
|
||||
chunk_text: str
|
||||
contextualized_text: Optional[str]
|
||||
chunk_index: int
|
||||
chunk_metadata: Dict[str, Any]
|
||||
filename: str
|
||||
document_uuid: str
|
||||
similarity: float
|
||||
|
||||
|
||||
class ChunkSearchResponseSchema(BaseModel):
|
||||
"""Response schema for chunk search results."""
|
||||
|
||||
chunks: List[ChunkResponseSchema]
|
||||
query: str
|
||||
total_results: int
|
||||
|
|
@ -3,6 +3,7 @@ from datetime import datetime
|
|||
from pydantic import BaseModel
|
||||
|
||||
from api.services.configuration.registry import (
|
||||
EmbeddingsConfig,
|
||||
LLMConfig,
|
||||
STTConfig,
|
||||
TTSConfig,
|
||||
|
|
@ -13,6 +14,7 @@ class UserConfiguration(BaseModel):
|
|||
llm: LLMConfig | None = None
|
||||
stt: STTConfig | None = None
|
||||
tts: TTSConfig | None = None
|
||||
embeddings: EmbeddingsConfig | None = None
|
||||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
last_validated_at: datetime | None = None
|
||||
|
|
|
|||
|
|
@ -48,6 +48,12 @@ class UserConfigurationValidator:
|
|||
status_list.extend(self._validate_service(configuration.llm, "llm"))
|
||||
status_list.extend(self._validate_service(configuration.stt, "stt"))
|
||||
status_list.extend(self._validate_service(configuration.tts, "tts"))
|
||||
# Embeddings is optional - only validate if configured
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.embeddings, "embeddings", required=False
|
||||
)
|
||||
)
|
||||
|
||||
if status_list:
|
||||
raise ValueError(status_list)
|
||||
|
|
@ -55,11 +61,16 @@ class UserConfigurationValidator:
|
|||
return {"status": [{"model": "all", "message": "ok"}]}
|
||||
|
||||
def _validate_service(
|
||||
self, service_config: Optional[ServiceConfig], service_name: str
|
||||
self,
|
||||
service_config: Optional[ServiceConfig],
|
||||
service_name: str,
|
||||
required: bool = True,
|
||||
) -> list[APIKeyStatus]:
|
||||
"""Validate a service configuration and return any error statuses."""
|
||||
if not service_config:
|
||||
return [{"model": service_name, "message": "API key is missing"}]
|
||||
if required:
|
||||
return [{"model": service_name, "message": "API key is missing"}]
|
||||
return [] # Optional service not configured is OK
|
||||
|
||||
provider = service_config.provider
|
||||
api_key = service_config.api_key
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ left as ``None``.
|
|||
from api.services.configuration.registry import (
|
||||
DeepgramSTTConfiguration,
|
||||
ElevenlabsTTSConfiguration,
|
||||
OpenAIEmbeddingsConfiguration,
|
||||
OpenAILLMService,
|
||||
ServiceProviders,
|
||||
)
|
||||
|
|
@ -22,6 +23,7 @@ _DEFAULTS = {
|
|||
"llm": (ServiceProviders.OPENAI, OpenAILLMService),
|
||||
"tts": (ServiceProviders.ELEVENLABS, ElevenlabsTTSConfiguration),
|
||||
"stt": (ServiceProviders.DEEPGRAM, DeepgramSTTConfiguration),
|
||||
"embeddings": (ServiceProviders.OPENAI, OpenAIEmbeddingsConfiguration),
|
||||
}
|
||||
|
||||
# Public mapping of service name -> default provider
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
|
|||
"llm": _mask_service(config.llm),
|
||||
"tts": _mask_service(config.tts),
|
||||
"stt": _mask_service(config.stt),
|
||||
"embeddings": _mask_service(config.embeddings),
|
||||
"test_phone_number": config.test_phone_number,
|
||||
"timezone": config.timezone,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from typing import Dict
|
|||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.masking import is_mask_of
|
||||
|
||||
SERVICE_FIELDS = ("llm", "tts", "stt")
|
||||
SERVICE_FIELDS = ("llm", "tts", "stt", "embeddings")
|
||||
|
||||
|
||||
def merge_user_configurations(
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ class ServiceType(Enum):
|
|||
LLM = auto()
|
||||
TTS = auto()
|
||||
STT = auto()
|
||||
EMBEDDINGS = auto()
|
||||
|
||||
|
||||
class ServiceProviders(str, Enum):
|
||||
|
|
@ -50,11 +51,16 @@ class BaseSTTConfiguration(BaseServiceConfiguration):
|
|||
model: str
|
||||
|
||||
|
||||
class BaseEmbeddingsConfiguration(BaseServiceConfiguration):
|
||||
model: str
|
||||
|
||||
|
||||
# Unified registry for all service types
|
||||
REGISTRY: Dict[ServiceType, Dict[str, Type[BaseServiceConfiguration]]] = {
|
||||
ServiceType.LLM: {},
|
||||
ServiceType.TTS: {},
|
||||
ServiceType.STT: {},
|
||||
ServiceType.EMBEDDINGS: {},
|
||||
}
|
||||
|
||||
T = TypeVar("T", bound=BaseServiceConfiguration)
|
||||
|
|
@ -93,6 +99,10 @@ def register_stt(cls: Type[BaseSTTConfiguration]):
|
|||
return register_service(ServiceType.STT)(cls)
|
||||
|
||||
|
||||
def register_embeddings(cls: Type[BaseEmbeddingsConfiguration]):
|
||||
return register_service(ServiceType.EMBEDDINGS)(cls)
|
||||
|
||||
|
||||
###################################################### LLM ########################################################################
|
||||
|
||||
# Suggested models for each provider (used for UI dropdown)
|
||||
|
|
@ -436,6 +446,27 @@ STTConfig = Annotated[
|
|||
Field(discriminator="provider"),
|
||||
]
|
||||
|
||||
ServiceConfig = Annotated[
|
||||
Union[LLMConfig, TTSConfig, STTConfig], Field(discriminator="provider")
|
||||
###################################################### EMBEDDINGS ########################################################################
|
||||
|
||||
OPENAI_EMBEDDING_MODELS = ["text-embedding-3-small"]
|
||||
|
||||
|
||||
@register_embeddings
|
||||
class OpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
||||
provider: Literal[ServiceProviders.OPENAI] = ServiceProviders.OPENAI
|
||||
model: str = Field(
|
||||
default="text-embedding-3-small",
|
||||
json_schema_extra={"examples": OPENAI_EMBEDDING_MODELS},
|
||||
)
|
||||
api_key: str
|
||||
|
||||
|
||||
EmbeddingsConfig = Annotated[
|
||||
Union[OpenAIEmbeddingsConfiguration],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
||||
ServiceConfig = Annotated[
|
||||
Union[LLMConfig, TTSConfig, STTConfig, EmbeddingsConfig],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -85,3 +85,16 @@ class BaseFileSystem(ABC):
|
|||
Optional[str]: Presigned PUT URL if successful, None otherwise
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def adownload_file(self, source_path: str, local_path: str) -> bool:
|
||||
"""Download a file from storage to local path.
|
||||
|
||||
Args:
|
||||
source_path: Path to the file in storage
|
||||
local_path: Local path where file should be downloaded
|
||||
|
||||
Returns:
|
||||
bool: True if file was downloaded successfully, False otherwise
|
||||
"""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -170,3 +170,15 @@ class MinioFileSystem(BaseFileSystem):
|
|||
except Exception as e:
|
||||
logger.error(f"Error generating MinIO upload URL: {e}")
|
||||
return None
|
||||
|
||||
async def adownload_file(self, source_path: str, local_path: str) -> bool:
|
||||
"""Download a file from MinIO to local path."""
|
||||
try:
|
||||
|
||||
def _fget():
|
||||
self.client.fget_object(self.bucket_name, source_path, local_path)
|
||||
|
||||
await asyncio.to_thread(_fget)
|
||||
return True
|
||||
except S3Error:
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -126,3 +126,14 @@ class S3FileSystem(BaseFileSystem):
|
|||
return url
|
||||
except ClientError:
|
||||
return None
|
||||
|
||||
async def adownload_file(self, source_path: str, local_path: str) -> bool:
|
||||
"""Download a file from S3 to local path."""
|
||||
try:
|
||||
async with self.session.client(
|
||||
"s3", region_name=self.region_name
|
||||
) as s3_client:
|
||||
await s3_client.download_file(self.bucket_name, source_path, local_path)
|
||||
return True
|
||||
except ClientError:
|
||||
return False
|
||||
|
|
|
|||
15
api/services/gen_ai/__init__.py
Normal file
15
api/services/gen_ai/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
"""Generative AI services for embeddings and document processing."""
|
||||
|
||||
from .embedding import (
|
||||
BaseEmbeddingService,
|
||||
EmbeddingAPIKeyNotConfiguredError,
|
||||
OpenAIEmbeddingService,
|
||||
SentenceTransformerEmbeddingService,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseEmbeddingService",
|
||||
"EmbeddingAPIKeyNotConfiguredError",
|
||||
"SentenceTransformerEmbeddingService",
|
||||
"OpenAIEmbeddingService",
|
||||
]
|
||||
12
api/services/gen_ai/embedding/__init__.py
Normal file
12
api/services/gen_ai/embedding/__init__.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
"""Embedding services for document processing and retrieval."""
|
||||
|
||||
from .base import BaseEmbeddingService
|
||||
from .openai_service import EmbeddingAPIKeyNotConfiguredError, OpenAIEmbeddingService
|
||||
from .sentence_transformer_service import SentenceTransformerEmbeddingService
|
||||
|
||||
__all__ = [
|
||||
"BaseEmbeddingService",
|
||||
"EmbeddingAPIKeyNotConfiguredError",
|
||||
"SentenceTransformerEmbeddingService",
|
||||
"OpenAIEmbeddingService",
|
||||
]
|
||||
75
api/services/gen_ai/embedding/base.py
Normal file
75
api/services/gen_ai/embedding/base.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
"""Base class for embedding services."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class BaseEmbeddingService(ABC):
|
||||
"""Abstract base class for embedding services.
|
||||
|
||||
All embedding services (SentenceTransformer, OpenAI, etc.) should inherit from this class
|
||||
and implement the required methods.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_model_id(self) -> str:
|
||||
"""Return the model identifier.
|
||||
|
||||
Returns:
|
||||
String identifier for the model (e.g., 'sentence-transformers/all-MiniLM-L6-v2')
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_embedding_dimension(self) -> int:
|
||||
"""Return the embedding dimension.
|
||||
|
||||
Returns:
|
||||
Integer dimension of the embedding vectors
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a batch of texts.
|
||||
|
||||
Args:
|
||||
texts: List of text strings to embed
|
||||
|
||||
Returns:
|
||||
List of embedding vectors (each vector is a list of floats)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def embed_query(self, query: str) -> List[float]:
|
||||
"""Embed a single query text.
|
||||
|
||||
Args:
|
||||
query: Query text to embed
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def search_similar_chunks(
|
||||
self,
|
||||
query: str,
|
||||
organization_id: int,
|
||||
limit: int = 5,
|
||||
document_uuids: Optional[List[str]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search for similar chunks using vector similarity.
|
||||
|
||||
Args:
|
||||
query: Search query text
|
||||
organization_id: Organization ID for scoping
|
||||
limit: Maximum number of results to return
|
||||
document_uuids: Optional list of document UUIDs to filter by
|
||||
|
||||
Returns:
|
||||
List of dictionaries containing chunk data and similarity scores
|
||||
"""
|
||||
pass
|
||||
372
api/services/gen_ai/embedding/openai_service.py
Normal file
372
api/services/gen_ai/embedding/openai_service.py
Normal file
|
|
@ -0,0 +1,372 @@
|
|||
"""OpenAI embedding service.
|
||||
|
||||
This module provides document processing capabilities using:
|
||||
- OpenAI's text-embedding-3-small for embeddings (1536 dimensions)
|
||||
- Docling for document conversion and chunking
|
||||
- pgvector for vector similarity search
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from docling.chunking import HybridChunker
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
|
||||
from loguru import logger
|
||||
from openai import AsyncOpenAI
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from api.db.db_client import DBClient
|
||||
from api.db.models import KnowledgeBaseChunkModel
|
||||
|
||||
from .base import BaseEmbeddingService
|
||||
|
||||
# Model configuration
|
||||
DEFAULT_MODEL_ID = "text-embedding-3-small"
|
||||
EMBEDDING_DIMENSION = 1536 # Dimension for text-embedding-3-small
|
||||
|
||||
# For chunking, we'll use the same tokenizer as SentenceTransformer
|
||||
# since OpenAI uses similar tokenization
|
||||
TOKENIZER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
|
||||
class EmbeddingAPIKeyNotConfiguredError(Exception):
|
||||
"""Raised when OpenAI API key is not configured for embeddings."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
"OpenAI API key not configured. Please set your API key in "
|
||||
"Model Configurations > Embedding to use document processing."
|
||||
)
|
||||
|
||||
|
||||
class OpenAIEmbeddingService(BaseEmbeddingService):
|
||||
"""Embedding service using OpenAI's text-embedding-3-small."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_client: DBClient,
|
||||
api_key: Optional[str] = None,
|
||||
model_id: str = DEFAULT_MODEL_ID,
|
||||
max_tokens: int = 512,
|
||||
):
|
||||
"""Initialize the OpenAI embedding service.
|
||||
|
||||
Args:
|
||||
db_client: Database client for storing documents and chunks
|
||||
api_key: OpenAI API key. If not provided, the client will not be
|
||||
initialized and operations will fail with a clear error.
|
||||
model_id: OpenAI embedding model ID (default: text-embedding-3-small)
|
||||
max_tokens: Maximum number of tokens per chunk (default: 512)
|
||||
"""
|
||||
self.db = db_client
|
||||
self.model_id = model_id
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
# Only initialize OpenAI client if API key is provided
|
||||
self._api_key_configured = bool(api_key)
|
||||
if self._api_key_configured:
|
||||
self.client = AsyncOpenAI(api_key=api_key)
|
||||
logger.info(f"OpenAI embedding service initialized with model: {model_id}")
|
||||
else:
|
||||
self.client = None
|
||||
logger.warning(
|
||||
"OpenAI embedding service initialized without API key. "
|
||||
"Operations will fail until API key is configured in Model Configurations."
|
||||
)
|
||||
|
||||
# Initialize tokenizer for chunking
|
||||
# We use a HuggingFace tokenizer for consistent chunking
|
||||
logger.info(
|
||||
f"Loading tokenizer for chunking: {TOKENIZER_MODEL} with max_tokens={max_tokens}"
|
||||
)
|
||||
try:
|
||||
self.tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=AutoTokenizer.from_pretrained(
|
||||
TOKENIZER_MODEL,
|
||||
local_files_only=True,
|
||||
),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
logger.info("Loaded tokenizer from cache")
|
||||
except Exception as e:
|
||||
logger.warning(f"Tokenizer not in cache, downloading: {e}")
|
||||
self.tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=AutoTokenizer.from_pretrained(TOKENIZER_MODEL),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
logger.info("Tokenizer downloaded and cached")
|
||||
|
||||
# Initialize chunker
|
||||
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
|
||||
self.chunker = HybridChunker(tokenizer=self.tokenizer)
|
||||
|
||||
# Initialize document converter
|
||||
self.converter = DocumentConverter()
|
||||
|
||||
def get_model_id(self) -> str:
|
||||
"""Return the model identifier."""
|
||||
return self.model_id
|
||||
|
||||
def get_embedding_dimension(self) -> int:
|
||||
"""Return the embedding dimension."""
|
||||
return EMBEDDING_DIMENSION
|
||||
|
||||
def _ensure_api_key_configured(self):
|
||||
"""Check if API key is configured and raise error if not."""
|
||||
if not self._api_key_configured or self.client is None:
|
||||
raise EmbeddingAPIKeyNotConfiguredError()
|
||||
|
||||
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a batch of texts using OpenAI API.
|
||||
|
||||
Args:
|
||||
texts: List of text strings to embed
|
||||
|
||||
Returns:
|
||||
List of embedding vectors (each vector is a list of floats)
|
||||
|
||||
Raises:
|
||||
EmbeddingAPIKeyNotConfiguredError: If API key is not configured
|
||||
"""
|
||||
self._ensure_api_key_configured()
|
||||
|
||||
try:
|
||||
# OpenAI API call
|
||||
response = await self.client.embeddings.create(
|
||||
input=texts,
|
||||
model=self.model_id,
|
||||
)
|
||||
|
||||
# Extract embeddings from response
|
||||
embeddings = [item.embedding for item in response.data]
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating OpenAI embeddings: {e}")
|
||||
raise
|
||||
|
||||
async def embed_query(self, query: str) -> List[float]:
|
||||
"""Embed a single query text using OpenAI API.
|
||||
|
||||
Args:
|
||||
query: Query text to embed
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats
|
||||
|
||||
Raises:
|
||||
EmbeddingAPIKeyNotConfiguredError: If API key is not configured
|
||||
"""
|
||||
self._ensure_api_key_configured()
|
||||
embeddings = await self.embed_texts([query])
|
||||
return embeddings[0]
|
||||
|
||||
async def search_similar_chunks(
|
||||
self,
|
||||
query: str,
|
||||
organization_id: int,
|
||||
limit: int = 5,
|
||||
document_uuids: Optional[List[str]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search for similar chunks using vector similarity.
|
||||
|
||||
Args:
|
||||
query: Search query text
|
||||
organization_id: Organization ID for scoping
|
||||
limit: Maximum number of results to return
|
||||
document_uuids: Optional list of document UUIDs to filter by
|
||||
|
||||
Returns:
|
||||
List of dictionaries with chunk data and similarity scores
|
||||
|
||||
Raises:
|
||||
EmbeddingAPIKeyNotConfiguredError: If API key is not configured
|
||||
"""
|
||||
self._ensure_api_key_configured()
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = await self.embed_query(query)
|
||||
|
||||
# Perform vector similarity search
|
||||
results = await self.db.search_similar_chunks(
|
||||
query_embedding=query_embedding,
|
||||
organization_id=organization_id,
|
||||
limit=limit,
|
||||
document_uuids=document_uuids,
|
||||
embedding_model=self.model_id,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def process_document(
|
||||
self,
|
||||
file_path: str,
|
||||
organization_id: int,
|
||||
created_by: int,
|
||||
custom_metadata: dict = None,
|
||||
):
|
||||
"""Process a document: convert, chunk, embed, and store in database.
|
||||
|
||||
Args:
|
||||
file_path: Path to the document file
|
||||
organization_id: Organization ID for scoping
|
||||
created_by: User ID who uploaded the document
|
||||
custom_metadata: Optional custom metadata dictionary
|
||||
|
||||
Returns:
|
||||
The created document record
|
||||
"""
|
||||
try:
|
||||
# Extract file metadata
|
||||
filename = Path(file_path).name
|
||||
file_hash = self.db.compute_file_hash(file_path)
|
||||
file_size = os.path.getsize(file_path)
|
||||
mime_type = self.db.get_mime_type(file_path)
|
||||
|
||||
# Check if document already exists
|
||||
existing_doc = await self.db.get_document_by_hash(
|
||||
file_hash, organization_id
|
||||
)
|
||||
if existing_doc:
|
||||
logger.info(f"Document already exists: {filename} (hash: {file_hash})")
|
||||
return existing_doc
|
||||
|
||||
# Create document record
|
||||
doc_record = await self.db.create_document(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
filename=filename,
|
||||
file_size_bytes=file_size,
|
||||
file_hash=file_hash,
|
||||
mime_type=mime_type,
|
||||
custom_metadata=custom_metadata or {},
|
||||
)
|
||||
|
||||
logger.info(f"Processing document with OpenAI embeddings: {filename}")
|
||||
|
||||
# Update status to processing
|
||||
await self.db.update_document_status(doc_record.id, "processing")
|
||||
|
||||
# Step 1: Convert document using docling
|
||||
logger.info("Converting document with docling...")
|
||||
conversion_result = self.converter.convert(file_path)
|
||||
doc = conversion_result.document
|
||||
|
||||
# Store docling metadata
|
||||
docling_metadata = {
|
||||
"num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
|
||||
"document_type": type(doc).__name__,
|
||||
}
|
||||
|
||||
# Step 2: Chunk the document
|
||||
logger.info(f"Chunking document with max_tokens={self.max_tokens}...")
|
||||
chunks = list(self.chunker.chunk(dl_doc=doc))
|
||||
total_chunks = len(chunks)
|
||||
|
||||
logger.info(f"Generated {total_chunks} chunks")
|
||||
|
||||
# Step 3: Process each chunk
|
||||
chunk_texts = []
|
||||
chunk_records = []
|
||||
token_counts = []
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Get chunk text
|
||||
chunk_text = chunk.text
|
||||
|
||||
# Get contextualized text
|
||||
contextualized_text = self.chunker.contextualize(chunk=chunk)
|
||||
|
||||
# Calculate token count
|
||||
text_to_tokenize = (
|
||||
contextualized_text if contextualized_text else chunk_text
|
||||
)
|
||||
token_count = len(
|
||||
self.tokenizer.tokenizer.encode(
|
||||
text_to_tokenize, add_special_tokens=False
|
||||
)
|
||||
)
|
||||
token_counts.append(token_count)
|
||||
|
||||
# Prepare chunk metadata
|
||||
chunk_metadata = {}
|
||||
if hasattr(chunk, "meta") and chunk.meta:
|
||||
chunk_metadata = {
|
||||
"doc_items": (
|
||||
[str(item) for item in chunk.meta.doc_items]
|
||||
if hasattr(chunk.meta, "doc_items")
|
||||
else []
|
||||
),
|
||||
"headings": (
|
||||
chunk.meta.headings
|
||||
if hasattr(chunk.meta, "headings")
|
||||
else []
|
||||
),
|
||||
}
|
||||
|
||||
# Create chunk record (without embedding yet)
|
||||
chunk_record = KnowledgeBaseChunkModel(
|
||||
document_id=doc_record.id,
|
||||
organization_id=organization_id,
|
||||
chunk_text=chunk_text,
|
||||
contextualized_text=contextualized_text,
|
||||
chunk_index=i,
|
||||
chunk_metadata=chunk_metadata,
|
||||
embedding_model=self.model_id,
|
||||
embedding_dimension=EMBEDDING_DIMENSION,
|
||||
token_count=token_count,
|
||||
)
|
||||
|
||||
chunk_records.append(chunk_record)
|
||||
chunk_texts.append(text_to_tokenize)
|
||||
|
||||
# Log chunk statistics
|
||||
if token_counts:
|
||||
avg_tokens = sum(token_counts) / len(token_counts)
|
||||
min_tokens = min(token_counts)
|
||||
max_tokens = max(token_counts)
|
||||
logger.info("Chunk token statistics:")
|
||||
logger.info(f" - Average: {avg_tokens:.1f} tokens")
|
||||
logger.info(f" - Min: {min_tokens} tokens")
|
||||
logger.info(f" - Max: {max_tokens} tokens")
|
||||
|
||||
# Step 4: Generate embeddings using OpenAI API
|
||||
logger.info(f"Generating embeddings using OpenAI ({self.model_id})...")
|
||||
embeddings = await self.embed_texts(chunk_texts)
|
||||
|
||||
# Step 5: Attach embeddings to chunk records
|
||||
for chunk_record, embedding in zip(chunk_records, embeddings):
|
||||
chunk_record.embedding = embedding
|
||||
|
||||
# Step 6: Save all chunks in batch
|
||||
logger.info("Storing chunks in database...")
|
||||
await self.db.create_chunks_batch(chunk_records)
|
||||
|
||||
# Update document status to completed
|
||||
await self.db.update_document_status(
|
||||
doc_record.id,
|
||||
"completed",
|
||||
total_chunks=total_chunks,
|
||||
docling_metadata=docling_metadata,
|
||||
)
|
||||
|
||||
logger.info(f"Successfully processed document: {filename}")
|
||||
logger.info(f" - Total chunks: {total_chunks}")
|
||||
logger.info(f" - Embedding model: {self.model_id}")
|
||||
logger.info(f" - Document ID: {doc_record.id}")
|
||||
logger.info(f" - Document UUID: {doc_record.document_uuid}")
|
||||
|
||||
return doc_record
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing document with OpenAI: {e}")
|
||||
|
||||
# Update document status to failed if it exists
|
||||
if "doc_record" in locals():
|
||||
await self.db.update_document_status(
|
||||
doc_record.id, "failed", error_message=str(e)
|
||||
)
|
||||
|
||||
raise
|
||||
350
api/services/gen_ai/embedding/sentence_transformer_service.py
Normal file
350
api/services/gen_ai/embedding/sentence_transformer_service.py
Normal file
|
|
@ -0,0 +1,350 @@
|
|||
"""Sentence Transformer embedding service.
|
||||
|
||||
This module provides document processing capabilities using:
|
||||
- Sentence-transformers for embeddings (all-MiniLM-L6-v2)
|
||||
- Docling for document conversion and chunking
|
||||
- pgvector for vector similarity search
|
||||
|
||||
Setup for offline usage:
|
||||
1. First run: Downloads and caches models to ~/.cache/sentence_transformers
|
||||
2. Subsequent runs: Uses cached models (no internet needed)
|
||||
3. For fully offline mode: Set TRANSFORMERS_OFFLINE=1 and HF_HUB_OFFLINE=1
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from docling.chunking import HybridChunker
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
|
||||
from loguru import logger
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from api.db.db_client import DBClient
|
||||
from api.db.models import KnowledgeBaseChunkModel
|
||||
|
||||
from .base import BaseEmbeddingService
|
||||
|
||||
# Set environment variables for model caching
|
||||
os.environ.setdefault("TRANSFORMERS_OFFLINE", "0")
|
||||
os.environ.setdefault("HF_HUB_OFFLINE", "0")
|
||||
os.environ.setdefault(
|
||||
"SENTENCE_TRANSFORMERS_HOME", os.path.expanduser("~/.cache/sentence_transformers")
|
||||
)
|
||||
|
||||
# Model configuration
|
||||
DEFAULT_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
EMBEDDING_DIMENSION = 384 # Dimension for all-MiniLM-L6-v2
|
||||
|
||||
|
||||
class SentenceTransformerEmbeddingService(BaseEmbeddingService):
|
||||
"""Embedding service using Sentence Transformers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_client: DBClient,
|
||||
model_id: str = DEFAULT_MODEL_ID,
|
||||
max_tokens: int = 512,
|
||||
):
|
||||
"""Initialize the Sentence Transformer embedding service.
|
||||
|
||||
Args:
|
||||
db_client: Database client for storing documents and chunks
|
||||
model_id: Sentence-transformers model ID (default: all-MiniLM-L6-v2)
|
||||
max_tokens: Maximum number of tokens per chunk (default: 512)
|
||||
Note: This applies to the contextualized text (with headings/captions)
|
||||
"""
|
||||
self.db = db_client
|
||||
self.model_id = model_id
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
# Initialize embedding model
|
||||
logger.info(f"Loading embedding model: {model_id}")
|
||||
try:
|
||||
# Try to load from cache first (local_files_only=True)
|
||||
self.embedding_model = SentenceTransformer(
|
||||
model_id,
|
||||
cache_folder=os.environ.get("SENTENCE_TRANSFORMERS_HOME"),
|
||||
local_files_only=True,
|
||||
)
|
||||
logger.info("Loaded model from cache")
|
||||
except Exception as e:
|
||||
logger.warning(f"Model not in cache, downloading: {e}")
|
||||
# If not in cache, download it (this will cache it for next time)
|
||||
self.embedding_model = SentenceTransformer(
|
||||
model_id,
|
||||
cache_folder=os.environ.get("SENTENCE_TRANSFORMERS_HOME"),
|
||||
)
|
||||
logger.info("Model downloaded and cached")
|
||||
|
||||
# Initialize tokenizer for chunking with max_tokens
|
||||
logger.info(f"Loading tokenizer: {model_id} with max_tokens={max_tokens}")
|
||||
try:
|
||||
# Try to load from cache first
|
||||
self.tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
local_files_only=True,
|
||||
),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
logger.info("Loaded tokenizer from cache")
|
||||
except Exception as e:
|
||||
logger.warning(f"Tokenizer not in cache, downloading: {e}")
|
||||
# If not in cache, download it
|
||||
self.tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=AutoTokenizer.from_pretrained(model_id),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
logger.info("Tokenizer downloaded and cached")
|
||||
|
||||
# Initialize chunker
|
||||
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
|
||||
self.chunker = HybridChunker(tokenizer=self.tokenizer)
|
||||
|
||||
# Initialize document converter
|
||||
self.converter = DocumentConverter()
|
||||
|
||||
def get_model_id(self) -> str:
|
||||
"""Return the model identifier."""
|
||||
return self.model_id
|
||||
|
||||
def get_embedding_dimension(self) -> int:
|
||||
"""Return the embedding dimension."""
|
||||
return EMBEDDING_DIMENSION
|
||||
|
||||
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a batch of texts.
|
||||
|
||||
Args:
|
||||
texts: List of text strings to embed
|
||||
|
||||
Returns:
|
||||
List of embedding vectors (each vector is a list of floats)
|
||||
"""
|
||||
embeddings = self.embedding_model.encode(
|
||||
texts,
|
||||
show_progress_bar=False,
|
||||
convert_to_numpy=True,
|
||||
)
|
||||
return [embedding.tolist() for embedding in embeddings]
|
||||
|
||||
async def embed_query(self, query: str) -> List[float]:
|
||||
"""Embed a single query text.
|
||||
|
||||
Args:
|
||||
query: Query text to embed
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats
|
||||
"""
|
||||
embedding = self.embedding_model.encode([query])[0]
|
||||
return embedding.tolist()
|
||||
|
||||
async def search_similar_chunks(
|
||||
self,
|
||||
query: str,
|
||||
organization_id: int,
|
||||
limit: int = 5,
|
||||
document_uuids: Optional[List[str]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search for similar chunks using vector similarity.
|
||||
|
||||
Returns top-k most similar chunks without any threshold filtering.
|
||||
Apply similarity thresholds and reranking at the application layer.
|
||||
|
||||
Args:
|
||||
query: Search query text
|
||||
organization_id: Organization ID for scoping
|
||||
limit: Maximum number of results to return
|
||||
document_uuids: Optional list of document UUIDs to filter by
|
||||
|
||||
Returns:
|
||||
List of dictionaries with chunk data and similarity scores
|
||||
"""
|
||||
# Generate query embedding
|
||||
query_embedding = await self.embed_query(query)
|
||||
|
||||
# Perform vector similarity search
|
||||
results = await self.db.search_similar_chunks(
|
||||
query_embedding=query_embedding,
|
||||
organization_id=organization_id,
|
||||
limit=limit,
|
||||
document_uuids=document_uuids,
|
||||
embedding_model=self.model_id,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def process_document(
|
||||
self,
|
||||
file_path: str,
|
||||
organization_id: int,
|
||||
created_by: int,
|
||||
custom_metadata: dict = None,
|
||||
):
|
||||
"""Process a document: convert, chunk, embed, and store in database.
|
||||
|
||||
Args:
|
||||
file_path: Path to the document file
|
||||
organization_id: Organization ID for scoping
|
||||
created_by: User ID who uploaded the document
|
||||
custom_metadata: Optional custom metadata dictionary
|
||||
|
||||
Returns:
|
||||
The created document record
|
||||
"""
|
||||
try:
|
||||
# Extract file metadata
|
||||
filename = Path(file_path).name
|
||||
file_hash = self.db.compute_file_hash(file_path)
|
||||
file_size = os.path.getsize(file_path)
|
||||
mime_type = self.db.get_mime_type(file_path)
|
||||
|
||||
# Check if document already exists
|
||||
existing_doc = await self.db.get_document_by_hash(
|
||||
file_hash, organization_id
|
||||
)
|
||||
if existing_doc:
|
||||
logger.info(f"Document already exists: {filename} (hash: {file_hash})")
|
||||
return existing_doc
|
||||
|
||||
# Create document record
|
||||
doc_record = await self.db.create_document(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
filename=filename,
|
||||
file_size_bytes=file_size,
|
||||
file_hash=file_hash,
|
||||
mime_type=mime_type,
|
||||
custom_metadata=custom_metadata or {},
|
||||
)
|
||||
|
||||
logger.info(f"Processing document: {filename}")
|
||||
|
||||
# Update status to processing
|
||||
await self.db.update_document_status(doc_record.id, "processing")
|
||||
|
||||
# Step 1: Convert document using docling
|
||||
logger.info("Converting document with docling...")
|
||||
conversion_result = self.converter.convert(file_path)
|
||||
doc = conversion_result.document
|
||||
|
||||
# Store docling metadata
|
||||
docling_metadata = {
|
||||
"num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
|
||||
"document_type": type(doc).__name__,
|
||||
}
|
||||
|
||||
# Step 2: Chunk the document
|
||||
logger.info(f"Chunking document with max_tokens={self.max_tokens}...")
|
||||
chunks = list(self.chunker.chunk(dl_doc=doc))
|
||||
total_chunks = len(chunks)
|
||||
|
||||
logger.info(f"Generated {total_chunks} chunks")
|
||||
|
||||
# Step 3: Process each chunk
|
||||
chunk_texts = []
|
||||
chunk_records = []
|
||||
token_counts = []
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Get chunk text
|
||||
chunk_text = chunk.text
|
||||
|
||||
# Get contextualized text (enriched with surrounding context)
|
||||
contextualized_text = self.chunker.contextualize(chunk=chunk)
|
||||
|
||||
# Calculate actual token count using the tokenizer
|
||||
text_to_tokenize = (
|
||||
contextualized_text if contextualized_text else chunk_text
|
||||
)
|
||||
token_count = len(
|
||||
self.tokenizer.tokenizer.encode(
|
||||
text_to_tokenize, add_special_tokens=False
|
||||
)
|
||||
)
|
||||
token_counts.append(token_count)
|
||||
|
||||
# Prepare chunk metadata
|
||||
chunk_metadata = {}
|
||||
if hasattr(chunk, "meta") and chunk.meta:
|
||||
chunk_metadata = {
|
||||
"doc_items": (
|
||||
[str(item) for item in chunk.meta.doc_items]
|
||||
if hasattr(chunk.meta, "doc_items")
|
||||
else []
|
||||
),
|
||||
"headings": (
|
||||
chunk.meta.headings
|
||||
if hasattr(chunk.meta, "headings")
|
||||
else []
|
||||
),
|
||||
}
|
||||
|
||||
# Create chunk record (without embedding yet)
|
||||
chunk_record = KnowledgeBaseChunkModel(
|
||||
document_id=doc_record.id,
|
||||
organization_id=organization_id,
|
||||
chunk_text=chunk_text,
|
||||
contextualized_text=contextualized_text,
|
||||
chunk_index=i,
|
||||
chunk_metadata=chunk_metadata,
|
||||
embedding_model=self.model_id,
|
||||
embedding_dimension=EMBEDDING_DIMENSION,
|
||||
token_count=token_count,
|
||||
)
|
||||
|
||||
chunk_records.append(chunk_record)
|
||||
# Use contextualized text for embedding if available
|
||||
chunk_texts.append(text_to_tokenize)
|
||||
|
||||
# Log chunk statistics
|
||||
if token_counts:
|
||||
avg_tokens = sum(token_counts) / len(token_counts)
|
||||
min_tokens = min(token_counts)
|
||||
max_tokens = max(token_counts)
|
||||
logger.info("Chunk token statistics:")
|
||||
logger.info(f" - Average: {avg_tokens:.1f} tokens")
|
||||
logger.info(f" - Min: {min_tokens} tokens")
|
||||
logger.info(f" - Max: {max_tokens} tokens")
|
||||
|
||||
# Step 4: Generate embeddings in batch
|
||||
logger.info("Generating embeddings...")
|
||||
embeddings = await self.embed_texts(chunk_texts)
|
||||
|
||||
# Step 5: Attach embeddings to chunk records
|
||||
for chunk_record, embedding in zip(chunk_records, embeddings):
|
||||
chunk_record.embedding = embedding
|
||||
|
||||
# Step 6: Save all chunks in batch
|
||||
logger.info("Storing chunks in database...")
|
||||
await self.db.create_chunks_batch(chunk_records)
|
||||
|
||||
# Update document status to completed
|
||||
await self.db.update_document_status(
|
||||
doc_record.id,
|
||||
"completed",
|
||||
total_chunks=total_chunks,
|
||||
docling_metadata=docling_metadata,
|
||||
)
|
||||
|
||||
logger.info(f"Successfully processed document: {filename}")
|
||||
logger.info(f" - Total chunks: {total_chunks}")
|
||||
logger.info(f" - Document ID: {doc_record.id}")
|
||||
logger.info(f" - Document UUID: {doc_record.document_uuid}")
|
||||
|
||||
return doc_record
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing document: {e}")
|
||||
|
||||
# Update document status to failed if it exists
|
||||
if "doc_record" in locals():
|
||||
await self.db.update_document_status(
|
||||
doc_record.id, "failed", error_message=str(e)
|
||||
)
|
||||
|
||||
raise
|
||||
44
api/services/pricing/embeddings.py
Normal file
44
api/services/pricing/embeddings.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"""
|
||||
Embeddings pricing models for different providers.
|
||||
|
||||
Prices are per token for embedding models.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import PricingModel
|
||||
|
||||
|
||||
class EmbeddingPricingModel(PricingModel):
|
||||
"""Pricing model for token-based embedding services."""
|
||||
|
||||
def __init__(self, token_price: Decimal):
|
||||
"""Initialize with price per token.
|
||||
|
||||
Args:
|
||||
token_price: Cost per token for embedding
|
||||
"""
|
||||
self.token_price = token_price
|
||||
|
||||
def calculate_cost(self, token_count: int) -> Decimal:
|
||||
"""Calculate cost for embedding token usage."""
|
||||
return Decimal(token_count) * self.token_price
|
||||
|
||||
|
||||
# Embeddings pricing registry
|
||||
EMBEDDINGS_PRICING: Dict[str, Dict[str, EmbeddingPricingModel]] = {
|
||||
ServiceProviders.OPENAI: {
|
||||
"text-embedding-3-small": EmbeddingPricingModel(
|
||||
token_price=Decimal("0.02") / 1_000_000, # $0.02 per 1M tokens
|
||||
),
|
||||
"text-embedding-3-large": EmbeddingPricingModel(
|
||||
token_price=Decimal("0.13") / 1_000_000, # $0.13 per 1M tokens
|
||||
),
|
||||
"text-embedding-ada-002": EmbeddingPricingModel(
|
||||
token_price=Decimal("0.10") / 1_000_000, # $0.10 per 1M tokens (legacy)
|
||||
),
|
||||
},
|
||||
}
|
||||
|
|
@ -4,6 +4,7 @@ Main pricing registry that combines all service type pricing models.
|
|||
|
||||
from typing import Dict
|
||||
|
||||
from .embeddings import EMBEDDINGS_PRICING
|
||||
from .llm import LLM_PRICING
|
||||
from .stt import STT_PRICING
|
||||
from .tts import TTS_PRICING
|
||||
|
|
@ -13,4 +14,5 @@ PRICING_REGISTRY: Dict = {
|
|||
"llm": LLM_PRICING,
|
||||
"tts": TTS_PRICING,
|
||||
"stt": STT_PRICING,
|
||||
"embeddings": EMBEDDINGS_PRICING,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ class NodeDataDTO(BaseModel):
|
|||
delayed_start: bool = False
|
||||
delayed_start_duration: Optional[float] = None
|
||||
tool_uuids: Optional[List[str]] = None
|
||||
document_uuids: Optional[List[str]] = None
|
||||
trigger_path: Optional[str] = None
|
||||
# Webhook node specific fields
|
||||
enabled: bool = True
|
||||
|
|
|
|||
|
|
@ -41,6 +41,10 @@ from api.services.workflow.pipecat_engine_variable_extractor import (
|
|||
VariableExtractionManager,
|
||||
)
|
||||
from api.services.workflow.tools.calculator import get_calculator_tools, safe_calculator
|
||||
from api.services.workflow.tools.knowledge_base import (
|
||||
get_knowledge_base_tool,
|
||||
retrieve_from_knowledge_base,
|
||||
)
|
||||
from api.services.workflow.tools.timezone import (
|
||||
convert_time,
|
||||
get_current_time,
|
||||
|
|
@ -290,6 +294,48 @@ class PipecatEngine:
|
|||
self.llm.register_function("get_current_time", get_current_time_func)
|
||||
self.llm.register_function("convert_time", convert_time_func)
|
||||
|
||||
async def _register_knowledge_base_function(
|
||||
self, document_uuids: list[str]
|
||||
) -> None:
|
||||
"""Register knowledge base retrieval function with the LLM.
|
||||
|
||||
Args:
|
||||
document_uuids: List of document UUIDs to filter the search by
|
||||
"""
|
||||
logger.debug(
|
||||
f"Registering knowledge base retrieval function with {len(document_uuids)} document(s)"
|
||||
)
|
||||
|
||||
async def retrieve_kb_func(function_call_params: FunctionCallParams) -> None:
|
||||
logger.info("LLM Function Call EXECUTED: retrieve_from_knowledge_base")
|
||||
logger.info(f"Arguments: {function_call_params.arguments}")
|
||||
try:
|
||||
query = function_call_params.arguments.get("query", "")
|
||||
organization_id = await self._get_organization_id()
|
||||
|
||||
if not organization_id:
|
||||
raise ValueError(
|
||||
"Organization ID not available for knowledge base retrieval"
|
||||
)
|
||||
|
||||
result = await retrieve_from_knowledge_base(
|
||||
query=query,
|
||||
organization_id=organization_id,
|
||||
document_uuids=document_uuids,
|
||||
limit=3, # Return top 3 most relevant chunks
|
||||
)
|
||||
|
||||
await function_call_params.result_callback(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Knowledge base retrieval failed: {e}")
|
||||
await function_call_params.result_callback(
|
||||
{"error": str(e), "chunks": [], "query": query, "total_results": 0}
|
||||
)
|
||||
|
||||
# Register the function with the LLM
|
||||
self.llm.register_function("retrieve_from_knowledge_base", retrieve_kb_func)
|
||||
|
||||
async def _perform_variable_extraction_if_needed(
|
||||
self, previous_node: Optional[Node]
|
||||
) -> None:
|
||||
|
|
@ -346,6 +392,10 @@ class PipecatEngine:
|
|||
if node.tool_uuids and self._custom_tool_manager:
|
||||
await self._custom_tool_manager.register_handlers(node.tool_uuids)
|
||||
|
||||
# Register knowledge base retrieval handler if node has documents
|
||||
if node.document_uuids:
|
||||
await self._register_knowledge_base_function(node.document_uuids)
|
||||
|
||||
# Set up system message and functions
|
||||
(
|
||||
system_message,
|
||||
|
|
@ -575,6 +625,17 @@ class PipecatEngine:
|
|||
# Add built-in function schemas (calculator and timezone tools)
|
||||
functions.extend(self.builtin_function_schemas)
|
||||
|
||||
# Add knowledge base retrieval tool if node has documents
|
||||
if node.document_uuids:
|
||||
kb_tool_def = get_knowledge_base_tool(node.document_uuids)
|
||||
kb_schema = get_function_schema(
|
||||
kb_tool_def["function"]["name"],
|
||||
kb_tool_def["function"]["description"],
|
||||
properties=kb_tool_def["function"]["parameters"].get("properties", {}),
|
||||
required=kb_tool_def["function"]["parameters"].get("required", []),
|
||||
)
|
||||
functions.append(kb_schema)
|
||||
|
||||
# Add custom tools from node.tool_uuids
|
||||
if node.tool_uuids and self._custom_tool_manager:
|
||||
custom_tool_schemas = await self._custom_tool_manager.get_tool_schemas(
|
||||
|
|
|
|||
305
api/services/workflow/tools/knowledge_base.py
Normal file
305
api/services/workflow/tools/knowledge_base.py
Normal file
|
|
@ -0,0 +1,305 @@
|
|||
"""Knowledge Base retrieval tool for workflow execution.
|
||||
|
||||
This module provides vector similarity search capabilities for retrieving
|
||||
relevant information from the knowledge base during conversations.
|
||||
|
||||
Implements OpenTelemetry tracing for observability in Langfuse.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from opentelemetry import trace
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.gen_ai import OpenAIEmbeddingService
|
||||
from api.services.pipecat.tracing_config import is_tracing_enabled
|
||||
from pipecat.utils.tracing.context_registry import (
|
||||
get_current_conversation_context,
|
||||
get_current_turn_context,
|
||||
)
|
||||
|
||||
|
||||
async def retrieve_from_knowledge_base(
|
||||
query: str,
|
||||
organization_id: int,
|
||||
document_uuids: Optional[List[str]] = None,
|
||||
limit: int = 3,
|
||||
embeddings_api_key: Optional[str] = None,
|
||||
embeddings_model: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Retrieve relevant information from the knowledge base using vector similarity search.
|
||||
|
||||
Uses OpenAI text-embedding-3-small for embeddings by default. This provides
|
||||
high-quality 1536-dimensional embeddings for accurate retrieval.
|
||||
|
||||
This function includes OpenTelemetry tracing for Langfuse observability.
|
||||
|
||||
Args:
|
||||
query: The search query to find relevant information
|
||||
organization_id: Organization ID for scoping the search
|
||||
document_uuids: Optional list of document UUIDs to filter by
|
||||
limit: Maximum number of chunks to return (default: 3)
|
||||
embeddings_api_key: Optional API key for embedding service
|
||||
embeddings_model: Optional model ID for embedding service
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- chunks: List of relevant text chunks with metadata
|
||||
- query: The original query
|
||||
- total_results: Number of results returned
|
||||
"""
|
||||
# Create span for retrieval operation if tracing is enabled
|
||||
if is_tracing_enabled():
|
||||
try:
|
||||
# Get parent context from turn or conversation
|
||||
turn_context = get_current_turn_context()
|
||||
conversation_context = get_current_conversation_context()
|
||||
parent_context = turn_context or conversation_context
|
||||
|
||||
# Get tracer
|
||||
tracer = trace.get_tracer("pipecat")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to setup tracing context: {e}")
|
||||
# Fall back to non-traced execution
|
||||
return await _perform_retrieval(
|
||||
query,
|
||||
organization_id,
|
||||
document_uuids,
|
||||
limit,
|
||||
embeddings_api_key,
|
||||
embeddings_model,
|
||||
)
|
||||
|
||||
# Create span with parent context
|
||||
if parent_context:
|
||||
with tracer.start_as_current_span(
|
||||
"knowledge_base_retrieval", context=parent_context
|
||||
) as span:
|
||||
try:
|
||||
# Mark trace as public for Langfuse
|
||||
span.set_attribute("langfuse.trace.public", True)
|
||||
|
||||
# Add operation metadata
|
||||
span.set_attribute(
|
||||
"gen_ai.operation.name", "knowledge_base_retrieval"
|
||||
)
|
||||
span.set_attribute("retrieval.query", query)
|
||||
span.set_attribute("retrieval.limit", limit)
|
||||
span.set_attribute("retrieval.organization_id", organization_id)
|
||||
|
||||
# Add document filter info
|
||||
if document_uuids:
|
||||
span.set_attribute(
|
||||
"retrieval.document_count", len(document_uuids)
|
||||
)
|
||||
span.set_attribute(
|
||||
"retrieval.document_uuids", json.dumps(document_uuids)
|
||||
)
|
||||
|
||||
# Perform the actual retrieval
|
||||
result = await _perform_retrieval(
|
||||
query,
|
||||
organization_id,
|
||||
document_uuids,
|
||||
limit,
|
||||
embeddings_api_key,
|
||||
embeddings_model,
|
||||
)
|
||||
|
||||
# Add result metadata to span
|
||||
span.set_attribute(
|
||||
"retrieval.results_count", result["total_results"]
|
||||
)
|
||||
|
||||
if result.get("error"):
|
||||
span.set_attribute("retrieval.error", result["error"])
|
||||
span.set_status(
|
||||
trace.Status(trace.StatusCode.ERROR, result["error"])
|
||||
)
|
||||
else:
|
||||
# Add similarity scores
|
||||
if result["chunks"]:
|
||||
similarities = [
|
||||
chunk["similarity"] for chunk in result["chunks"]
|
||||
]
|
||||
span.set_attribute(
|
||||
"retrieval.avg_similarity",
|
||||
round(sum(similarities) / len(similarities), 4),
|
||||
)
|
||||
span.set_attribute(
|
||||
"retrieval.max_similarity", max(similarities)
|
||||
)
|
||||
span.set_attribute(
|
||||
"retrieval.min_similarity", min(similarities)
|
||||
)
|
||||
|
||||
# Add retrieved documents info
|
||||
filenames = list(
|
||||
set(chunk["filename"] for chunk in result["chunks"])
|
||||
)
|
||||
span.set_attribute(
|
||||
"retrieval.source_files", json.dumps(filenames)
|
||||
)
|
||||
|
||||
# Add output as JSON for Langfuse
|
||||
output_data = {
|
||||
"query": query,
|
||||
"chunks_retrieved": len(result["chunks"]),
|
||||
"chunks": [
|
||||
{
|
||||
"text": chunk["text"][:200] + "..."
|
||||
if len(chunk["text"]) > 200
|
||||
else chunk["text"],
|
||||
"filename": chunk["filename"],
|
||||
"similarity": chunk["similarity"],
|
||||
}
|
||||
for chunk in result["chunks"]
|
||||
],
|
||||
}
|
||||
span.set_attribute("output", json.dumps(output_data))
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in traced retrieval: {e}")
|
||||
span.record_exception(e)
|
||||
span.set_status(trace.Status(trace.StatusCode.ERROR, str(e)))
|
||||
raise
|
||||
else:
|
||||
# No parent context - perform retrieval without tracing
|
||||
logger.debug(
|
||||
"No parent context available for knowledge base retrieval tracing"
|
||||
)
|
||||
return await _perform_retrieval(
|
||||
query,
|
||||
organization_id,
|
||||
document_uuids,
|
||||
limit,
|
||||
embeddings_api_key,
|
||||
embeddings_model,
|
||||
)
|
||||
else:
|
||||
# Tracing is disabled - perform retrieval without tracing
|
||||
return await _perform_retrieval(
|
||||
query,
|
||||
organization_id,
|
||||
document_uuids,
|
||||
limit,
|
||||
embeddings_api_key,
|
||||
embeddings_model,
|
||||
)
|
||||
|
||||
|
||||
async def _perform_retrieval(
|
||||
query: str,
|
||||
organization_id: int,
|
||||
document_uuids: Optional[List[str]],
|
||||
limit: int,
|
||||
embeddings_api_key: Optional[str] = None,
|
||||
embeddings_model: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Internal function to perform the actual retrieval operation.
|
||||
|
||||
Separated from tracing logic for cleaner code organization.
|
||||
Uses OpenAI embeddings by default for high-quality retrieval.
|
||||
"""
|
||||
try:
|
||||
# Create a new embedding service instance
|
||||
# Uses OpenAI text-embedding-3-small by default, or user-provided config
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
max_tokens=128, # This is only used for chunking, not for retrieval
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
)
|
||||
|
||||
# Perform vector similarity search
|
||||
results = await embedding_service.search_similar_chunks(
|
||||
query=query,
|
||||
organization_id=organization_id,
|
||||
limit=limit,
|
||||
document_uuids=document_uuids,
|
||||
)
|
||||
|
||||
# Format results for LLM consumption
|
||||
chunks = []
|
||||
for result in results:
|
||||
chunk_info = {
|
||||
"text": result.get("contextualized_text") or result.get("chunk_text"),
|
||||
"filename": result.get("filename"),
|
||||
"similarity": round(result.get("similarity", 0), 4),
|
||||
"chunk_index": result.get("chunk_index"),
|
||||
}
|
||||
chunks.append(chunk_info)
|
||||
|
||||
logger.info(
|
||||
f"Knowledge base retrieval: query='{query}', "
|
||||
f"results={len(chunks)}, "
|
||||
f"document_filter={document_uuids}"
|
||||
)
|
||||
|
||||
return {
|
||||
"chunks": chunks,
|
||||
"query": query,
|
||||
"total_results": len(chunks),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving from knowledge base: {e}")
|
||||
return {
|
||||
"error": str(e),
|
||||
"chunks": [],
|
||||
"query": query,
|
||||
"total_results": 0,
|
||||
}
|
||||
|
||||
|
||||
def get_knowledge_base_tool(
|
||||
document_uuids: Optional[List[str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Get knowledge base retrieval tool definition for LLM function calling.
|
||||
|
||||
Args:
|
||||
document_uuids: Optional list of document UUIDs to include in description
|
||||
|
||||
Returns:
|
||||
Tool definition compatible with LLM function calling
|
||||
"""
|
||||
# Build description based on whether specific documents are filtered
|
||||
if document_uuids and len(document_uuids) > 0:
|
||||
description = (
|
||||
"Retrieve relevant information from specific documents in the knowledge base. "
|
||||
"Use this tool when you need to look up facts, policies, procedures, or any information "
|
||||
"that might be stored in the available documents. The search will only look in the "
|
||||
f"documents associated with this conversation step ({len(document_uuids)} document(s) available)."
|
||||
)
|
||||
else:
|
||||
description = (
|
||||
"Retrieve relevant information from the knowledge base. "
|
||||
"Use this tool when you need to look up facts, policies, procedures, or any information "
|
||||
"that might be stored in the knowledge base documents."
|
||||
)
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "retrieve_from_knowledge_base",
|
||||
"description": description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The search query to find relevant information. "
|
||||
"Be specific and use natural language. "
|
||||
"Example: 'What is the refund policy for canceled orders?'"
|
||||
),
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -48,6 +48,7 @@ class Node:
|
|||
self.delayed_start = data.delayed_start
|
||||
self.delayed_start_duration = data.delayed_start_duration
|
||||
self.tool_uuids = data.tool_uuids
|
||||
self.document_uuids = data.document_uuids
|
||||
|
||||
self.data = data
|
||||
|
||||
|
|
@ -189,16 +190,6 @@ class WorkflowGraph:
|
|||
in_d, out_d = in_deg[n.id], out_deg[n.id]
|
||||
|
||||
match n.node_type:
|
||||
case NodeType.startNode:
|
||||
if in_d != 0 or out_d < 1:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.node,
|
||||
id=n.id,
|
||||
field=None,
|
||||
message=f"StartNode must have at least 1 outgoing edge",
|
||||
)
|
||||
)
|
||||
case NodeType.endNode:
|
||||
if in_d < 1 or out_d != 0:
|
||||
errors.append(
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ from api.tasks.campaign_tasks import (
|
|||
process_campaign_batch,
|
||||
sync_campaign_source,
|
||||
)
|
||||
from api.tasks.knowledge_base_processing import process_knowledge_base_document
|
||||
from api.tasks.run_integrations import run_integrations_post_workflow_run
|
||||
from api.tasks.s3_upload import (
|
||||
upload_audio_to_s3,
|
||||
|
|
@ -64,6 +65,7 @@ class WorkerSettings:
|
|||
sync_campaign_source,
|
||||
process_campaign_batch,
|
||||
monitor_campaign_progress,
|
||||
process_knowledge_base_document,
|
||||
]
|
||||
cron_jobs = []
|
||||
redis_settings = REDIS_SETTINGS
|
||||
|
|
|
|||
|
|
@ -7,3 +7,4 @@ class FunctionNames:
|
|||
SYNC_CAMPAIGN_SOURCE = "sync_campaign_source"
|
||||
PROCESS_CAMPAIGN_BATCH = "process_campaign_batch"
|
||||
MONITOR_CAMPAIGN_PROGRESS = "monitor_campaign_progress"
|
||||
PROCESS_KNOWLEDGE_BASE_DOCUMENT = "process_knowledge_base_document"
|
||||
|
|
|
|||
311
api/tasks/knowledge_base_processing.py
Normal file
311
api/tasks/knowledge_base_processing.py
Normal file
|
|
@ -0,0 +1,311 @@
|
|||
"""ARQ background task for processing knowledge base documents."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Literal
|
||||
|
||||
from docling.chunking import HybridChunker
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
|
||||
from loguru import logger
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import KnowledgeBaseChunkModel
|
||||
from api.services.gen_ai import (
|
||||
OpenAIEmbeddingService,
|
||||
SentenceTransformerEmbeddingService,
|
||||
)
|
||||
from api.services.storage import storage_fs
|
||||
|
||||
# For tokenization/chunking - use SentenceTransformer tokenizer as baseline
|
||||
TOKENIZER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
|
||||
async def process_knowledge_base_document(
|
||||
ctx,
|
||||
document_id: int,
|
||||
s3_key: str,
|
||||
organization_id: int,
|
||||
max_tokens: int = 128,
|
||||
embedding_service: Literal["sentence_transformer", "openai"] = "openai",
|
||||
):
|
||||
"""Process a knowledge base document: download, chunk, embed, and store.
|
||||
|
||||
Args:
|
||||
ctx: ARQ context
|
||||
document_id: Database ID of the document
|
||||
s3_key: S3 key where the file is stored
|
||||
organization_id: Organization ID
|
||||
max_tokens: Maximum number of tokens per chunk (default: 128)
|
||||
embedding_service: Embedding service to use (default: "openai")
|
||||
- "openai": Use OpenAI text-embedding-3-small (1536-dim, requires API key)
|
||||
- "sentence_transformer": Use SentenceTransformer (all-MiniLM-L6-v2, 384-dim, free)
|
||||
"""
|
||||
logger.info(
|
||||
f"Starting knowledge base document processing for document_id={document_id}, "
|
||||
f"s3_key={s3_key}, organization_id={organization_id}"
|
||||
)
|
||||
|
||||
temp_file_path = None
|
||||
|
||||
try:
|
||||
# Update status to processing
|
||||
await db_client.update_document_status(document_id, "processing")
|
||||
|
||||
# Extract file extension from S3 key
|
||||
filename = s3_key.split("/")[-1]
|
||||
file_extension = (
|
||||
os.path.splitext(filename)[1] or ".bin"
|
||||
) # Default to .bin if no extension
|
||||
|
||||
# Create temp file for download with correct extension
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=file_extension)
|
||||
temp_file_path = temp_file.name
|
||||
temp_file.close()
|
||||
|
||||
# Download file from S3
|
||||
logger.info(f"Downloading file from S3: {s3_key}")
|
||||
download_success = await storage_fs.adownload_file(s3_key, temp_file_path)
|
||||
|
||||
if not download_success:
|
||||
raise Exception(f"Failed to download file from S3: {s3_key}")
|
||||
|
||||
if not os.path.exists(temp_file_path):
|
||||
raise FileNotFoundError(f"Downloaded file not found: {temp_file_path}")
|
||||
|
||||
file_size = os.path.getsize(temp_file_path)
|
||||
logger.info(f"Downloaded file size: {file_size} bytes")
|
||||
|
||||
# Compute file hash and get mime type
|
||||
file_hash = db_client.compute_file_hash(temp_file_path)
|
||||
mime_type = db_client.get_mime_type(temp_file_path)
|
||||
filename = s3_key.split("/")[-1]
|
||||
|
||||
# Get document record
|
||||
document = await db_client.get_document_by_id(document_id)
|
||||
if not document:
|
||||
raise Exception(f"Document {document_id} not found")
|
||||
|
||||
# Check if a document with this hash already exists (reject duplicates)
|
||||
existing_doc = await db_client.get_document_by_hash(file_hash, organization_id)
|
||||
if existing_doc and existing_doc.id != document_id:
|
||||
error_message = (
|
||||
f"This file is a duplicate of '{existing_doc.filename}'. "
|
||||
f"Please delete the duplicate files and consolidate them into a single unique file before uploading."
|
||||
)
|
||||
logger.warning(
|
||||
f"Duplicate document detected: {document_id} is duplicate of {existing_doc.id} "
|
||||
f"({existing_doc.filename})"
|
||||
)
|
||||
# Update file metadata
|
||||
await db_client.update_document_metadata(
|
||||
document_id,
|
||||
file_size_bytes=file_size,
|
||||
file_hash=file_hash,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
# Mark as failed with duplicate error message
|
||||
await db_client.update_document_status(
|
||||
document_id,
|
||||
"failed",
|
||||
error_message=error_message,
|
||||
docling_metadata={
|
||||
"duplicate_of": existing_doc.document_uuid,
|
||||
"duplicate_filename": existing_doc.filename,
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
# Update document with file metadata
|
||||
await db_client.update_document_metadata(
|
||||
document_id,
|
||||
file_size_bytes=file_size,
|
||||
file_hash=file_hash,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
|
||||
# Initialize the embedding service based on the parameter
|
||||
if embedding_service == "openai":
|
||||
logger.info(
|
||||
f"Initializing OpenAI embedding service with max_tokens={max_tokens}"
|
||||
)
|
||||
# Try to get user's embeddings configuration
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
if document.created_by:
|
||||
user_config = await db_client.get_user_configurations(
|
||||
document.created_by
|
||||
)
|
||||
if user_config.embeddings:
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
logger.info(
|
||||
f"Using user embeddings config: model={embeddings_model}"
|
||||
)
|
||||
|
||||
# Check if API key is configured
|
||||
if not embeddings_api_key:
|
||||
error_message = (
|
||||
"OpenAI API key not configured. Please set your API key in "
|
||||
"Model Configurations > Embedding to process documents."
|
||||
)
|
||||
logger.warning(f"Document {document_id}: {error_message}")
|
||||
await db_client.update_document_status(
|
||||
document_id, "failed", error_message=error_message
|
||||
)
|
||||
return
|
||||
|
||||
service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
max_tokens=max_tokens,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
)
|
||||
elif embedding_service == "sentence_transformer":
|
||||
logger.info(
|
||||
f"Initializing SentenceTransformer embedding service with max_tokens={max_tokens}"
|
||||
)
|
||||
service = SentenceTransformerEmbeddingService(
|
||||
db_client=db_client,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid embedding_service: {embedding_service}. "
|
||||
f"Must be 'sentence_transformer' or 'openai'"
|
||||
)
|
||||
|
||||
# Step 1: Convert document with docling
|
||||
logger.info("Converting document with docling")
|
||||
converter = DocumentConverter()
|
||||
conversion_result = converter.convert(temp_file_path)
|
||||
doc = conversion_result.document
|
||||
|
||||
# Store docling metadata
|
||||
docling_metadata = {
|
||||
"num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
|
||||
"document_type": type(doc).__name__,
|
||||
}
|
||||
|
||||
# Step 2: Initialize tokenizer for chunking
|
||||
logger.info(
|
||||
f"Loading tokenizer: {TOKENIZER_MODEL} with max_tokens={max_tokens}"
|
||||
)
|
||||
tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=AutoTokenizer.from_pretrained(TOKENIZER_MODEL),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
# Step 3: Initialize chunker
|
||||
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
|
||||
chunker = HybridChunker(tokenizer=tokenizer)
|
||||
|
||||
# Step 4: Chunk the document
|
||||
logger.info(f"Chunking document with max_tokens={max_tokens}")
|
||||
chunks = list(chunker.chunk(dl_doc=doc))
|
||||
total_chunks = len(chunks)
|
||||
logger.info(f"Generated {total_chunks} chunks")
|
||||
|
||||
# Step 5: Process each chunk
|
||||
chunk_texts = []
|
||||
chunk_records = []
|
||||
token_counts = []
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
chunk_text = chunk.text
|
||||
contextualized_text = chunker.contextualize(chunk=chunk)
|
||||
|
||||
# Calculate token count
|
||||
text_to_tokenize = (
|
||||
contextualized_text if contextualized_text else chunk_text
|
||||
)
|
||||
token_count = len(
|
||||
tokenizer.tokenizer.encode(text_to_tokenize, add_special_tokens=False)
|
||||
)
|
||||
token_counts.append(token_count)
|
||||
|
||||
# Prepare chunk metadata
|
||||
chunk_metadata = {}
|
||||
if hasattr(chunk, "meta") and chunk.meta:
|
||||
chunk_metadata = {
|
||||
"doc_items": (
|
||||
[str(item) for item in chunk.meta.doc_items]
|
||||
if hasattr(chunk.meta, "doc_items")
|
||||
else []
|
||||
),
|
||||
"headings": (
|
||||
chunk.meta.headings if hasattr(chunk.meta, "headings") else []
|
||||
),
|
||||
}
|
||||
|
||||
# Create chunk record (without embedding yet)
|
||||
chunk_record = KnowledgeBaseChunkModel(
|
||||
document_id=document_id,
|
||||
organization_id=organization_id,
|
||||
chunk_text=chunk_text,
|
||||
contextualized_text=contextualized_text,
|
||||
chunk_index=i,
|
||||
chunk_metadata=chunk_metadata,
|
||||
embedding_model=service.get_model_id(),
|
||||
embedding_dimension=service.get_embedding_dimension(),
|
||||
token_count=token_count,
|
||||
)
|
||||
|
||||
chunk_records.append(chunk_record)
|
||||
chunk_texts.append(text_to_tokenize)
|
||||
|
||||
# Log chunk statistics
|
||||
if token_counts:
|
||||
avg_tokens = sum(token_counts) / len(token_counts)
|
||||
min_tokens = min(token_counts)
|
||||
max_tokens_actual = max(token_counts)
|
||||
logger.info("Chunk token statistics:")
|
||||
logger.info(f" - Average: {avg_tokens:.1f} tokens")
|
||||
logger.info(f" - Min: {min_tokens} tokens")
|
||||
logger.info(f" - Max: {max_tokens_actual} tokens")
|
||||
|
||||
# Step 6: Generate embeddings using the embedding service
|
||||
logger.info(f"Generating embeddings using {embedding_service}")
|
||||
embeddings = await service.embed_texts(chunk_texts)
|
||||
|
||||
# Step 7: Attach embeddings to chunk records
|
||||
for chunk_record, embedding in zip(chunk_records, embeddings):
|
||||
chunk_record.embedding = embedding
|
||||
|
||||
# Step 8: Save chunks in database
|
||||
logger.info("Storing chunks in database")
|
||||
await db_client.create_chunks_batch(chunk_records)
|
||||
|
||||
# Step 9: Update document status to completed
|
||||
await db_client.update_document_status(
|
||||
document_id,
|
||||
"completed",
|
||||
total_chunks=total_chunks,
|
||||
docling_metadata=docling_metadata,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Successfully processed knowledge base document {document_id}. "
|
||||
f"Total chunks: {total_chunks}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing knowledge base document {document_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Update document status to failed
|
||||
await db_client.update_document_status(
|
||||
document_id, "failed", error_message=str(e)
|
||||
)
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Clean up temp file
|
||||
if temp_file_path and os.path.exists(temp_file_path):
|
||||
try:
|
||||
os.remove(temp_file_path)
|
||||
logger.debug(f"Cleaned up temp file: {temp_file_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp file {temp_file_path}: {e}")
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
services:
|
||||
postgres:
|
||||
image: postgres:17
|
||||
image: pgvector/pgvector:pg17
|
||||
environment:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
|
|
|
|||
249
ui/src/app/files/DocumentList.tsx
Normal file
249
ui/src/app/files/DocumentList.tsx
Normal file
|
|
@ -0,0 +1,249 @@
|
|||
'use client';
|
||||
|
||||
import { FileText, RefreshCw, Search, Trash2 } from 'lucide-react';
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
import { toast } from 'sonner';
|
||||
|
||||
import {
|
||||
deleteDocumentApiV1KnowledgeBaseDocumentsDocumentUuidDelete,
|
||||
listDocumentsApiV1KnowledgeBaseDocumentsGet,
|
||||
} from '@/client/sdk.gen';
|
||||
import type { DocumentResponseSchema } from '@/client/types.gen';
|
||||
import { Badge } from '@/components/ui/badge';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Input } from '@/components/ui/input';
|
||||
import { Skeleton } from '@/components/ui/skeleton';
|
||||
import logger from '@/lib/logger';
|
||||
|
||||
interface DocumentListProps {
|
||||
accessToken: string;
|
||||
refreshTrigger: number;
|
||||
}
|
||||
|
||||
export default function DocumentList({ accessToken, refreshTrigger }: DocumentListProps) {
|
||||
const [documents, setDocuments] = useState<DocumentResponseSchema[]>([]);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [searchQuery, setSearchQuery] = useState('');
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const fetchDocuments = useCallback(async () => {
|
||||
if (!accessToken) return;
|
||||
|
||||
try {
|
||||
setIsLoading(true);
|
||||
setError(null);
|
||||
|
||||
const response = await listDocumentsApiV1KnowledgeBaseDocumentsGet({
|
||||
headers: {
|
||||
'Authorization': `Bearer ${accessToken}`,
|
||||
},
|
||||
query: {
|
||||
limit: 100,
|
||||
offset: 0,
|
||||
},
|
||||
});
|
||||
|
||||
if (response.error || !response.data) {
|
||||
throw new Error('Failed to fetch documents');
|
||||
}
|
||||
|
||||
setDocuments(response.data.documents);
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'Failed to fetch documents');
|
||||
logger.error('Error fetching documents:', err);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}, [accessToken]);
|
||||
|
||||
// Fetch documents on mount and when refreshTrigger changes
|
||||
useEffect(() => {
|
||||
fetchDocuments();
|
||||
}, [fetchDocuments, refreshTrigger]);
|
||||
|
||||
// Poll for documents that are processing
|
||||
useEffect(() => {
|
||||
const processingDocs = documents.filter(
|
||||
(doc) => doc.processing_status === 'processing' || doc.processing_status === 'pending'
|
||||
);
|
||||
|
||||
if (processingDocs.length === 0) return;
|
||||
|
||||
const pollInterval = setInterval(() => {
|
||||
logger.info(`Polling for ${processingDocs.length} processing documents...`);
|
||||
fetchDocuments();
|
||||
}, 5000); // Poll every 5 seconds
|
||||
|
||||
return () => clearInterval(pollInterval);
|
||||
}, [documents, fetchDocuments]);
|
||||
|
||||
const handleDelete = async (documentUuid: string, filename: string) => {
|
||||
if (!confirm(`Are you sure you want to delete "${filename}"?`)) return;
|
||||
|
||||
try {
|
||||
const response = await deleteDocumentApiV1KnowledgeBaseDocumentsDocumentUuidDelete({
|
||||
path: {
|
||||
document_uuid: documentUuid,
|
||||
},
|
||||
headers: {
|
||||
'Authorization': `Bearer ${accessToken}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (response.error) {
|
||||
throw new Error('Failed to delete document');
|
||||
}
|
||||
|
||||
toast.success(`Deleted "${filename}"`);
|
||||
fetchDocuments();
|
||||
} catch (err) {
|
||||
toast.error(err instanceof Error ? err.message : 'Failed to delete document');
|
||||
logger.error('Error deleting document:', err);
|
||||
}
|
||||
};
|
||||
|
||||
const getStatusBadge = (status: string) => {
|
||||
switch (status) {
|
||||
case 'completed':
|
||||
return <Badge className="bg-green-500">Completed</Badge>;
|
||||
case 'processing':
|
||||
return (
|
||||
<Badge variant="secondary" className="animate-pulse">
|
||||
Processing
|
||||
</Badge>
|
||||
);
|
||||
case 'pending':
|
||||
return <Badge variant="outline">Pending</Badge>;
|
||||
case 'failed':
|
||||
return <Badge variant="destructive">Failed</Badge>;
|
||||
default:
|
||||
return <Badge variant="outline">{status}</Badge>;
|
||||
}
|
||||
};
|
||||
|
||||
const formatFileSize = (bytes: number): string => {
|
||||
if (bytes === 0) return '0 B';
|
||||
const k = 1024;
|
||||
const sizes = ['B', 'KB', 'MB', 'GB'];
|
||||
const i = Math.floor(Math.log(bytes) / Math.log(k));
|
||||
return `${parseFloat((bytes / Math.pow(k, i)).toFixed(2))} ${sizes[i]}`;
|
||||
};
|
||||
|
||||
const formatDate = (dateString: string): string => {
|
||||
const date = new Date(dateString);
|
||||
return date.toLocaleDateString() + ' ' + date.toLocaleTimeString();
|
||||
};
|
||||
|
||||
const filteredDocuments = documents.filter((doc) =>
|
||||
doc.filename.toLowerCase().includes(searchQuery.toLowerCase())
|
||||
);
|
||||
|
||||
if (isLoading && documents.length === 0) {
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
{[1, 2, 3].map((i) => (
|
||||
<div key={i} className="flex items-center justify-between p-4 border rounded-lg">
|
||||
<div className="space-y-2 flex-1">
|
||||
<Skeleton className="h-4 w-48" />
|
||||
<Skeleton className="h-3 w-64" />
|
||||
</div>
|
||||
<Skeleton className="h-8 w-24" />
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="p-4 bg-destructive/10 border border-destructive/20 rounded-lg text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
{/* Search and Refresh */}
|
||||
<div className="flex items-center gap-4">
|
||||
<div className="relative flex-1">
|
||||
<Search className="absolute left-3 top-1/2 transform -translate-y-1/2 h-4 w-4 text-muted-foreground" />
|
||||
<Input
|
||||
placeholder="Search documents..."
|
||||
value={searchQuery}
|
||||
onChange={(e) => setSearchQuery(e.target.value)}
|
||||
className="pl-10"
|
||||
/>
|
||||
</div>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
onClick={fetchDocuments}
|
||||
disabled={isLoading}
|
||||
>
|
||||
<RefreshCw className={`h-4 w-4 ${isLoading ? 'animate-spin' : ''}`} />
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{/* Document List */}
|
||||
{filteredDocuments.length === 0 ? (
|
||||
<div className="text-center py-12">
|
||||
<FileText className="w-12 h-12 text-muted-foreground mx-auto mb-4" />
|
||||
<p className="text-muted-foreground">
|
||||
{searchQuery
|
||||
? 'No documents match your search'
|
||||
: 'No documents uploaded yet'}
|
||||
</p>
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-3">
|
||||
{filteredDocuments.map((doc) => (
|
||||
<div
|
||||
key={doc.document_uuid}
|
||||
className="flex items-center justify-between p-4 border rounded-lg hover:bg-muted/50 transition-colors"
|
||||
>
|
||||
<div className="flex items-center gap-4 flex-1">
|
||||
<div className="w-10 h-10 rounded-lg bg-primary/10 flex items-center justify-center">
|
||||
<FileText className="w-5 h-5 text-primary" />
|
||||
</div>
|
||||
<div className="flex-1 min-w-0">
|
||||
<div className="flex items-center gap-2 mb-1">
|
||||
<span className="font-medium truncate">{doc.filename}</span>
|
||||
{getStatusBadge(doc.processing_status)}
|
||||
</div>
|
||||
<div className="flex items-center gap-4 text-sm text-muted-foreground">
|
||||
<span>{formatFileSize(doc.file_size_bytes)}</span>
|
||||
{doc.processing_status === 'completed' && (
|
||||
<span>{doc.total_chunks} chunks</span>
|
||||
)}
|
||||
<span>{formatDate(doc.created_at)}</span>
|
||||
</div>
|
||||
{doc.processing_error && (
|
||||
<p className="text-xs text-destructive mt-1">
|
||||
Error: {doc.processing_error}
|
||||
</p>
|
||||
)}
|
||||
{doc.docling_metadata &&
|
||||
typeof doc.docling_metadata === 'object' &&
|
||||
'duplicate_of' in doc.docling_metadata && (
|
||||
<p className="text-xs text-muted-foreground mt-1">
|
||||
Duplicate of another document
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={() => handleDelete(doc.document_uuid, doc.filename)}
|
||||
className="text-destructive hover:text-destructive/90"
|
||||
>
|
||||
<Trash2 className="w-4 h-4" />
|
||||
</Button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
220
ui/src/app/files/DocumentUpload.tsx
Normal file
220
ui/src/app/files/DocumentUpload.tsx
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
'use client';
|
||||
|
||||
import { Upload } from 'lucide-react';
|
||||
import { useRef, useState } from 'react';
|
||||
import { toast } from 'sonner';
|
||||
|
||||
import {
|
||||
getUploadUrlApiV1KnowledgeBaseUploadUrlPost,
|
||||
processDocumentApiV1KnowledgeBaseProcessDocumentPost,
|
||||
} from '@/client/sdk.gen';
|
||||
import type { DocumentUploadResponseSchema } from '@/client/types.gen';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Progress } from '@/components/ui/progress';
|
||||
import logger from '@/lib/logger';
|
||||
|
||||
interface DocumentUploadProps {
|
||||
accessToken: string;
|
||||
onUploadSuccess: () => void;
|
||||
}
|
||||
|
||||
const MAX_FILE_SIZE = 100 * 1024 * 1024; // 100MB
|
||||
const ACCEPTED_FILE_TYPES = ['.pdf', '.docx', '.doc', '.txt'];
|
||||
|
||||
export default function DocumentUpload({ accessToken, onUploadSuccess }: DocumentUploadProps) {
|
||||
const [uploading, setUploading] = useState(false);
|
||||
const [uploadProgress, setUploadProgress] = useState(0);
|
||||
const [dragActive, setDragActive] = useState(false);
|
||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const validateFile = (file: File): boolean => {
|
||||
// Validate file type
|
||||
const fileExtension = '.' + file.name.split('.').pop()?.toLowerCase();
|
||||
if (!ACCEPTED_FILE_TYPES.includes(fileExtension)) {
|
||||
toast.error(`Please select a supported file type: ${ACCEPTED_FILE_TYPES.join(', ')}`);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Validate file size
|
||||
if (file.size > MAX_FILE_SIZE) {
|
||||
toast.error('File size must be less than 100MB');
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
const uploadFile = async (file: File) => {
|
||||
if (!validateFile(file)) return;
|
||||
|
||||
setUploading(true);
|
||||
setUploadProgress(0);
|
||||
|
||||
try {
|
||||
// Step 1: Request presigned upload URL
|
||||
logger.info('Requesting presigned upload URL for:', file.name);
|
||||
const uploadUrlResponse = await getUploadUrlApiV1KnowledgeBaseUploadUrlPost({
|
||||
body: {
|
||||
filename: file.name,
|
||||
mime_type: file.type || 'application/octet-stream',
|
||||
custom_metadata: {
|
||||
original_filename: file.name,
|
||||
uploaded_at: new Date().toISOString(),
|
||||
},
|
||||
},
|
||||
headers: {
|
||||
'Authorization': `Bearer ${accessToken}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (uploadUrlResponse.error || !uploadUrlResponse.data) {
|
||||
throw new Error('Failed to get upload URL');
|
||||
}
|
||||
|
||||
const uploadData: DocumentUploadResponseSchema = uploadUrlResponse.data;
|
||||
logger.info('Received presigned URL, uploading file...');
|
||||
|
||||
setUploadProgress(25);
|
||||
|
||||
// Step 2: Upload file directly to S3/MinIO using PUT
|
||||
const uploadResponse = await fetch(uploadData.upload_url, {
|
||||
method: 'PUT',
|
||||
body: file,
|
||||
headers: {
|
||||
'Content-Type': file.type || 'application/octet-stream',
|
||||
},
|
||||
});
|
||||
|
||||
if (!uploadResponse.ok) {
|
||||
throw new Error('Failed to upload file to storage');
|
||||
}
|
||||
|
||||
setUploadProgress(75);
|
||||
logger.info('File uploaded successfully, triggering processing...');
|
||||
|
||||
// Step 3: Trigger document processing
|
||||
const processResponse = await processDocumentApiV1KnowledgeBaseProcessDocumentPost({
|
||||
body: {
|
||||
document_uuid: uploadData.document_uuid,
|
||||
s3_key: uploadData.s3_key,
|
||||
},
|
||||
headers: {
|
||||
'Authorization': `Bearer ${accessToken}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (processResponse.error) {
|
||||
throw new Error('Failed to trigger processing');
|
||||
}
|
||||
|
||||
setUploadProgress(100);
|
||||
logger.info('Document processing triggered successfully');
|
||||
|
||||
toast.success(`File uploaded: ${file.name}. Processing started.`);
|
||||
onUploadSuccess();
|
||||
} catch (error) {
|
||||
logger.error('Error uploading document:', error);
|
||||
toast.error(error instanceof Error ? error.message : 'Failed to upload document');
|
||||
} finally {
|
||||
setUploading(false);
|
||||
setUploadProgress(0);
|
||||
// Reset file input
|
||||
if (fileInputRef.current) {
|
||||
fileInputRef.current.value = '';
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleFileSelect = async (event: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const file = event.target.files?.[0];
|
||||
if (file) {
|
||||
await uploadFile(file);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDrag = (e: React.DragEvent) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
if (e.type === 'dragenter' || e.type === 'dragover') {
|
||||
setDragActive(true);
|
||||
} else if (e.type === 'dragleave') {
|
||||
setDragActive(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDrop = async (e: React.DragEvent) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
setDragActive(false);
|
||||
|
||||
const file = e.dataTransfer.files?.[0];
|
||||
if (file) {
|
||||
await uploadFile(file);
|
||||
}
|
||||
};
|
||||
|
||||
const handleButtonClick = () => {
|
||||
fileInputRef.current?.click();
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
accept={ACCEPTED_FILE_TYPES.join(',')}
|
||||
onChange={handleFileSelect}
|
||||
className="hidden"
|
||||
disabled={uploading}
|
||||
/>
|
||||
|
||||
{/* Drag and Drop Area */}
|
||||
<div
|
||||
className={`
|
||||
border-2 border-dashed rounded-lg p-8 text-center transition-colors
|
||||
${dragActive ? 'border-primary bg-primary/5' : 'border-muted-foreground/25'}
|
||||
${uploading ? 'opacity-50 pointer-events-none' : 'cursor-pointer hover:border-primary hover:bg-muted/50'}
|
||||
`}
|
||||
onDragEnter={handleDrag}
|
||||
onDragLeave={handleDrag}
|
||||
onDragOver={handleDrag}
|
||||
onDrop={handleDrop}
|
||||
onClick={handleButtonClick}
|
||||
>
|
||||
<Upload className="w-12 h-12 mx-auto mb-4 text-muted-foreground" />
|
||||
<p className="text-lg font-medium mb-2">
|
||||
{uploading ? 'Uploading...' : 'Drop your document here'}
|
||||
</p>
|
||||
<p className="text-sm text-muted-foreground mb-4">
|
||||
or click to browse
|
||||
</p>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Supported formats: {ACCEPTED_FILE_TYPES.join(', ')} (Max 100MB)
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Upload Progress */}
|
||||
{uploading && (
|
||||
<div className="space-y-2">
|
||||
<div className="flex justify-between text-sm">
|
||||
<span>Uploading...</span>
|
||||
<span>{uploadProgress}%</span>
|
||||
</div>
|
||||
<Progress value={uploadProgress} />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Manual Upload Button */}
|
||||
<div className="flex justify-center">
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
onClick={handleButtonClick}
|
||||
disabled={uploading}
|
||||
>
|
||||
{uploading ? 'Uploading...' : 'Choose File'}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
104
ui/src/app/files/page.tsx
Normal file
104
ui/src/app/files/page.tsx
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
|
||||
import DocumentList from "./DocumentList";
|
||||
import DocumentUpload from "./DocumentUpload";
|
||||
|
||||
export default function FilesPage() {
|
||||
const { user, getAccessToken, redirectToLogin, loading } = useAuth();
|
||||
const [refreshKey, setRefreshKey] = useState(0);
|
||||
const [accessToken, setAccessToken] = useState<string>('');
|
||||
|
||||
// Redirect if not authenticated
|
||||
useEffect(() => {
|
||||
if (!loading && !user) {
|
||||
redirectToLogin();
|
||||
}
|
||||
}, [loading, user, redirectToLogin]);
|
||||
|
||||
// Get access token
|
||||
const fetchAccessToken = useCallback(async () => {
|
||||
if (user) {
|
||||
const token = await getAccessToken();
|
||||
setAccessToken(token);
|
||||
}
|
||||
}, [user, getAccessToken]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchAccessToken();
|
||||
}, [fetchAccessToken]);
|
||||
|
||||
const handleUploadSuccess = () => {
|
||||
// Trigger refresh of document list
|
||||
setRefreshKey(prev => prev + 1);
|
||||
};
|
||||
|
||||
if (loading || !user || !accessToken) {
|
||||
return (
|
||||
<div className="container mx-auto px-4 py-8">
|
||||
<div className="space-y-4">
|
||||
<Skeleton className="h-12 w-64" />
|
||||
<Skeleton className="h-64 w-full" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="container mx-auto px-4 py-8">
|
||||
<div className="mb-8">
|
||||
<h1 className="text-3xl font-bold mb-2">Knowledge Base Files</h1>
|
||||
<p className="text-muted-foreground">
|
||||
Upload and manage documents for your voice agents to reference.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<Tabs defaultValue="all" className="space-y-6">
|
||||
<TabsList>
|
||||
<TabsTrigger value="all">All Files</TabsTrigger>
|
||||
<TabsTrigger value="upload">Upload New</TabsTrigger>
|
||||
</TabsList>
|
||||
|
||||
<TabsContent value="all" className="space-y-4">
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>Your Documents</CardTitle>
|
||||
<CardDescription>
|
||||
View and manage your uploaded documents
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<DocumentList
|
||||
accessToken={accessToken}
|
||||
refreshTrigger={refreshKey}
|
||||
/>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</TabsContent>
|
||||
|
||||
<TabsContent value="upload" className="space-y-4">
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>Upload Document</CardTitle>
|
||||
<CardDescription>
|
||||
Upload a PDF or document file to add to your knowledge base
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<DocumentUpload
|
||||
accessToken={accessToken}
|
||||
onUploadSuccess={handleUploadSuccess}
|
||||
/>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</TabsContent>
|
||||
</Tabs>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -7,8 +7,10 @@ import {
|
|||
ReactFlow,
|
||||
} from "@xyflow/react";
|
||||
import { BrushCleaning, Maximize2, Minus, Plus, Rocket, Settings, Variable } from 'lucide-react';
|
||||
import React, { useMemo, useState } from 'react';
|
||||
import React, { useEffect, useMemo, useState } from 'react';
|
||||
|
||||
import { listDocumentsApiV1KnowledgeBaseDocumentsGet, listToolsApiV1ToolsGet } from '@/client';
|
||||
import type { DocumentResponseSchema, ToolResponse } from '@/client/types.gen';
|
||||
import { FlowEdge, FlowNode, NodeType } from "@/components/flow/types";
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/tooltip';
|
||||
|
|
@ -63,6 +65,8 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
|
|||
const [isConfigurationsDialogOpen, setIsConfigurationsDialogOpen] = useState(false);
|
||||
const [isEmbedDialogOpen, setIsEmbedDialogOpen] = useState(false);
|
||||
const [isPhoneCallDialogOpen, setIsPhoneCallDialogOpen] = useState(false);
|
||||
const [documents, setDocuments] = useState<DocumentResponseSchema[] | undefined>(undefined);
|
||||
const [tools, setTools] = useState<ToolResponse[] | undefined>(undefined);
|
||||
|
||||
const {
|
||||
rfInstance,
|
||||
|
|
@ -95,6 +99,36 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
|
|||
getAccessToken
|
||||
});
|
||||
|
||||
// Fetch documents and tools once for the entire workflow
|
||||
useEffect(() => {
|
||||
const fetchData = async () => {
|
||||
try {
|
||||
const accessToken = await getAccessToken();
|
||||
|
||||
// Fetch documents
|
||||
const documentsResponse = await listDocumentsApiV1KnowledgeBaseDocumentsGet({
|
||||
headers: { Authorization: `Bearer ${accessToken}` },
|
||||
query: { limit: 100 },
|
||||
});
|
||||
if (documentsResponse.data) {
|
||||
setDocuments(documentsResponse.data.documents);
|
||||
}
|
||||
|
||||
// Fetch tools
|
||||
const toolsResponse = await listToolsApiV1ToolsGet({
|
||||
headers: { Authorization: `Bearer ${accessToken}` },
|
||||
});
|
||||
if (toolsResponse.data) {
|
||||
setTools(toolsResponse.data);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch documents and tools:', error);
|
||||
}
|
||||
};
|
||||
|
||||
fetchData();
|
||||
}, [getAccessToken]);
|
||||
|
||||
// Memoize defaultEdgeOptions to prevent unnecessary re-renders
|
||||
const defaultEdgeOptions = useMemo(() => ({
|
||||
animated: true,
|
||||
|
|
@ -102,7 +136,11 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
|
|||
}), []);
|
||||
|
||||
// Memoize the context value to prevent unnecessary re-renders
|
||||
const workflowContextValue = useMemo(() => ({ saveWorkflow }), [saveWorkflow]);
|
||||
const workflowContextValue = useMemo(() => ({
|
||||
saveWorkflow,
|
||||
documents,
|
||||
tools
|
||||
}), [saveWorkflow, documents, tools]);
|
||||
|
||||
return (
|
||||
<WorkflowProvider value={workflowContextValue}>
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
import { createContext, useContext } from 'react';
|
||||
|
||||
import type { DocumentResponseSchema, ToolResponse } from '@/client/types.gen';
|
||||
|
||||
interface WorkflowContextType {
|
||||
saveWorkflow: (updateWorkflowDefinition?: boolean) => Promise<void>;
|
||||
documents?: DocumentResponseSchema[];
|
||||
tools?: ToolResponse[];
|
||||
}
|
||||
|
||||
const WorkflowContext = createContext<WorkflowContextType | undefined>(undefined);
|
||||
|
|
@ -15,3 +19,8 @@ export const useWorkflow = () => {
|
|||
}
|
||||
return context;
|
||||
};
|
||||
|
||||
// Optional hook that doesn't throw if context is not available
|
||||
export const useWorkflowOptional = () => {
|
||||
return useContext(WorkflowContext);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
// This file is auto-generated by @hey-api/openapi-ts
|
||||
|
||||
import { type ClientOptions as DefaultClientOptions, type Config, createClient, createConfig } from '@hey-api/client-fetch';
|
||||
|
||||
import { createClientConfig } from '../lib/apiClient';
|
||||
import type { ClientOptions } from './types.gen';
|
||||
import { type Config, type ClientOptions as DefaultClientOptions, createClient, createConfig } from '@hey-api/client-fetch';
|
||||
import { createClientConfig } from '../lib/apiClient';
|
||||
|
||||
/**
|
||||
* The `createClientConfig()` function will be called on client initialization
|
||||
|
|
@ -17,4 +16,4 @@ export type CreateClientConfig<T extends DefaultClientOptions = ClientOptions> =
|
|||
|
||||
export const client = createClient(createClientConfig(createConfig<ClientOptions>({
|
||||
baseUrl: 'http://127.0.0.1:8000'
|
||||
})));
|
||||
})));
|
||||
|
|
@ -1,3 +1,3 @@
|
|||
// This file is auto-generated by @hey-api/openapi-ts
|
||||
export * from './sdk.gen';
|
||||
export * from './types.gen';
|
||||
export * from './sdk.gen';
|
||||
File diff suppressed because one or more lines are too long
|
|
@ -83,6 +83,54 @@ export type CampaignsResponse = {
|
|||
campaigns: Array<CampaignResponse>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Response schema for a document chunk.
|
||||
*/
|
||||
export type ChunkResponseSchema = {
|
||||
id: number;
|
||||
document_id: number;
|
||||
chunk_text: string;
|
||||
contextualized_text: string | null;
|
||||
chunk_index: number;
|
||||
chunk_metadata: {
|
||||
[key: string]: unknown;
|
||||
};
|
||||
filename: string;
|
||||
document_uuid: string;
|
||||
similarity: number;
|
||||
};
|
||||
|
||||
/**
|
||||
* Request schema for searching similar chunks.
|
||||
*/
|
||||
export type ChunkSearchRequestSchema = {
|
||||
/**
|
||||
* Search query text
|
||||
*/
|
||||
query: string;
|
||||
/**
|
||||
* Maximum number of results
|
||||
*/
|
||||
limit?: number;
|
||||
/**
|
||||
* Filter by specific document UUIDs
|
||||
*/
|
||||
document_uuids?: Array<string> | null;
|
||||
/**
|
||||
* Minimum similarity threshold
|
||||
*/
|
||||
min_similarity?: number | null;
|
||||
};
|
||||
|
||||
/**
|
||||
* Response schema for chunk search results.
|
||||
*/
|
||||
export type ChunkSearchResponseSchema = {
|
||||
chunks: Array<ChunkResponseSchema>;
|
||||
query: string;
|
||||
total_results: number;
|
||||
};
|
||||
|
||||
/**
|
||||
* Request schema for Cloudonix configuration.
|
||||
*/
|
||||
|
|
@ -303,11 +351,91 @@ export type DefaultConfigurationsResponse = {
|
|||
[key: string]: unknown;
|
||||
};
|
||||
};
|
||||
embeddings: {
|
||||
[key: string]: {
|
||||
[key: string]: unknown;
|
||||
};
|
||||
};
|
||||
default_providers: {
|
||||
[key: string]: string;
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Response schema for list of documents.
|
||||
*/
|
||||
export type DocumentListResponseSchema = {
|
||||
documents: Array<DocumentResponseSchema>;
|
||||
total: number;
|
||||
limit: number;
|
||||
offset: number;
|
||||
};
|
||||
|
||||
/**
|
||||
* Response schema for document metadata.
|
||||
*/
|
||||
export type DocumentResponseSchema = {
|
||||
id: number;
|
||||
document_uuid: string;
|
||||
filename: string;
|
||||
file_size_bytes: number;
|
||||
file_hash: string;
|
||||
mime_type: string;
|
||||
processing_status: string;
|
||||
processing_error?: string | null;
|
||||
total_chunks: number;
|
||||
custom_metadata: {
|
||||
[key: string]: unknown;
|
||||
};
|
||||
docling_metadata: {
|
||||
[key: string]: unknown;
|
||||
};
|
||||
source_url?: string | null;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
organization_id: number;
|
||||
created_by: number;
|
||||
is_active: boolean;
|
||||
};
|
||||
|
||||
/**
|
||||
* Request schema for initiating document upload.
|
||||
*/
|
||||
export type DocumentUploadRequestSchema = {
|
||||
/**
|
||||
* Name of the file to upload
|
||||
*/
|
||||
filename: string;
|
||||
/**
|
||||
* MIME type of the file
|
||||
*/
|
||||
mime_type: string;
|
||||
/**
|
||||
* Optional custom metadata
|
||||
*/
|
||||
custom_metadata?: {
|
||||
[key: string]: unknown;
|
||||
} | null;
|
||||
};
|
||||
|
||||
/**
|
||||
* Response schema containing upload URL and document metadata.
|
||||
*/
|
||||
export type DocumentUploadResponseSchema = {
|
||||
/**
|
||||
* Signed URL for uploading the file
|
||||
*/
|
||||
upload_url: string;
|
||||
/**
|
||||
* Unique identifier for the document
|
||||
*/
|
||||
document_uuid: string;
|
||||
/**
|
||||
* S3 key where file should be uploaded
|
||||
*/
|
||||
s3_key: string;
|
||||
};
|
||||
|
||||
export type DuplicateTemplateRequest = {
|
||||
template_id: number;
|
||||
workflow_name: string;
|
||||
|
|
@ -537,6 +665,24 @@ export type PresignedUploadUrlResponse = {
|
|||
expires_in: number;
|
||||
};
|
||||
|
||||
/**
|
||||
* Request schema for triggering document processing.
|
||||
*/
|
||||
export type ProcessDocumentRequestSchema = {
|
||||
/**
|
||||
* Document UUID to process
|
||||
*/
|
||||
document_uuid: string;
|
||||
/**
|
||||
* S3 key of the uploaded file
|
||||
*/
|
||||
s3_key: string;
|
||||
/**
|
||||
* Embedding service to use for processing. Options: 'openai' (default, 1536-dim, requires API key) or 'sentence_transformer' (free, 384-dim)
|
||||
*/
|
||||
embedding_service?: 'sentence_transformer' | 'openai';
|
||||
};
|
||||
|
||||
export type S3SignedUrlResponse = {
|
||||
url: string;
|
||||
expires_in: number;
|
||||
|
|
@ -787,6 +933,9 @@ export type UserConfigurationRequestResponseSchema = {
|
|||
stt?: {
|
||||
[key: string]: string | number;
|
||||
} | null;
|
||||
embeddings?: {
|
||||
[key: string]: string | number;
|
||||
} | null;
|
||||
test_phone_number?: string | null;
|
||||
timezone?: string | null;
|
||||
organization_pricing?: {
|
||||
|
|
@ -4126,6 +4275,213 @@ export type CreateOrUpdateEmbedTokenApiV1WorkflowWorkflowIdEmbedTokenPostRespons
|
|||
|
||||
export type CreateOrUpdateEmbedTokenApiV1WorkflowWorkflowIdEmbedTokenPostResponse = CreateOrUpdateEmbedTokenApiV1WorkflowWorkflowIdEmbedTokenPostResponses[keyof CreateOrUpdateEmbedTokenApiV1WorkflowWorkflowIdEmbedTokenPostResponses];
|
||||
|
||||
export type GetUploadUrlApiV1KnowledgeBaseUploadUrlPostData = {
|
||||
body: DocumentUploadRequestSchema;
|
||||
headers?: {
|
||||
authorization?: string | null;
|
||||
'X-API-Key'?: string | null;
|
||||
};
|
||||
path?: never;
|
||||
query?: never;
|
||||
url: '/api/v1/knowledge-base/upload-url';
|
||||
};
|
||||
|
||||
export type GetUploadUrlApiV1KnowledgeBaseUploadUrlPostErrors = {
|
||||
/**
|
||||
* Not found
|
||||
*/
|
||||
404: unknown;
|
||||
/**
|
||||
* Validation Error
|
||||
*/
|
||||
422: HttpValidationError;
|
||||
};
|
||||
|
||||
export type GetUploadUrlApiV1KnowledgeBaseUploadUrlPostError = GetUploadUrlApiV1KnowledgeBaseUploadUrlPostErrors[keyof GetUploadUrlApiV1KnowledgeBaseUploadUrlPostErrors];
|
||||
|
||||
export type GetUploadUrlApiV1KnowledgeBaseUploadUrlPostResponses = {
|
||||
/**
|
||||
* Successful Response
|
||||
*/
|
||||
200: DocumentUploadResponseSchema;
|
||||
};
|
||||
|
||||
export type GetUploadUrlApiV1KnowledgeBaseUploadUrlPostResponse = GetUploadUrlApiV1KnowledgeBaseUploadUrlPostResponses[keyof GetUploadUrlApiV1KnowledgeBaseUploadUrlPostResponses];
|
||||
|
||||
export type ProcessDocumentApiV1KnowledgeBaseProcessDocumentPostData = {
|
||||
body: ProcessDocumentRequestSchema;
|
||||
headers?: {
|
||||
authorization?: string | null;
|
||||
'X-API-Key'?: string | null;
|
||||
};
|
||||
path?: never;
|
||||
query?: never;
|
||||
url: '/api/v1/knowledge-base/process-document';
|
||||
};
|
||||
|
||||
export type ProcessDocumentApiV1KnowledgeBaseProcessDocumentPostErrors = {
|
||||
/**
|
||||
* Not found
|
||||
*/
|
||||
404: unknown;
|
||||
/**
|
||||
* Validation Error
|
||||
*/
|
||||
422: HttpValidationError;
|
||||
};
|
||||
|
||||
export type ProcessDocumentApiV1KnowledgeBaseProcessDocumentPostError = ProcessDocumentApiV1KnowledgeBaseProcessDocumentPostErrors[keyof ProcessDocumentApiV1KnowledgeBaseProcessDocumentPostErrors];
|
||||
|
||||
export type ProcessDocumentApiV1KnowledgeBaseProcessDocumentPostResponses = {
|
||||
/**
|
||||
* Successful Response
|
||||
*/
|
||||
200: DocumentResponseSchema;
|
||||
};
|
||||
|
||||
export type ProcessDocumentApiV1KnowledgeBaseProcessDocumentPostResponse = ProcessDocumentApiV1KnowledgeBaseProcessDocumentPostResponses[keyof ProcessDocumentApiV1KnowledgeBaseProcessDocumentPostResponses];
|
||||
|
||||
export type ListDocumentsApiV1KnowledgeBaseDocumentsGetData = {
|
||||
body?: never;
|
||||
headers?: {
|
||||
authorization?: string | null;
|
||||
'X-API-Key'?: string | null;
|
||||
};
|
||||
path?: never;
|
||||
query?: {
|
||||
/**
|
||||
* Filter by processing status
|
||||
*/
|
||||
status?: string | null;
|
||||
limit?: number;
|
||||
offset?: number;
|
||||
};
|
||||
url: '/api/v1/knowledge-base/documents';
|
||||
};
|
||||
|
||||
export type ListDocumentsApiV1KnowledgeBaseDocumentsGetErrors = {
|
||||
/**
|
||||
* Not found
|
||||
*/
|
||||
404: unknown;
|
||||
/**
|
||||
* Validation Error
|
||||
*/
|
||||
422: HttpValidationError;
|
||||
};
|
||||
|
||||
export type ListDocumentsApiV1KnowledgeBaseDocumentsGetError = ListDocumentsApiV1KnowledgeBaseDocumentsGetErrors[keyof ListDocumentsApiV1KnowledgeBaseDocumentsGetErrors];
|
||||
|
||||
export type ListDocumentsApiV1KnowledgeBaseDocumentsGetResponses = {
|
||||
/**
|
||||
* Successful Response
|
||||
*/
|
||||
200: DocumentListResponseSchema;
|
||||
};
|
||||
|
||||
export type ListDocumentsApiV1KnowledgeBaseDocumentsGetResponse = ListDocumentsApiV1KnowledgeBaseDocumentsGetResponses[keyof ListDocumentsApiV1KnowledgeBaseDocumentsGetResponses];
|
||||
|
||||
export type DeleteDocumentApiV1KnowledgeBaseDocumentsDocumentUuidDeleteData = {
|
||||
body?: never;
|
||||
headers?: {
|
||||
authorization?: string | null;
|
||||
'X-API-Key'?: string | null;
|
||||
};
|
||||
path: {
|
||||
document_uuid: string;
|
||||
};
|
||||
query?: never;
|
||||
url: '/api/v1/knowledge-base/documents/{document_uuid}';
|
||||
};
|
||||
|
||||
export type DeleteDocumentApiV1KnowledgeBaseDocumentsDocumentUuidDeleteErrors = {
|
||||
/**
|
||||
* Not found
|
||||
*/
|
||||
404: unknown;
|
||||
/**
|
||||
* Validation Error
|
||||
*/
|
||||
422: HttpValidationError;
|
||||
};
|
||||
|
||||
export type DeleteDocumentApiV1KnowledgeBaseDocumentsDocumentUuidDeleteError = DeleteDocumentApiV1KnowledgeBaseDocumentsDocumentUuidDeleteErrors[keyof DeleteDocumentApiV1KnowledgeBaseDocumentsDocumentUuidDeleteErrors];
|
||||
|
||||
export type DeleteDocumentApiV1KnowledgeBaseDocumentsDocumentUuidDeleteResponses = {
|
||||
/**
|
||||
* Successful Response
|
||||
*/
|
||||
200: unknown;
|
||||
};
|
||||
|
||||
export type GetDocumentApiV1KnowledgeBaseDocumentsDocumentUuidGetData = {
|
||||
body?: never;
|
||||
headers?: {
|
||||
authorization?: string | null;
|
||||
'X-API-Key'?: string | null;
|
||||
};
|
||||
path: {
|
||||
document_uuid: string;
|
||||
};
|
||||
query?: never;
|
||||
url: '/api/v1/knowledge-base/documents/{document_uuid}';
|
||||
};
|
||||
|
||||
export type GetDocumentApiV1KnowledgeBaseDocumentsDocumentUuidGetErrors = {
|
||||
/**
|
||||
* Not found
|
||||
*/
|
||||
404: unknown;
|
||||
/**
|
||||
* Validation Error
|
||||
*/
|
||||
422: HttpValidationError;
|
||||
};
|
||||
|
||||
export type GetDocumentApiV1KnowledgeBaseDocumentsDocumentUuidGetError = GetDocumentApiV1KnowledgeBaseDocumentsDocumentUuidGetErrors[keyof GetDocumentApiV1KnowledgeBaseDocumentsDocumentUuidGetErrors];
|
||||
|
||||
export type GetDocumentApiV1KnowledgeBaseDocumentsDocumentUuidGetResponses = {
|
||||
/**
|
||||
* Successful Response
|
||||
*/
|
||||
200: DocumentResponseSchema;
|
||||
};
|
||||
|
||||
export type GetDocumentApiV1KnowledgeBaseDocumentsDocumentUuidGetResponse = GetDocumentApiV1KnowledgeBaseDocumentsDocumentUuidGetResponses[keyof GetDocumentApiV1KnowledgeBaseDocumentsDocumentUuidGetResponses];
|
||||
|
||||
export type SearchChunksApiV1KnowledgeBaseSearchPostData = {
|
||||
body: ChunkSearchRequestSchema;
|
||||
headers?: {
|
||||
authorization?: string | null;
|
||||
'X-API-Key'?: string | null;
|
||||
};
|
||||
path?: never;
|
||||
query?: never;
|
||||
url: '/api/v1/knowledge-base/search';
|
||||
};
|
||||
|
||||
export type SearchChunksApiV1KnowledgeBaseSearchPostErrors = {
|
||||
/**
|
||||
* Not found
|
||||
*/
|
||||
404: unknown;
|
||||
/**
|
||||
* Validation Error
|
||||
*/
|
||||
422: HttpValidationError;
|
||||
};
|
||||
|
||||
export type SearchChunksApiV1KnowledgeBaseSearchPostError = SearchChunksApiV1KnowledgeBaseSearchPostErrors[keyof SearchChunksApiV1KnowledgeBaseSearchPostErrors];
|
||||
|
||||
export type SearchChunksApiV1KnowledgeBaseSearchPostResponses = {
|
||||
/**
|
||||
* Successful Response
|
||||
*/
|
||||
200: ChunkSearchResponseSchema;
|
||||
};
|
||||
|
||||
export type SearchChunksApiV1KnowledgeBaseSearchPostResponse = SearchChunksApiV1KnowledgeBaseSearchPostResponses[keyof SearchChunksApiV1KnowledgeBaseSearchPostResponses];
|
||||
|
||||
export type HealthApiV1HealthGetData = {
|
||||
body?: never;
|
||||
path?: never;
|
||||
|
|
@ -4149,4 +4505,4 @@ export type HealthApiV1HealthGetResponses = {
|
|||
|
||||
export type ClientOptions = {
|
||||
baseUrl: 'http://127.0.0.1:8000' | (string & {});
|
||||
};
|
||||
};
|
||||
|
|
@ -14,7 +14,7 @@ import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
|||
import { VoiceSelector } from "@/components/VoiceSelector";
|
||||
import { useUserConfig } from "@/context/UserConfigContext";
|
||||
|
||||
type ServiceSegment = "llm" | "tts" | "stt";
|
||||
type ServiceSegment = "llm" | "tts" | "stt" | "embeddings";
|
||||
|
||||
interface SchemaProperty {
|
||||
type?: string;
|
||||
|
|
@ -41,6 +41,7 @@ const TAB_CONFIG: { key: ServiceSegment; label: string }[] = [
|
|||
{ key: "llm", label: "LLM" },
|
||||
{ key: "tts", label: "Voice" },
|
||||
{ key: "stt", label: "Transcriber" },
|
||||
{ key: "embeddings", label: "Embedding" },
|
||||
];
|
||||
|
||||
// Display names for language codes (Deepgram + Sarvam)
|
||||
|
|
@ -109,12 +110,14 @@ export default function ServiceConfiguration() {
|
|||
const [schemas, setSchemas] = useState<Record<ServiceSegment, Record<string, ProviderSchema>>>({
|
||||
llm: {},
|
||||
tts: {},
|
||||
stt: {}
|
||||
stt: {},
|
||||
embeddings: {}
|
||||
});
|
||||
const [serviceProviders, setServiceProviders] = useState<Record<ServiceSegment, string>>({
|
||||
llm: "",
|
||||
tts: "",
|
||||
stt: ""
|
||||
stt: "",
|
||||
embeddings: ""
|
||||
});
|
||||
const [isManualModelInput, setIsManualModelInput] = useState(false);
|
||||
const [hasCheckedManualMode, setHasCheckedManualMode] = useState(false);
|
||||
|
|
@ -136,7 +139,8 @@ export default function ServiceConfiguration() {
|
|||
setSchemas({
|
||||
llm: response.data.llm as Record<string, ProviderSchema>,
|
||||
tts: response.data.tts as Record<string, ProviderSchema>,
|
||||
stt: response.data.stt as Record<string, ProviderSchema>
|
||||
stt: response.data.stt as Record<string, ProviderSchema>,
|
||||
embeddings: response.data.embeddings as Record<string, ProviderSchema>
|
||||
});
|
||||
} else {
|
||||
console.error("Failed to fetch configurations");
|
||||
|
|
@ -147,7 +151,8 @@ export default function ServiceConfiguration() {
|
|||
const selectedProviders: Record<ServiceSegment, string> = {
|
||||
llm: response.data.default_providers.llm,
|
||||
tts: response.data.default_providers.tts,
|
||||
stt: response.data.default_providers.stt
|
||||
stt: response.data.default_providers.stt,
|
||||
embeddings: response.data.default_providers.embeddings
|
||||
};
|
||||
|
||||
const setServicePropertyValues = (service: ServiceSegment) => {
|
||||
|
|
@ -173,6 +178,7 @@ export default function ServiceConfiguration() {
|
|||
setServicePropertyValues("llm");
|
||||
setServicePropertyValues("tts");
|
||||
setServicePropertyValues("stt");
|
||||
setServicePropertyValues("embeddings");
|
||||
|
||||
// IMPORTANT: Reset form values BEFORE changing providers
|
||||
// Otherwise, Radix Select sees old values that don't match new provider's enum
|
||||
|
|
@ -246,7 +252,7 @@ export default function ServiceConfiguration() {
|
|||
setApiError(null);
|
||||
setIsSaving(true);
|
||||
|
||||
const userConfig = {
|
||||
const userConfig: Record<ServiceSegment, Record<string, string | number>> = {
|
||||
llm: {
|
||||
provider: serviceProviders.llm,
|
||||
api_key: data.llm_api_key as string,
|
||||
|
|
@ -259,6 +265,11 @@ export default function ServiceConfiguration() {
|
|||
stt: {
|
||||
provider: serviceProviders.stt,
|
||||
api_key: data.stt_api_key as string
|
||||
},
|
||||
embeddings: {
|
||||
provider: serviceProviders.embeddings,
|
||||
api_key: data.embeddings_api_key as string,
|
||||
model: data.embeddings_model as string
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -273,12 +284,25 @@ export default function ServiceConfiguration() {
|
|||
}
|
||||
});
|
||||
|
||||
// Build save config - only include embeddings if api_key is provided
|
||||
const saveConfig: {
|
||||
llm: Record<string, string | number>;
|
||||
tts: Record<string, string | number>;
|
||||
stt: Record<string, string | number>;
|
||||
embeddings?: Record<string, string | number>;
|
||||
} = {
|
||||
llm: userConfig.llm,
|
||||
tts: userConfig.tts,
|
||||
stt: userConfig.stt
|
||||
};
|
||||
|
||||
// Only include embeddings if user has configured it (has api_key)
|
||||
if (userConfig.embeddings.api_key) {
|
||||
saveConfig.embeddings = userConfig.embeddings;
|
||||
}
|
||||
|
||||
try {
|
||||
await saveUserConfig({
|
||||
llm: userConfig.llm,
|
||||
tts: userConfig.tts,
|
||||
stt: userConfig.stt
|
||||
});
|
||||
await saveUserConfig(saveConfig);
|
||||
setApiError(null);
|
||||
} catch (error: unknown) {
|
||||
if (error instanceof Error) {
|
||||
|
|
@ -543,7 +567,7 @@ export default function ServiceConfiguration() {
|
|||
<Card>
|
||||
<CardContent className="pt-6">
|
||||
<Tabs defaultValue="llm" className="w-full">
|
||||
<TabsList className="grid w-full grid-cols-3 mb-6">
|
||||
<TabsList className="grid w-full grid-cols-4 mb-6">
|
||||
{TAB_CONFIG.map(({ key, label }) => (
|
||||
<TabsTrigger key={key} value={key}>
|
||||
{label}
|
||||
|
|
|
|||
65
ui/src/components/flow/DocumentBadges.tsx
Normal file
65
ui/src/components/flow/DocumentBadges.tsx
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
|
||||
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
|
||||
import type { DocumentResponseSchema } from "@/client/types.gen";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
|
||||
interface DocumentBadgesProps {
|
||||
documentUuids: string[];
|
||||
onStaleUuidsDetected?: (staleUuids: string[]) => void;
|
||||
}
|
||||
|
||||
export const DocumentBadges = ({ documentUuids, onStaleUuidsDetected }: DocumentBadgesProps) => {
|
||||
const { documents } = useWorkflow();
|
||||
const [documentNames, setDocumentNames] = useState<Record<string, string>>({});
|
||||
|
||||
const processDocuments = useCallback((docs: DocumentResponseSchema[]) => {
|
||||
const nameMap: Record<string, string> = {};
|
||||
const validUuids = new Set<string>();
|
||||
|
||||
docs
|
||||
.filter((doc) => documentUuids.includes(doc.document_uuid))
|
||||
.forEach((doc) => {
|
||||
nameMap[doc.document_uuid] = doc.filename;
|
||||
validUuids.add(doc.document_uuid);
|
||||
});
|
||||
setDocumentNames(nameMap);
|
||||
|
||||
// Detect stale UUIDs - this only runs when we have loaded data (not undefined)
|
||||
if (onStaleUuidsDetected) {
|
||||
const staleUuids = documentUuids.filter(uuid => !validUuids.has(uuid));
|
||||
if (staleUuids.length > 0) {
|
||||
onStaleUuidsDetected(staleUuids);
|
||||
}
|
||||
}
|
||||
}, [documentUuids, onStaleUuidsDetected]);
|
||||
|
||||
useEffect(() => {
|
||||
if (documentUuids.length > 0 && documents !== undefined) {
|
||||
processDocuments(documents);
|
||||
} else if (documentUuids.length === 0) {
|
||||
setDocumentNames({});
|
||||
}
|
||||
}, [documentUuids, documents, processDocuments]);
|
||||
|
||||
if (documentUuids.length === 0) {
|
||||
return <></>;
|
||||
}
|
||||
|
||||
// Show loading while data hasn't loaded yet
|
||||
if (documents === undefined) {
|
||||
return <Badge variant="outline">Loading...</Badge>;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
{documentUuids.map((uuid) => (
|
||||
<Badge key={uuid} variant="outline">
|
||||
{documentNames[uuid] || uuid}
|
||||
</Badge>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
};
|
||||
137
ui/src/components/flow/DocumentSelector.tsx
Normal file
137
ui/src/components/flow/DocumentSelector.tsx
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
"use client";
|
||||
|
||||
import { FileText } from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { useMemo } from "react";
|
||||
|
||||
import type { DocumentResponseSchema } from "@/client/types.gen";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import { Label } from "@/components/ui/label";
|
||||
|
||||
interface DocumentSelectorProps {
|
||||
value: string[];
|
||||
onChange: (uuids: string[]) => void;
|
||||
documents: DocumentResponseSchema[];
|
||||
disabled?: boolean;
|
||||
label?: string;
|
||||
description?: string;
|
||||
showLabel?: boolean;
|
||||
}
|
||||
|
||||
export const DocumentSelector = ({
|
||||
value,
|
||||
onChange,
|
||||
documents,
|
||||
disabled = false,
|
||||
label = "Knowledge Base Documents",
|
||||
description = "Select documents that the agent can reference during conversations.",
|
||||
showLabel = true,
|
||||
}: DocumentSelectorProps) => {
|
||||
// Only show completed documents
|
||||
const completedDocuments = useMemo(
|
||||
() => documents.filter((doc) => doc.processing_status === "completed"),
|
||||
[documents]
|
||||
);
|
||||
|
||||
const handleToggle = (documentUuid: string, checked: boolean) => {
|
||||
if (checked) {
|
||||
onChange([...value, documentUuid]);
|
||||
} else {
|
||||
onChange(value.filter((uuid) => uuid !== documentUuid));
|
||||
}
|
||||
};
|
||||
|
||||
const formatFileSize = (bytes: number): string => {
|
||||
if (bytes === 0) return "0 Bytes";
|
||||
const k = 1024;
|
||||
const sizes = ["Bytes", "KB", "MB", "GB"];
|
||||
const i = Math.floor(Math.log(bytes) / Math.log(k));
|
||||
return Math.round(bytes / Math.pow(k, i) * 100) / 100 + " " + sizes[i];
|
||||
};
|
||||
|
||||
if (completedDocuments.length === 0) {
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
{showLabel && (
|
||||
<>
|
||||
<Label>{label}</Label>
|
||||
{description && (
|
||||
<Label className="text-xs text-muted-foreground">{description}</Label>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
<div className="border rounded-md p-4 space-y-3">
|
||||
<div className="text-sm text-muted-foreground text-center">
|
||||
No documents available. Upload documents to the knowledge base first.
|
||||
</div>
|
||||
<div className="flex justify-center">
|
||||
<Link href="/files">
|
||||
<Button variant="outline" size="sm">
|
||||
Upload Documents
|
||||
</Button>
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
{showLabel && (
|
||||
<>
|
||||
<Label>{label}</Label>
|
||||
{description && (
|
||||
<Label className="text-xs text-muted-foreground">{description}</Label>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
<div className="border rounded-md max-h-[300px] overflow-y-auto">
|
||||
<div className="divide-y">
|
||||
{completedDocuments.map((doc) => (
|
||||
<div
|
||||
key={doc.document_uuid}
|
||||
className="flex items-start gap-3 p-3 hover:bg-muted/50 transition-colors"
|
||||
>
|
||||
<Checkbox
|
||||
id={`doc-${doc.document_uuid}`}
|
||||
checked={value.includes(doc.document_uuid)}
|
||||
onCheckedChange={(checked) =>
|
||||
handleToggle(doc.document_uuid, checked as boolean)
|
||||
}
|
||||
disabled={disabled}
|
||||
/>
|
||||
<div className="flex-1 space-y-1">
|
||||
<label
|
||||
htmlFor={`doc-${doc.document_uuid}`}
|
||||
className="flex items-center gap-2 cursor-pointer"
|
||||
>
|
||||
<div className="w-8 h-8 rounded-md bg-blue-500/10 flex items-center justify-center flex-shrink-0">
|
||||
<FileText className="w-4 h-4 text-blue-500" />
|
||||
</div>
|
||||
<div className="flex-1 min-w-0">
|
||||
<div className="text-sm font-medium truncate">
|
||||
{doc.filename}
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{formatFileSize(doc.file_size_bytes)} • {doc.total_chunks} chunks
|
||||
</div>
|
||||
</div>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center justify-between text-xs text-muted-foreground pt-1">
|
||||
<span>
|
||||
{value.length} {value.length === 1 ? "document" : "documents"} selected
|
||||
</span>
|
||||
<Link href="/files" className="hover:underline">
|
||||
Manage Documents
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
|
@ -2,43 +2,43 @@
|
|||
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
|
||||
import { listToolsApiV1ToolsGet } from "@/client/sdk.gen";
|
||||
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
|
||||
import type { ToolResponse } from "@/client/types.gen";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
|
||||
interface ToolBadgesProps {
|
||||
toolUuids: string[];
|
||||
onStaleUuidsDetected?: (staleUuids: string[]) => void;
|
||||
}
|
||||
|
||||
export function ToolBadges({ toolUuids }: ToolBadgesProps) {
|
||||
const { getAccessToken } = useAuth();
|
||||
const [tools, setTools] = useState<ToolResponse[]>([]);
|
||||
export function ToolBadges({ toolUuids, onStaleUuidsDetected }: ToolBadgesProps) {
|
||||
const { tools } = useWorkflow();
|
||||
const [selectedTools, setSelectedTools] = useState<ToolResponse[]>([]);
|
||||
|
||||
const fetchTools = useCallback(async () => {
|
||||
try {
|
||||
const accessToken = await getAccessToken();
|
||||
const response = await listToolsApiV1ToolsGet({
|
||||
headers: { Authorization: `Bearer ${accessToken}` },
|
||||
});
|
||||
if (response.data) {
|
||||
setTools(response.data);
|
||||
const processTools = useCallback((toolsData: ToolResponse[]) => {
|
||||
const filtered = toolsData.filter(tool => toolUuids.includes(tool.tool_uuid));
|
||||
setSelectedTools(filtered);
|
||||
|
||||
// Detect stale UUIDs - this only runs when we have loaded data (not undefined)
|
||||
if (onStaleUuidsDetected) {
|
||||
const validUuids = new Set(toolsData.map(tool => tool.tool_uuid));
|
||||
const staleUuids = toolUuids.filter(uuid => !validUuids.has(uuid));
|
||||
if (staleUuids.length > 0) {
|
||||
onStaleUuidsDetected(staleUuids);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch tools:", error);
|
||||
}
|
||||
}, [getAccessToken]);
|
||||
}, [toolUuids, onStaleUuidsDetected]);
|
||||
|
||||
useEffect(() => {
|
||||
if (toolUuids.length > 0) {
|
||||
fetchTools();
|
||||
if (toolUuids.length > 0 && tools !== undefined) {
|
||||
processTools(tools);
|
||||
} else if (toolUuids.length === 0) {
|
||||
setSelectedTools([]);
|
||||
}
|
||||
}, [toolUuids.length, fetchTools]);
|
||||
}, [toolUuids, tools, processTools]);
|
||||
|
||||
const selectedTools = tools.filter((tool) => toolUuids.includes(tool.tool_uuid));
|
||||
|
||||
if (selectedTools.length === 0 && toolUuids.length > 0) {
|
||||
// Still loading or tools not found
|
||||
// Show loading while data hasn't loaded yet
|
||||
if (tools === undefined && toolUuids.length > 0) {
|
||||
return (
|
||||
<div className="flex flex-wrap gap-1">
|
||||
<Badge variant="outline" className="text-xs">
|
||||
|
|
|
|||
|
|
@ -1,20 +1,18 @@
|
|||
"use client";
|
||||
|
||||
import { ExternalLink, Loader2 } from "lucide-react";
|
||||
import { ExternalLink } from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
|
||||
import { renderToolIcon } from "@/app/tools/config";
|
||||
import { listToolsApiV1ToolsGet } from "@/client/sdk.gen";
|
||||
import type { ToolResponse } from "@/client/types.gen";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
|
||||
interface ToolSelectorProps {
|
||||
value: string[];
|
||||
onChange: (uuids: string[]) => void;
|
||||
tools: ToolResponse[];
|
||||
disabled?: boolean;
|
||||
label?: string;
|
||||
description?: string;
|
||||
|
|
@ -24,43 +22,14 @@ interface ToolSelectorProps {
|
|||
export function ToolSelector({
|
||||
value,
|
||||
onChange,
|
||||
tools,
|
||||
disabled = false,
|
||||
label = "Tools",
|
||||
description = "Select tools that the agent can use during the conversation.",
|
||||
showLabel = true,
|
||||
}: ToolSelectorProps) {
|
||||
const { getAccessToken } = useAuth();
|
||||
|
||||
const [tools, setTools] = useState<ToolResponse[]>([]);
|
||||
const [loading, setLoading] = useState(false);
|
||||
|
||||
const fetchTools = useCallback(async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const accessToken = await getAccessToken();
|
||||
const response = await listToolsApiV1ToolsGet({
|
||||
headers: { Authorization: `Bearer ${accessToken}` },
|
||||
query: { status: "active" },
|
||||
});
|
||||
if (response.error) {
|
||||
console.error("Failed to fetch tools:", response.error);
|
||||
setTools([]);
|
||||
return;
|
||||
}
|
||||
if (response.data) {
|
||||
setTools(response.data);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch tools:", error);
|
||||
setTools([]);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, [getAccessToken]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchTools();
|
||||
}, [fetchTools]);
|
||||
// Filter to only show active tools
|
||||
const activeTools = tools.filter((tool) => tool.status === "active");
|
||||
|
||||
const handleToggle = (toolUuid: string, checked: boolean) => {
|
||||
if (checked) {
|
||||
|
|
@ -83,12 +52,7 @@ export function ToolSelector({
|
|||
</>
|
||||
)}
|
||||
|
||||
{loading ? (
|
||||
<div className="flex items-center gap-2 p-3 border rounded-md">
|
||||
<Loader2 className="h-4 w-4 animate-spin" />
|
||||
<span className="text-sm text-muted-foreground">Loading tools...</span>
|
||||
</div>
|
||||
) : tools.length === 0 ? (
|
||||
{activeTools.length === 0 ? (
|
||||
<div className="p-4 border rounded-md text-center">
|
||||
<p className="text-sm text-muted-foreground mb-2">
|
||||
No tools available.
|
||||
|
|
@ -102,7 +66,7 @@ export function ToolSelector({
|
|||
</div>
|
||||
) : (
|
||||
<div className="border rounded-md divide-y">
|
||||
{tools.map((tool) => {
|
||||
{activeTools.map((tool) => {
|
||||
const isSelected = value.includes(tool.tool_uuid);
|
||||
return (
|
||||
<label
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import { NodeProps, NodeToolbar, Position } from "@xyflow/react";
|
||||
import { Edit, Headset, PlusIcon, Trash2Icon, Wrench } from "lucide-react";
|
||||
import { memo, useEffect, useMemo, useState } from "react";
|
||||
import { Edit, FileText, Headset, PlusIcon, Trash2Icon, Wrench } from "lucide-react";
|
||||
import { memo, useCallback, useEffect, useMemo, useState } from "react";
|
||||
|
||||
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
|
||||
import type { DocumentResponseSchema, ToolResponse } from "@/client/types.gen";
|
||||
import { DocumentBadges } from "@/components/flow/DocumentBadges";
|
||||
import { DocumentSelector } from "@/components/flow/DocumentSelector";
|
||||
import { ToolBadges } from "@/components/flow/ToolBadges";
|
||||
import { ToolSelector } from "@/components/flow/ToolSelector";
|
||||
import { ExtractionVariable, FlowNodeData } from "@/components/flow/types";
|
||||
|
|
@ -34,6 +37,10 @@ interface AgentNodeEditFormProps {
|
|||
setAddGlobalPrompt: (value: boolean) => void;
|
||||
toolUuids: string[];
|
||||
setToolUuids: (value: string[]) => void;
|
||||
documentUuids: string[];
|
||||
setDocumentUuids: (value: string[]) => void;
|
||||
tools: ToolResponse[];
|
||||
documents: DocumentResponseSchema[];
|
||||
}
|
||||
|
||||
interface AgentNodeProps extends NodeProps {
|
||||
|
|
@ -42,7 +49,7 @@ interface AgentNodeProps extends NodeProps {
|
|||
|
||||
export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
|
||||
const { open, setOpen, handleSaveNodeData, handleDeleteNode } = useNodeHandlers({ id });
|
||||
const { saveWorkflow } = useWorkflow();
|
||||
const { saveWorkflow, tools, documents } = useWorkflow();
|
||||
|
||||
// Form state
|
||||
const [prompt, setPrompt] = useState(data.prompt);
|
||||
|
|
@ -55,6 +62,7 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
|
|||
const [variables, setVariables] = useState<ExtractionVariable[]>(data.extraction_variables ?? []);
|
||||
const [addGlobalPrompt, setAddGlobalPrompt] = useState(data.add_global_prompt ?? true);
|
||||
const [toolUuids, setToolUuids] = useState<string[]>(data.tool_uuids ?? []);
|
||||
const [documentUuids, setDocumentUuids] = useState<string[]>(data.document_uuids ?? []);
|
||||
|
||||
// Compute if form has unsaved changes (only check prompt, name)
|
||||
const isDirty = useMemo(() => {
|
||||
|
|
@ -75,6 +83,7 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
|
|||
extraction_variables: variables,
|
||||
add_global_prompt: addGlobalPrompt,
|
||||
tool_uuids: toolUuids.length > 0 ? toolUuids : undefined,
|
||||
document_uuids: documentUuids.length > 0 ? documentUuids : undefined,
|
||||
});
|
||||
setOpen(false);
|
||||
// Save the workflow after updating node data with a small delay to ensure state is updated
|
||||
|
|
@ -94,6 +103,7 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
|
|||
setVariables(data.extraction_variables ?? []);
|
||||
setAddGlobalPrompt(data.add_global_prompt ?? true);
|
||||
setToolUuids(data.tool_uuids ?? []);
|
||||
setDocumentUuids(data.document_uuids ?? []);
|
||||
}
|
||||
setOpen(newOpen);
|
||||
};
|
||||
|
|
@ -109,9 +119,34 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
|
|||
setVariables(data.extraction_variables ?? []);
|
||||
setAddGlobalPrompt(data.add_global_prompt ?? true);
|
||||
setToolUuids(data.tool_uuids ?? []);
|
||||
setDocumentUuids(data.document_uuids ?? []);
|
||||
}
|
||||
}, [data, open]);
|
||||
|
||||
// Handle cleanup of stale document UUIDs
|
||||
const handleStaleDocuments = useCallback((staleUuids: string[]) => {
|
||||
const cleanedUuids = (data.document_uuids ?? []).filter(uuid => !staleUuids.includes(uuid));
|
||||
handleSaveNodeData({
|
||||
...data,
|
||||
document_uuids: cleanedUuids.length > 0 ? cleanedUuids : undefined,
|
||||
});
|
||||
setTimeout(async () => {
|
||||
await saveWorkflow();
|
||||
}, 100);
|
||||
}, [data, handleSaveNodeData, saveWorkflow]);
|
||||
|
||||
// Handle cleanup of stale tool UUIDs
|
||||
const handleStaleTools = useCallback((staleUuids: string[]) => {
|
||||
const cleanedUuids = (data.tool_uuids ?? []).filter(uuid => !staleUuids.includes(uuid));
|
||||
handleSaveNodeData({
|
||||
...data,
|
||||
tool_uuids: cleanedUuids.length > 0 ? cleanedUuids : undefined,
|
||||
});
|
||||
setTimeout(async () => {
|
||||
await saveWorkflow();
|
||||
}, 100);
|
||||
}, [data, handleSaveNodeData, saveWorkflow]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<NodeContent
|
||||
|
|
@ -136,7 +171,16 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
|
|||
<Wrench className="h-3 w-3" />
|
||||
<span>Tools:</span>
|
||||
</div>
|
||||
<ToolBadges toolUuids={data.tool_uuids} />
|
||||
<ToolBadges toolUuids={data.tool_uuids} onStaleUuidsDetected={handleStaleTools} />
|
||||
</div>
|
||||
)}
|
||||
{data.document_uuids && data.document_uuids.length > 0 && (
|
||||
<div className="mt-3 pt-3 border-t border-border/50">
|
||||
<div className="flex items-center gap-1.5 text-xs text-muted-foreground mb-2">
|
||||
<FileText className="h-3 w-3" />
|
||||
<span>Documents:</span>
|
||||
</div>
|
||||
<DocumentBadges documentUuids={data.document_uuids} onStaleUuidsDetected={handleStaleDocuments} />
|
||||
</div>
|
||||
)}
|
||||
</NodeContent>
|
||||
|
|
@ -179,6 +223,10 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
|
|||
setAddGlobalPrompt={setAddGlobalPrompt}
|
||||
toolUuids={toolUuids}
|
||||
setToolUuids={setToolUuids}
|
||||
documentUuids={documentUuids}
|
||||
setDocumentUuids={setDocumentUuids}
|
||||
tools={tools ?? []}
|
||||
documents={documents ?? []}
|
||||
/>
|
||||
)}
|
||||
</NodeEditDialog>
|
||||
|
|
@ -203,6 +251,10 @@ const AgentNodeEditForm = ({
|
|||
setAddGlobalPrompt,
|
||||
toolUuids,
|
||||
setToolUuids,
|
||||
documentUuids,
|
||||
setDocumentUuids,
|
||||
tools,
|
||||
documents,
|
||||
}: AgentNodeEditFormProps) => {
|
||||
const handleVariableNameChange = (idx: number, value: string) => {
|
||||
const newVars = [...variables];
|
||||
|
|
@ -343,9 +395,20 @@ const AgentNodeEditForm = ({
|
|||
<ToolSelector
|
||||
value={toolUuids}
|
||||
onChange={setToolUuids}
|
||||
tools={tools}
|
||||
description="Select tools that the agent can invoke during this conversation step."
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Documents Section */}
|
||||
<div className="pt-4 border-t mt-4">
|
||||
<DocumentSelector
|
||||
value={documentUuids}
|
||||
onChange={setDocumentUuids}
|
||||
documents={documents}
|
||||
description="Select documents from the knowledge base that the agent can reference during this conversation step."
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import { NodeProps, NodeToolbar, Position } from "@xyflow/react";
|
||||
import { Edit, Play, PlusIcon, Trash2Icon, Wrench } from "lucide-react";
|
||||
import { memo, useEffect, useMemo, useState } from "react";
|
||||
import { Edit, FileText, Play, PlusIcon, Trash2Icon, Wrench } from "lucide-react";
|
||||
import { memo, useCallback, useEffect, useMemo, useState } from "react";
|
||||
|
||||
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
|
||||
import type { DocumentResponseSchema, ToolResponse } from "@/client/types.gen";
|
||||
import { DocumentBadges } from "@/components/flow/DocumentBadges";
|
||||
import { DocumentSelector } from "@/components/flow/DocumentSelector";
|
||||
import { ToolBadges } from "@/components/flow/ToolBadges";
|
||||
import { ToolSelector } from "@/components/flow/ToolSelector";
|
||||
import { ExtractionVariable, FlowNodeData } from "@/components/flow/types";
|
||||
|
|
@ -41,6 +44,10 @@ interface StartCallEditFormProps {
|
|||
setVariables: (vars: ExtractionVariable[]) => void;
|
||||
toolUuids: string[];
|
||||
setToolUuids: (value: string[]) => void;
|
||||
documentUuids: string[];
|
||||
setDocumentUuids: (value: string[]) => void;
|
||||
tools: ToolResponse[];
|
||||
documents: DocumentResponseSchema[];
|
||||
}
|
||||
|
||||
interface StartCallNodeProps extends NodeProps {
|
||||
|
|
@ -52,7 +59,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
id,
|
||||
additionalData: { is_start: true }
|
||||
});
|
||||
const { saveWorkflow } = useWorkflow();
|
||||
const { saveWorkflow, tools, documents } = useWorkflow();
|
||||
|
||||
// Form state
|
||||
const [prompt, setPrompt] = useState(data.prompt ?? "");
|
||||
|
|
@ -66,6 +73,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
const [extractionPrompt, setExtractionPrompt] = useState(data.extraction_prompt ?? "");
|
||||
const [variables, setVariables] = useState<ExtractionVariable[]>(data.extraction_variables ?? []);
|
||||
const [toolUuids, setToolUuids] = useState<string[]>(data.tool_uuids ?? []);
|
||||
const [documentUuids, setDocumentUuids] = useState<string[]>(data.document_uuids ?? []);
|
||||
|
||||
// Compute if form has unsaved changes (only check prompt, name)
|
||||
const isDirty = useMemo(() => {
|
||||
|
|
@ -89,6 +97,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
extraction_prompt: extractionPrompt,
|
||||
extraction_variables: variables,
|
||||
tool_uuids: toolUuids.length > 0 ? toolUuids : undefined,
|
||||
document_uuids: documentUuids.length > 0 ? documentUuids : undefined,
|
||||
});
|
||||
setOpen(false);
|
||||
// Save the workflow after updating node data with a small delay to ensure state is updated
|
||||
|
|
@ -111,6 +120,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
setExtractionPrompt(data.extraction_prompt ?? "");
|
||||
setVariables(data.extraction_variables ?? []);
|
||||
setToolUuids(data.tool_uuids ?? []);
|
||||
setDocumentUuids(data.document_uuids ?? []);
|
||||
}
|
||||
setOpen(newOpen);
|
||||
};
|
||||
|
|
@ -129,9 +139,34 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
setExtractionPrompt(data.extraction_prompt ?? "");
|
||||
setVariables(data.extraction_variables ?? []);
|
||||
setToolUuids(data.tool_uuids ?? []);
|
||||
setDocumentUuids(data.document_uuids ?? []);
|
||||
}
|
||||
}, [data, open]);
|
||||
|
||||
// Handle cleanup of stale document UUIDs
|
||||
const handleStaleDocuments = useCallback((staleUuids: string[]) => {
|
||||
const cleanedUuids = (data.document_uuids ?? []).filter(uuid => !staleUuids.includes(uuid));
|
||||
handleSaveNodeData({
|
||||
...data,
|
||||
document_uuids: cleanedUuids.length > 0 ? cleanedUuids : undefined,
|
||||
});
|
||||
setTimeout(async () => {
|
||||
await saveWorkflow();
|
||||
}, 100);
|
||||
}, [data, handleSaveNodeData, saveWorkflow]);
|
||||
|
||||
// Handle cleanup of stale tool UUIDs
|
||||
const handleStaleTools = useCallback((staleUuids: string[]) => {
|
||||
const cleanedUuids = (data.tool_uuids ?? []).filter(uuid => !staleUuids.includes(uuid));
|
||||
handleSaveNodeData({
|
||||
...data,
|
||||
tool_uuids: cleanedUuids.length > 0 ? cleanedUuids : undefined,
|
||||
});
|
||||
setTimeout(async () => {
|
||||
await saveWorkflow();
|
||||
}, 100);
|
||||
}, [data, handleSaveNodeData, saveWorkflow]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<NodeContent
|
||||
|
|
@ -155,7 +190,16 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
<Wrench className="h-3 w-3" />
|
||||
<span>Tools:</span>
|
||||
</div>
|
||||
<ToolBadges toolUuids={data.tool_uuids} />
|
||||
<ToolBadges toolUuids={data.tool_uuids} onStaleUuidsDetected={handleStaleTools} />
|
||||
</div>
|
||||
)}
|
||||
{data.document_uuids && data.document_uuids.length > 0 && (
|
||||
<div className="mt-3 pt-3 border-t border-border/50">
|
||||
<div className="flex items-center gap-1.5 text-xs text-muted-foreground mb-2">
|
||||
<FileText className="h-3 w-3" />
|
||||
<span>Documents:</span>
|
||||
</div>
|
||||
<DocumentBadges documentUuids={data.document_uuids} onStaleUuidsDetected={handleStaleDocuments} />
|
||||
</div>
|
||||
)}
|
||||
</NodeContent>
|
||||
|
|
@ -199,6 +243,10 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
setVariables={setVariables}
|
||||
toolUuids={toolUuids}
|
||||
setToolUuids={setToolUuids}
|
||||
documentUuids={documentUuids}
|
||||
setDocumentUuids={setDocumentUuids}
|
||||
tools={tools ?? []}
|
||||
documents={documents ?? []}
|
||||
/>
|
||||
)}
|
||||
</NodeEditDialog>
|
||||
|
|
@ -229,6 +277,10 @@ const StartCallEditForm = ({
|
|||
setVariables,
|
||||
toolUuids,
|
||||
setToolUuids,
|
||||
documentUuids,
|
||||
setDocumentUuids,
|
||||
tools,
|
||||
documents,
|
||||
}: StartCallEditFormProps) => {
|
||||
const handleVariableNameChange = (idx: number, value: string) => {
|
||||
const newVars = [...variables];
|
||||
|
|
@ -414,9 +466,20 @@ const StartCallEditForm = ({
|
|||
<ToolSelector
|
||||
value={toolUuids}
|
||||
onChange={setToolUuids}
|
||||
tools={tools}
|
||||
description="Select tools that the agent can invoke during this conversation step."
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Documents Section */}
|
||||
<div className="pt-4 border-t mt-4">
|
||||
<DocumentSelector
|
||||
value={documentUuids}
|
||||
onChange={setDocumentUuids}
|
||||
documents={documents}
|
||||
description="Select documents from the knowledge base that the agent can reference during this conversation step."
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -42,6 +42,8 @@ export type FlowNodeData = {
|
|||
};
|
||||
// Tools - array of tool UUIDs that can be invoked by this node
|
||||
tool_uuids?: string[];
|
||||
// Documents - array of knowledge base document UUIDs that can be referenced by this node
|
||||
document_uuids?: string[];
|
||||
}
|
||||
|
||||
export type FlowNode = {
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import {
|
|||
ChevronLeft,
|
||||
ChevronRight,
|
||||
CircleDollarSign,
|
||||
Database,
|
||||
FileText,
|
||||
HelpCircle,
|
||||
Home,
|
||||
|
|
@ -114,6 +115,11 @@ export function AppSidebar() {
|
|||
url: "/tools",
|
||||
icon: Wrench,
|
||||
},
|
||||
{
|
||||
title: "Files",
|
||||
url: "/files",
|
||||
icon: Database,
|
||||
},
|
||||
// {
|
||||
// title: "Integrations",
|
||||
// url: "/integrations",
|
||||
|
|
|
|||
|
|
@ -18,6 +18,9 @@ export type SaveUserConfigFunctionParams = {
|
|||
stt?: {
|
||||
[key: string]: string | number;
|
||||
} | null;
|
||||
embeddings?: {
|
||||
[key: string]: string | number;
|
||||
} | null;
|
||||
test_phone_number?: string | null;
|
||||
timezone?: string | null;
|
||||
};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue