feat: knowledge base functionality for the voice agent (#120)

* feat: upload file and store embedding

* feat: add documents in nodes

* feat: add openai embedding service
This commit is contained in:
Abhishek 2026-01-17 14:37:03 +05:30 committed by GitHub
parent e2fa4bbb98
commit ef5b9e40a9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
52 changed files with 4551 additions and 114 deletions

View file

@ -0,0 +1,194 @@
"""add document tables
Revision ID: dc33eef8dabe
Revises: dcb0a27d98c6
Create Date: 2026-01-16 13:40:17.808807
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from pgvector.sqlalchemy import Vector
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "dc33eef8dabe"
down_revision: Union[str, None] = "dcb0a27d98c6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# Enable pgvector extension
op.execute("CREATE EXTENSION IF NOT EXISTS vector")
sa.Enum(
"pending",
"processing",
"completed",
"failed",
name="document_processing_status",
).create(op.get_bind())
op.create_table(
"knowledge_base_documents",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("document_uuid", sa.String(length=36), nullable=False),
sa.Column("organization_id", sa.Integer(), nullable=False),
sa.Column("filename", sa.String(length=500), nullable=False),
sa.Column("file_size_bytes", sa.Integer(), nullable=True),
sa.Column("file_hash", sa.String(length=64), nullable=True),
sa.Column("mime_type", sa.String(length=100), nullable=True),
sa.Column("source_url", sa.String(), nullable=True),
sa.Column("total_chunks", sa.Integer(), nullable=False),
sa.Column(
"processing_status",
postgresql.ENUM(
"pending",
"processing",
"completed",
"failed",
name="document_processing_status",
create_type=False,
),
server_default=sa.text("'pending'::document_processing_status"),
nullable=False,
),
sa.Column("processing_error", sa.Text(), nullable=True),
sa.Column("docling_metadata", sa.JSON(), nullable=False),
sa.Column("custom_metadata", sa.JSON(), nullable=False),
sa.Column("created_by", sa.Integer(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("archived_at", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(
["created_by"],
["users.id"],
),
sa.ForeignKeyConstraint(
["organization_id"], ["organizations.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_kb_documents_created_at",
"knowledge_base_documents",
["created_at"],
unique=False,
)
op.create_index(
"ix_kb_documents_organization_id",
"knowledge_base_documents",
["organization_id"],
unique=False,
)
op.create_index(
"ix_kb_documents_status",
"knowledge_base_documents",
["processing_status"],
unique=False,
)
op.create_index(
"ix_kb_documents_uuid",
"knowledge_base_documents",
["document_uuid"],
unique=False,
)
op.create_index(
op.f("ix_knowledge_base_documents_document_uuid"),
"knowledge_base_documents",
["document_uuid"],
unique=True,
)
op.create_table(
"knowledge_base_chunks",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("document_id", sa.Integer(), nullable=False),
sa.Column("organization_id", sa.Integer(), nullable=False),
sa.Column("chunk_text", sa.Text(), nullable=False),
sa.Column("contextualized_text", sa.Text(), nullable=True),
sa.Column("chunk_index", sa.Integer(), nullable=False),
sa.Column("chunk_metadata", sa.JSON(), nullable=False),
sa.Column("embedding_model", sa.String(length=200), nullable=False),
sa.Column("embedding_dimension", sa.Integer(), nullable=False),
sa.Column("embedding", Vector(1536), nullable=True),
sa.Column("token_count", sa.Integer(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(
["document_id"], ["knowledge_base_documents.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(
["organization_id"], ["organizations.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_kb_chunks_chunk_index",
"knowledge_base_chunks",
["chunk_index"],
unique=False,
)
op.create_index(
"ix_kb_chunks_document_id",
"knowledge_base_chunks",
["document_id"],
unique=False,
)
op.create_index(
"ix_kb_chunks_embedding_ivfflat",
"knowledge_base_chunks",
["embedding"],
unique=False,
postgresql_using="ivfflat",
postgresql_with={"lists": 100},
postgresql_ops={"embedding": "vector_cosine_ops"},
)
op.create_index(
"ix_kb_chunks_organization_id",
"knowledge_base_chunks",
["organization_id"],
unique=False,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("ix_kb_chunks_organization_id", table_name="knowledge_base_chunks")
op.drop_index(
"ix_kb_chunks_embedding_ivfflat",
table_name="knowledge_base_chunks",
postgresql_using="ivfflat",
postgresql_with={"lists": 100},
postgresql_ops={"embedding": "vector_cosine_ops"},
)
op.drop_index("ix_kb_chunks_document_id", table_name="knowledge_base_chunks")
op.drop_index("ix_kb_chunks_chunk_index", table_name="knowledge_base_chunks")
op.drop_table("knowledge_base_chunks")
op.drop_index(
op.f("ix_knowledge_base_documents_document_uuid"),
table_name="knowledge_base_documents",
)
op.drop_index("ix_kb_documents_uuid", table_name="knowledge_base_documents")
op.drop_index("ix_kb_documents_status", table_name="knowledge_base_documents")
op.drop_index(
"ix_kb_documents_organization_id", table_name="knowledge_base_documents"
)
op.drop_index("ix_kb_documents_created_at", table_name="knowledge_base_documents")
op.drop_table("knowledge_base_documents")
sa.Enum(
"pending",
"processing",
"completed",
"failed",
name="document_processing_status",
).drop(op.get_bind())
# Note: We don't drop the vector extension as it may be used by other tables
# If you want to drop it, uncomment the following line:
# op.execute('DROP EXTENSION IF EXISTS vector')
# ### end Alembic commands ###

View file

@ -3,6 +3,7 @@ from api.db.api_key_client import APIKeyClient
from api.db.campaign_client import CampaignClient
from api.db.embed_token_client import EmbedTokenClient
from api.db.integration_client import IntegrationClient
from api.db.knowledge_base_client import KnowledgeBaseClient
from api.db.looptalk_client import LoopTalkClient
from api.db.organization_client import OrganizationClient
from api.db.organization_configuration_client import OrganizationConfigurationClient
@ -33,6 +34,7 @@ class DBClient(
AgentTriggerClient,
WebhookCredentialClient,
ToolClient,
KnowledgeBaseClient,
):
"""
Unified database client that combines all specialized database operations.
@ -54,6 +56,7 @@ class DBClient(
- AgentTriggerClient: handles agent trigger operations for API-based call triggering
- WebhookCredentialClient: handles webhook credential operations
- ToolClient: handles tool operations for reusable HTTP API tools
- KnowledgeBaseClient: handles knowledge base document and vector search operations
"""
pass

View file

@ -0,0 +1,499 @@
"""Database client for managing knowledge base documents and chunks."""
import hashlib
from pathlib import Path
from typing import List, Optional
from loguru import logger
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from api.db.base_client import BaseDBClient
from api.db.models import KnowledgeBaseChunkModel, KnowledgeBaseDocumentModel
class KnowledgeBaseClient(BaseDBClient):
"""Client for managing knowledge base documents and vector embeddings."""
async def create_document(
self,
organization_id: int,
created_by: int,
filename: str,
file_size_bytes: int,
file_hash: str,
mime_type: str,
source_url: Optional[str] = None,
custom_metadata: Optional[dict] = None,
docling_metadata: Optional[dict] = None,
document_uuid: Optional[str] = None,
) -> KnowledgeBaseDocumentModel:
"""Create a new knowledge base document record.
Args:
organization_id: ID of the organization
created_by: ID of the user uploading the document
filename: Name of the file
file_size_bytes: Size of the file in bytes
file_hash: SHA-256 hash of the file
mime_type: MIME type of the file
source_url: Optional URL if document was fetched from web
custom_metadata: Optional custom metadata dictionary
docling_metadata: Optional docling processing metadata
document_uuid: Optional UUID to use (if not provided, one will be generated)
Returns:
The created KnowledgeBaseDocumentModel
"""
async with self.async_session() as session:
document = KnowledgeBaseDocumentModel(
organization_id=organization_id,
created_by=created_by,
filename=filename,
file_size_bytes=file_size_bytes,
file_hash=file_hash,
mime_type=mime_type,
source_url=source_url,
custom_metadata=custom_metadata or {},
docling_metadata=docling_metadata or {},
processing_status="pending",
total_chunks=0,
)
# Use provided UUID or let the model generate one
if document_uuid:
document.document_uuid = document_uuid
session.add(document)
await session.commit()
await session.refresh(document)
logger.info(
f"Created document '{filename}' ({document.document_uuid}) "
f"for organization {organization_id}"
)
return document
async def get_document_by_id(
self,
document_id: int,
) -> Optional[KnowledgeBaseDocumentModel]:
"""Get a document by its database ID.
Args:
document_id: The database ID of the document
Returns:
KnowledgeBaseDocumentModel if found, None otherwise
"""
async with self.async_session() as session:
query = select(KnowledgeBaseDocumentModel).where(
KnowledgeBaseDocumentModel.id == document_id
)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_document_by_uuid(
self,
document_uuid: str,
organization_id: int,
) -> Optional[KnowledgeBaseDocumentModel]:
"""Get a document by its UUID, scoped to organization.
Args:
document_uuid: The unique document UUID
organization_id: ID of the organization
Returns:
KnowledgeBaseDocumentModel if found, None otherwise
"""
async with self.async_session() as session:
query = (
select(KnowledgeBaseDocumentModel)
.where(
KnowledgeBaseDocumentModel.document_uuid == document_uuid,
KnowledgeBaseDocumentModel.organization_id == organization_id,
KnowledgeBaseDocumentModel.is_active == True,
)
.options(selectinload(KnowledgeBaseDocumentModel.created_by_user))
)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_document_by_hash(
self,
file_hash: str,
organization_id: int,
) -> Optional[KnowledgeBaseDocumentModel]:
"""Check if a document with the same hash already exists.
Returns the first matching document if multiple exist (can happen with duplicates).
Args:
file_hash: SHA-256 hash of the file
organization_id: ID of the organization
Returns:
KnowledgeBaseDocumentModel if found, None otherwise
"""
async with self.async_session() as session:
query = (
select(KnowledgeBaseDocumentModel)
.where(
KnowledgeBaseDocumentModel.file_hash == file_hash,
KnowledgeBaseDocumentModel.organization_id == organization_id,
KnowledgeBaseDocumentModel.is_active == True,
)
.order_by(KnowledgeBaseDocumentModel.created_at.asc())
.limit(1)
)
result = await session.execute(query)
return result.scalars().first()
async def get_documents_for_organization(
self,
organization_id: int,
processing_status: Optional[str] = None,
limit: int = 100,
offset: int = 0,
) -> List[KnowledgeBaseDocumentModel]:
"""Get all documents for an organization.
Args:
organization_id: ID of the organization
processing_status: Optional filter by status
limit: Maximum number of documents to return
offset: Number of documents to skip
Returns:
List of KnowledgeBaseDocumentModel instances
"""
async with self.async_session() as session:
query = select(KnowledgeBaseDocumentModel).where(
KnowledgeBaseDocumentModel.organization_id == organization_id,
KnowledgeBaseDocumentModel.is_active == True,
)
if processing_status:
query = query.where(
KnowledgeBaseDocumentModel.processing_status == processing_status
)
query = (
query.order_by(KnowledgeBaseDocumentModel.created_at.desc())
.limit(limit)
.offset(offset)
)
result = await session.execute(query)
return list(result.scalars().all())
async def update_document_metadata(
self,
document_id: int,
file_size_bytes: Optional[int] = None,
file_hash: Optional[str] = None,
mime_type: Optional[str] = None,
) -> Optional[KnowledgeBaseDocumentModel]:
"""Update document file metadata.
Args:
document_id: ID of the document
file_size_bytes: Optional file size in bytes
file_hash: Optional SHA-256 hash of the file
mime_type: Optional MIME type
Returns:
Updated KnowledgeBaseDocumentModel
"""
async with self.async_session() as session:
query = select(KnowledgeBaseDocumentModel).where(
KnowledgeBaseDocumentModel.id == document_id
)
result = await session.execute(query)
document = result.scalar_one_or_none()
if not document:
return None
if file_size_bytes is not None:
document.file_size_bytes = file_size_bytes
if file_hash is not None:
document.file_hash = file_hash
if mime_type is not None:
document.mime_type = mime_type
await session.commit()
await session.refresh(document)
logger.info(f"Updated document {document_id} metadata")
return document
async def update_document_status(
self,
document_id: int,
status: str,
error_message: Optional[str] = None,
total_chunks: Optional[int] = None,
docling_metadata: Optional[dict] = None,
) -> Optional[KnowledgeBaseDocumentModel]:
"""Update document processing status.
Args:
document_id: ID of the document
status: New status (pending, processing, completed, failed)
error_message: Optional error message if status is failed
total_chunks: Optional total number of chunks
docling_metadata: Optional docling metadata
Returns:
Updated KnowledgeBaseDocumentModel
"""
async with self.async_session() as session:
query = select(KnowledgeBaseDocumentModel).where(
KnowledgeBaseDocumentModel.id == document_id
)
result = await session.execute(query)
document = result.scalar_one_or_none()
if not document:
return None
document.processing_status = status
if error_message:
document.processing_error = error_message
if total_chunks is not None:
document.total_chunks = total_chunks
if docling_metadata:
document.docling_metadata = docling_metadata
await session.commit()
await session.refresh(document)
logger.info(f"Updated document {document_id} status to {status}")
return document
async def create_chunks_batch(
self,
chunks: List[KnowledgeBaseChunkModel],
) -> List[KnowledgeBaseChunkModel]:
"""Create multiple chunks in a batch.
Args:
chunks: List of KnowledgeBaseChunkModel instances
Returns:
List of created chunks with IDs
"""
async with self.async_session() as session:
session.add_all(chunks)
await session.commit()
for chunk in chunks:
await session.refresh(chunk)
logger.info(f"Created {len(chunks)} chunks")
return chunks
async def get_chunks_for_document(
self,
document_id: int,
organization_id: int,
) -> List[KnowledgeBaseChunkModel]:
"""Get all chunks for a document.
Args:
document_id: ID of the document
organization_id: ID of the organization (for authorization)
Returns:
List of KnowledgeBaseChunkModel instances
"""
async with self.async_session() as session:
query = (
select(KnowledgeBaseChunkModel)
.where(
KnowledgeBaseChunkModel.document_id == document_id,
KnowledgeBaseChunkModel.organization_id == organization_id,
)
.order_by(KnowledgeBaseChunkModel.chunk_index)
)
result = await session.execute(query)
return list(result.scalars().all())
async def search_similar_chunks(
self,
query_embedding: List[float],
organization_id: int,
limit: int = 5,
document_ids: Optional[List[int]] = None,
document_uuids: Optional[List[str]] = None,
embedding_model: Optional[str] = None,
) -> List[dict]:
"""Search for similar chunks using vector similarity.
Returns top-k most similar chunks without any similarity threshold filtering.
Filtering and reranking should be done at the application layer.
Args:
query_embedding: The query embedding vector
organization_id: Organization ID for scoping
limit: Maximum number of results to return
document_ids: Optional list of document IDs to filter by
document_uuids: Optional list of document UUIDs to filter by
embedding_model: Optional embedding model to filter by (for dimension compatibility)
Returns:
List of dictionaries with chunk data and similarity scores, ordered by similarity (highest first)
"""
async with self.async_session() as session:
# Get the raw connection to execute directly with asyncpg
# This avoids parameter binding issues with text() and asyncpg
connection = await session.connection()
raw_connection = await connection.get_raw_connection()
# Build WHERE clause conditions (no similarity threshold)
where_conditions = [
"c.organization_id = $2",
"d.is_active = true",
]
params = [
None,
organization_id,
limit,
] # $1 will be embedding_str, $3 is limit
param_index = 4 # Next available parameter index
# Add document_ids filter if provided
if document_ids:
placeholders = ", ".join(
f"${param_index + i}" for i in range(len(document_ids))
)
where_conditions.append(f"c.document_id IN ({placeholders})")
params.extend(document_ids)
param_index += len(document_ids)
# Add document_uuids filter if provided
if document_uuids:
placeholders = ", ".join(
f"${param_index + i}" for i in range(len(document_uuids))
)
where_conditions.append(f"d.document_uuid IN ({placeholders})")
params.extend(document_uuids)
param_index += len(document_uuids)
# Add embedding_model filter if provided (for dimension compatibility)
if embedding_model:
where_conditions.append(f"c.embedding_model = ${param_index}")
params.append(embedding_model)
param_index += 1
# Build the complete SQL query
where_clause = " AND ".join(where_conditions)
query_sql = f"""
SELECT
c.id,
c.document_id,
c.chunk_text,
c.contextualized_text,
c.chunk_metadata,
c.chunk_index,
d.filename,
d.document_uuid,
1 - (c.embedding <=> $1::vector) as similarity
FROM knowledge_base_chunks c
JOIN knowledge_base_documents d ON c.document_id = d.id
WHERE {where_clause}
ORDER BY c.embedding <=> $1::vector
LIMIT $3
"""
# Convert embedding to string format for PostgreSQL vector type
embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"
params[0] = embedding_str # Set $1
# Execute query directly with asyncpg
rows = await raw_connection.driver_connection.fetch(
query_sql,
*params,
)
# Convert asyncpg records to dictionaries
return [dict(row) for row in rows]
async def delete_document(
self,
document_uuid: str,
organization_id: int,
) -> bool:
"""Soft delete a document by setting is_active to False.
This will also cascade delete all chunks via the database foreign key.
Args:
document_uuid: The unique document UUID
organization_id: ID of the organization (for authorization)
Returns:
True if document was deleted, False if not found
"""
async with self.async_session() as session:
query = select(KnowledgeBaseDocumentModel).where(
KnowledgeBaseDocumentModel.document_uuid == document_uuid,
KnowledgeBaseDocumentModel.organization_id == organization_id,
)
result = await session.execute(query)
document = result.scalar_one_or_none()
if not document:
return False
document.is_active = False
await session.commit()
logger.info(
f"Deleted document {document_uuid} for organization {organization_id}"
)
return True
@staticmethod
def compute_file_hash(file_path: str) -> str:
"""Compute SHA-256 hash of a file.
Args:
file_path: Path to the file
Returns:
SHA-256 hash as hex string
"""
sha256_hash = hashlib.sha256()
with open(file_path, "rb") as f:
for byte_block in iter(lambda: f.read(4096), b""):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest()
@staticmethod
def get_mime_type(file_path: str) -> str:
"""Get MIME type based on file extension.
Args:
file_path: Path to the file
Returns:
MIME type string
"""
extension = Path(file_path).suffix.lower()
mime_types = {
".pdf": "application/pdf",
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
".doc": "application/msword",
".txt": "text/plain",
".html": "text/html",
".md": "text/markdown",
}
return mime_types.get(extension, "application/octet-stream")

View file

@ -2,6 +2,7 @@ import uuid
from datetime import UTC, datetime
from loguru import logger
from pgvector.sqlalchemy import Vector
from sqlalchemy import (
JSON,
Boolean,
@ -14,6 +15,7 @@ from sqlalchemy import (
Integer,
String,
Table,
Text,
UniqueConstraint,
and_,
text,
@ -890,3 +892,179 @@ class ToolModel(Base):
Index("ix_tools_status", "status"),
Index("ix_tools_category", "category"),
)
class KnowledgeBaseDocumentModel(Base):
"""Model for storing document-level metadata in the knowledge base.
Each document represents a source file (PDF, DOCX, etc.) that has been
processed and chunked for retrieval.
"""
__tablename__ = "knowledge_base_documents"
id = Column(Integer, primary_key=True, index=True)
# Public identifier for API references
document_uuid = Column(
String(36),
unique=True,
nullable=False,
index=True,
default=lambda: str(uuid.uuid4()),
)
# Organization scoping
organization_id = Column(
Integer, ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False
)
# Document metadata
filename = Column(String(500), nullable=False)
file_size_bytes = Column(Integer, nullable=True)
file_hash = Column(String(64), nullable=True) # SHA-256 hash for deduplication
mime_type = Column(String(100), nullable=True)
# Processing metadata
source_url = Column(String, nullable=True) # If document was fetched from URL
total_chunks = Column(Integer, nullable=False, default=0)
processing_status = Column(
Enum(
"pending",
"processing",
"completed",
"failed",
name="document_processing_status",
),
nullable=False,
default="pending",
server_default=text("'pending'::document_processing_status"),
)
processing_error = Column(Text, nullable=True)
# Docling conversion metadata
docling_metadata = Column(
JSON, nullable=False, default=dict
) # Store docling document metadata
# Custom metadata (user-defined tags, categories, etc.)
custom_metadata = Column(JSON, nullable=False, default=dict)
# Audit fields
created_by = Column(Integer, ForeignKey("users.id"), nullable=False)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
updated_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
)
# Soft delete
is_active = Column(Boolean, default=True, nullable=False)
archived_at = Column(DateTime(timezone=True), nullable=True)
# Relationships
organization = relationship("OrganizationModel")
created_by_user = relationship("UserModel")
chunks = relationship(
"KnowledgeBaseChunkModel",
back_populates="document",
cascade="all, delete-orphan",
)
# Indexes and constraints
__table_args__ = (
Index("ix_kb_documents_organization_id", "organization_id"),
Index("ix_kb_documents_uuid", "document_uuid"),
Index("ix_kb_documents_status", "processing_status"),
Index("ix_kb_documents_created_at", "created_at"),
)
class KnowledgeBaseChunkModel(Base):
"""Model for storing document chunks with vector embeddings.
Each chunk represents a portion of a document that has been:
1. Extracted and chunked by docling's HybridChunker
2. Optionally contextualized with surrounding information
3. Embedded into a vector representation for semantic search
"""
__tablename__ = "knowledge_base_chunks"
id = Column(Integer, primary_key=True, index=True)
# Link to parent document
document_id = Column(
Integer,
ForeignKey("knowledge_base_documents.id", ondelete="CASCADE"),
nullable=False,
)
# Organization scoping (denormalized for efficient querying)
organization_id = Column(
Integer, ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False
)
# Chunk content
chunk_text = Column(Text, nullable=False) # The actual chunk text
contextualized_text = Column(
Text, nullable=True
) # Enriched text from chunker.contextualize()
# Chunk positioning and metadata
chunk_index = Column(Integer, nullable=False) # Position in document (0-based)
# Docling chunk metadata
chunk_metadata = Column(
JSON, nullable=False, default=dict
) # Store chunk.meta if available
# Embedding configuration
embedding_model = Column(
String(200), nullable=False
) # e.g., "sentence-transformers/all-MiniLM-L6-v2"
embedding_dimension = Column(
Integer, nullable=False
) # e.g., 384 for all-MiniLM-L6-v2
# Vector embedding (pgvector column)
# The dimension should match the embedding_dimension field
# Default: 1536 dimensions for OpenAI text-embedding-3-small
# SentenceTransformer (384-dim) also supported but stored as 384-dim vectors
embedding = Column(Vector(1536), nullable=True)
# Token count (useful for chunking strategy analysis)
token_count = Column(Integer, nullable=True)
# Timestamps
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
updated_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
)
# Relationships
document = relationship("KnowledgeBaseDocumentModel", back_populates="chunks")
organization = relationship("OrganizationModel")
# Indexes and constraints
__table_args__ = (
Index("ix_kb_chunks_document_id", "document_id"),
Index("ix_kb_chunks_organization_id", "organization_id"),
Index("ix_kb_chunks_chunk_index", "chunk_index"),
Index(
"ix_kb_chunks_embedding_model", "embedding_model"
), # For filtering by model
# Vector similarity search index (using IVFFlat or HNSW)
# IVFFlat is good for datasets with 10k-1M vectors
# HNSW is better for larger datasets but uses more memory
Index(
"ix_kb_chunks_embedding_ivfflat",
"embedding",
postgresql_using="ivfflat",
postgresql_with={"lists": 100}, # Adjust based on dataset size
postgresql_ops={"embedding": "vector_cosine_ops"},
),
)

View file

@ -13,3 +13,6 @@ python-multipart==0.0.20
sentry-sdk[fastapi]==2.38.0
sqlalchemy[asyncio]==2.0.43
msgpack==1.1.2
docling[rapidocr]==2.68.0
sentence-transformers==5.2.0
pgvector==0.4.2

View file

@ -0,0 +1,405 @@
"""API routes for knowledge base operations."""
import uuid
from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from loguru import logger
from api.db import db_client
from api.schemas.knowledge_base import (
ChunkSearchRequestSchema,
ChunkSearchResponseSchema,
DocumentListResponseSchema,
DocumentResponseSchema,
DocumentUploadRequestSchema,
DocumentUploadResponseSchema,
ProcessDocumentRequestSchema,
)
from api.services.auth.depends import get_user
from api.services.storage import storage_fs
from api.tasks.arq import enqueue_job
from api.tasks.function_names import FunctionNames
router = APIRouter(prefix="/knowledge-base", tags=["knowledge-base"])
@router.post(
"/upload-url",
response_model=DocumentUploadResponseSchema,
summary="Get presigned URL for document upload",
)
async def get_upload_url(
request: DocumentUploadRequestSchema,
user=Depends(get_user),
):
"""Generate a presigned PUT URL for uploading a document.
This endpoint:
1. Generates a unique document UUID for organizing the S3 key
2. Generates a presigned S3/MinIO URL for uploading the file
3. Returns the upload URL and document metadata
After uploading to the returned URL, call /process-document to create
the document record and trigger processing.
Access Control:
* All authenticated users can upload documents scoped to their organization.
"""
try:
# Generate unique document UUID for S3 organization
document_uuid = str(uuid.uuid4())
# Generate S3 key: knowledge_base/{org_id}/{document_uuid}/{filename}
s3_key = f"knowledge_base/{user.selected_organization_id}/{document_uuid}/{request.filename}"
# Generate presigned PUT URL (valid for 30 minutes)
upload_url = await storage_fs.aget_presigned_put_url(
file_path=s3_key,
expiration=1800, # 30 minutes
content_type=request.mime_type,
max_size=100_000_000, # 100MB max
)
if not upload_url:
raise HTTPException(
status_code=500, detail="Failed to generate presigned upload URL"
)
logger.info(
f"Generated upload URL for document {document_uuid}, "
f"user {user.id}, org {user.selected_organization_id}"
)
return DocumentUploadResponseSchema(
upload_url=upload_url,
document_uuid=document_uuid,
s3_key=s3_key,
)
except Exception as exc:
logger.error(f"Error generating upload URL: {exc}")
raise HTTPException(
status_code=500, detail="Failed to generate upload URL"
) from exc
@router.post(
"/process-document",
response_model=DocumentResponseSchema,
summary="Trigger document processing",
)
async def process_document(
request: ProcessDocumentRequestSchema,
user=Depends(get_user),
):
"""Trigger asynchronous processing of an uploaded document.
This endpoint should be called after successfully uploading a file to the presigned URL.
It will:
1. Create a document record in the database with the specified UUID
2. Enqueue a background task to process the document (chunking and embedding)
The document status will be updated from 'pending' -> 'processing' -> 'completed' or 'failed'.
Embedding Services:
* openai (default): High-quality 1536-dimensional embeddings (requires OPENAI_API_KEY)
* sentence_transformer: Free, offline-capable, 384-dimensional embeddings
Access Control:
* Users can only process documents in their organization.
"""
try:
# Extract filename from s3_key
filename = request.s3_key.split("/")[-1]
# Create document record with the specific UUID from upload
document = await db_client.create_document(
organization_id=user.selected_organization_id,
created_by=user.id,
filename=filename,
file_size_bytes=0, # Will be updated by background task
file_hash="", # Will be computed by background task
mime_type="application/octet-stream", # Will be detected by background task
custom_metadata={"s3_key": request.s3_key},
document_uuid=request.document_uuid, # Use UUID from upload
)
# Enqueue background task for processing
await enqueue_job(
FunctionNames.PROCESS_KNOWLEDGE_BASE_DOCUMENT,
document.id,
request.s3_key,
user.selected_organization_id,
128, # max_tokens (default)
request.embedding_service,
)
logger.info(
f"Created document {request.document_uuid} (id={document.id}) and enqueued processing "
f"with {request.embedding_service} embeddings, org {user.selected_organization_id}"
)
return DocumentResponseSchema(
id=document.id,
document_uuid=request.document_uuid,
filename=filename,
file_size_bytes=0,
file_hash="",
mime_type="application/octet-stream",
processing_status="pending",
processing_error=None,
total_chunks=0,
custom_metadata={"s3_key": request.s3_key},
docling_metadata={},
source_url=None,
created_at=document.created_at,
updated_at=document.updated_at,
organization_id=user.selected_organization_id,
created_by=user.id,
is_active=True,
)
except HTTPException:
raise
except Exception as exc:
logger.error(f"Error processing document: {exc}")
raise HTTPException(
status_code=500, detail="Failed to process document"
) from exc
@router.get(
"/documents",
response_model=DocumentListResponseSchema,
summary="List documents",
)
async def list_documents(
status: Annotated[
Optional[str],
Query(description="Filter by processing status"),
] = None,
limit: Annotated[int, Query(ge=1, le=100)] = 100,
offset: Annotated[int, Query(ge=0)] = 0,
user=Depends(get_user),
):
"""List all documents for the user's organization.
Access Control:
* Users can only see documents from their organization.
"""
try:
documents = await db_client.get_documents_for_organization(
organization_id=user.selected_organization_id,
processing_status=status,
limit=limit,
offset=offset,
)
# Convert to response schema
document_list = [
DocumentResponseSchema(
id=doc.id,
document_uuid=doc.document_uuid,
filename=doc.filename,
file_size_bytes=doc.file_size_bytes,
file_hash=doc.file_hash,
mime_type=doc.mime_type,
processing_status=doc.processing_status,
processing_error=doc.processing_error,
total_chunks=doc.total_chunks,
custom_metadata=doc.custom_metadata,
docling_metadata=doc.docling_metadata,
source_url=doc.source_url,
created_at=doc.created_at,
updated_at=doc.updated_at,
organization_id=doc.organization_id,
created_by=doc.created_by,
is_active=doc.is_active,
)
for doc in documents
]
return DocumentListResponseSchema(
documents=document_list,
total=len(document_list),
limit=limit,
offset=offset,
)
except Exception as exc:
logger.error(f"Error listing documents: {exc}")
raise HTTPException(status_code=500, detail="Failed to list documents") from exc
@router.get(
"/documents/{document_uuid}",
response_model=DocumentResponseSchema,
summary="Get document details",
)
async def get_document(
document_uuid: str,
user=Depends(get_user),
):
"""Get details of a specific document.
Access Control:
* Users can only access documents from their organization.
"""
try:
document = await db_client.get_document_by_uuid(
document_uuid=document_uuid,
organization_id=user.selected_organization_id,
)
if not document:
raise HTTPException(status_code=404, detail="Document not found")
return DocumentResponseSchema(
id=document.id,
document_uuid=document.document_uuid,
filename=document.filename,
file_size_bytes=document.file_size_bytes,
file_hash=document.file_hash,
mime_type=document.mime_type,
processing_status=document.processing_status,
processing_error=document.processing_error,
total_chunks=document.total_chunks,
custom_metadata=document.custom_metadata,
docling_metadata=document.docling_metadata,
source_url=document.source_url,
created_at=document.created_at,
updated_at=document.updated_at,
organization_id=document.organization_id,
created_by=document.created_by,
is_active=document.is_active,
)
except HTTPException:
raise
except Exception as exc:
logger.error(f"Error getting document: {exc}")
raise HTTPException(status_code=500, detail="Failed to get document") from exc
@router.delete(
"/documents/{document_uuid}",
summary="Delete document",
)
async def delete_document(
document_uuid: str,
user=Depends(get_user),
):
"""Soft delete a document and its chunks.
Access Control:
* Users can only delete documents from their organization.
"""
try:
success = await db_client.delete_document(
document_uuid=document_uuid,
organization_id=user.selected_organization_id,
)
if not success:
raise HTTPException(status_code=404, detail="Document not found")
logger.info(
f"Deleted document {document_uuid}, "
f"user {user.id}, org {user.selected_organization_id}"
)
return {"success": True, "message": "Document deleted successfully"}
except HTTPException:
raise
except Exception as exc:
logger.error(f"Error deleting document: {exc}")
raise HTTPException(
status_code=500, detail="Failed to delete document"
) from exc
@router.post(
"/search",
response_model=ChunkSearchResponseSchema,
summary="Search for similar chunks",
)
async def search_chunks(
request: ChunkSearchRequestSchema,
user=Depends(get_user),
):
"""Search for document chunks similar to the query.
This endpoint uses vector similarity search to find relevant chunks.
Results are returned without threshold filtering - apply similarity
thresholds at the application layer after optional reranking.
Access Control:
* Users can only search documents from their organization.
"""
try:
# Import here to avoid circular dependency
from api.services.gen_ai import OpenAIEmbeddingService
# Try to get user's embeddings configuration
user_config = await db_client.get_user_configurations(user.id)
embeddings_api_key = None
embeddings_model = None
if user_config.embeddings:
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
# Initialize embedding service with user config or fallback to env
embedding_service = OpenAIEmbeddingService(
db_client=db_client,
api_key=embeddings_api_key,
model_id=embeddings_model or "text-embedding-3-small",
)
# Perform search
results = await embedding_service.search_similar_chunks(
query=request.query,
organization_id=user.selected_organization_id,
limit=request.limit,
document_uuids=request.document_uuids,
)
# Apply similarity threshold if provided
if request.min_similarity is not None:
results = [r for r in results if r["similarity"] >= request.min_similarity]
# Convert to response schema
from api.schemas.knowledge_base import ChunkResponseSchema
chunks = [
ChunkResponseSchema(
id=r["id"],
document_id=r["document_id"],
chunk_text=r["chunk_text"],
contextualized_text=r.get("contextualized_text"),
chunk_index=r["chunk_index"],
chunk_metadata=r["chunk_metadata"],
filename=r["filename"],
document_uuid=r["document_uuid"],
similarity=r["similarity"],
)
for r in results
]
return ChunkSearchResponseSchema(
chunks=chunks,
query=request.query,
total_results=len(chunks),
)
except Exception as exc:
logger.error(f"Error searching chunks: {exc}")
raise HTTPException(status_code=500, detail="Failed to search chunks") from exc

View file

@ -4,6 +4,7 @@ from loguru import logger
from api.routes.campaign import router as campaign_router
from api.routes.credentials import router as credentials_router
from api.routes.integration import router as integration_router
from api.routes.knowledge_base import router as knowledge_base_router
from api.routes.looptalk import router as looptalk_router
from api.routes.organization import router as organization_router
from api.routes.organization_usage import router as organization_usage_router
@ -43,6 +44,7 @@ router.include_router(webrtc_signaling_router)
router.include_router(public_embed_router)
router.include_router(public_agent_router)
router.include_router(workflow_embed_router)
router.include_router(knowledge_base_router)
@router.get("/health")

View file

@ -32,6 +32,7 @@ class DefaultConfigurationsResponse(TypedDict):
llm: dict[str, dict]
tts: dict[str, dict]
stt: dict[str, dict]
embeddings: dict[str, dict]
default_providers: dict[str, str]
@ -50,6 +51,10 @@ async def get_default_configurations() -> DefaultConfigurationsResponse:
provider: model_cls.model_json_schema()
for provider, model_cls in REGISTRY[ServiceType.STT].items()
},
"embeddings": {
provider: model_cls.model_json_schema()
for provider, model_cls in REGISTRY[ServiceType.EMBEDDINGS].items()
},
"default_providers": DEFAULT_SERVICE_PROVIDERS,
}
return configurations
@ -69,6 +74,7 @@ class UserConfigurationRequestResponseSchema(BaseModel):
llm: dict[str, Union[str, float]] | None = None
tts: dict[str, Union[str, float]] | None = None
stt: dict[str, Union[str, float]] | None = None
embeddings: dict[str, Union[str, float]] | None = None
test_phone_number: str | None = None
timezone: str | None = None
organization_pricing: dict[str, Union[float, str, bool]] | None = None

View file

@ -0,0 +1,102 @@
"""Pydantic schemas for knowledge base operations."""
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
class DocumentUploadRequestSchema(BaseModel):
"""Request schema for initiating document upload."""
filename: str = Field(..., description="Name of the file to upload")
mime_type: str = Field(..., description="MIME type of the file")
custom_metadata: Optional[Dict[str, Any]] = Field(
default=None, description="Optional custom metadata"
)
class DocumentUploadResponseSchema(BaseModel):
"""Response schema containing upload URL and document metadata."""
upload_url: str = Field(..., description="Signed URL for uploading the file")
document_uuid: str = Field(..., description="Unique identifier for the document")
s3_key: str = Field(..., description="S3 key where file should be uploaded")
class ProcessDocumentRequestSchema(BaseModel):
"""Request schema for triggering document processing."""
document_uuid: str = Field(..., description="Document UUID to process")
s3_key: str = Field(..., description="S3 key of the uploaded file")
embedding_service: Literal["sentence_transformer", "openai"] = Field(
default="openai",
description="Embedding service to use for processing. "
"Options: 'openai' (default, 1536-dim, requires API key) or 'sentence_transformer' (free, 384-dim)",
)
class DocumentResponseSchema(BaseModel):
"""Response schema for document metadata."""
id: int
document_uuid: str
filename: str
file_size_bytes: int
file_hash: str
mime_type: str
processing_status: str # pending, processing, completed, failed
processing_error: Optional[str] = None
total_chunks: int
custom_metadata: Dict[str, Any]
docling_metadata: Dict[str, Any]
source_url: Optional[str] = None
created_at: datetime
updated_at: datetime
organization_id: int
created_by: int
is_active: bool
class DocumentListResponseSchema(BaseModel):
"""Response schema for list of documents."""
documents: List[DocumentResponseSchema]
total: int
limit: int
offset: int
class ChunkSearchRequestSchema(BaseModel):
"""Request schema for searching similar chunks."""
query: str = Field(..., description="Search query text")
limit: int = Field(default=5, ge=1, le=50, description="Maximum number of results")
document_uuids: Optional[List[str]] = Field(
default=None, description="Filter by specific document UUIDs"
)
min_similarity: Optional[float] = Field(
default=None, ge=0.0, le=1.0, description="Minimum similarity threshold"
)
class ChunkResponseSchema(BaseModel):
"""Response schema for a document chunk."""
id: int
document_id: int
chunk_text: str
contextualized_text: Optional[str]
chunk_index: int
chunk_metadata: Dict[str, Any]
filename: str
document_uuid: str
similarity: float
class ChunkSearchResponseSchema(BaseModel):
"""Response schema for chunk search results."""
chunks: List[ChunkResponseSchema]
query: str
total_results: int

View file

@ -3,6 +3,7 @@ from datetime import datetime
from pydantic import BaseModel
from api.services.configuration.registry import (
EmbeddingsConfig,
LLMConfig,
STTConfig,
TTSConfig,
@ -13,6 +14,7 @@ class UserConfiguration(BaseModel):
llm: LLMConfig | None = None
stt: STTConfig | None = None
tts: TTSConfig | None = None
embeddings: EmbeddingsConfig | None = None
test_phone_number: str | None = None
timezone: str | None = None
last_validated_at: datetime | None = None

View file

@ -48,6 +48,12 @@ class UserConfigurationValidator:
status_list.extend(self._validate_service(configuration.llm, "llm"))
status_list.extend(self._validate_service(configuration.stt, "stt"))
status_list.extend(self._validate_service(configuration.tts, "tts"))
# Embeddings is optional - only validate if configured
status_list.extend(
self._validate_service(
configuration.embeddings, "embeddings", required=False
)
)
if status_list:
raise ValueError(status_list)
@ -55,11 +61,16 @@ class UserConfigurationValidator:
return {"status": [{"model": "all", "message": "ok"}]}
def _validate_service(
self, service_config: Optional[ServiceConfig], service_name: str
self,
service_config: Optional[ServiceConfig],
service_name: str,
required: bool = True,
) -> list[APIKeyStatus]:
"""Validate a service configuration and return any error statuses."""
if not service_config:
return [{"model": service_name, "message": "API key is missing"}]
if required:
return [{"model": service_name, "message": "API key is missing"}]
return [] # Optional service not configured is OK
provider = service_config.provider
api_key = service_config.api_key

View file

@ -13,6 +13,7 @@ left as ``None``.
from api.services.configuration.registry import (
DeepgramSTTConfiguration,
ElevenlabsTTSConfiguration,
OpenAIEmbeddingsConfiguration,
OpenAILLMService,
ServiceProviders,
)
@ -22,6 +23,7 @@ _DEFAULTS = {
"llm": (ServiceProviders.OPENAI, OpenAILLMService),
"tts": (ServiceProviders.ELEVENLABS, ElevenlabsTTSConfiguration),
"stt": (ServiceProviders.DEEPGRAM, DeepgramSTTConfiguration),
"embeddings": (ServiceProviders.OPENAI, OpenAIEmbeddingsConfiguration),
}
# Public mapping of service name -> default provider

View file

@ -64,6 +64,7 @@ def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
"llm": _mask_service(config.llm),
"tts": _mask_service(config.tts),
"stt": _mask_service(config.stt),
"embeddings": _mask_service(config.embeddings),
"test_phone_number": config.test_phone_number,
"timezone": config.timezone,
}

View file

@ -9,7 +9,7 @@ from typing import Dict
from api.schemas.user_configuration import UserConfiguration
from api.services.configuration.masking import is_mask_of
SERVICE_FIELDS = ("llm", "tts", "stt")
SERVICE_FIELDS = ("llm", "tts", "stt", "embeddings")
def merge_user_configurations(

View file

@ -8,6 +8,7 @@ class ServiceType(Enum):
LLM = auto()
TTS = auto()
STT = auto()
EMBEDDINGS = auto()
class ServiceProviders(str, Enum):
@ -50,11 +51,16 @@ class BaseSTTConfiguration(BaseServiceConfiguration):
model: str
class BaseEmbeddingsConfiguration(BaseServiceConfiguration):
model: str
# Unified registry for all service types
REGISTRY: Dict[ServiceType, Dict[str, Type[BaseServiceConfiguration]]] = {
ServiceType.LLM: {},
ServiceType.TTS: {},
ServiceType.STT: {},
ServiceType.EMBEDDINGS: {},
}
T = TypeVar("T", bound=BaseServiceConfiguration)
@ -93,6 +99,10 @@ def register_stt(cls: Type[BaseSTTConfiguration]):
return register_service(ServiceType.STT)(cls)
def register_embeddings(cls: Type[BaseEmbeddingsConfiguration]):
return register_service(ServiceType.EMBEDDINGS)(cls)
###################################################### LLM ########################################################################
# Suggested models for each provider (used for UI dropdown)
@ -436,6 +446,27 @@ STTConfig = Annotated[
Field(discriminator="provider"),
]
ServiceConfig = Annotated[
Union[LLMConfig, TTSConfig, STTConfig], Field(discriminator="provider")
###################################################### EMBEDDINGS ########################################################################
OPENAI_EMBEDDING_MODELS = ["text-embedding-3-small"]
@register_embeddings
class OpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
provider: Literal[ServiceProviders.OPENAI] = ServiceProviders.OPENAI
model: str = Field(
default="text-embedding-3-small",
json_schema_extra={"examples": OPENAI_EMBEDDING_MODELS},
)
api_key: str
EmbeddingsConfig = Annotated[
Union[OpenAIEmbeddingsConfiguration],
Field(discriminator="provider"),
]
ServiceConfig = Annotated[
Union[LLMConfig, TTSConfig, STTConfig, EmbeddingsConfig],
Field(discriminator="provider"),
]

View file

@ -85,3 +85,16 @@ class BaseFileSystem(ABC):
Optional[str]: Presigned PUT URL if successful, None otherwise
"""
pass
@abstractmethod
async def adownload_file(self, source_path: str, local_path: str) -> bool:
"""Download a file from storage to local path.
Args:
source_path: Path to the file in storage
local_path: Local path where file should be downloaded
Returns:
bool: True if file was downloaded successfully, False otherwise
"""
pass

View file

@ -170,3 +170,15 @@ class MinioFileSystem(BaseFileSystem):
except Exception as e:
logger.error(f"Error generating MinIO upload URL: {e}")
return None
async def adownload_file(self, source_path: str, local_path: str) -> bool:
"""Download a file from MinIO to local path."""
try:
def _fget():
self.client.fget_object(self.bucket_name, source_path, local_path)
await asyncio.to_thread(_fget)
return True
except S3Error:
return False

View file

@ -126,3 +126,14 @@ class S3FileSystem(BaseFileSystem):
return url
except ClientError:
return None
async def adownload_file(self, source_path: str, local_path: str) -> bool:
"""Download a file from S3 to local path."""
try:
async with self.session.client(
"s3", region_name=self.region_name
) as s3_client:
await s3_client.download_file(self.bucket_name, source_path, local_path)
return True
except ClientError:
return False

View file

@ -0,0 +1,15 @@
"""Generative AI services for embeddings and document processing."""
from .embedding import (
BaseEmbeddingService,
EmbeddingAPIKeyNotConfiguredError,
OpenAIEmbeddingService,
SentenceTransformerEmbeddingService,
)
__all__ = [
"BaseEmbeddingService",
"EmbeddingAPIKeyNotConfiguredError",
"SentenceTransformerEmbeddingService",
"OpenAIEmbeddingService",
]

View file

@ -0,0 +1,12 @@
"""Embedding services for document processing and retrieval."""
from .base import BaseEmbeddingService
from .openai_service import EmbeddingAPIKeyNotConfiguredError, OpenAIEmbeddingService
from .sentence_transformer_service import SentenceTransformerEmbeddingService
__all__ = [
"BaseEmbeddingService",
"EmbeddingAPIKeyNotConfiguredError",
"SentenceTransformerEmbeddingService",
"OpenAIEmbeddingService",
]

View file

@ -0,0 +1,75 @@
"""Base class for embedding services."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
class BaseEmbeddingService(ABC):
"""Abstract base class for embedding services.
All embedding services (SentenceTransformer, OpenAI, etc.) should inherit from this class
and implement the required methods.
"""
@abstractmethod
def get_model_id(self) -> str:
"""Return the model identifier.
Returns:
String identifier for the model (e.g., 'sentence-transformers/all-MiniLM-L6-v2')
"""
pass
@abstractmethod
def get_embedding_dimension(self) -> int:
"""Return the embedding dimension.
Returns:
Integer dimension of the embedding vectors
"""
pass
@abstractmethod
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Embed a batch of texts.
Args:
texts: List of text strings to embed
Returns:
List of embedding vectors (each vector is a list of floats)
"""
pass
@abstractmethod
async def embed_query(self, query: str) -> List[float]:
"""Embed a single query text.
Args:
query: Query text to embed
Returns:
Embedding vector as list of floats
"""
pass
@abstractmethod
async def search_similar_chunks(
self,
query: str,
organization_id: int,
limit: int = 5,
document_uuids: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
"""Search for similar chunks using vector similarity.
Args:
query: Search query text
organization_id: Organization ID for scoping
limit: Maximum number of results to return
document_uuids: Optional list of document UUIDs to filter by
Returns:
List of dictionaries containing chunk data and similarity scores
"""
pass

View file

@ -0,0 +1,372 @@
"""OpenAI embedding service.
This module provides document processing capabilities using:
- OpenAI's text-embedding-3-small for embeddings (1536 dimensions)
- Docling for document conversion and chunking
- pgvector for vector similarity search
"""
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
from docling.chunking import HybridChunker
from docling.document_converter import DocumentConverter
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
from loguru import logger
from openai import AsyncOpenAI
from transformers import AutoTokenizer
from api.db.db_client import DBClient
from api.db.models import KnowledgeBaseChunkModel
from .base import BaseEmbeddingService
# Model configuration
DEFAULT_MODEL_ID = "text-embedding-3-small"
EMBEDDING_DIMENSION = 1536 # Dimension for text-embedding-3-small
# For chunking, we'll use the same tokenizer as SentenceTransformer
# since OpenAI uses similar tokenization
TOKENIZER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
class EmbeddingAPIKeyNotConfiguredError(Exception):
"""Raised when OpenAI API key is not configured for embeddings."""
def __init__(self):
super().__init__(
"OpenAI API key not configured. Please set your API key in "
"Model Configurations > Embedding to use document processing."
)
class OpenAIEmbeddingService(BaseEmbeddingService):
"""Embedding service using OpenAI's text-embedding-3-small."""
def __init__(
self,
db_client: DBClient,
api_key: Optional[str] = None,
model_id: str = DEFAULT_MODEL_ID,
max_tokens: int = 512,
):
"""Initialize the OpenAI embedding service.
Args:
db_client: Database client for storing documents and chunks
api_key: OpenAI API key. If not provided, the client will not be
initialized and operations will fail with a clear error.
model_id: OpenAI embedding model ID (default: text-embedding-3-small)
max_tokens: Maximum number of tokens per chunk (default: 512)
"""
self.db = db_client
self.model_id = model_id
self.max_tokens = max_tokens
# Only initialize OpenAI client if API key is provided
self._api_key_configured = bool(api_key)
if self._api_key_configured:
self.client = AsyncOpenAI(api_key=api_key)
logger.info(f"OpenAI embedding service initialized with model: {model_id}")
else:
self.client = None
logger.warning(
"OpenAI embedding service initialized without API key. "
"Operations will fail until API key is configured in Model Configurations."
)
# Initialize tokenizer for chunking
# We use a HuggingFace tokenizer for consistent chunking
logger.info(
f"Loading tokenizer for chunking: {TOKENIZER_MODEL} with max_tokens={max_tokens}"
)
try:
self.tokenizer = HuggingFaceTokenizer(
tokenizer=AutoTokenizer.from_pretrained(
TOKENIZER_MODEL,
local_files_only=True,
),
max_tokens=max_tokens,
)
logger.info("Loaded tokenizer from cache")
except Exception as e:
logger.warning(f"Tokenizer not in cache, downloading: {e}")
self.tokenizer = HuggingFaceTokenizer(
tokenizer=AutoTokenizer.from_pretrained(TOKENIZER_MODEL),
max_tokens=max_tokens,
)
logger.info("Tokenizer downloaded and cached")
# Initialize chunker
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
self.chunker = HybridChunker(tokenizer=self.tokenizer)
# Initialize document converter
self.converter = DocumentConverter()
def get_model_id(self) -> str:
"""Return the model identifier."""
return self.model_id
def get_embedding_dimension(self) -> int:
"""Return the embedding dimension."""
return EMBEDDING_DIMENSION
def _ensure_api_key_configured(self):
"""Check if API key is configured and raise error if not."""
if not self._api_key_configured or self.client is None:
raise EmbeddingAPIKeyNotConfiguredError()
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Embed a batch of texts using OpenAI API.
Args:
texts: List of text strings to embed
Returns:
List of embedding vectors (each vector is a list of floats)
Raises:
EmbeddingAPIKeyNotConfiguredError: If API key is not configured
"""
self._ensure_api_key_configured()
try:
# OpenAI API call
response = await self.client.embeddings.create(
input=texts,
model=self.model_id,
)
# Extract embeddings from response
embeddings = [item.embedding for item in response.data]
return embeddings
except Exception as e:
logger.error(f"Error generating OpenAI embeddings: {e}")
raise
async def embed_query(self, query: str) -> List[float]:
"""Embed a single query text using OpenAI API.
Args:
query: Query text to embed
Returns:
Embedding vector as list of floats
Raises:
EmbeddingAPIKeyNotConfiguredError: If API key is not configured
"""
self._ensure_api_key_configured()
embeddings = await self.embed_texts([query])
return embeddings[0]
async def search_similar_chunks(
self,
query: str,
organization_id: int,
limit: int = 5,
document_uuids: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
"""Search for similar chunks using vector similarity.
Args:
query: Search query text
organization_id: Organization ID for scoping
limit: Maximum number of results to return
document_uuids: Optional list of document UUIDs to filter by
Returns:
List of dictionaries with chunk data and similarity scores
Raises:
EmbeddingAPIKeyNotConfiguredError: If API key is not configured
"""
self._ensure_api_key_configured()
# Generate query embedding
query_embedding = await self.embed_query(query)
# Perform vector similarity search
results = await self.db.search_similar_chunks(
query_embedding=query_embedding,
organization_id=organization_id,
limit=limit,
document_uuids=document_uuids,
embedding_model=self.model_id,
)
return results
async def process_document(
self,
file_path: str,
organization_id: int,
created_by: int,
custom_metadata: dict = None,
):
"""Process a document: convert, chunk, embed, and store in database.
Args:
file_path: Path to the document file
organization_id: Organization ID for scoping
created_by: User ID who uploaded the document
custom_metadata: Optional custom metadata dictionary
Returns:
The created document record
"""
try:
# Extract file metadata
filename = Path(file_path).name
file_hash = self.db.compute_file_hash(file_path)
file_size = os.path.getsize(file_path)
mime_type = self.db.get_mime_type(file_path)
# Check if document already exists
existing_doc = await self.db.get_document_by_hash(
file_hash, organization_id
)
if existing_doc:
logger.info(f"Document already exists: {filename} (hash: {file_hash})")
return existing_doc
# Create document record
doc_record = await self.db.create_document(
organization_id=organization_id,
created_by=created_by,
filename=filename,
file_size_bytes=file_size,
file_hash=file_hash,
mime_type=mime_type,
custom_metadata=custom_metadata or {},
)
logger.info(f"Processing document with OpenAI embeddings: {filename}")
# Update status to processing
await self.db.update_document_status(doc_record.id, "processing")
# Step 1: Convert document using docling
logger.info("Converting document with docling...")
conversion_result = self.converter.convert(file_path)
doc = conversion_result.document
# Store docling metadata
docling_metadata = {
"num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
"document_type": type(doc).__name__,
}
# Step 2: Chunk the document
logger.info(f"Chunking document with max_tokens={self.max_tokens}...")
chunks = list(self.chunker.chunk(dl_doc=doc))
total_chunks = len(chunks)
logger.info(f"Generated {total_chunks} chunks")
# Step 3: Process each chunk
chunk_texts = []
chunk_records = []
token_counts = []
for i, chunk in enumerate(chunks):
# Get chunk text
chunk_text = chunk.text
# Get contextualized text
contextualized_text = self.chunker.contextualize(chunk=chunk)
# Calculate token count
text_to_tokenize = (
contextualized_text if contextualized_text else chunk_text
)
token_count = len(
self.tokenizer.tokenizer.encode(
text_to_tokenize, add_special_tokens=False
)
)
token_counts.append(token_count)
# Prepare chunk metadata
chunk_metadata = {}
if hasattr(chunk, "meta") and chunk.meta:
chunk_metadata = {
"doc_items": (
[str(item) for item in chunk.meta.doc_items]
if hasattr(chunk.meta, "doc_items")
else []
),
"headings": (
chunk.meta.headings
if hasattr(chunk.meta, "headings")
else []
),
}
# Create chunk record (without embedding yet)
chunk_record = KnowledgeBaseChunkModel(
document_id=doc_record.id,
organization_id=organization_id,
chunk_text=chunk_text,
contextualized_text=contextualized_text,
chunk_index=i,
chunk_metadata=chunk_metadata,
embedding_model=self.model_id,
embedding_dimension=EMBEDDING_DIMENSION,
token_count=token_count,
)
chunk_records.append(chunk_record)
chunk_texts.append(text_to_tokenize)
# Log chunk statistics
if token_counts:
avg_tokens = sum(token_counts) / len(token_counts)
min_tokens = min(token_counts)
max_tokens = max(token_counts)
logger.info("Chunk token statistics:")
logger.info(f" - Average: {avg_tokens:.1f} tokens")
logger.info(f" - Min: {min_tokens} tokens")
logger.info(f" - Max: {max_tokens} tokens")
# Step 4: Generate embeddings using OpenAI API
logger.info(f"Generating embeddings using OpenAI ({self.model_id})...")
embeddings = await self.embed_texts(chunk_texts)
# Step 5: Attach embeddings to chunk records
for chunk_record, embedding in zip(chunk_records, embeddings):
chunk_record.embedding = embedding
# Step 6: Save all chunks in batch
logger.info("Storing chunks in database...")
await self.db.create_chunks_batch(chunk_records)
# Update document status to completed
await self.db.update_document_status(
doc_record.id,
"completed",
total_chunks=total_chunks,
docling_metadata=docling_metadata,
)
logger.info(f"Successfully processed document: {filename}")
logger.info(f" - Total chunks: {total_chunks}")
logger.info(f" - Embedding model: {self.model_id}")
logger.info(f" - Document ID: {doc_record.id}")
logger.info(f" - Document UUID: {doc_record.document_uuid}")
return doc_record
except Exception as e:
logger.error(f"Error processing document with OpenAI: {e}")
# Update document status to failed if it exists
if "doc_record" in locals():
await self.db.update_document_status(
doc_record.id, "failed", error_message=str(e)
)
raise

View file

@ -0,0 +1,350 @@
"""Sentence Transformer embedding service.
This module provides document processing capabilities using:
- Sentence-transformers for embeddings (all-MiniLM-L6-v2)
- Docling for document conversion and chunking
- pgvector for vector similarity search
Setup for offline usage:
1. First run: Downloads and caches models to ~/.cache/sentence_transformers
2. Subsequent runs: Uses cached models (no internet needed)
3. For fully offline mode: Set TRANSFORMERS_OFFLINE=1 and HF_HUB_OFFLINE=1
"""
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
from docling.chunking import HybridChunker
from docling.document_converter import DocumentConverter
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
from loguru import logger
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from api.db.db_client import DBClient
from api.db.models import KnowledgeBaseChunkModel
from .base import BaseEmbeddingService
# Set environment variables for model caching
os.environ.setdefault("TRANSFORMERS_OFFLINE", "0")
os.environ.setdefault("HF_HUB_OFFLINE", "0")
os.environ.setdefault(
"SENTENCE_TRANSFORMERS_HOME", os.path.expanduser("~/.cache/sentence_transformers")
)
# Model configuration
DEFAULT_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
EMBEDDING_DIMENSION = 384 # Dimension for all-MiniLM-L6-v2
class SentenceTransformerEmbeddingService(BaseEmbeddingService):
"""Embedding service using Sentence Transformers."""
def __init__(
self,
db_client: DBClient,
model_id: str = DEFAULT_MODEL_ID,
max_tokens: int = 512,
):
"""Initialize the Sentence Transformer embedding service.
Args:
db_client: Database client for storing documents and chunks
model_id: Sentence-transformers model ID (default: all-MiniLM-L6-v2)
max_tokens: Maximum number of tokens per chunk (default: 512)
Note: This applies to the contextualized text (with headings/captions)
"""
self.db = db_client
self.model_id = model_id
self.max_tokens = max_tokens
# Initialize embedding model
logger.info(f"Loading embedding model: {model_id}")
try:
# Try to load from cache first (local_files_only=True)
self.embedding_model = SentenceTransformer(
model_id,
cache_folder=os.environ.get("SENTENCE_TRANSFORMERS_HOME"),
local_files_only=True,
)
logger.info("Loaded model from cache")
except Exception as e:
logger.warning(f"Model not in cache, downloading: {e}")
# If not in cache, download it (this will cache it for next time)
self.embedding_model = SentenceTransformer(
model_id,
cache_folder=os.environ.get("SENTENCE_TRANSFORMERS_HOME"),
)
logger.info("Model downloaded and cached")
# Initialize tokenizer for chunking with max_tokens
logger.info(f"Loading tokenizer: {model_id} with max_tokens={max_tokens}")
try:
# Try to load from cache first
self.tokenizer = HuggingFaceTokenizer(
tokenizer=AutoTokenizer.from_pretrained(
model_id,
local_files_only=True,
),
max_tokens=max_tokens,
)
logger.info("Loaded tokenizer from cache")
except Exception as e:
logger.warning(f"Tokenizer not in cache, downloading: {e}")
# If not in cache, download it
self.tokenizer = HuggingFaceTokenizer(
tokenizer=AutoTokenizer.from_pretrained(model_id),
max_tokens=max_tokens,
)
logger.info("Tokenizer downloaded and cached")
# Initialize chunker
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
self.chunker = HybridChunker(tokenizer=self.tokenizer)
# Initialize document converter
self.converter = DocumentConverter()
def get_model_id(self) -> str:
"""Return the model identifier."""
return self.model_id
def get_embedding_dimension(self) -> int:
"""Return the embedding dimension."""
return EMBEDDING_DIMENSION
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Embed a batch of texts.
Args:
texts: List of text strings to embed
Returns:
List of embedding vectors (each vector is a list of floats)
"""
embeddings = self.embedding_model.encode(
texts,
show_progress_bar=False,
convert_to_numpy=True,
)
return [embedding.tolist() for embedding in embeddings]
async def embed_query(self, query: str) -> List[float]:
"""Embed a single query text.
Args:
query: Query text to embed
Returns:
Embedding vector as list of floats
"""
embedding = self.embedding_model.encode([query])[0]
return embedding.tolist()
async def search_similar_chunks(
self,
query: str,
organization_id: int,
limit: int = 5,
document_uuids: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
"""Search for similar chunks using vector similarity.
Returns top-k most similar chunks without any threshold filtering.
Apply similarity thresholds and reranking at the application layer.
Args:
query: Search query text
organization_id: Organization ID for scoping
limit: Maximum number of results to return
document_uuids: Optional list of document UUIDs to filter by
Returns:
List of dictionaries with chunk data and similarity scores
"""
# Generate query embedding
query_embedding = await self.embed_query(query)
# Perform vector similarity search
results = await self.db.search_similar_chunks(
query_embedding=query_embedding,
organization_id=organization_id,
limit=limit,
document_uuids=document_uuids,
embedding_model=self.model_id,
)
return results
async def process_document(
self,
file_path: str,
organization_id: int,
created_by: int,
custom_metadata: dict = None,
):
"""Process a document: convert, chunk, embed, and store in database.
Args:
file_path: Path to the document file
organization_id: Organization ID for scoping
created_by: User ID who uploaded the document
custom_metadata: Optional custom metadata dictionary
Returns:
The created document record
"""
try:
# Extract file metadata
filename = Path(file_path).name
file_hash = self.db.compute_file_hash(file_path)
file_size = os.path.getsize(file_path)
mime_type = self.db.get_mime_type(file_path)
# Check if document already exists
existing_doc = await self.db.get_document_by_hash(
file_hash, organization_id
)
if existing_doc:
logger.info(f"Document already exists: {filename} (hash: {file_hash})")
return existing_doc
# Create document record
doc_record = await self.db.create_document(
organization_id=organization_id,
created_by=created_by,
filename=filename,
file_size_bytes=file_size,
file_hash=file_hash,
mime_type=mime_type,
custom_metadata=custom_metadata or {},
)
logger.info(f"Processing document: {filename}")
# Update status to processing
await self.db.update_document_status(doc_record.id, "processing")
# Step 1: Convert document using docling
logger.info("Converting document with docling...")
conversion_result = self.converter.convert(file_path)
doc = conversion_result.document
# Store docling metadata
docling_metadata = {
"num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
"document_type": type(doc).__name__,
}
# Step 2: Chunk the document
logger.info(f"Chunking document with max_tokens={self.max_tokens}...")
chunks = list(self.chunker.chunk(dl_doc=doc))
total_chunks = len(chunks)
logger.info(f"Generated {total_chunks} chunks")
# Step 3: Process each chunk
chunk_texts = []
chunk_records = []
token_counts = []
for i, chunk in enumerate(chunks):
# Get chunk text
chunk_text = chunk.text
# Get contextualized text (enriched with surrounding context)
contextualized_text = self.chunker.contextualize(chunk=chunk)
# Calculate actual token count using the tokenizer
text_to_tokenize = (
contextualized_text if contextualized_text else chunk_text
)
token_count = len(
self.tokenizer.tokenizer.encode(
text_to_tokenize, add_special_tokens=False
)
)
token_counts.append(token_count)
# Prepare chunk metadata
chunk_metadata = {}
if hasattr(chunk, "meta") and chunk.meta:
chunk_metadata = {
"doc_items": (
[str(item) for item in chunk.meta.doc_items]
if hasattr(chunk.meta, "doc_items")
else []
),
"headings": (
chunk.meta.headings
if hasattr(chunk.meta, "headings")
else []
),
}
# Create chunk record (without embedding yet)
chunk_record = KnowledgeBaseChunkModel(
document_id=doc_record.id,
organization_id=organization_id,
chunk_text=chunk_text,
contextualized_text=contextualized_text,
chunk_index=i,
chunk_metadata=chunk_metadata,
embedding_model=self.model_id,
embedding_dimension=EMBEDDING_DIMENSION,
token_count=token_count,
)
chunk_records.append(chunk_record)
# Use contextualized text for embedding if available
chunk_texts.append(text_to_tokenize)
# Log chunk statistics
if token_counts:
avg_tokens = sum(token_counts) / len(token_counts)
min_tokens = min(token_counts)
max_tokens = max(token_counts)
logger.info("Chunk token statistics:")
logger.info(f" - Average: {avg_tokens:.1f} tokens")
logger.info(f" - Min: {min_tokens} tokens")
logger.info(f" - Max: {max_tokens} tokens")
# Step 4: Generate embeddings in batch
logger.info("Generating embeddings...")
embeddings = await self.embed_texts(chunk_texts)
# Step 5: Attach embeddings to chunk records
for chunk_record, embedding in zip(chunk_records, embeddings):
chunk_record.embedding = embedding
# Step 6: Save all chunks in batch
logger.info("Storing chunks in database...")
await self.db.create_chunks_batch(chunk_records)
# Update document status to completed
await self.db.update_document_status(
doc_record.id,
"completed",
total_chunks=total_chunks,
docling_metadata=docling_metadata,
)
logger.info(f"Successfully processed document: {filename}")
logger.info(f" - Total chunks: {total_chunks}")
logger.info(f" - Document ID: {doc_record.id}")
logger.info(f" - Document UUID: {doc_record.document_uuid}")
return doc_record
except Exception as e:
logger.error(f"Error processing document: {e}")
# Update document status to failed if it exists
if "doc_record" in locals():
await self.db.update_document_status(
doc_record.id, "failed", error_message=str(e)
)
raise

View file

@ -0,0 +1,44 @@
"""
Embeddings pricing models for different providers.
Prices are per token for embedding models.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import PricingModel
class EmbeddingPricingModel(PricingModel):
"""Pricing model for token-based embedding services."""
def __init__(self, token_price: Decimal):
"""Initialize with price per token.
Args:
token_price: Cost per token for embedding
"""
self.token_price = token_price
def calculate_cost(self, token_count: int) -> Decimal:
"""Calculate cost for embedding token usage."""
return Decimal(token_count) * self.token_price
# Embeddings pricing registry
EMBEDDINGS_PRICING: Dict[str, Dict[str, EmbeddingPricingModel]] = {
ServiceProviders.OPENAI: {
"text-embedding-3-small": EmbeddingPricingModel(
token_price=Decimal("0.02") / 1_000_000, # $0.02 per 1M tokens
),
"text-embedding-3-large": EmbeddingPricingModel(
token_price=Decimal("0.13") / 1_000_000, # $0.13 per 1M tokens
),
"text-embedding-ada-002": EmbeddingPricingModel(
token_price=Decimal("0.10") / 1_000_000, # $0.10 per 1M tokens (legacy)
),
},
}

View file

@ -4,6 +4,7 @@ Main pricing registry that combines all service type pricing models.
from typing import Dict
from .embeddings import EMBEDDINGS_PRICING
from .llm import LLM_PRICING
from .stt import STT_PRICING
from .tts import TTS_PRICING
@ -13,4 +14,5 @@ PRICING_REGISTRY: Dict = {
"llm": LLM_PRICING,
"tts": TTS_PRICING,
"stt": STT_PRICING,
"embeddings": EMBEDDINGS_PRICING,
}

View file

@ -58,6 +58,7 @@ class NodeDataDTO(BaseModel):
delayed_start: bool = False
delayed_start_duration: Optional[float] = None
tool_uuids: Optional[List[str]] = None
document_uuids: Optional[List[str]] = None
trigger_path: Optional[str] = None
# Webhook node specific fields
enabled: bool = True

View file

@ -41,6 +41,10 @@ from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
from api.services.workflow.tools.calculator import get_calculator_tools, safe_calculator
from api.services.workflow.tools.knowledge_base import (
get_knowledge_base_tool,
retrieve_from_knowledge_base,
)
from api.services.workflow.tools.timezone import (
convert_time,
get_current_time,
@ -290,6 +294,48 @@ class PipecatEngine:
self.llm.register_function("get_current_time", get_current_time_func)
self.llm.register_function("convert_time", convert_time_func)
async def _register_knowledge_base_function(
self, document_uuids: list[str]
) -> None:
"""Register knowledge base retrieval function with the LLM.
Args:
document_uuids: List of document UUIDs to filter the search by
"""
logger.debug(
f"Registering knowledge base retrieval function with {len(document_uuids)} document(s)"
)
async def retrieve_kb_func(function_call_params: FunctionCallParams) -> None:
logger.info("LLM Function Call EXECUTED: retrieve_from_knowledge_base")
logger.info(f"Arguments: {function_call_params.arguments}")
try:
query = function_call_params.arguments.get("query", "")
organization_id = await self._get_organization_id()
if not organization_id:
raise ValueError(
"Organization ID not available for knowledge base retrieval"
)
result = await retrieve_from_knowledge_base(
query=query,
organization_id=organization_id,
document_uuids=document_uuids,
limit=3, # Return top 3 most relevant chunks
)
await function_call_params.result_callback(result)
except Exception as e:
logger.error(f"Knowledge base retrieval failed: {e}")
await function_call_params.result_callback(
{"error": str(e), "chunks": [], "query": query, "total_results": 0}
)
# Register the function with the LLM
self.llm.register_function("retrieve_from_knowledge_base", retrieve_kb_func)
async def _perform_variable_extraction_if_needed(
self, previous_node: Optional[Node]
) -> None:
@ -346,6 +392,10 @@ class PipecatEngine:
if node.tool_uuids and self._custom_tool_manager:
await self._custom_tool_manager.register_handlers(node.tool_uuids)
# Register knowledge base retrieval handler if node has documents
if node.document_uuids:
await self._register_knowledge_base_function(node.document_uuids)
# Set up system message and functions
(
system_message,
@ -575,6 +625,17 @@ class PipecatEngine:
# Add built-in function schemas (calculator and timezone tools)
functions.extend(self.builtin_function_schemas)
# Add knowledge base retrieval tool if node has documents
if node.document_uuids:
kb_tool_def = get_knowledge_base_tool(node.document_uuids)
kb_schema = get_function_schema(
kb_tool_def["function"]["name"],
kb_tool_def["function"]["description"],
properties=kb_tool_def["function"]["parameters"].get("properties", {}),
required=kb_tool_def["function"]["parameters"].get("required", []),
)
functions.append(kb_schema)
# Add custom tools from node.tool_uuids
if node.tool_uuids and self._custom_tool_manager:
custom_tool_schemas = await self._custom_tool_manager.get_tool_schemas(

View file

@ -0,0 +1,305 @@
"""Knowledge Base retrieval tool for workflow execution.
This module provides vector similarity search capabilities for retrieving
relevant information from the knowledge base during conversations.
Implements OpenTelemetry tracing for observability in Langfuse.
"""
import json
from typing import Any, Dict, List, Optional
from loguru import logger
from opentelemetry import trace
from api.db import db_client
from api.services.gen_ai import OpenAIEmbeddingService
from api.services.pipecat.tracing_config import is_tracing_enabled
from pipecat.utils.tracing.context_registry import (
get_current_conversation_context,
get_current_turn_context,
)
async def retrieve_from_knowledge_base(
query: str,
organization_id: int,
document_uuids: Optional[List[str]] = None,
limit: int = 3,
embeddings_api_key: Optional[str] = None,
embeddings_model: Optional[str] = None,
) -> Dict[str, Any]:
"""Retrieve relevant information from the knowledge base using vector similarity search.
Uses OpenAI text-embedding-3-small for embeddings by default. This provides
high-quality 1536-dimensional embeddings for accurate retrieval.
This function includes OpenTelemetry tracing for Langfuse observability.
Args:
query: The search query to find relevant information
organization_id: Organization ID for scoping the search
document_uuids: Optional list of document UUIDs to filter by
limit: Maximum number of chunks to return (default: 3)
embeddings_api_key: Optional API key for embedding service
embeddings_model: Optional model ID for embedding service
Returns:
Dictionary containing:
- chunks: List of relevant text chunks with metadata
- query: The original query
- total_results: Number of results returned
"""
# Create span for retrieval operation if tracing is enabled
if is_tracing_enabled():
try:
# Get parent context from turn or conversation
turn_context = get_current_turn_context()
conversation_context = get_current_conversation_context()
parent_context = turn_context or conversation_context
# Get tracer
tracer = trace.get_tracer("pipecat")
except Exception as e:
logger.debug(f"Failed to setup tracing context: {e}")
# Fall back to non-traced execution
return await _perform_retrieval(
query,
organization_id,
document_uuids,
limit,
embeddings_api_key,
embeddings_model,
)
# Create span with parent context
if parent_context:
with tracer.start_as_current_span(
"knowledge_base_retrieval", context=parent_context
) as span:
try:
# Mark trace as public for Langfuse
span.set_attribute("langfuse.trace.public", True)
# Add operation metadata
span.set_attribute(
"gen_ai.operation.name", "knowledge_base_retrieval"
)
span.set_attribute("retrieval.query", query)
span.set_attribute("retrieval.limit", limit)
span.set_attribute("retrieval.organization_id", organization_id)
# Add document filter info
if document_uuids:
span.set_attribute(
"retrieval.document_count", len(document_uuids)
)
span.set_attribute(
"retrieval.document_uuids", json.dumps(document_uuids)
)
# Perform the actual retrieval
result = await _perform_retrieval(
query,
organization_id,
document_uuids,
limit,
embeddings_api_key,
embeddings_model,
)
# Add result metadata to span
span.set_attribute(
"retrieval.results_count", result["total_results"]
)
if result.get("error"):
span.set_attribute("retrieval.error", result["error"])
span.set_status(
trace.Status(trace.StatusCode.ERROR, result["error"])
)
else:
# Add similarity scores
if result["chunks"]:
similarities = [
chunk["similarity"] for chunk in result["chunks"]
]
span.set_attribute(
"retrieval.avg_similarity",
round(sum(similarities) / len(similarities), 4),
)
span.set_attribute(
"retrieval.max_similarity", max(similarities)
)
span.set_attribute(
"retrieval.min_similarity", min(similarities)
)
# Add retrieved documents info
filenames = list(
set(chunk["filename"] for chunk in result["chunks"])
)
span.set_attribute(
"retrieval.source_files", json.dumps(filenames)
)
# Add output as JSON for Langfuse
output_data = {
"query": query,
"chunks_retrieved": len(result["chunks"]),
"chunks": [
{
"text": chunk["text"][:200] + "..."
if len(chunk["text"]) > 200
else chunk["text"],
"filename": chunk["filename"],
"similarity": chunk["similarity"],
}
for chunk in result["chunks"]
],
}
span.set_attribute("output", json.dumps(output_data))
return result
except Exception as e:
logger.error(f"Error in traced retrieval: {e}")
span.record_exception(e)
span.set_status(trace.Status(trace.StatusCode.ERROR, str(e)))
raise
else:
# No parent context - perform retrieval without tracing
logger.debug(
"No parent context available for knowledge base retrieval tracing"
)
return await _perform_retrieval(
query,
organization_id,
document_uuids,
limit,
embeddings_api_key,
embeddings_model,
)
else:
# Tracing is disabled - perform retrieval without tracing
return await _perform_retrieval(
query,
organization_id,
document_uuids,
limit,
embeddings_api_key,
embeddings_model,
)
async def _perform_retrieval(
query: str,
organization_id: int,
document_uuids: Optional[List[str]],
limit: int,
embeddings_api_key: Optional[str] = None,
embeddings_model: Optional[str] = None,
) -> Dict[str, Any]:
"""Internal function to perform the actual retrieval operation.
Separated from tracing logic for cleaner code organization.
Uses OpenAI embeddings by default for high-quality retrieval.
"""
try:
# Create a new embedding service instance
# Uses OpenAI text-embedding-3-small by default, or user-provided config
embedding_service = OpenAIEmbeddingService(
db_client=db_client,
max_tokens=128, # This is only used for chunking, not for retrieval
api_key=embeddings_api_key,
model_id=embeddings_model or "text-embedding-3-small",
)
# Perform vector similarity search
results = await embedding_service.search_similar_chunks(
query=query,
organization_id=organization_id,
limit=limit,
document_uuids=document_uuids,
)
# Format results for LLM consumption
chunks = []
for result in results:
chunk_info = {
"text": result.get("contextualized_text") or result.get("chunk_text"),
"filename": result.get("filename"),
"similarity": round(result.get("similarity", 0), 4),
"chunk_index": result.get("chunk_index"),
}
chunks.append(chunk_info)
logger.info(
f"Knowledge base retrieval: query='{query}', "
f"results={len(chunks)}, "
f"document_filter={document_uuids}"
)
return {
"chunks": chunks,
"query": query,
"total_results": len(chunks),
}
except Exception as e:
logger.error(f"Error retrieving from knowledge base: {e}")
return {
"error": str(e),
"chunks": [],
"query": query,
"total_results": 0,
}
def get_knowledge_base_tool(
document_uuids: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""Get knowledge base retrieval tool definition for LLM function calling.
Args:
document_uuids: Optional list of document UUIDs to include in description
Returns:
Tool definition compatible with LLM function calling
"""
# Build description based on whether specific documents are filtered
if document_uuids and len(document_uuids) > 0:
description = (
"Retrieve relevant information from specific documents in the knowledge base. "
"Use this tool when you need to look up facts, policies, procedures, or any information "
"that might be stored in the available documents. The search will only look in the "
f"documents associated with this conversation step ({len(document_uuids)} document(s) available)."
)
else:
description = (
"Retrieve relevant information from the knowledge base. "
"Use this tool when you need to look up facts, policies, procedures, or any information "
"that might be stored in the knowledge base documents."
)
return {
"type": "function",
"function": {
"name": "retrieve_from_knowledge_base",
"description": description,
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": (
"The search query to find relevant information. "
"Be specific and use natural language. "
"Example: 'What is the refund policy for canceled orders?'"
),
}
},
"required": ["query"],
},
},
}

View file

@ -48,6 +48,7 @@ class Node:
self.delayed_start = data.delayed_start
self.delayed_start_duration = data.delayed_start_duration
self.tool_uuids = data.tool_uuids
self.document_uuids = data.document_uuids
self.data = data
@ -189,16 +190,6 @@ class WorkflowGraph:
in_d, out_d = in_deg[n.id], out_deg[n.id]
match n.node_type:
case NodeType.startNode:
if in_d != 0 or out_d < 1:
errors.append(
WorkflowError(
kind=ItemKind.node,
id=n.id,
field=None,
message=f"StartNode must have at least 1 outgoing edge",
)
)
case NodeType.endNode:
if in_d < 1 or out_d != 0:
errors.append(

View file

@ -46,6 +46,7 @@ from api.tasks.campaign_tasks import (
process_campaign_batch,
sync_campaign_source,
)
from api.tasks.knowledge_base_processing import process_knowledge_base_document
from api.tasks.run_integrations import run_integrations_post_workflow_run
from api.tasks.s3_upload import (
upload_audio_to_s3,
@ -64,6 +65,7 @@ class WorkerSettings:
sync_campaign_source,
process_campaign_batch,
monitor_campaign_progress,
process_knowledge_base_document,
]
cron_jobs = []
redis_settings = REDIS_SETTINGS

View file

@ -7,3 +7,4 @@ class FunctionNames:
SYNC_CAMPAIGN_SOURCE = "sync_campaign_source"
PROCESS_CAMPAIGN_BATCH = "process_campaign_batch"
MONITOR_CAMPAIGN_PROGRESS = "monitor_campaign_progress"
PROCESS_KNOWLEDGE_BASE_DOCUMENT = "process_knowledge_base_document"

View file

@ -0,0 +1,311 @@
"""ARQ background task for processing knowledge base documents."""
import os
import tempfile
from typing import Literal
from docling.chunking import HybridChunker
from docling.document_converter import DocumentConverter
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
from loguru import logger
from transformers import AutoTokenizer
from api.db import db_client
from api.db.models import KnowledgeBaseChunkModel
from api.services.gen_ai import (
OpenAIEmbeddingService,
SentenceTransformerEmbeddingService,
)
from api.services.storage import storage_fs
# For tokenization/chunking - use SentenceTransformer tokenizer as baseline
TOKENIZER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
async def process_knowledge_base_document(
ctx,
document_id: int,
s3_key: str,
organization_id: int,
max_tokens: int = 128,
embedding_service: Literal["sentence_transformer", "openai"] = "openai",
):
"""Process a knowledge base document: download, chunk, embed, and store.
Args:
ctx: ARQ context
document_id: Database ID of the document
s3_key: S3 key where the file is stored
organization_id: Organization ID
max_tokens: Maximum number of tokens per chunk (default: 128)
embedding_service: Embedding service to use (default: "openai")
- "openai": Use OpenAI text-embedding-3-small (1536-dim, requires API key)
- "sentence_transformer": Use SentenceTransformer (all-MiniLM-L6-v2, 384-dim, free)
"""
logger.info(
f"Starting knowledge base document processing for document_id={document_id}, "
f"s3_key={s3_key}, organization_id={organization_id}"
)
temp_file_path = None
try:
# Update status to processing
await db_client.update_document_status(document_id, "processing")
# Extract file extension from S3 key
filename = s3_key.split("/")[-1]
file_extension = (
os.path.splitext(filename)[1] or ".bin"
) # Default to .bin if no extension
# Create temp file for download with correct extension
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=file_extension)
temp_file_path = temp_file.name
temp_file.close()
# Download file from S3
logger.info(f"Downloading file from S3: {s3_key}")
download_success = await storage_fs.adownload_file(s3_key, temp_file_path)
if not download_success:
raise Exception(f"Failed to download file from S3: {s3_key}")
if not os.path.exists(temp_file_path):
raise FileNotFoundError(f"Downloaded file not found: {temp_file_path}")
file_size = os.path.getsize(temp_file_path)
logger.info(f"Downloaded file size: {file_size} bytes")
# Compute file hash and get mime type
file_hash = db_client.compute_file_hash(temp_file_path)
mime_type = db_client.get_mime_type(temp_file_path)
filename = s3_key.split("/")[-1]
# Get document record
document = await db_client.get_document_by_id(document_id)
if not document:
raise Exception(f"Document {document_id} not found")
# Check if a document with this hash already exists (reject duplicates)
existing_doc = await db_client.get_document_by_hash(file_hash, organization_id)
if existing_doc and existing_doc.id != document_id:
error_message = (
f"This file is a duplicate of '{existing_doc.filename}'. "
f"Please delete the duplicate files and consolidate them into a single unique file before uploading."
)
logger.warning(
f"Duplicate document detected: {document_id} is duplicate of {existing_doc.id} "
f"({existing_doc.filename})"
)
# Update file metadata
await db_client.update_document_metadata(
document_id,
file_size_bytes=file_size,
file_hash=file_hash,
mime_type=mime_type,
)
# Mark as failed with duplicate error message
await db_client.update_document_status(
document_id,
"failed",
error_message=error_message,
docling_metadata={
"duplicate_of": existing_doc.document_uuid,
"duplicate_filename": existing_doc.filename,
},
)
return
# Update document with file metadata
await db_client.update_document_metadata(
document_id,
file_size_bytes=file_size,
file_hash=file_hash,
mime_type=mime_type,
)
# Initialize the embedding service based on the parameter
if embedding_service == "openai":
logger.info(
f"Initializing OpenAI embedding service with max_tokens={max_tokens}"
)
# Try to get user's embeddings configuration
embeddings_api_key = None
embeddings_model = None
if document.created_by:
user_config = await db_client.get_user_configurations(
document.created_by
)
if user_config.embeddings:
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
logger.info(
f"Using user embeddings config: model={embeddings_model}"
)
# Check if API key is configured
if not embeddings_api_key:
error_message = (
"OpenAI API key not configured. Please set your API key in "
"Model Configurations > Embedding to process documents."
)
logger.warning(f"Document {document_id}: {error_message}")
await db_client.update_document_status(
document_id, "failed", error_message=error_message
)
return
service = OpenAIEmbeddingService(
db_client=db_client,
max_tokens=max_tokens,
api_key=embeddings_api_key,
model_id=embeddings_model or "text-embedding-3-small",
)
elif embedding_service == "sentence_transformer":
logger.info(
f"Initializing SentenceTransformer embedding service with max_tokens={max_tokens}"
)
service = SentenceTransformerEmbeddingService(
db_client=db_client,
max_tokens=max_tokens,
)
else:
raise ValueError(
f"Invalid embedding_service: {embedding_service}. "
f"Must be 'sentence_transformer' or 'openai'"
)
# Step 1: Convert document with docling
logger.info("Converting document with docling")
converter = DocumentConverter()
conversion_result = converter.convert(temp_file_path)
doc = conversion_result.document
# Store docling metadata
docling_metadata = {
"num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
"document_type": type(doc).__name__,
}
# Step 2: Initialize tokenizer for chunking
logger.info(
f"Loading tokenizer: {TOKENIZER_MODEL} with max_tokens={max_tokens}"
)
tokenizer = HuggingFaceTokenizer(
tokenizer=AutoTokenizer.from_pretrained(TOKENIZER_MODEL),
max_tokens=max_tokens,
)
# Step 3: Initialize chunker
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
chunker = HybridChunker(tokenizer=tokenizer)
# Step 4: Chunk the document
logger.info(f"Chunking document with max_tokens={max_tokens}")
chunks = list(chunker.chunk(dl_doc=doc))
total_chunks = len(chunks)
logger.info(f"Generated {total_chunks} chunks")
# Step 5: Process each chunk
chunk_texts = []
chunk_records = []
token_counts = []
for i, chunk in enumerate(chunks):
chunk_text = chunk.text
contextualized_text = chunker.contextualize(chunk=chunk)
# Calculate token count
text_to_tokenize = (
contextualized_text if contextualized_text else chunk_text
)
token_count = len(
tokenizer.tokenizer.encode(text_to_tokenize, add_special_tokens=False)
)
token_counts.append(token_count)
# Prepare chunk metadata
chunk_metadata = {}
if hasattr(chunk, "meta") and chunk.meta:
chunk_metadata = {
"doc_items": (
[str(item) for item in chunk.meta.doc_items]
if hasattr(chunk.meta, "doc_items")
else []
),
"headings": (
chunk.meta.headings if hasattr(chunk.meta, "headings") else []
),
}
# Create chunk record (without embedding yet)
chunk_record = KnowledgeBaseChunkModel(
document_id=document_id,
organization_id=organization_id,
chunk_text=chunk_text,
contextualized_text=contextualized_text,
chunk_index=i,
chunk_metadata=chunk_metadata,
embedding_model=service.get_model_id(),
embedding_dimension=service.get_embedding_dimension(),
token_count=token_count,
)
chunk_records.append(chunk_record)
chunk_texts.append(text_to_tokenize)
# Log chunk statistics
if token_counts:
avg_tokens = sum(token_counts) / len(token_counts)
min_tokens = min(token_counts)
max_tokens_actual = max(token_counts)
logger.info("Chunk token statistics:")
logger.info(f" - Average: {avg_tokens:.1f} tokens")
logger.info(f" - Min: {min_tokens} tokens")
logger.info(f" - Max: {max_tokens_actual} tokens")
# Step 6: Generate embeddings using the embedding service
logger.info(f"Generating embeddings using {embedding_service}")
embeddings = await service.embed_texts(chunk_texts)
# Step 7: Attach embeddings to chunk records
for chunk_record, embedding in zip(chunk_records, embeddings):
chunk_record.embedding = embedding
# Step 8: Save chunks in database
logger.info("Storing chunks in database")
await db_client.create_chunks_batch(chunk_records)
# Step 9: Update document status to completed
await db_client.update_document_status(
document_id,
"completed",
total_chunks=total_chunks,
docling_metadata=docling_metadata,
)
logger.info(
f"Successfully processed knowledge base document {document_id}. "
f"Total chunks: {total_chunks}"
)
except Exception as e:
logger.error(
f"Error processing knowledge base document {document_id}: {e}",
exc_info=True,
)
# Update document status to failed
await db_client.update_document_status(
document_id, "failed", error_message=str(e)
)
raise
finally:
# Clean up temp file
if temp_file_path and os.path.exists(temp_file_path):
try:
os.remove(temp_file_path)
logger.debug(f"Cleaned up temp file: {temp_file_path}")
except Exception as e:
logger.warning(f"Failed to clean up temp file {temp_file_path}: {e}")

View file

@ -1,6 +1,6 @@
services:
postgres:
image: postgres:17
image: pgvector/pgvector:pg17
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres

View file

@ -0,0 +1,249 @@
'use client';
import { FileText, RefreshCw, Search, Trash2 } from 'lucide-react';
import { useCallback, useEffect, useState } from 'react';
import { toast } from 'sonner';
import {
deleteDocumentApiV1KnowledgeBaseDocumentsDocumentUuidDelete,
listDocumentsApiV1KnowledgeBaseDocumentsGet,
} from '@/client/sdk.gen';
import type { DocumentResponseSchema } from '@/client/types.gen';
import { Badge } from '@/components/ui/badge';
import { Button } from '@/components/ui/button';
import { Input } from '@/components/ui/input';
import { Skeleton } from '@/components/ui/skeleton';
import logger from '@/lib/logger';
interface DocumentListProps {
accessToken: string;
refreshTrigger: number;
}
export default function DocumentList({ accessToken, refreshTrigger }: DocumentListProps) {
const [documents, setDocuments] = useState<DocumentResponseSchema[]>([]);
const [isLoading, setIsLoading] = useState(true);
const [searchQuery, setSearchQuery] = useState('');
const [error, setError] = useState<string | null>(null);
const fetchDocuments = useCallback(async () => {
if (!accessToken) return;
try {
setIsLoading(true);
setError(null);
const response = await listDocumentsApiV1KnowledgeBaseDocumentsGet({
headers: {
'Authorization': `Bearer ${accessToken}`,
},
query: {
limit: 100,
offset: 0,
},
});
if (response.error || !response.data) {
throw new Error('Failed to fetch documents');
}
setDocuments(response.data.documents);
} catch (err) {
setError(err instanceof Error ? err.message : 'Failed to fetch documents');
logger.error('Error fetching documents:', err);
} finally {
setIsLoading(false);
}
}, [accessToken]);
// Fetch documents on mount and when refreshTrigger changes
useEffect(() => {
fetchDocuments();
}, [fetchDocuments, refreshTrigger]);
// Poll for documents that are processing
useEffect(() => {
const processingDocs = documents.filter(
(doc) => doc.processing_status === 'processing' || doc.processing_status === 'pending'
);
if (processingDocs.length === 0) return;
const pollInterval = setInterval(() => {
logger.info(`Polling for ${processingDocs.length} processing documents...`);
fetchDocuments();
}, 5000); // Poll every 5 seconds
return () => clearInterval(pollInterval);
}, [documents, fetchDocuments]);
const handleDelete = async (documentUuid: string, filename: string) => {
if (!confirm(`Are you sure you want to delete "${filename}"?`)) return;
try {
const response = await deleteDocumentApiV1KnowledgeBaseDocumentsDocumentUuidDelete({
path: {
document_uuid: documentUuid,
},
headers: {
'Authorization': `Bearer ${accessToken}`,
},
});
if (response.error) {
throw new Error('Failed to delete document');
}
toast.success(`Deleted "${filename}"`);
fetchDocuments();
} catch (err) {
toast.error(err instanceof Error ? err.message : 'Failed to delete document');
logger.error('Error deleting document:', err);
}
};
const getStatusBadge = (status: string) => {
switch (status) {
case 'completed':
return <Badge className="bg-green-500">Completed</Badge>;
case 'processing':
return (
<Badge variant="secondary" className="animate-pulse">
Processing
</Badge>
);
case 'pending':
return <Badge variant="outline">Pending</Badge>;
case 'failed':
return <Badge variant="destructive">Failed</Badge>;
default:
return <Badge variant="outline">{status}</Badge>;
}
};
const formatFileSize = (bytes: number): string => {
if (bytes === 0) return '0 B';
const k = 1024;
const sizes = ['B', 'KB', 'MB', 'GB'];
const i = Math.floor(Math.log(bytes) / Math.log(k));
return `${parseFloat((bytes / Math.pow(k, i)).toFixed(2))} ${sizes[i]}`;
};
const formatDate = (dateString: string): string => {
const date = new Date(dateString);
return date.toLocaleDateString() + ' ' + date.toLocaleTimeString();
};
const filteredDocuments = documents.filter((doc) =>
doc.filename.toLowerCase().includes(searchQuery.toLowerCase())
);
if (isLoading && documents.length === 0) {
return (
<div className="space-y-4">
{[1, 2, 3].map((i) => (
<div key={i} className="flex items-center justify-between p-4 border rounded-lg">
<div className="space-y-2 flex-1">
<Skeleton className="h-4 w-48" />
<Skeleton className="h-3 w-64" />
</div>
<Skeleton className="h-8 w-24" />
</div>
))}
</div>
);
}
if (error) {
return (
<div className="p-4 bg-destructive/10 border border-destructive/20 rounded-lg text-destructive">
{error}
</div>
);
}
return (
<div className="space-y-4">
{/* Search and Refresh */}
<div className="flex items-center gap-4">
<div className="relative flex-1">
<Search className="absolute left-3 top-1/2 transform -translate-y-1/2 h-4 w-4 text-muted-foreground" />
<Input
placeholder="Search documents..."
value={searchQuery}
onChange={(e) => setSearchQuery(e.target.value)}
className="pl-10"
/>
</div>
<Button
variant="outline"
size="icon"
onClick={fetchDocuments}
disabled={isLoading}
>
<RefreshCw className={`h-4 w-4 ${isLoading ? 'animate-spin' : ''}`} />
</Button>
</div>
{/* Document List */}
{filteredDocuments.length === 0 ? (
<div className="text-center py-12">
<FileText className="w-12 h-12 text-muted-foreground mx-auto mb-4" />
<p className="text-muted-foreground">
{searchQuery
? 'No documents match your search'
: 'No documents uploaded yet'}
</p>
</div>
) : (
<div className="space-y-3">
{filteredDocuments.map((doc) => (
<div
key={doc.document_uuid}
className="flex items-center justify-between p-4 border rounded-lg hover:bg-muted/50 transition-colors"
>
<div className="flex items-center gap-4 flex-1">
<div className="w-10 h-10 rounded-lg bg-primary/10 flex items-center justify-center">
<FileText className="w-5 h-5 text-primary" />
</div>
<div className="flex-1 min-w-0">
<div className="flex items-center gap-2 mb-1">
<span className="font-medium truncate">{doc.filename}</span>
{getStatusBadge(doc.processing_status)}
</div>
<div className="flex items-center gap-4 text-sm text-muted-foreground">
<span>{formatFileSize(doc.file_size_bytes)}</span>
{doc.processing_status === 'completed' && (
<span>{doc.total_chunks} chunks</span>
)}
<span>{formatDate(doc.created_at)}</span>
</div>
{doc.processing_error && (
<p className="text-xs text-destructive mt-1">
Error: {doc.processing_error}
</p>
)}
{doc.docling_metadata &&
typeof doc.docling_metadata === 'object' &&
'duplicate_of' in doc.docling_metadata && (
<p className="text-xs text-muted-foreground mt-1">
Duplicate of another document
</p>
)}
</div>
</div>
<Button
variant="ghost"
size="sm"
onClick={() => handleDelete(doc.document_uuid, doc.filename)}
className="text-destructive hover:text-destructive/90"
>
<Trash2 className="w-4 h-4" />
</Button>
</div>
))}
</div>
)}
</div>
);
}

View file

@ -0,0 +1,220 @@
'use client';
import { Upload } from 'lucide-react';
import { useRef, useState } from 'react';
import { toast } from 'sonner';
import {
getUploadUrlApiV1KnowledgeBaseUploadUrlPost,
processDocumentApiV1KnowledgeBaseProcessDocumentPost,
} from '@/client/sdk.gen';
import type { DocumentUploadResponseSchema } from '@/client/types.gen';
import { Button } from '@/components/ui/button';
import { Progress } from '@/components/ui/progress';
import logger from '@/lib/logger';
interface DocumentUploadProps {
accessToken: string;
onUploadSuccess: () => void;
}
const MAX_FILE_SIZE = 100 * 1024 * 1024; // 100MB
const ACCEPTED_FILE_TYPES = ['.pdf', '.docx', '.doc', '.txt'];
export default function DocumentUpload({ accessToken, onUploadSuccess }: DocumentUploadProps) {
const [uploading, setUploading] = useState(false);
const [uploadProgress, setUploadProgress] = useState(0);
const [dragActive, setDragActive] = useState(false);
const fileInputRef = useRef<HTMLInputElement>(null);
const validateFile = (file: File): boolean => {
// Validate file type
const fileExtension = '.' + file.name.split('.').pop()?.toLowerCase();
if (!ACCEPTED_FILE_TYPES.includes(fileExtension)) {
toast.error(`Please select a supported file type: ${ACCEPTED_FILE_TYPES.join(', ')}`);
return false;
}
// Validate file size
if (file.size > MAX_FILE_SIZE) {
toast.error('File size must be less than 100MB');
return false;
}
return true;
};
const uploadFile = async (file: File) => {
if (!validateFile(file)) return;
setUploading(true);
setUploadProgress(0);
try {
// Step 1: Request presigned upload URL
logger.info('Requesting presigned upload URL for:', file.name);
const uploadUrlResponse = await getUploadUrlApiV1KnowledgeBaseUploadUrlPost({
body: {
filename: file.name,
mime_type: file.type || 'application/octet-stream',
custom_metadata: {
original_filename: file.name,
uploaded_at: new Date().toISOString(),
},
},
headers: {
'Authorization': `Bearer ${accessToken}`,
},
});
if (uploadUrlResponse.error || !uploadUrlResponse.data) {
throw new Error('Failed to get upload URL');
}
const uploadData: DocumentUploadResponseSchema = uploadUrlResponse.data;
logger.info('Received presigned URL, uploading file...');
setUploadProgress(25);
// Step 2: Upload file directly to S3/MinIO using PUT
const uploadResponse = await fetch(uploadData.upload_url, {
method: 'PUT',
body: file,
headers: {
'Content-Type': file.type || 'application/octet-stream',
},
});
if (!uploadResponse.ok) {
throw new Error('Failed to upload file to storage');
}
setUploadProgress(75);
logger.info('File uploaded successfully, triggering processing...');
// Step 3: Trigger document processing
const processResponse = await processDocumentApiV1KnowledgeBaseProcessDocumentPost({
body: {
document_uuid: uploadData.document_uuid,
s3_key: uploadData.s3_key,
},
headers: {
'Authorization': `Bearer ${accessToken}`,
},
});
if (processResponse.error) {
throw new Error('Failed to trigger processing');
}
setUploadProgress(100);
logger.info('Document processing triggered successfully');
toast.success(`File uploaded: ${file.name}. Processing started.`);
onUploadSuccess();
} catch (error) {
logger.error('Error uploading document:', error);
toast.error(error instanceof Error ? error.message : 'Failed to upload document');
} finally {
setUploading(false);
setUploadProgress(0);
// Reset file input
if (fileInputRef.current) {
fileInputRef.current.value = '';
}
}
};
const handleFileSelect = async (event: React.ChangeEvent<HTMLInputElement>) => {
const file = event.target.files?.[0];
if (file) {
await uploadFile(file);
}
};
const handleDrag = (e: React.DragEvent) => {
e.preventDefault();
e.stopPropagation();
if (e.type === 'dragenter' || e.type === 'dragover') {
setDragActive(true);
} else if (e.type === 'dragleave') {
setDragActive(false);
}
};
const handleDrop = async (e: React.DragEvent) => {
e.preventDefault();
e.stopPropagation();
setDragActive(false);
const file = e.dataTransfer.files?.[0];
if (file) {
await uploadFile(file);
}
};
const handleButtonClick = () => {
fileInputRef.current?.click();
};
return (
<div className="space-y-4">
<input
ref={fileInputRef}
type="file"
accept={ACCEPTED_FILE_TYPES.join(',')}
onChange={handleFileSelect}
className="hidden"
disabled={uploading}
/>
{/* Drag and Drop Area */}
<div
className={`
border-2 border-dashed rounded-lg p-8 text-center transition-colors
${dragActive ? 'border-primary bg-primary/5' : 'border-muted-foreground/25'}
${uploading ? 'opacity-50 pointer-events-none' : 'cursor-pointer hover:border-primary hover:bg-muted/50'}
`}
onDragEnter={handleDrag}
onDragLeave={handleDrag}
onDragOver={handleDrag}
onDrop={handleDrop}
onClick={handleButtonClick}
>
<Upload className="w-12 h-12 mx-auto mb-4 text-muted-foreground" />
<p className="text-lg font-medium mb-2">
{uploading ? 'Uploading...' : 'Drop your document here'}
</p>
<p className="text-sm text-muted-foreground mb-4">
or click to browse
</p>
<p className="text-xs text-muted-foreground">
Supported formats: {ACCEPTED_FILE_TYPES.join(', ')} (Max 100MB)
</p>
</div>
{/* Upload Progress */}
{uploading && (
<div className="space-y-2">
<div className="flex justify-between text-sm">
<span>Uploading...</span>
<span>{uploadProgress}%</span>
</div>
<Progress value={uploadProgress} />
</div>
)}
{/* Manual Upload Button */}
<div className="flex justify-center">
<Button
type="button"
variant="outline"
onClick={handleButtonClick}
disabled={uploading}
>
{uploading ? 'Uploading...' : 'Choose File'}
</Button>
</div>
</div>
);
}

104
ui/src/app/files/page.tsx Normal file
View file

@ -0,0 +1,104 @@
"use client";
import { useCallback, useEffect, useState } from "react";
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
import { Skeleton } from "@/components/ui/skeleton";
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { useAuth } from "@/lib/auth";
import DocumentList from "./DocumentList";
import DocumentUpload from "./DocumentUpload";
export default function FilesPage() {
const { user, getAccessToken, redirectToLogin, loading } = useAuth();
const [refreshKey, setRefreshKey] = useState(0);
const [accessToken, setAccessToken] = useState<string>('');
// Redirect if not authenticated
useEffect(() => {
if (!loading && !user) {
redirectToLogin();
}
}, [loading, user, redirectToLogin]);
// Get access token
const fetchAccessToken = useCallback(async () => {
if (user) {
const token = await getAccessToken();
setAccessToken(token);
}
}, [user, getAccessToken]);
useEffect(() => {
fetchAccessToken();
}, [fetchAccessToken]);
const handleUploadSuccess = () => {
// Trigger refresh of document list
setRefreshKey(prev => prev + 1);
};
if (loading || !user || !accessToken) {
return (
<div className="container mx-auto px-4 py-8">
<div className="space-y-4">
<Skeleton className="h-12 w-64" />
<Skeleton className="h-64 w-full" />
</div>
</div>
);
}
return (
<div className="container mx-auto px-4 py-8">
<div className="mb-8">
<h1 className="text-3xl font-bold mb-2">Knowledge Base Files</h1>
<p className="text-muted-foreground">
Upload and manage documents for your voice agents to reference.
</p>
</div>
<Tabs defaultValue="all" className="space-y-6">
<TabsList>
<TabsTrigger value="all">All Files</TabsTrigger>
<TabsTrigger value="upload">Upload New</TabsTrigger>
</TabsList>
<TabsContent value="all" className="space-y-4">
<Card>
<CardHeader>
<CardTitle>Your Documents</CardTitle>
<CardDescription>
View and manage your uploaded documents
</CardDescription>
</CardHeader>
<CardContent>
<DocumentList
accessToken={accessToken}
refreshTrigger={refreshKey}
/>
</CardContent>
</Card>
</TabsContent>
<TabsContent value="upload" className="space-y-4">
<Card>
<CardHeader>
<CardTitle>Upload Document</CardTitle>
<CardDescription>
Upload a PDF or document file to add to your knowledge base
</CardDescription>
</CardHeader>
<CardContent>
<DocumentUpload
accessToken={accessToken}
onUploadSuccess={handleUploadSuccess}
/>
</CardContent>
</Card>
</TabsContent>
</Tabs>
</div>
);
}

View file

@ -7,8 +7,10 @@ import {
ReactFlow,
} from "@xyflow/react";
import { BrushCleaning, Maximize2, Minus, Plus, Rocket, Settings, Variable } from 'lucide-react';
import React, { useMemo, useState } from 'react';
import React, { useEffect, useMemo, useState } from 'react';
import { listDocumentsApiV1KnowledgeBaseDocumentsGet, listToolsApiV1ToolsGet } from '@/client';
import type { DocumentResponseSchema, ToolResponse } from '@/client/types.gen';
import { FlowEdge, FlowNode, NodeType } from "@/components/flow/types";
import { Button } from '@/components/ui/button';
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/tooltip';
@ -63,6 +65,8 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
const [isConfigurationsDialogOpen, setIsConfigurationsDialogOpen] = useState(false);
const [isEmbedDialogOpen, setIsEmbedDialogOpen] = useState(false);
const [isPhoneCallDialogOpen, setIsPhoneCallDialogOpen] = useState(false);
const [documents, setDocuments] = useState<DocumentResponseSchema[] | undefined>(undefined);
const [tools, setTools] = useState<ToolResponse[] | undefined>(undefined);
const {
rfInstance,
@ -95,6 +99,36 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
getAccessToken
});
// Fetch documents and tools once for the entire workflow
useEffect(() => {
const fetchData = async () => {
try {
const accessToken = await getAccessToken();
// Fetch documents
const documentsResponse = await listDocumentsApiV1KnowledgeBaseDocumentsGet({
headers: { Authorization: `Bearer ${accessToken}` },
query: { limit: 100 },
});
if (documentsResponse.data) {
setDocuments(documentsResponse.data.documents);
}
// Fetch tools
const toolsResponse = await listToolsApiV1ToolsGet({
headers: { Authorization: `Bearer ${accessToken}` },
});
if (toolsResponse.data) {
setTools(toolsResponse.data);
}
} catch (error) {
console.error('Failed to fetch documents and tools:', error);
}
};
fetchData();
}, [getAccessToken]);
// Memoize defaultEdgeOptions to prevent unnecessary re-renders
const defaultEdgeOptions = useMemo(() => ({
animated: true,
@ -102,7 +136,11 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
}), []);
// Memoize the context value to prevent unnecessary re-renders
const workflowContextValue = useMemo(() => ({ saveWorkflow }), [saveWorkflow]);
const workflowContextValue = useMemo(() => ({
saveWorkflow,
documents,
tools
}), [saveWorkflow, documents, tools]);
return (
<WorkflowProvider value={workflowContextValue}>

View file

@ -1,7 +1,11 @@
import { createContext, useContext } from 'react';
import type { DocumentResponseSchema, ToolResponse } from '@/client/types.gen';
interface WorkflowContextType {
saveWorkflow: (updateWorkflowDefinition?: boolean) => Promise<void>;
documents?: DocumentResponseSchema[];
tools?: ToolResponse[];
}
const WorkflowContext = createContext<WorkflowContextType | undefined>(undefined);
@ -15,3 +19,8 @@ export const useWorkflow = () => {
}
return context;
};
// Optional hook that doesn't throw if context is not available
export const useWorkflowOptional = () => {
return useContext(WorkflowContext);
};

View file

@ -1,9 +1,8 @@
// This file is auto-generated by @hey-api/openapi-ts
import { type ClientOptions as DefaultClientOptions, type Config, createClient, createConfig } from '@hey-api/client-fetch';
import { createClientConfig } from '../lib/apiClient';
import type { ClientOptions } from './types.gen';
import { type Config, type ClientOptions as DefaultClientOptions, createClient, createConfig } from '@hey-api/client-fetch';
import { createClientConfig } from '../lib/apiClient';
/**
* The `createClientConfig()` function will be called on client initialization
@ -17,4 +16,4 @@ export type CreateClientConfig<T extends DefaultClientOptions = ClientOptions> =
export const client = createClient(createClientConfig(createConfig<ClientOptions>({
baseUrl: 'http://127.0.0.1:8000'
})));
})));

View file

@ -1,3 +1,3 @@
// This file is auto-generated by @hey-api/openapi-ts
export * from './sdk.gen';
export * from './types.gen';
export * from './sdk.gen';

File diff suppressed because one or more lines are too long

View file

@ -83,6 +83,54 @@ export type CampaignsResponse = {
campaigns: Array<CampaignResponse>;
};
/**
* Response schema for a document chunk.
*/
export type ChunkResponseSchema = {
id: number;
document_id: number;
chunk_text: string;
contextualized_text: string | null;
chunk_index: number;
chunk_metadata: {
[key: string]: unknown;
};
filename: string;
document_uuid: string;
similarity: number;
};
/**
* Request schema for searching similar chunks.
*/
export type ChunkSearchRequestSchema = {
/**
* Search query text
*/
query: string;
/**
* Maximum number of results
*/
limit?: number;
/**
* Filter by specific document UUIDs
*/
document_uuids?: Array<string> | null;
/**
* Minimum similarity threshold
*/
min_similarity?: number | null;
};
/**
* Response schema for chunk search results.
*/
export type ChunkSearchResponseSchema = {
chunks: Array<ChunkResponseSchema>;
query: string;
total_results: number;
};
/**
* Request schema for Cloudonix configuration.
*/
@ -303,11 +351,91 @@ export type DefaultConfigurationsResponse = {
[key: string]: unknown;
};
};
embeddings: {
[key: string]: {
[key: string]: unknown;
};
};
default_providers: {
[key: string]: string;
};
};
/**
* Response schema for list of documents.
*/
export type DocumentListResponseSchema = {
documents: Array<DocumentResponseSchema>;
total: number;
limit: number;
offset: number;
};
/**
* Response schema for document metadata.
*/
export type DocumentResponseSchema = {
id: number;
document_uuid: string;
filename: string;
file_size_bytes: number;
file_hash: string;
mime_type: string;
processing_status: string;
processing_error?: string | null;
total_chunks: number;
custom_metadata: {
[key: string]: unknown;
};
docling_metadata: {
[key: string]: unknown;
};
source_url?: string | null;
created_at: string;
updated_at: string;
organization_id: number;
created_by: number;
is_active: boolean;
};
/**
* Request schema for initiating document upload.
*/
export type DocumentUploadRequestSchema = {
/**
* Name of the file to upload
*/
filename: string;
/**
* MIME type of the file
*/
mime_type: string;
/**
* Optional custom metadata
*/
custom_metadata?: {
[key: string]: unknown;
} | null;
};
/**
* Response schema containing upload URL and document metadata.
*/
export type DocumentUploadResponseSchema = {
/**
* Signed URL for uploading the file
*/
upload_url: string;
/**
* Unique identifier for the document
*/
document_uuid: string;
/**
* S3 key where file should be uploaded
*/
s3_key: string;
};
export type DuplicateTemplateRequest = {
template_id: number;
workflow_name: string;
@ -537,6 +665,24 @@ export type PresignedUploadUrlResponse = {
expires_in: number;
};
/**
* Request schema for triggering document processing.
*/
export type ProcessDocumentRequestSchema = {
/**
* Document UUID to process
*/
document_uuid: string;
/**
* S3 key of the uploaded file
*/
s3_key: string;
/**
* Embedding service to use for processing. Options: 'openai' (default, 1536-dim, requires API key) or 'sentence_transformer' (free, 384-dim)
*/
embedding_service?: 'sentence_transformer' | 'openai';
};
export type S3SignedUrlResponse = {
url: string;
expires_in: number;
@ -787,6 +933,9 @@ export type UserConfigurationRequestResponseSchema = {
stt?: {
[key: string]: string | number;
} | null;
embeddings?: {
[key: string]: string | number;
} | null;
test_phone_number?: string | null;
timezone?: string | null;
organization_pricing?: {
@ -4126,6 +4275,213 @@ export type CreateOrUpdateEmbedTokenApiV1WorkflowWorkflowIdEmbedTokenPostRespons
export type CreateOrUpdateEmbedTokenApiV1WorkflowWorkflowIdEmbedTokenPostResponse = CreateOrUpdateEmbedTokenApiV1WorkflowWorkflowIdEmbedTokenPostResponses[keyof CreateOrUpdateEmbedTokenApiV1WorkflowWorkflowIdEmbedTokenPostResponses];
export type GetUploadUrlApiV1KnowledgeBaseUploadUrlPostData = {
body: DocumentUploadRequestSchema;
headers?: {
authorization?: string | null;
'X-API-Key'?: string | null;
};
path?: never;
query?: never;
url: '/api/v1/knowledge-base/upload-url';
};
export type GetUploadUrlApiV1KnowledgeBaseUploadUrlPostErrors = {
/**
* Not found
*/
404: unknown;
/**
* Validation Error
*/
422: HttpValidationError;
};
export type GetUploadUrlApiV1KnowledgeBaseUploadUrlPostError = GetUploadUrlApiV1KnowledgeBaseUploadUrlPostErrors[keyof GetUploadUrlApiV1KnowledgeBaseUploadUrlPostErrors];
export type GetUploadUrlApiV1KnowledgeBaseUploadUrlPostResponses = {
/**
* Successful Response
*/
200: DocumentUploadResponseSchema;
};
export type GetUploadUrlApiV1KnowledgeBaseUploadUrlPostResponse = GetUploadUrlApiV1KnowledgeBaseUploadUrlPostResponses[keyof GetUploadUrlApiV1KnowledgeBaseUploadUrlPostResponses];
export type ProcessDocumentApiV1KnowledgeBaseProcessDocumentPostData = {
body: ProcessDocumentRequestSchema;
headers?: {
authorization?: string | null;
'X-API-Key'?: string | null;
};
path?: never;
query?: never;
url: '/api/v1/knowledge-base/process-document';
};
export type ProcessDocumentApiV1KnowledgeBaseProcessDocumentPostErrors = {
/**
* Not found
*/
404: unknown;
/**
* Validation Error
*/
422: HttpValidationError;
};
export type ProcessDocumentApiV1KnowledgeBaseProcessDocumentPostError = ProcessDocumentApiV1KnowledgeBaseProcessDocumentPostErrors[keyof ProcessDocumentApiV1KnowledgeBaseProcessDocumentPostErrors];
export type ProcessDocumentApiV1KnowledgeBaseProcessDocumentPostResponses = {
/**
* Successful Response
*/
200: DocumentResponseSchema;
};
export type ProcessDocumentApiV1KnowledgeBaseProcessDocumentPostResponse = ProcessDocumentApiV1KnowledgeBaseProcessDocumentPostResponses[keyof ProcessDocumentApiV1KnowledgeBaseProcessDocumentPostResponses];
export type ListDocumentsApiV1KnowledgeBaseDocumentsGetData = {
body?: never;
headers?: {
authorization?: string | null;
'X-API-Key'?: string | null;
};
path?: never;
query?: {
/**
* Filter by processing status
*/
status?: string | null;
limit?: number;
offset?: number;
};
url: '/api/v1/knowledge-base/documents';
};
export type ListDocumentsApiV1KnowledgeBaseDocumentsGetErrors = {
/**
* Not found
*/
404: unknown;
/**
* Validation Error
*/
422: HttpValidationError;
};
export type ListDocumentsApiV1KnowledgeBaseDocumentsGetError = ListDocumentsApiV1KnowledgeBaseDocumentsGetErrors[keyof ListDocumentsApiV1KnowledgeBaseDocumentsGetErrors];
export type ListDocumentsApiV1KnowledgeBaseDocumentsGetResponses = {
/**
* Successful Response
*/
200: DocumentListResponseSchema;
};
export type ListDocumentsApiV1KnowledgeBaseDocumentsGetResponse = ListDocumentsApiV1KnowledgeBaseDocumentsGetResponses[keyof ListDocumentsApiV1KnowledgeBaseDocumentsGetResponses];
export type DeleteDocumentApiV1KnowledgeBaseDocumentsDocumentUuidDeleteData = {
body?: never;
headers?: {
authorization?: string | null;
'X-API-Key'?: string | null;
};
path: {
document_uuid: string;
};
query?: never;
url: '/api/v1/knowledge-base/documents/{document_uuid}';
};
export type DeleteDocumentApiV1KnowledgeBaseDocumentsDocumentUuidDeleteErrors = {
/**
* Not found
*/
404: unknown;
/**
* Validation Error
*/
422: HttpValidationError;
};
export type DeleteDocumentApiV1KnowledgeBaseDocumentsDocumentUuidDeleteError = DeleteDocumentApiV1KnowledgeBaseDocumentsDocumentUuidDeleteErrors[keyof DeleteDocumentApiV1KnowledgeBaseDocumentsDocumentUuidDeleteErrors];
export type DeleteDocumentApiV1KnowledgeBaseDocumentsDocumentUuidDeleteResponses = {
/**
* Successful Response
*/
200: unknown;
};
export type GetDocumentApiV1KnowledgeBaseDocumentsDocumentUuidGetData = {
body?: never;
headers?: {
authorization?: string | null;
'X-API-Key'?: string | null;
};
path: {
document_uuid: string;
};
query?: never;
url: '/api/v1/knowledge-base/documents/{document_uuid}';
};
export type GetDocumentApiV1KnowledgeBaseDocumentsDocumentUuidGetErrors = {
/**
* Not found
*/
404: unknown;
/**
* Validation Error
*/
422: HttpValidationError;
};
export type GetDocumentApiV1KnowledgeBaseDocumentsDocumentUuidGetError = GetDocumentApiV1KnowledgeBaseDocumentsDocumentUuidGetErrors[keyof GetDocumentApiV1KnowledgeBaseDocumentsDocumentUuidGetErrors];
export type GetDocumentApiV1KnowledgeBaseDocumentsDocumentUuidGetResponses = {
/**
* Successful Response
*/
200: DocumentResponseSchema;
};
export type GetDocumentApiV1KnowledgeBaseDocumentsDocumentUuidGetResponse = GetDocumentApiV1KnowledgeBaseDocumentsDocumentUuidGetResponses[keyof GetDocumentApiV1KnowledgeBaseDocumentsDocumentUuidGetResponses];
export type SearchChunksApiV1KnowledgeBaseSearchPostData = {
body: ChunkSearchRequestSchema;
headers?: {
authorization?: string | null;
'X-API-Key'?: string | null;
};
path?: never;
query?: never;
url: '/api/v1/knowledge-base/search';
};
export type SearchChunksApiV1KnowledgeBaseSearchPostErrors = {
/**
* Not found
*/
404: unknown;
/**
* Validation Error
*/
422: HttpValidationError;
};
export type SearchChunksApiV1KnowledgeBaseSearchPostError = SearchChunksApiV1KnowledgeBaseSearchPostErrors[keyof SearchChunksApiV1KnowledgeBaseSearchPostErrors];
export type SearchChunksApiV1KnowledgeBaseSearchPostResponses = {
/**
* Successful Response
*/
200: ChunkSearchResponseSchema;
};
export type SearchChunksApiV1KnowledgeBaseSearchPostResponse = SearchChunksApiV1KnowledgeBaseSearchPostResponses[keyof SearchChunksApiV1KnowledgeBaseSearchPostResponses];
export type HealthApiV1HealthGetData = {
body?: never;
path?: never;
@ -4149,4 +4505,4 @@ export type HealthApiV1HealthGetResponses = {
export type ClientOptions = {
baseUrl: 'http://127.0.0.1:8000' | (string & {});
};
};

View file

@ -14,7 +14,7 @@ import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { VoiceSelector } from "@/components/VoiceSelector";
import { useUserConfig } from "@/context/UserConfigContext";
type ServiceSegment = "llm" | "tts" | "stt";
type ServiceSegment = "llm" | "tts" | "stt" | "embeddings";
interface SchemaProperty {
type?: string;
@ -41,6 +41,7 @@ const TAB_CONFIG: { key: ServiceSegment; label: string }[] = [
{ key: "llm", label: "LLM" },
{ key: "tts", label: "Voice" },
{ key: "stt", label: "Transcriber" },
{ key: "embeddings", label: "Embedding" },
];
// Display names for language codes (Deepgram + Sarvam)
@ -109,12 +110,14 @@ export default function ServiceConfiguration() {
const [schemas, setSchemas] = useState<Record<ServiceSegment, Record<string, ProviderSchema>>>({
llm: {},
tts: {},
stt: {}
stt: {},
embeddings: {}
});
const [serviceProviders, setServiceProviders] = useState<Record<ServiceSegment, string>>({
llm: "",
tts: "",
stt: ""
stt: "",
embeddings: ""
});
const [isManualModelInput, setIsManualModelInput] = useState(false);
const [hasCheckedManualMode, setHasCheckedManualMode] = useState(false);
@ -136,7 +139,8 @@ export default function ServiceConfiguration() {
setSchemas({
llm: response.data.llm as Record<string, ProviderSchema>,
tts: response.data.tts as Record<string, ProviderSchema>,
stt: response.data.stt as Record<string, ProviderSchema>
stt: response.data.stt as Record<string, ProviderSchema>,
embeddings: response.data.embeddings as Record<string, ProviderSchema>
});
} else {
console.error("Failed to fetch configurations");
@ -147,7 +151,8 @@ export default function ServiceConfiguration() {
const selectedProviders: Record<ServiceSegment, string> = {
llm: response.data.default_providers.llm,
tts: response.data.default_providers.tts,
stt: response.data.default_providers.stt
stt: response.data.default_providers.stt,
embeddings: response.data.default_providers.embeddings
};
const setServicePropertyValues = (service: ServiceSegment) => {
@ -173,6 +178,7 @@ export default function ServiceConfiguration() {
setServicePropertyValues("llm");
setServicePropertyValues("tts");
setServicePropertyValues("stt");
setServicePropertyValues("embeddings");
// IMPORTANT: Reset form values BEFORE changing providers
// Otherwise, Radix Select sees old values that don't match new provider's enum
@ -246,7 +252,7 @@ export default function ServiceConfiguration() {
setApiError(null);
setIsSaving(true);
const userConfig = {
const userConfig: Record<ServiceSegment, Record<string, string | number>> = {
llm: {
provider: serviceProviders.llm,
api_key: data.llm_api_key as string,
@ -259,6 +265,11 @@ export default function ServiceConfiguration() {
stt: {
provider: serviceProviders.stt,
api_key: data.stt_api_key as string
},
embeddings: {
provider: serviceProviders.embeddings,
api_key: data.embeddings_api_key as string,
model: data.embeddings_model as string
}
};
@ -273,12 +284,25 @@ export default function ServiceConfiguration() {
}
});
// Build save config - only include embeddings if api_key is provided
const saveConfig: {
llm: Record<string, string | number>;
tts: Record<string, string | number>;
stt: Record<string, string | number>;
embeddings?: Record<string, string | number>;
} = {
llm: userConfig.llm,
tts: userConfig.tts,
stt: userConfig.stt
};
// Only include embeddings if user has configured it (has api_key)
if (userConfig.embeddings.api_key) {
saveConfig.embeddings = userConfig.embeddings;
}
try {
await saveUserConfig({
llm: userConfig.llm,
tts: userConfig.tts,
stt: userConfig.stt
});
await saveUserConfig(saveConfig);
setApiError(null);
} catch (error: unknown) {
if (error instanceof Error) {
@ -543,7 +567,7 @@ export default function ServiceConfiguration() {
<Card>
<CardContent className="pt-6">
<Tabs defaultValue="llm" className="w-full">
<TabsList className="grid w-full grid-cols-3 mb-6">
<TabsList className="grid w-full grid-cols-4 mb-6">
{TAB_CONFIG.map(({ key, label }) => (
<TabsTrigger key={key} value={key}>
{label}

View file

@ -0,0 +1,65 @@
"use client";
import { useCallback, useEffect, useState } from "react";
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
import type { DocumentResponseSchema } from "@/client/types.gen";
import { Badge } from "@/components/ui/badge";
interface DocumentBadgesProps {
documentUuids: string[];
onStaleUuidsDetected?: (staleUuids: string[]) => void;
}
export const DocumentBadges = ({ documentUuids, onStaleUuidsDetected }: DocumentBadgesProps) => {
const { documents } = useWorkflow();
const [documentNames, setDocumentNames] = useState<Record<string, string>>({});
const processDocuments = useCallback((docs: DocumentResponseSchema[]) => {
const nameMap: Record<string, string> = {};
const validUuids = new Set<string>();
docs
.filter((doc) => documentUuids.includes(doc.document_uuid))
.forEach((doc) => {
nameMap[doc.document_uuid] = doc.filename;
validUuids.add(doc.document_uuid);
});
setDocumentNames(nameMap);
// Detect stale UUIDs - this only runs when we have loaded data (not undefined)
if (onStaleUuidsDetected) {
const staleUuids = documentUuids.filter(uuid => !validUuids.has(uuid));
if (staleUuids.length > 0) {
onStaleUuidsDetected(staleUuids);
}
}
}, [documentUuids, onStaleUuidsDetected]);
useEffect(() => {
if (documentUuids.length > 0 && documents !== undefined) {
processDocuments(documents);
} else if (documentUuids.length === 0) {
setDocumentNames({});
}
}, [documentUuids, documents, processDocuments]);
if (documentUuids.length === 0) {
return <></>;
}
// Show loading while data hasn't loaded yet
if (documents === undefined) {
return <Badge variant="outline">Loading...</Badge>;
}
return (
<>
{documentUuids.map((uuid) => (
<Badge key={uuid} variant="outline">
{documentNames[uuid] || uuid}
</Badge>
))}
</>
);
};

View file

@ -0,0 +1,137 @@
"use client";
import { FileText } from "lucide-react";
import Link from "next/link";
import { useMemo } from "react";
import type { DocumentResponseSchema } from "@/client/types.gen";
import { Button } from "@/components/ui/button";
import { Checkbox } from "@/components/ui/checkbox";
import { Label } from "@/components/ui/label";
interface DocumentSelectorProps {
value: string[];
onChange: (uuids: string[]) => void;
documents: DocumentResponseSchema[];
disabled?: boolean;
label?: string;
description?: string;
showLabel?: boolean;
}
export const DocumentSelector = ({
value,
onChange,
documents,
disabled = false,
label = "Knowledge Base Documents",
description = "Select documents that the agent can reference during conversations.",
showLabel = true,
}: DocumentSelectorProps) => {
// Only show completed documents
const completedDocuments = useMemo(
() => documents.filter((doc) => doc.processing_status === "completed"),
[documents]
);
const handleToggle = (documentUuid: string, checked: boolean) => {
if (checked) {
onChange([...value, documentUuid]);
} else {
onChange(value.filter((uuid) => uuid !== documentUuid));
}
};
const formatFileSize = (bytes: number): string => {
if (bytes === 0) return "0 Bytes";
const k = 1024;
const sizes = ["Bytes", "KB", "MB", "GB"];
const i = Math.floor(Math.log(bytes) / Math.log(k));
return Math.round(bytes / Math.pow(k, i) * 100) / 100 + " " + sizes[i];
};
if (completedDocuments.length === 0) {
return (
<div className="space-y-2">
{showLabel && (
<>
<Label>{label}</Label>
{description && (
<Label className="text-xs text-muted-foreground">{description}</Label>
)}
</>
)}
<div className="border rounded-md p-4 space-y-3">
<div className="text-sm text-muted-foreground text-center">
No documents available. Upload documents to the knowledge base first.
</div>
<div className="flex justify-center">
<Link href="/files">
<Button variant="outline" size="sm">
Upload Documents
</Button>
</Link>
</div>
</div>
</div>
);
}
return (
<div className="space-y-2">
{showLabel && (
<>
<Label>{label}</Label>
{description && (
<Label className="text-xs text-muted-foreground">{description}</Label>
)}
</>
)}
<div className="border rounded-md max-h-[300px] overflow-y-auto">
<div className="divide-y">
{completedDocuments.map((doc) => (
<div
key={doc.document_uuid}
className="flex items-start gap-3 p-3 hover:bg-muted/50 transition-colors"
>
<Checkbox
id={`doc-${doc.document_uuid}`}
checked={value.includes(doc.document_uuid)}
onCheckedChange={(checked) =>
handleToggle(doc.document_uuid, checked as boolean)
}
disabled={disabled}
/>
<div className="flex-1 space-y-1">
<label
htmlFor={`doc-${doc.document_uuid}`}
className="flex items-center gap-2 cursor-pointer"
>
<div className="w-8 h-8 rounded-md bg-blue-500/10 flex items-center justify-center flex-shrink-0">
<FileText className="w-4 h-4 text-blue-500" />
</div>
<div className="flex-1 min-w-0">
<div className="text-sm font-medium truncate">
{doc.filename}
</div>
<div className="text-xs text-muted-foreground">
{formatFileSize(doc.file_size_bytes)} {doc.total_chunks} chunks
</div>
</div>
</label>
</div>
</div>
))}
</div>
</div>
<div className="flex items-center justify-between text-xs text-muted-foreground pt-1">
<span>
{value.length} {value.length === 1 ? "document" : "documents"} selected
</span>
<Link href="/files" className="hover:underline">
Manage Documents
</Link>
</div>
</div>
);
};

View file

@ -2,43 +2,43 @@
import { useCallback, useEffect, useState } from "react";
import { listToolsApiV1ToolsGet } from "@/client/sdk.gen";
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
import type { ToolResponse } from "@/client/types.gen";
import { Badge } from "@/components/ui/badge";
import { useAuth } from "@/lib/auth";
interface ToolBadgesProps {
toolUuids: string[];
onStaleUuidsDetected?: (staleUuids: string[]) => void;
}
export function ToolBadges({ toolUuids }: ToolBadgesProps) {
const { getAccessToken } = useAuth();
const [tools, setTools] = useState<ToolResponse[]>([]);
export function ToolBadges({ toolUuids, onStaleUuidsDetected }: ToolBadgesProps) {
const { tools } = useWorkflow();
const [selectedTools, setSelectedTools] = useState<ToolResponse[]>([]);
const fetchTools = useCallback(async () => {
try {
const accessToken = await getAccessToken();
const response = await listToolsApiV1ToolsGet({
headers: { Authorization: `Bearer ${accessToken}` },
});
if (response.data) {
setTools(response.data);
const processTools = useCallback((toolsData: ToolResponse[]) => {
const filtered = toolsData.filter(tool => toolUuids.includes(tool.tool_uuid));
setSelectedTools(filtered);
// Detect stale UUIDs - this only runs when we have loaded data (not undefined)
if (onStaleUuidsDetected) {
const validUuids = new Set(toolsData.map(tool => tool.tool_uuid));
const staleUuids = toolUuids.filter(uuid => !validUuids.has(uuid));
if (staleUuids.length > 0) {
onStaleUuidsDetected(staleUuids);
}
} catch (error) {
console.error("Failed to fetch tools:", error);
}
}, [getAccessToken]);
}, [toolUuids, onStaleUuidsDetected]);
useEffect(() => {
if (toolUuids.length > 0) {
fetchTools();
if (toolUuids.length > 0 && tools !== undefined) {
processTools(tools);
} else if (toolUuids.length === 0) {
setSelectedTools([]);
}
}, [toolUuids.length, fetchTools]);
}, [toolUuids, tools, processTools]);
const selectedTools = tools.filter((tool) => toolUuids.includes(tool.tool_uuid));
if (selectedTools.length === 0 && toolUuids.length > 0) {
// Still loading or tools not found
// Show loading while data hasn't loaded yet
if (tools === undefined && toolUuids.length > 0) {
return (
<div className="flex flex-wrap gap-1">
<Badge variant="outline" className="text-xs">

View file

@ -1,20 +1,18 @@
"use client";
import { ExternalLink, Loader2 } from "lucide-react";
import { ExternalLink } from "lucide-react";
import Link from "next/link";
import { useCallback, useEffect, useState } from "react";
import { renderToolIcon } from "@/app/tools/config";
import { listToolsApiV1ToolsGet } from "@/client/sdk.gen";
import type { ToolResponse } from "@/client/types.gen";
import { Button } from "@/components/ui/button";
import { Checkbox } from "@/components/ui/checkbox";
import { Label } from "@/components/ui/label";
import { useAuth } from "@/lib/auth";
interface ToolSelectorProps {
value: string[];
onChange: (uuids: string[]) => void;
tools: ToolResponse[];
disabled?: boolean;
label?: string;
description?: string;
@ -24,43 +22,14 @@ interface ToolSelectorProps {
export function ToolSelector({
value,
onChange,
tools,
disabled = false,
label = "Tools",
description = "Select tools that the agent can use during the conversation.",
showLabel = true,
}: ToolSelectorProps) {
const { getAccessToken } = useAuth();
const [tools, setTools] = useState<ToolResponse[]>([]);
const [loading, setLoading] = useState(false);
const fetchTools = useCallback(async () => {
setLoading(true);
try {
const accessToken = await getAccessToken();
const response = await listToolsApiV1ToolsGet({
headers: { Authorization: `Bearer ${accessToken}` },
query: { status: "active" },
});
if (response.error) {
console.error("Failed to fetch tools:", response.error);
setTools([]);
return;
}
if (response.data) {
setTools(response.data);
}
} catch (error) {
console.error("Failed to fetch tools:", error);
setTools([]);
} finally {
setLoading(false);
}
}, [getAccessToken]);
useEffect(() => {
fetchTools();
}, [fetchTools]);
// Filter to only show active tools
const activeTools = tools.filter((tool) => tool.status === "active");
const handleToggle = (toolUuid: string, checked: boolean) => {
if (checked) {
@ -83,12 +52,7 @@ export function ToolSelector({
</>
)}
{loading ? (
<div className="flex items-center gap-2 p-3 border rounded-md">
<Loader2 className="h-4 w-4 animate-spin" />
<span className="text-sm text-muted-foreground">Loading tools...</span>
</div>
) : tools.length === 0 ? (
{activeTools.length === 0 ? (
<div className="p-4 border rounded-md text-center">
<p className="text-sm text-muted-foreground mb-2">
No tools available.
@ -102,7 +66,7 @@ export function ToolSelector({
</div>
) : (
<div className="border rounded-md divide-y">
{tools.map((tool) => {
{activeTools.map((tool) => {
const isSelected = value.includes(tool.tool_uuid);
return (
<label

View file

@ -1,8 +1,11 @@
import { NodeProps, NodeToolbar, Position } from "@xyflow/react";
import { Edit, Headset, PlusIcon, Trash2Icon, Wrench } from "lucide-react";
import { memo, useEffect, useMemo, useState } from "react";
import { Edit, FileText, Headset, PlusIcon, Trash2Icon, Wrench } from "lucide-react";
import { memo, useCallback, useEffect, useMemo, useState } from "react";
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
import type { DocumentResponseSchema, ToolResponse } from "@/client/types.gen";
import { DocumentBadges } from "@/components/flow/DocumentBadges";
import { DocumentSelector } from "@/components/flow/DocumentSelector";
import { ToolBadges } from "@/components/flow/ToolBadges";
import { ToolSelector } from "@/components/flow/ToolSelector";
import { ExtractionVariable, FlowNodeData } from "@/components/flow/types";
@ -34,6 +37,10 @@ interface AgentNodeEditFormProps {
setAddGlobalPrompt: (value: boolean) => void;
toolUuids: string[];
setToolUuids: (value: string[]) => void;
documentUuids: string[];
setDocumentUuids: (value: string[]) => void;
tools: ToolResponse[];
documents: DocumentResponseSchema[];
}
interface AgentNodeProps extends NodeProps {
@ -42,7 +49,7 @@ interface AgentNodeProps extends NodeProps {
export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
const { open, setOpen, handleSaveNodeData, handleDeleteNode } = useNodeHandlers({ id });
const { saveWorkflow } = useWorkflow();
const { saveWorkflow, tools, documents } = useWorkflow();
// Form state
const [prompt, setPrompt] = useState(data.prompt);
@ -55,6 +62,7 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
const [variables, setVariables] = useState<ExtractionVariable[]>(data.extraction_variables ?? []);
const [addGlobalPrompt, setAddGlobalPrompt] = useState(data.add_global_prompt ?? true);
const [toolUuids, setToolUuids] = useState<string[]>(data.tool_uuids ?? []);
const [documentUuids, setDocumentUuids] = useState<string[]>(data.document_uuids ?? []);
// Compute if form has unsaved changes (only check prompt, name)
const isDirty = useMemo(() => {
@ -75,6 +83,7 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
extraction_variables: variables,
add_global_prompt: addGlobalPrompt,
tool_uuids: toolUuids.length > 0 ? toolUuids : undefined,
document_uuids: documentUuids.length > 0 ? documentUuids : undefined,
});
setOpen(false);
// Save the workflow after updating node data with a small delay to ensure state is updated
@ -94,6 +103,7 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
setVariables(data.extraction_variables ?? []);
setAddGlobalPrompt(data.add_global_prompt ?? true);
setToolUuids(data.tool_uuids ?? []);
setDocumentUuids(data.document_uuids ?? []);
}
setOpen(newOpen);
};
@ -109,9 +119,34 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
setVariables(data.extraction_variables ?? []);
setAddGlobalPrompt(data.add_global_prompt ?? true);
setToolUuids(data.tool_uuids ?? []);
setDocumentUuids(data.document_uuids ?? []);
}
}, [data, open]);
// Handle cleanup of stale document UUIDs
const handleStaleDocuments = useCallback((staleUuids: string[]) => {
const cleanedUuids = (data.document_uuids ?? []).filter(uuid => !staleUuids.includes(uuid));
handleSaveNodeData({
...data,
document_uuids: cleanedUuids.length > 0 ? cleanedUuids : undefined,
});
setTimeout(async () => {
await saveWorkflow();
}, 100);
}, [data, handleSaveNodeData, saveWorkflow]);
// Handle cleanup of stale tool UUIDs
const handleStaleTools = useCallback((staleUuids: string[]) => {
const cleanedUuids = (data.tool_uuids ?? []).filter(uuid => !staleUuids.includes(uuid));
handleSaveNodeData({
...data,
tool_uuids: cleanedUuids.length > 0 ? cleanedUuids : undefined,
});
setTimeout(async () => {
await saveWorkflow();
}, 100);
}, [data, handleSaveNodeData, saveWorkflow]);
return (
<>
<NodeContent
@ -136,7 +171,16 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
<Wrench className="h-3 w-3" />
<span>Tools:</span>
</div>
<ToolBadges toolUuids={data.tool_uuids} />
<ToolBadges toolUuids={data.tool_uuids} onStaleUuidsDetected={handleStaleTools} />
</div>
)}
{data.document_uuids && data.document_uuids.length > 0 && (
<div className="mt-3 pt-3 border-t border-border/50">
<div className="flex items-center gap-1.5 text-xs text-muted-foreground mb-2">
<FileText className="h-3 w-3" />
<span>Documents:</span>
</div>
<DocumentBadges documentUuids={data.document_uuids} onStaleUuidsDetected={handleStaleDocuments} />
</div>
)}
</NodeContent>
@ -179,6 +223,10 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
setAddGlobalPrompt={setAddGlobalPrompt}
toolUuids={toolUuids}
setToolUuids={setToolUuids}
documentUuids={documentUuids}
setDocumentUuids={setDocumentUuids}
tools={tools ?? []}
documents={documents ?? []}
/>
)}
</NodeEditDialog>
@ -203,6 +251,10 @@ const AgentNodeEditForm = ({
setAddGlobalPrompt,
toolUuids,
setToolUuids,
documentUuids,
setDocumentUuids,
tools,
documents,
}: AgentNodeEditFormProps) => {
const handleVariableNameChange = (idx: number, value: string) => {
const newVars = [...variables];
@ -343,9 +395,20 @@ const AgentNodeEditForm = ({
<ToolSelector
value={toolUuids}
onChange={setToolUuids}
tools={tools}
description="Select tools that the agent can invoke during this conversation step."
/>
</div>
{/* Documents Section */}
<div className="pt-4 border-t mt-4">
<DocumentSelector
value={documentUuids}
onChange={setDocumentUuids}
documents={documents}
description="Select documents from the knowledge base that the agent can reference during this conversation step."
/>
</div>
</div>
);
};

View file

@ -1,8 +1,11 @@
import { NodeProps, NodeToolbar, Position } from "@xyflow/react";
import { Edit, Play, PlusIcon, Trash2Icon, Wrench } from "lucide-react";
import { memo, useEffect, useMemo, useState } from "react";
import { Edit, FileText, Play, PlusIcon, Trash2Icon, Wrench } from "lucide-react";
import { memo, useCallback, useEffect, useMemo, useState } from "react";
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
import type { DocumentResponseSchema, ToolResponse } from "@/client/types.gen";
import { DocumentBadges } from "@/components/flow/DocumentBadges";
import { DocumentSelector } from "@/components/flow/DocumentSelector";
import { ToolBadges } from "@/components/flow/ToolBadges";
import { ToolSelector } from "@/components/flow/ToolSelector";
import { ExtractionVariable, FlowNodeData } from "@/components/flow/types";
@ -41,6 +44,10 @@ interface StartCallEditFormProps {
setVariables: (vars: ExtractionVariable[]) => void;
toolUuids: string[];
setToolUuids: (value: string[]) => void;
documentUuids: string[];
setDocumentUuids: (value: string[]) => void;
tools: ToolResponse[];
documents: DocumentResponseSchema[];
}
interface StartCallNodeProps extends NodeProps {
@ -52,7 +59,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
id,
additionalData: { is_start: true }
});
const { saveWorkflow } = useWorkflow();
const { saveWorkflow, tools, documents } = useWorkflow();
// Form state
const [prompt, setPrompt] = useState(data.prompt ?? "");
@ -66,6 +73,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
const [extractionPrompt, setExtractionPrompt] = useState(data.extraction_prompt ?? "");
const [variables, setVariables] = useState<ExtractionVariable[]>(data.extraction_variables ?? []);
const [toolUuids, setToolUuids] = useState<string[]>(data.tool_uuids ?? []);
const [documentUuids, setDocumentUuids] = useState<string[]>(data.document_uuids ?? []);
// Compute if form has unsaved changes (only check prompt, name)
const isDirty = useMemo(() => {
@ -89,6 +97,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
extraction_prompt: extractionPrompt,
extraction_variables: variables,
tool_uuids: toolUuids.length > 0 ? toolUuids : undefined,
document_uuids: documentUuids.length > 0 ? documentUuids : undefined,
});
setOpen(false);
// Save the workflow after updating node data with a small delay to ensure state is updated
@ -111,6 +120,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
setExtractionPrompt(data.extraction_prompt ?? "");
setVariables(data.extraction_variables ?? []);
setToolUuids(data.tool_uuids ?? []);
setDocumentUuids(data.document_uuids ?? []);
}
setOpen(newOpen);
};
@ -129,9 +139,34 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
setExtractionPrompt(data.extraction_prompt ?? "");
setVariables(data.extraction_variables ?? []);
setToolUuids(data.tool_uuids ?? []);
setDocumentUuids(data.document_uuids ?? []);
}
}, [data, open]);
// Handle cleanup of stale document UUIDs
const handleStaleDocuments = useCallback((staleUuids: string[]) => {
const cleanedUuids = (data.document_uuids ?? []).filter(uuid => !staleUuids.includes(uuid));
handleSaveNodeData({
...data,
document_uuids: cleanedUuids.length > 0 ? cleanedUuids : undefined,
});
setTimeout(async () => {
await saveWorkflow();
}, 100);
}, [data, handleSaveNodeData, saveWorkflow]);
// Handle cleanup of stale tool UUIDs
const handleStaleTools = useCallback((staleUuids: string[]) => {
const cleanedUuids = (data.tool_uuids ?? []).filter(uuid => !staleUuids.includes(uuid));
handleSaveNodeData({
...data,
tool_uuids: cleanedUuids.length > 0 ? cleanedUuids : undefined,
});
setTimeout(async () => {
await saveWorkflow();
}, 100);
}, [data, handleSaveNodeData, saveWorkflow]);
return (
<>
<NodeContent
@ -155,7 +190,16 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
<Wrench className="h-3 w-3" />
<span>Tools:</span>
</div>
<ToolBadges toolUuids={data.tool_uuids} />
<ToolBadges toolUuids={data.tool_uuids} onStaleUuidsDetected={handleStaleTools} />
</div>
)}
{data.document_uuids && data.document_uuids.length > 0 && (
<div className="mt-3 pt-3 border-t border-border/50">
<div className="flex items-center gap-1.5 text-xs text-muted-foreground mb-2">
<FileText className="h-3 w-3" />
<span>Documents:</span>
</div>
<DocumentBadges documentUuids={data.document_uuids} onStaleUuidsDetected={handleStaleDocuments} />
</div>
)}
</NodeContent>
@ -199,6 +243,10 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
setVariables={setVariables}
toolUuids={toolUuids}
setToolUuids={setToolUuids}
documentUuids={documentUuids}
setDocumentUuids={setDocumentUuids}
tools={tools ?? []}
documents={documents ?? []}
/>
)}
</NodeEditDialog>
@ -229,6 +277,10 @@ const StartCallEditForm = ({
setVariables,
toolUuids,
setToolUuids,
documentUuids,
setDocumentUuids,
tools,
documents,
}: StartCallEditFormProps) => {
const handleVariableNameChange = (idx: number, value: string) => {
const newVars = [...variables];
@ -414,9 +466,20 @@ const StartCallEditForm = ({
<ToolSelector
value={toolUuids}
onChange={setToolUuids}
tools={tools}
description="Select tools that the agent can invoke during this conversation step."
/>
</div>
{/* Documents Section */}
<div className="pt-4 border-t mt-4">
<DocumentSelector
value={documentUuids}
onChange={setDocumentUuids}
documents={documents}
description="Select documents from the knowledge base that the agent can reference during this conversation step."
/>
</div>
</div>
);
};

View file

@ -42,6 +42,8 @@ export type FlowNodeData = {
};
// Tools - array of tool UUIDs that can be invoked by this node
tool_uuids?: string[];
// Documents - array of knowledge base document UUIDs that can be referenced by this node
document_uuids?: string[];
}
export type FlowNode = {

View file

@ -6,6 +6,7 @@ import {
ChevronLeft,
ChevronRight,
CircleDollarSign,
Database,
FileText,
HelpCircle,
Home,
@ -114,6 +115,11 @@ export function AppSidebar() {
url: "/tools",
icon: Wrench,
},
{
title: "Files",
url: "/files",
icon: Database,
},
// {
// title: "Integrations",
// url: "/integrations",

View file

@ -18,6 +18,9 @@ export type SaveUserConfigFunctionParams = {
stt?: {
[key: string]: string | number;
} | null;
embeddings?: {
[key: string]: string | number;
} | null;
test_phone_number?: string | null;
timezone?: string | null;
};