feat: add openai embedding service

This commit is contained in:
Abhishek Kumar 2026-01-17 13:36:26 +05:30
parent eb41285204
commit 3f0e500fde
39 changed files with 1902 additions and 339 deletions

View file

@ -5,16 +5,17 @@ Revises: dcb0a27d98c6
Create Date: 2026-01-16 13:40:17.808807
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
from pgvector.sqlalchemy import Vector
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = 'dc33eef8dabe'
down_revision: Union[str, None] = 'dcb0a27d98c6'
revision: str = "dc33eef8dabe"
down_revision: Union[str, None] = "dcb0a27d98c6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@ -22,75 +23,170 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# Enable pgvector extension
op.execute('CREATE EXTENSION IF NOT EXISTS vector')
op.execute("CREATE EXTENSION IF NOT EXISTS vector")
sa.Enum('pending', 'processing', 'completed', 'failed', name='document_processing_status').create(op.get_bind())
op.create_table('knowledge_base_documents',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('document_uuid', sa.String(length=36), nullable=False),
sa.Column('organization_id', sa.Integer(), nullable=False),
sa.Column('filename', sa.String(length=500), nullable=False),
sa.Column('file_size_bytes', sa.Integer(), nullable=True),
sa.Column('file_hash', sa.String(length=64), nullable=True),
sa.Column('mime_type', sa.String(length=100), nullable=True),
sa.Column('source_url', sa.String(), nullable=True),
sa.Column('total_chunks', sa.Integer(), nullable=False),
sa.Column('processing_status', postgresql.ENUM('pending', 'processing', 'completed', 'failed', name='document_processing_status', create_type=False), server_default=sa.text("'pending'::document_processing_status"), nullable=False),
sa.Column('processing_error', sa.Text(), nullable=True),
sa.Column('docling_metadata', sa.JSON(), nullable=False),
sa.Column('custom_metadata', sa.JSON(), nullable=False),
sa.Column('created_by', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('archived_at', sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(['created_by'], ['users.id'], ),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
op.create_index('ix_kb_documents_created_at', 'knowledge_base_documents', ['created_at'], unique=False)
op.create_index('ix_kb_documents_organization_id', 'knowledge_base_documents', ['organization_id'], unique=False)
op.create_index('ix_kb_documents_status', 'knowledge_base_documents', ['processing_status'], unique=False)
op.create_index('ix_kb_documents_uuid', 'knowledge_base_documents', ['document_uuid'], unique=False)
op.create_index(op.f('ix_knowledge_base_documents_document_uuid'), 'knowledge_base_documents', ['document_uuid'], unique=True)
op.create_table('knowledge_base_chunks'),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('document_id', sa.Integer(), nullable=False),
sa.Column('organization_id', sa.Integer(), nullable=False),
sa.Column('chunk_text', sa.Text(), nullable=False),
sa.Column('contextualized_text', sa.Text(), nullable=True),
sa.Column('chunk_index', sa.Integer(), nullable=False),
sa.Column('chunk_metadata', sa.JSON(), nullable=False),
sa.Column('embedding_model', sa.String(length=200), nullable=False),
sa.Column('embedding_dimension', sa.Integer(), nullable=False),
sa.Column('embedding', Vector(384), nullable=True),
sa.Column('token_count', sa.Integer(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(['document_id'], ['knowledge_base_documents.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
sa.Enum(
"pending",
"processing",
"completed",
"failed",
name="document_processing_status",
).create(op.get_bind())
op.create_table(
"knowledge_base_documents",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("document_uuid", sa.String(length=36), nullable=False),
sa.Column("organization_id", sa.Integer(), nullable=False),
sa.Column("filename", sa.String(length=500), nullable=False),
sa.Column("file_size_bytes", sa.Integer(), nullable=True),
sa.Column("file_hash", sa.String(length=64), nullable=True),
sa.Column("mime_type", sa.String(length=100), nullable=True),
sa.Column("source_url", sa.String(), nullable=True),
sa.Column("total_chunks", sa.Integer(), nullable=False),
sa.Column(
"processing_status",
postgresql.ENUM(
"pending",
"processing",
"completed",
"failed",
name="document_processing_status",
create_type=False,
),
server_default=sa.text("'pending'::document_processing_status"),
nullable=False,
),
sa.Column("processing_error", sa.Text(), nullable=True),
sa.Column("docling_metadata", sa.JSON(), nullable=False),
sa.Column("custom_metadata", sa.JSON(), nullable=False),
sa.Column("created_by", sa.Integer(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("archived_at", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(
["created_by"],
["users.id"],
),
sa.ForeignKeyConstraint(
["organization_id"], ["organizations.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_kb_documents_created_at",
"knowledge_base_documents",
["created_at"],
unique=False,
)
op.create_index(
"ix_kb_documents_organization_id",
"knowledge_base_documents",
["organization_id"],
unique=False,
)
op.create_index(
"ix_kb_documents_status",
"knowledge_base_documents",
["processing_status"],
unique=False,
)
op.create_index(
"ix_kb_documents_uuid",
"knowledge_base_documents",
["document_uuid"],
unique=False,
)
op.create_index(
op.f("ix_knowledge_base_documents_document_uuid"),
"knowledge_base_documents",
["document_uuid"],
unique=True,
)
op.create_table(
"knowledge_base_chunks",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("document_id", sa.Integer(), nullable=False),
sa.Column("organization_id", sa.Integer(), nullable=False),
sa.Column("chunk_text", sa.Text(), nullable=False),
sa.Column("contextualized_text", sa.Text(), nullable=True),
sa.Column("chunk_index", sa.Integer(), nullable=False),
sa.Column("chunk_metadata", sa.JSON(), nullable=False),
sa.Column("embedding_model", sa.String(length=200), nullable=False),
sa.Column("embedding_dimension", sa.Integer(), nullable=False),
sa.Column("embedding", Vector(1536), nullable=True),
sa.Column("token_count", sa.Integer(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(
["document_id"], ["knowledge_base_documents.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(
["organization_id"], ["organizations.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_kb_chunks_chunk_index",
"knowledge_base_chunks",
["chunk_index"],
unique=False,
)
op.create_index(
"ix_kb_chunks_document_id",
"knowledge_base_chunks",
["document_id"],
unique=False,
)
op.create_index(
"ix_kb_chunks_embedding_ivfflat",
"knowledge_base_chunks",
["embedding"],
unique=False,
postgresql_using="ivfflat",
postgresql_with={"lists": 100},
postgresql_ops={"embedding": "vector_cosine_ops"},
)
op.create_index(
"ix_kb_chunks_organization_id",
"knowledge_base_chunks",
["organization_id"],
unique=False,
)
op.create_index('ix_kb_chunks_chunk_index', 'knowledge_base_chunks', ['chunk_index'], unique=False)
op.create_index('ix_kb_chunks_document_id', 'knowledge_base_chunks', ['document_id'], unique=False)
op.create_index('ix_kb_chunks_embedding_ivfflat', 'knowledge_base_chunks', ['embedding'], unique=False, postgresql_using='ivfflat', postgresql_with={'lists': 100}, postgresql_ops={'embedding': 'vector_cosine_ops'})
op.create_index('ix_kb_chunks_organization_id', 'knowledge_base_chunks', ['organization_id'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index('ix_kb_chunks_organization_id', table_name='knowledge_base_chunks')
op.drop_index('ix_kb_chunks_embedding_ivfflat', table_name='knowledge_base_chunks', postgresql_using='ivfflat', postgresql_with={'lists': 100}, postgresql_ops={'embedding': 'vector_cosine_ops'})
op.drop_index('ix_kb_chunks_document_id', table_name='knowledge_base_chunks')
op.drop_index('ix_kb_chunks_chunk_index', table_name='knowledge_base_chunks')
op.drop_table('knowledge_base_chunks')
op.drop_index(op.f('ix_knowledge_base_documents_document_uuid'), table_name='knowledge_base_documents')
op.drop_index('ix_kb_documents_uuid', table_name='knowledge_base_documents')
op.drop_index('ix_kb_documents_status', table_name='knowledge_base_documents')
op.drop_index('ix_kb_documents_organization_id', table_name='knowledge_base_documents')
op.drop_index('ix_kb_documents_created_at', table_name='knowledge_base_documents')
op.drop_table('knowledge_base_documents')
sa.Enum('pending', 'processing', 'completed', 'failed', name='document_processing_status').drop(op.get_bind())
op.drop_index("ix_kb_chunks_organization_id", table_name="knowledge_base_chunks")
op.drop_index(
"ix_kb_chunks_embedding_ivfflat",
table_name="knowledge_base_chunks",
postgresql_using="ivfflat",
postgresql_with={"lists": 100},
postgresql_ops={"embedding": "vector_cosine_ops"},
)
op.drop_index("ix_kb_chunks_document_id", table_name="knowledge_base_chunks")
op.drop_index("ix_kb_chunks_chunk_index", table_name="knowledge_base_chunks")
op.drop_table("knowledge_base_chunks")
op.drop_index(
op.f("ix_knowledge_base_documents_document_uuid"),
table_name="knowledge_base_documents",
)
op.drop_index("ix_kb_documents_uuid", table_name="knowledge_base_documents")
op.drop_index("ix_kb_documents_status", table_name="knowledge_base_documents")
op.drop_index(
"ix_kb_documents_organization_id", table_name="knowledge_base_documents"
)
op.drop_index("ix_kb_documents_created_at", table_name="knowledge_base_documents")
op.drop_table("knowledge_base_documents")
sa.Enum(
"pending",
"processing",
"completed",
"failed",
name="document_processing_status",
).drop(op.get_bind())
# Note: We don't drop the vector extension as it may be used by other tables
# If you want to drop it, uncomment the following line:

View file

@ -332,6 +332,7 @@ class KnowledgeBaseClient(BaseDBClient):
limit: int = 5,
document_ids: Optional[List[int]] = None,
document_uuids: Optional[List[str]] = None,
embedding_model: Optional[str] = None,
) -> List[dict]:
"""Search for similar chunks using vector similarity.
@ -344,6 +345,7 @@ class KnowledgeBaseClient(BaseDBClient):
limit: Maximum number of results to return
document_ids: Optional list of document IDs to filter by
document_uuids: Optional list of document UUIDs to filter by
embedding_model: Optional embedding model to filter by (for dimension compatibility)
Returns:
List of dictionaries with chunk data and similarity scores, ordered by similarity (highest first)
@ -359,23 +361,37 @@ class KnowledgeBaseClient(BaseDBClient):
"c.organization_id = $2",
"d.is_active = true",
]
params = [None, organization_id, limit] # $1 will be embedding_str, $3 is limit
params = [
None,
organization_id,
limit,
] # $1 will be embedding_str, $3 is limit
param_index = 4 # Next available parameter index
# Add document_ids filter if provided
if document_ids:
placeholders = ", ".join(f"${param_index + i}" for i in range(len(document_ids)))
placeholders = ", ".join(
f"${param_index + i}" for i in range(len(document_ids))
)
where_conditions.append(f"c.document_id IN ({placeholders})")
params.extend(document_ids)
param_index += len(document_ids)
# Add document_uuids filter if provided
if document_uuids:
placeholders = ", ".join(f"${param_index + i}" for i in range(len(document_uuids)))
placeholders = ", ".join(
f"${param_index + i}" for i in range(len(document_uuids))
)
where_conditions.append(f"d.document_uuid IN ({placeholders})")
params.extend(document_uuids)
param_index += len(document_uuids)
# Add embedding_model filter if provided (for dimension compatibility)
if embedding_model:
where_conditions.append(f"c.embedding_model = ${param_index}")
params.append(embedding_model)
param_index += 1
# Build the complete SQL query
where_clause = " AND ".join(where_conditions)
query_sql = f"""

View file

@ -2,6 +2,7 @@ import uuid
from datetime import UTC, datetime
from loguru import logger
from pgvector.sqlalchemy import Vector
from sqlalchemy import (
JSON,
Boolean,
@ -19,7 +20,6 @@ from sqlalchemy import (
and_,
text,
)
from pgvector.sqlalchemy import Vector
from sqlalchemy.orm import declarative_base, relationship
from ..enums import (
@ -929,7 +929,13 @@ class KnowledgeBaseDocumentModel(Base):
source_url = Column(String, nullable=True) # If document was fetched from URL
total_chunks = Column(Integer, nullable=False, default=0)
processing_status = Column(
Enum("pending", "processing", "completed", "failed", name="document_processing_status"),
Enum(
"pending",
"processing",
"completed",
"failed",
name="document_processing_status",
),
nullable=False,
default="pending",
server_default=text("'pending'::document_processing_status"),
@ -937,7 +943,9 @@ class KnowledgeBaseDocumentModel(Base):
processing_error = Column(Text, nullable=True)
# Docling conversion metadata
docling_metadata = Column(JSON, nullable=False, default=dict) # Store docling document metadata
docling_metadata = Column(
JSON, nullable=False, default=dict
) # Store docling document metadata
# Custom metadata (user-defined tags, categories, etc.)
custom_metadata = Column(JSON, nullable=False, default=dict)
@ -1000,21 +1008,31 @@ class KnowledgeBaseChunkModel(Base):
# Chunk content
chunk_text = Column(Text, nullable=False) # The actual chunk text
contextualized_text = Column(Text, nullable=True) # Enriched text from chunker.contextualize()
contextualized_text = Column(
Text, nullable=True
) # Enriched text from chunker.contextualize()
# Chunk positioning and metadata
chunk_index = Column(Integer, nullable=False) # Position in document (0-based)
# Docling chunk metadata
chunk_metadata = Column(JSON, nullable=False, default=dict) # Store chunk.meta if available
chunk_metadata = Column(
JSON, nullable=False, default=dict
) # Store chunk.meta if available
# Embedding configuration
embedding_model = Column(String(200), nullable=False) # e.g., "sentence-transformers/all-MiniLM-L6-v2"
embedding_dimension = Column(Integer, nullable=False) # e.g., 384 for all-MiniLM-L6-v2
embedding_model = Column(
String(200), nullable=False
) # e.g., "sentence-transformers/all-MiniLM-L6-v2"
embedding_dimension = Column(
Integer, nullable=False
) # e.g., 384 for all-MiniLM-L6-v2
# Vector embedding (pgvector column)
# The dimension should match the embedding_dimension field
embedding = Column(Vector(384), nullable=True) # Default to 384 for all-MiniLM-L6-v2
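# Default: 1536 dimensions for OpenAI text-embedding-3-small
# NOTE(review): pgvector enforces the declared dimension, so 384-dim
# SentenceTransformer vectors do not fit this column as-is — confirm how
# (or whether) they are stored before relying on both services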
embedding = Column(Vector(1536), nullable=True)
# Token count (useful for chunking strategy analysis)
token_count = Column(Integer, nullable=True)
@ -1036,6 +1054,9 @@ class KnowledgeBaseChunkModel(Base):
Index("ix_kb_chunks_document_id", "document_id"),
Index("ix_kb_chunks_organization_id", "organization_id"),
Index("ix_kb_chunks_chunk_index", "chunk_index"),
Index(
"ix_kb_chunks_embedding_model", "embedding_model"
), # For filtering by model
# Vector similarity search index (using IVFFlat or HNSW)
# IVFFlat is good for datasets with 10k-1M vectors
# HNSW is better for larger datasets but uses more memory

View file

@ -103,6 +103,10 @@ async def process_document(
The document status will be updated from 'pending' -> 'processing' -> 'completed' or 'failed'.
Embedding Services:
* openai (default): High-quality 1536-dimensional embeddings (requires OPENAI_API_KEY)
* sentence_transformer: Free, offline-capable, 384-dimensional embeddings
Access Control:
* Users can only process documents in their organization.
"""
@ -129,11 +133,13 @@ async def process_document(
document.id,
request.s3_key,
user.selected_organization_id,
128, # max_tokens (default)
request.embedding_service,
)
logger.info(
f"Created document {request.document_uuid} (id={document.id}) and enqueued processing, "
f"org {user.selected_organization_id}"
f"Created document {request.document_uuid} (id={document.id}) and enqueued processing "
f"with {request.embedding_service} embeddings, org {user.selected_organization_id}"
)
return DocumentResponseSchema(
@ -277,9 +283,7 @@ async def get_document(
raise
except Exception as exc:
logger.error(f"Error getting document: {exc}")
raise HTTPException(
status_code=500, detail="Failed to get document"
) from exc
raise HTTPException(status_code=500, detail="Failed to get document") from exc
@router.delete(
@ -342,15 +346,26 @@ async def search_chunks(
try:
# Import here to avoid circular dependency
from api.services.admin_utils.local_exec import DocumentProcessor
from api.services.gen_ai import OpenAIEmbeddingService
# Initialize processor (reuses cached models)
processor = DocumentProcessor(
# Try to get user's embeddings configuration
user_config = await db_client.get_user_configurations(user.id)
embeddings_api_key = None
embeddings_model = None
if user_config.embeddings:
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
# Initialize embedding service with the user's config; when no API key is configured the search raises EmbeddingAPIKeyNotConfiguredError
embedding_service = OpenAIEmbeddingService(
db_client=db_client,
api_key=embeddings_api_key,
model_id=embeddings_model or "text-embedding-3-small",
)
# Perform search
results = await processor.search_similar_chunks(
results = await embedding_service.search_similar_chunks(
query=request.query,
organization_id=user.selected_organization_id,
limit=request.limit,

View file

@ -32,6 +32,7 @@ class DefaultConfigurationsResponse(TypedDict):
llm: dict[str, dict]
tts: dict[str, dict]
stt: dict[str, dict]
embeddings: dict[str, dict]
default_providers: dict[str, str]
@ -50,6 +51,10 @@ async def get_default_configurations() -> DefaultConfigurationsResponse:
provider: model_cls.model_json_schema()
for provider, model_cls in REGISTRY[ServiceType.STT].items()
},
"embeddings": {
provider: model_cls.model_json_schema()
for provider, model_cls in REGISTRY[ServiceType.EMBEDDINGS].items()
},
"default_providers": DEFAULT_SERVICE_PROVIDERS,
}
return configurations
@ -69,6 +74,7 @@ class UserConfigurationRequestResponseSchema(BaseModel):
llm: dict[str, Union[str, float]] | None = None
tts: dict[str, Union[str, float]] | None = None
stt: dict[str, Union[str, float]] | None = None
embeddings: dict[str, Union[str, float]] | None = None
test_phone_number: str | None = None
timezone: str | None = None
organization_pricing: dict[str, Union[float, str, bool]] | None = None

View file

@ -1,7 +1,7 @@
"""Pydantic schemas for knowledge base operations."""
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
@ -29,6 +29,11 @@ class ProcessDocumentRequestSchema(BaseModel):
document_uuid: str = Field(..., description="Document UUID to process")
s3_key: str = Field(..., description="S3 key of the uploaded file")
embedding_service: Literal["sentence_transformer", "openai"] = Field(
default="openai",
description="Embedding service to use for processing. "
"Options: 'openai' (default, 1536-dim, requires API key) or 'sentence_transformer' (free, 384-dim)",
)
class DocumentResponseSchema(BaseModel):

View file

@ -3,6 +3,7 @@ from datetime import datetime
from pydantic import BaseModel
from api.services.configuration.registry import (
EmbeddingsConfig,
LLMConfig,
STTConfig,
TTSConfig,
@ -13,6 +14,7 @@ class UserConfiguration(BaseModel):
llm: LLMConfig | None = None
stt: STTConfig | None = None
tts: TTSConfig | None = None
embeddings: EmbeddingsConfig | None = None
test_phone_number: str | None = None
timezone: str | None = None
last_validated_at: datetime | None = None

View file

@ -48,6 +48,12 @@ class UserConfigurationValidator:
status_list.extend(self._validate_service(configuration.llm, "llm"))
status_list.extend(self._validate_service(configuration.stt, "stt"))
status_list.extend(self._validate_service(configuration.tts, "tts"))
# Embeddings is optional - only validate if configured
status_list.extend(
self._validate_service(
configuration.embeddings, "embeddings", required=False
)
)
if status_list:
raise ValueError(status_list)
@ -55,11 +61,16 @@ class UserConfigurationValidator:
return {"status": [{"model": "all", "message": "ok"}]}
def _validate_service(
self, service_config: Optional[ServiceConfig], service_name: str
self,
service_config: Optional[ServiceConfig],
service_name: str,
required: bool = True,
) -> list[APIKeyStatus]:
"""Validate a service configuration and return any error statuses."""
if not service_config:
return [{"model": service_name, "message": "API key is missing"}]
if required:
return [{"model": service_name, "message": "API key is missing"}]
return [] # Optional service not configured is OK
provider = service_config.provider
api_key = service_config.api_key

View file

@ -13,6 +13,7 @@ left as ``None``.
from api.services.configuration.registry import (
DeepgramSTTConfiguration,
ElevenlabsTTSConfiguration,
OpenAIEmbeddingsConfiguration,
OpenAILLMService,
ServiceProviders,
)
@ -22,6 +23,7 @@ _DEFAULTS = {
"llm": (ServiceProviders.OPENAI, OpenAILLMService),
"tts": (ServiceProviders.ELEVENLABS, ElevenlabsTTSConfiguration),
"stt": (ServiceProviders.DEEPGRAM, DeepgramSTTConfiguration),
"embeddings": (ServiceProviders.OPENAI, OpenAIEmbeddingsConfiguration),
}
# Public mapping of service name -> default provider

View file

@ -64,6 +64,7 @@ def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
"llm": _mask_service(config.llm),
"tts": _mask_service(config.tts),
"stt": _mask_service(config.stt),
"embeddings": _mask_service(config.embeddings),
"test_phone_number": config.test_phone_number,
"timezone": config.timezone,
}

View file

@ -9,7 +9,7 @@ from typing import Dict
from api.schemas.user_configuration import UserConfiguration
from api.services.configuration.masking import is_mask_of
SERVICE_FIELDS = ("llm", "tts", "stt")
SERVICE_FIELDS = ("llm", "tts", "stt", "embeddings")
def merge_user_configurations(

View file

@ -8,6 +8,7 @@ class ServiceType(Enum):
LLM = auto()
TTS = auto()
STT = auto()
EMBEDDINGS = auto()
class ServiceProviders(str, Enum):
@ -50,11 +51,16 @@ class BaseSTTConfiguration(BaseServiceConfiguration):
model: str
class BaseEmbeddingsConfiguration(BaseServiceConfiguration):
model: str
# Unified registry for all service types
REGISTRY: Dict[ServiceType, Dict[str, Type[BaseServiceConfiguration]]] = {
ServiceType.LLM: {},
ServiceType.TTS: {},
ServiceType.STT: {},
ServiceType.EMBEDDINGS: {},
}
T = TypeVar("T", bound=BaseServiceConfiguration)
@ -93,6 +99,10 @@ def register_stt(cls: Type[BaseSTTConfiguration]):
return register_service(ServiceType.STT)(cls)
def register_embeddings(cls: Type[BaseEmbeddingsConfiguration]):
return register_service(ServiceType.EMBEDDINGS)(cls)
###################################################### LLM ########################################################################
# Suggested models for each provider (used for UI dropdown)
@ -436,6 +446,27 @@ STTConfig = Annotated[
Field(discriminator="provider"),
]
ServiceConfig = Annotated[
Union[LLMConfig, TTSConfig, STTConfig], Field(discriminator="provider")
###################################################### EMBEDDINGS ########################################################################
OPENAI_EMBEDDING_MODELS = ["text-embedding-3-small"]
@register_embeddings
class OpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
provider: Literal[ServiceProviders.OPENAI] = ServiceProviders.OPENAI
model: str = Field(
default="text-embedding-3-small",
json_schema_extra={"examples": OPENAI_EMBEDDING_MODELS},
)
api_key: str
EmbeddingsConfig = Annotated[
Union[OpenAIEmbeddingsConfiguration],
Field(discriminator="provider"),
]
ServiceConfig = Annotated[
Union[LLMConfig, TTSConfig, STTConfig, EmbeddingsConfig],
Field(discriminator="provider"),
]

View file

@ -133,9 +133,7 @@ class S3FileSystem(BaseFileSystem):
async with self.session.client(
"s3", region_name=self.region_name
) as s3_client:
await s3_client.download_file(
self.bucket_name, source_path, local_path
)
await s3_client.download_file(self.bucket_name, source_path, local_path)
return True
except ClientError:
return False

View file

@ -0,0 +1,15 @@
"""Generative AI services for embeddings and document processing."""
from .embedding import (
BaseEmbeddingService,
EmbeddingAPIKeyNotConfiguredError,
OpenAIEmbeddingService,
SentenceTransformerEmbeddingService,
)
__all__ = [
"BaseEmbeddingService",
"EmbeddingAPIKeyNotConfiguredError",
"SentenceTransformerEmbeddingService",
"OpenAIEmbeddingService",
]

View file

@ -0,0 +1,12 @@
"""Embedding services for document processing and retrieval."""
from .base import BaseEmbeddingService
from .openai_service import EmbeddingAPIKeyNotConfiguredError, OpenAIEmbeddingService
from .sentence_transformer_service import SentenceTransformerEmbeddingService
__all__ = [
"BaseEmbeddingService",
"EmbeddingAPIKeyNotConfiguredError",
"SentenceTransformerEmbeddingService",
"OpenAIEmbeddingService",
]

View file

@ -0,0 +1,75 @@
"""Base class for embedding services."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
class BaseEmbeddingService(ABC):
    """Contract that every embedding backend must fulfil.

    Concrete services (SentenceTransformer, OpenAI, etc.) subclass this and
    provide an implementation for each abstract method declared here.
    """

    @abstractmethod
    def get_model_id(self) -> str:
        """Identify the underlying embedding model.

        Returns:
            The model identifier string
            (e.g., 'sentence-transformers/all-MiniLM-L6-v2')
        """

    @abstractmethod
    def get_embedding_dimension(self) -> int:
        """Report the size of the vectors this service produces.

        Returns:
            Number of components in each embedding vector
        """

    @abstractmethod
    async def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Turn a batch of texts into embedding vectors.

        Args:
            texts: The strings to embed

        Returns:
            One embedding vector (a list of floats) per input text
        """

    @abstractmethod
    async def embed_query(self, query: str) -> List[float]:
        """Turn a single query string into an embedding vector.

        Args:
            query: The query text to embed

        Returns:
            The embedding vector as a list of floats
        """

    @abstractmethod
    async def search_similar_chunks(
        self,
        query: str,
        organization_id: int,
        limit: int = 5,
        document_uuids: Optional[List[str]] = None,
    ) -> List[Dict[str, Any]]:
        """Find stored chunks that are similar to a query via vector similarity.

        Args:
            query: The search query text
            organization_id: Organization ID the search is scoped to
            limit: Upper bound on the number of results
            document_uuids: Optional document UUIDs to restrict the search to

        Returns:
            Dictionaries holding chunk data together with similarity scores
        """

View file

@ -0,0 +1,372 @@
"""OpenAI embedding service.
This module provides document processing capabilities using:
- OpenAI's text-embedding-3-small for embeddings (1536 dimensions)
- Docling for document conversion and chunking
- pgvector for vector similarity search
"""
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
from docling.chunking import HybridChunker
from docling.document_converter import DocumentConverter
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
from loguru import logger
from openai import AsyncOpenAI
from transformers import AutoTokenizer
from api.db.db_client import DBClient
from api.db.models import KnowledgeBaseChunkModel
from .base import BaseEmbeddingService
# Model configuration
DEFAULT_MODEL_ID = "text-embedding-3-small"
EMBEDDING_DIMENSION = 1536 # Dimension for text-embedding-3-small
# For chunking, we reuse the SentenceTransformer tokenizer. OpenAI models
# tokenize differently, so max_tokens is only an approximate chunk-size budget.
TOKENIZER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
class EmbeddingAPIKeyNotConfiguredError(Exception):
    """Signals that embeddings were requested without an OpenAI API key configured."""

    def __init__(self):
        # Fixed, user-facing guidance: the message takes no arguments.
        message = (
            "OpenAI API key not configured. Please set your API key in "
            "Model Configurations > Embedding to use document processing."
        )
        super().__init__(message)
class OpenAIEmbeddingService(BaseEmbeddingService):
"""Embedding service using OpenAI's text-embedding-3-small."""
def __init__(
    self,
    db_client: DBClient,
    api_key: Optional[str] = None,
    model_id: str = DEFAULT_MODEL_ID,
    max_tokens: int = 512,
):
    """Set up the OpenAI client, chunking tokenizer, chunker and converter.

    Args:
        db_client: Database client used to store documents and chunks
        api_key: OpenAI API key. When missing or empty, no client is created
            and embedding operations raise a clear error instead.
        model_id: OpenAI embedding model ID (default: text-embedding-3-small)
        max_tokens: Maximum number of tokens per chunk (default: 512)
    """
    self.db = db_client
    self.model_id = model_id
    self.max_tokens = max_tokens

    # The OpenAI client only exists when a non-empty key was supplied.
    self._api_key_configured = bool(api_key)
    if not self._api_key_configured:
        self.client = None
        logger.warning(
            "OpenAI embedding service initialized without API key. "
            "Operations will fail until API key is configured in Model Configurations."
        )
    else:
        self.client = AsyncOpenAI(api_key=api_key)
        logger.info(f"OpenAI embedding service initialized with model: {model_id}")

    # Chunk sizing relies on a HuggingFace tokenizer: try the local cache
    # first and only download when that attempt fails.
    logger.info(
        f"Loading tokenizer for chunking: {TOKENIZER_MODEL} with max_tokens={max_tokens}"
    )
    try:
        cached_tokenizer = AutoTokenizer.from_pretrained(
            TOKENIZER_MODEL,
            local_files_only=True,
        )
        self.tokenizer = HuggingFaceTokenizer(
            tokenizer=cached_tokenizer,
            max_tokens=max_tokens,
        )
        logger.info("Loaded tokenizer from cache")
    except Exception as load_error:
        logger.warning(f"Tokenizer not in cache, downloading: {load_error}")
        downloaded_tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL)
        self.tokenizer = HuggingFaceTokenizer(
            tokenizer=downloaded_tokenizer,
            max_tokens=max_tokens,
        )
        logger.info("Tokenizer downloaded and cached")

    # Chunker built on top of the tokenizer loaded above.
    logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
    self.chunker = HybridChunker(tokenizer=self.tokenizer)

    # Converter that turns source files into docling documents.
    self.converter = DocumentConverter()
def get_model_id(self) -> str:
    """Give back the OpenAI embedding model ID this service was created with."""
    configured_model = self.model_id
    return configured_model
def get_embedding_dimension(self) -> int:
    """Return the embedding dimension for the configured model.

    The model ID is configurable, so the dimension is looked up per model
    instead of always reporting the text-embedding-3-small size. Models not
    listed here fall back to the module default (EMBEDDING_DIMENSION).

    Returns:
        Number of components in the vectors produced by ``self.model_id``
    """
    # Default output sizes of OpenAI's embedding models.
    known_dimensions = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }
    dimension = known_dimensions.get(self.model_id)
    if dimension is None:
        # Unknown model ID: keep the previous behaviour.
        return EMBEDDING_DIMENSION
    return dimension
def _ensure_api_key_configured(self):
"""Check if API key is configured and raise error if not."""
if not self._api_key_configured or self.client is None:
raise EmbeddingAPIKeyNotConfiguredError()
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
    """Embed a batch of texts using the OpenAI embeddings API.

    Args:
        texts: List of text strings to embed. An empty list yields an empty
            result without contacting the API.

    Returns:
        List of embedding vectors (each vector is a list of floats), one per
        input text

    Raises:
        EmbeddingAPIKeyNotConfiguredError: If API key is not configured
        Exception: Any error raised by the OpenAI client is logged and
            re-raised unchanged
    """
    self._ensure_api_key_configured()

    # The embeddings endpoint rejects an empty input array, so a batch with
    # nothing to embed is answered locally instead of failing the request.
    if not texts:
        return []

    try:
        response = await self.client.embeddings.create(
            input=texts,
            model=self.model_id,
        )
        # Pull the raw vector out of each returned embedding object.
        return [item.embedding for item in response.data]
    except Exception as e:
        logger.error(f"Error generating OpenAI embeddings: {e}")
        raise
async def embed_query(self, query: str) -> List[float]:
    """Embed one query string via the OpenAI API.

    Args:
        query: Query text to embed

    Returns:
        The embedding vector for ``query`` as a list of floats

    Raises:
        EmbeddingAPIKeyNotConfiguredError: If API key is not configured
    """
    self._ensure_api_key_configured()
    # Reuse the batch path with a one-element batch and unwrap its single result.
    vectors = await self.embed_texts([query])
    query_vector = vectors[0]
    return query_vector
async def search_similar_chunks(
    self,
    query: str,
    organization_id: int,
    limit: int = 5,
    document_uuids: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
    """Find stored chunks most similar to ``query``.

    The query is embedded with this service and the search is delegated to
    the database client, restricted to chunks embedded with the same model.

    Args:
        query: Search query text
        organization_id: Organization ID for scoping
        limit: Maximum number of results to return
        document_uuids: Optional list of document UUIDs to filter by

    Returns:
        Whatever the database client's ``search_similar_chunks`` returns.

    Raises:
        EmbeddingAPIKeyNotConfiguredError: If API key is not configured
    """
    self._ensure_api_key_configured()

    search_arguments = {
        "query_embedding": await self.embed_query(query),
        "organization_id": organization_id,
        "limit": limit,
        "document_uuids": document_uuids,
        "embedding_model": self.model_id,
    }
    return await self.db.search_similar_chunks(**search_arguments)
async def process_document(
    self,
    file_path: str,
    organization_id: int,
    created_by: int,
    custom_metadata: Optional[dict] = None,
):
    """Process a document: convert, chunk, embed, and store in database.

    Args:
        file_path: Path to the document file
        organization_id: Organization ID for scoping
        created_by: User ID who uploaded the document
        custom_metadata: Optional custom metadata dictionary

    Returns:
        The created document record, or the existing record when a document
        with the same file hash is already stored for the organization.

    Raises:
        Exception: Any error raised during processing is re-raised after the
            document (if one was created) has been marked as ``failed``.
    """
    # Bound before the try block so the error handler can tell whether a
    # document record was created without inspecting locals().
    doc_record = None
    try:
        # Extract file metadata
        filename = Path(file_path).name
        file_hash = self.db.compute_file_hash(file_path)
        file_size = os.path.getsize(file_path)
        mime_type = self.db.get_mime_type(file_path)

        # Check if document already exists
        existing_doc = await self.db.get_document_by_hash(file_hash, organization_id)
        if existing_doc:
            logger.info(f"Document already exists: {filename} (hash: {file_hash})")
            return existing_doc

        # Create document record
        doc_record = await self.db.create_document(
            organization_id=organization_id,
            created_by=created_by,
            filename=filename,
            file_size_bytes=file_size,
            file_hash=file_hash,
            mime_type=mime_type,
            custom_metadata=custom_metadata or {},
        )
        logger.info(f"Processing document with OpenAI embeddings: {filename}")

        # Update status to processing
        await self.db.update_document_status(doc_record.id, "processing")

        # Step 1: Convert document using docling
        logger.info("Converting document with docling...")
        conversion_result = self.converter.convert(file_path)
        doc = conversion_result.document

        # Store docling metadata
        docling_metadata = {
            "num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
            "document_type": type(doc).__name__,
        }

        # Step 2: Chunk the document
        logger.info(f"Chunking document with max_tokens={self.max_tokens}...")
        chunks = list(self.chunker.chunk(dl_doc=doc))
        total_chunks = len(chunks)
        logger.info(f"Generated {total_chunks} chunks")

        # Step 3: Build one record per chunk (embeddings are attached later)
        chunk_texts = []
        chunk_records = []
        token_counts = []
        for i, chunk in enumerate(chunks):
            chunk_text = chunk.text
            contextualized_text = self.chunker.contextualize(chunk=chunk)

            # Prefer the contextualized text; it is what gets tokenized and embedded.
            text_to_tokenize = (
                contextualized_text if contextualized_text else chunk_text
            )
            token_count = len(
                self.tokenizer.tokenizer.encode(
                    text_to_tokenize, add_special_tokens=False
                )
            )
            token_counts.append(token_count)

            # Prepare chunk metadata
            chunk_metadata = {}
            if hasattr(chunk, "meta") and chunk.meta:
                chunk_metadata = {
                    "doc_items": (
                        [str(item) for item in chunk.meta.doc_items]
                        if hasattr(chunk.meta, "doc_items")
                        else []
                    ),
                    "headings": (
                        chunk.meta.headings
                        if hasattr(chunk.meta, "headings")
                        else []
                    ),
                }

            chunk_records.append(
                KnowledgeBaseChunkModel(
                    document_id=doc_record.id,
                    organization_id=organization_id,
                    chunk_text=chunk_text,
                    contextualized_text=contextualized_text,
                    chunk_index=i,
                    chunk_metadata=chunk_metadata,
                    embedding_model=self.model_id,
                    embedding_dimension=EMBEDDING_DIMENSION,
                    token_count=token_count,
                )
            )
            chunk_texts.append(text_to_tokenize)

        # Log chunk statistics
        if token_counts:
            avg_tokens = sum(token_counts) / len(token_counts)
            smallest_chunk = min(token_counts)
            largest_chunk = max(token_counts)
            logger.info("Chunk token statistics:")
            logger.info(f" - Average: {avg_tokens:.1f} tokens")
            logger.info(f" - Min: {smallest_chunk} tokens")
            logger.info(f" - Max: {largest_chunk} tokens")

        # Step 4: Generate embeddings using OpenAI API
        logger.info(f"Generating embeddings using OpenAI ({self.model_id})...")
        embeddings = await self.embed_texts(chunk_texts)

        # zip() would silently drop records if the API returned fewer vectors
        # than requested, leaving chunks stored without an embedding.
        if len(embeddings) != len(chunk_records):
            raise RuntimeError(
                f"Expected {len(chunk_records)} embeddings but received {len(embeddings)}"
            )

        # Step 5: Attach embeddings to chunk records
        for chunk_record, embedding in zip(chunk_records, embeddings):
            chunk_record.embedding = embedding

        # Step 6: Save all chunks in batch
        logger.info("Storing chunks in database...")
        await self.db.create_chunks_batch(chunk_records)

        # Update document status to completed
        await self.db.update_document_status(
            doc_record.id,
            "completed",
            total_chunks=total_chunks,
            docling_metadata=docling_metadata,
        )
        logger.info(f"Successfully processed document: {filename}")
        logger.info(f" - Total chunks: {total_chunks}")
        logger.info(f" - Embedding model: {self.model_id}")
        logger.info(f" - Document ID: {doc_record.id}")
        logger.info(f" - Document UUID: {doc_record.document_uuid}")
        return doc_record
    except Exception as e:
        logger.error(f"Error processing document with OpenAI: {e}")
        # Mark the document as failed if it was created. A failure of the
        # status update itself is logged so it cannot mask the original error.
        if doc_record is not None:
            try:
                await self.db.update_document_status(
                    doc_record.id, "failed", error_message=str(e)
                )
            except Exception as status_error:
                logger.error(
                    f"Could not mark document {doc_record.id} as failed: {status_error}"
                )
        raise

View file

@ -0,0 +1,350 @@
"""Sentence Transformer embedding service.
This module provides document processing capabilities using:
- Sentence-transformers for embeddings (all-MiniLM-L6-v2)
- Docling for document conversion and chunking
- pgvector for vector similarity search
Setup for offline usage:
1. First run: Downloads and caches models to ~/.cache/sentence_transformers
2. Subsequent runs: Uses cached models (no internet needed)
3. For fully offline mode: Set TRANSFORMERS_OFFLINE=1 and HF_HUB_OFFLINE=1
"""
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
from docling.chunking import HybridChunker
from docling.document_converter import DocumentConverter
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
from loguru import logger
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from api.db.db_client import DBClient
from api.db.models import KnowledgeBaseChunkModel
from .base import BaseEmbeddingService
# Set environment variables for model caching.
# setdefault() only fills in a value when the variable is not already present,
# so values exported by the deployment environment take precedence.
# "0" leaves offline mode off; the module docstring describes setting both
# variables to 1 for fully offline use.
os.environ.setdefault("TRANSFORMERS_OFFLINE", "0")
os.environ.setdefault("HF_HUB_OFFLINE", "0")
# Default cache location, later passed as cache_folder when loading the model.
os.environ.setdefault(
"SENTENCE_TRANSFORMERS_HOME", os.path.expanduser("~/.cache/sentence_transformers")
)
# Model configuration
# Default model used when the service is constructed without an explicit model_id.
DEFAULT_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
# NOTE(review): this is tied to the default model; a different model_id may
# produce vectors of another size — confirm before passing a custom model.
EMBEDDING_DIMENSION = 384 # Dimension for all-MiniLM-L6-v2
class SentenceTransformerEmbeddingService(BaseEmbeddingService):
    """Embedding service using Sentence Transformers.

    Loads a sentence-transformers model and a HuggingFace tokenizer for the
    same model ID, chunks documents with docling's ``HybridChunker`` and
    stores embedded chunks through the supplied database client.
    """

    def __init__(
        self,
        db_client: DBClient,
        model_id: str = DEFAULT_MODEL_ID,
        max_tokens: int = 512,
    ):
        """Initialize the Sentence Transformer embedding service.

        Args:
            db_client: Database client for storing documents and chunks
            model_id: Sentence-transformers model ID (default: all-MiniLM-L6-v2)
            max_tokens: Maximum number of tokens per chunk (default: 512)
                Note: This applies to the contextualized text (with headings/captions)
        """
        self.db = db_client
        self.model_id = model_id
        self.max_tokens = max_tokens

        # Initialize embedding model
        logger.info(f"Loading embedding model: {model_id}")
        try:
            # Try to load from cache first (local_files_only=True)
            self.embedding_model = SentenceTransformer(
                model_id,
                cache_folder=os.environ.get("SENTENCE_TRANSFORMERS_HOME"),
                local_files_only=True,
            )
            logger.info("Loaded model from cache")
        except Exception as e:
            logger.warning(f"Model not in cache, downloading: {e}")
            # If not in cache, download it (this will cache it for next time)
            self.embedding_model = SentenceTransformer(
                model_id,
                cache_folder=os.environ.get("SENTENCE_TRANSFORMERS_HOME"),
            )
            logger.info("Model downloaded and cached")

        # Initialize tokenizer for chunking with max_tokens
        logger.info(f"Loading tokenizer: {model_id} with max_tokens={max_tokens}")
        try:
            # Try to load from cache first
            self.tokenizer = HuggingFaceTokenizer(
                tokenizer=AutoTokenizer.from_pretrained(
                    model_id,
                    local_files_only=True,
                ),
                max_tokens=max_tokens,
            )
            logger.info("Loaded tokenizer from cache")
        except Exception as e:
            logger.warning(f"Tokenizer not in cache, downloading: {e}")
            # If not in cache, download it
            self.tokenizer = HuggingFaceTokenizer(
                tokenizer=AutoTokenizer.from_pretrained(model_id),
                max_tokens=max_tokens,
            )
            logger.info("Tokenizer downloaded and cached")

        # Initialize chunker
        logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
        self.chunker = HybridChunker(tokenizer=self.tokenizer)

        # Initialize document converter
        self.converter = DocumentConverter()

    def get_model_id(self) -> str:
        """Return the model identifier."""
        return self.model_id

    def get_embedding_dimension(self) -> int:
        """Return the embedding dimension.

        NOTE(review): this returns the module constant for the default model
        regardless of ``model_id`` — confirm before using a custom model.
        """
        return EMBEDDING_DIMENSION

    async def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors (each vector is a list of floats)
        """
        embeddings = self.embedding_model.encode(
            texts,
            show_progress_bar=False,
            convert_to_numpy=True,
        )
        return [embedding.tolist() for embedding in embeddings]

    async def embed_query(self, query: str) -> List[float]:
        """Embed a single query text.

        Goes through ``embed_texts`` so queries and documents are encoded
        with exactly the same options.

        Args:
            query: Query text to embed

        Returns:
            Embedding vector as list of floats
        """
        embeddings = await self.embed_texts([query])
        return embeddings[0]

    async def search_similar_chunks(
        self,
        query: str,
        organization_id: int,
        limit: int = 5,
        document_uuids: Optional[List[str]] = None,
    ) -> List[Dict[str, Any]]:
        """Search for similar chunks using vector similarity.

        Returns top-k most similar chunks without any threshold filtering.
        Apply similarity thresholds and reranking at the application layer.

        Args:
            query: Search query text
            organization_id: Organization ID for scoping
            limit: Maximum number of results to return
            document_uuids: Optional list of document UUIDs to filter by

        Returns:
            List of dictionaries with chunk data and similarity scores
        """
        # Generate query embedding
        query_embedding = await self.embed_query(query)

        # Perform vector similarity search, limited to chunks of this model
        results = await self.db.search_similar_chunks(
            query_embedding=query_embedding,
            organization_id=organization_id,
            limit=limit,
            document_uuids=document_uuids,
            embedding_model=self.model_id,
        )
        return results

    async def process_document(
        self,
        file_path: str,
        organization_id: int,
        created_by: int,
        custom_metadata: Optional[dict] = None,
    ):
        """Process a document: convert, chunk, embed, and store in database.

        Args:
            file_path: Path to the document file
            organization_id: Organization ID for scoping
            created_by: User ID who uploaded the document
            custom_metadata: Optional custom metadata dictionary

        Returns:
            The created document record, or the existing record when a
            document with the same file hash is already stored.

        Raises:
            Exception: Any processing error is re-raised after the document
                (if one was created) has been marked as ``failed``.
        """
        # Bound before the try block so the error handler can tell whether a
        # document record was created without inspecting locals().
        doc_record = None
        try:
            # Extract file metadata
            filename = Path(file_path).name
            file_hash = self.db.compute_file_hash(file_path)
            file_size = os.path.getsize(file_path)
            mime_type = self.db.get_mime_type(file_path)

            # Check if document already exists
            existing_doc = await self.db.get_document_by_hash(
                file_hash, organization_id
            )
            if existing_doc:
                logger.info(f"Document already exists: {filename} (hash: {file_hash})")
                return existing_doc

            # Create document record
            doc_record = await self.db.create_document(
                organization_id=organization_id,
                created_by=created_by,
                filename=filename,
                file_size_bytes=file_size,
                file_hash=file_hash,
                mime_type=mime_type,
                custom_metadata=custom_metadata or {},
            )
            logger.info(f"Processing document: {filename}")

            # Update status to processing
            await self.db.update_document_status(doc_record.id, "processing")

            # Step 1: Convert document using docling
            logger.info("Converting document with docling...")
            conversion_result = self.converter.convert(file_path)
            doc = conversion_result.document

            # Store docling metadata
            docling_metadata = {
                "num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
                "document_type": type(doc).__name__,
            }

            # Step 2: Chunk the document
            logger.info(f"Chunking document with max_tokens={self.max_tokens}...")
            chunks = list(self.chunker.chunk(dl_doc=doc))
            total_chunks = len(chunks)
            logger.info(f"Generated {total_chunks} chunks")

            # Step 3: Build one record per chunk (embeddings are attached later)
            chunk_texts = []
            chunk_records = []
            token_counts = []
            for i, chunk in enumerate(chunks):
                chunk_text = chunk.text
                # Contextualized text (enriched with surrounding context)
                contextualized_text = self.chunker.contextualize(chunk=chunk)

                # Prefer the contextualized text; it is what gets tokenized
                # and embedded.
                text_to_tokenize = (
                    contextualized_text if contextualized_text else chunk_text
                )
                token_count = len(
                    self.tokenizer.tokenizer.encode(
                        text_to_tokenize, add_special_tokens=False
                    )
                )
                token_counts.append(token_count)

                # Prepare chunk metadata
                chunk_metadata = {}
                if hasattr(chunk, "meta") and chunk.meta:
                    chunk_metadata = {
                        "doc_items": (
                            [str(item) for item in chunk.meta.doc_items]
                            if hasattr(chunk.meta, "doc_items")
                            else []
                        ),
                        "headings": (
                            chunk.meta.headings
                            if hasattr(chunk.meta, "headings")
                            else []
                        ),
                    }

                chunk_records.append(
                    KnowledgeBaseChunkModel(
                        document_id=doc_record.id,
                        organization_id=organization_id,
                        chunk_text=chunk_text,
                        contextualized_text=contextualized_text,
                        chunk_index=i,
                        chunk_metadata=chunk_metadata,
                        embedding_model=self.model_id,
                        embedding_dimension=EMBEDDING_DIMENSION,
                        token_count=token_count,
                    )
                )
                chunk_texts.append(text_to_tokenize)

            # Log chunk statistics
            if token_counts:
                avg_tokens = sum(token_counts) / len(token_counts)
                smallest_chunk = min(token_counts)
                largest_chunk = max(token_counts)
                logger.info("Chunk token statistics:")
                logger.info(f" - Average: {avg_tokens:.1f} tokens")
                logger.info(f" - Min: {smallest_chunk} tokens")
                logger.info(f" - Max: {largest_chunk} tokens")

            # Step 4: Generate embeddings in batch
            logger.info("Generating embeddings...")
            embeddings = await self.embed_texts(chunk_texts)

            # zip() would silently drop records on a length mismatch, leaving
            # chunks stored without an embedding.
            if len(embeddings) != len(chunk_records):
                raise RuntimeError(
                    f"Expected {len(chunk_records)} embeddings but received {len(embeddings)}"
                )

            # Step 5: Attach embeddings to chunk records
            for chunk_record, embedding in zip(chunk_records, embeddings):
                chunk_record.embedding = embedding

            # Step 6: Save all chunks in batch
            logger.info("Storing chunks in database...")
            await self.db.create_chunks_batch(chunk_records)

            # Update document status to completed
            await self.db.update_document_status(
                doc_record.id,
                "completed",
                total_chunks=total_chunks,
                docling_metadata=docling_metadata,
            )
            logger.info(f"Successfully processed document: {filename}")
            logger.info(f" - Total chunks: {total_chunks}")
            logger.info(f" - Document ID: {doc_record.id}")
            logger.info(f" - Document UUID: {doc_record.document_uuid}")
            return doc_record
        except Exception as e:
            logger.error(f"Error processing document: {e}")
            # Mark the document as failed if it was created. A failure of the
            # status update itself is logged so it cannot mask the original error.
            if doc_record is not None:
                try:
                    await self.db.update_document_status(
                        doc_record.id, "failed", error_message=str(e)
                    )
                except Exception as status_error:
                    logger.error(
                        f"Could not mark document {doc_record.id} as failed: {status_error}"
                    )
            raise

View file

@ -0,0 +1,44 @@
"""
Embeddings pricing models for different providers.
Prices are per token for embedding models.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import PricingModel
class EmbeddingPricingModel(PricingModel):
    """Token-based pricing for embedding services: cost scales linearly with tokens."""

    def __init__(self, token_price: Decimal):
        """Store the per-token rate.

        Args:
            token_price: Cost of embedding a single token
        """
        self.token_price = token_price

    def calculate_cost(self, token_count: int) -> Decimal:
        """Return the cost of embedding ``token_count`` tokens."""
        tokens = Decimal(token_count)
        return tokens * self.token_price
# Embeddings pricing registry, keyed by provider and then by model name.
# Prices are listed in USD per one million tokens and converted to a
# per-token rate when the pricing model is built.
EMBEDDINGS_PRICING: Dict[str, Dict[str, EmbeddingPricingModel]] = {
    ServiceProviders.OPENAI: {
        model_name: EmbeddingPricingModel(
            token_price=Decimal(usd_per_million_tokens) / 1_000_000,
        )
        for model_name, usd_per_million_tokens in (
            ("text-embedding-3-small", "0.02"),
            ("text-embedding-3-large", "0.13"),
            ("text-embedding-ada-002", "0.10"),  # legacy model
        )
    },
}

View file

@ -4,6 +4,7 @@ Main pricing registry that combines all service type pricing models.
from typing import Dict
from .embeddings import EMBEDDINGS_PRICING
from .llm import LLM_PRICING
from .stt import STT_PRICING
from .tts import TTS_PRICING
@ -13,4 +14,5 @@ PRICING_REGISTRY: Dict = {
"llm": LLM_PRICING,
"tts": TTS_PRICING,
"stt": STT_PRICING,
"embeddings": EMBEDDINGS_PRICING,
}

View file

@ -58,6 +58,7 @@ class NodeDataDTO(BaseModel):
delayed_start: bool = False
delayed_start_duration: Optional[float] = None
tool_uuids: Optional[List[str]] = None
document_uuids: Optional[List[str]] = None
trigger_path: Optional[str] = None
# Webhook node specific fields
enabled: bool = True

View file

@ -41,6 +41,10 @@ from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
from api.services.workflow.tools.calculator import get_calculator_tools, safe_calculator
from api.services.workflow.tools.knowledge_base import (
get_knowledge_base_tool,
retrieve_from_knowledge_base,
)
from api.services.workflow.tools.timezone import (
convert_time,
get_current_time,
@ -290,6 +294,48 @@ class PipecatEngine:
self.llm.register_function("get_current_time", get_current_time_func)
self.llm.register_function("convert_time", convert_time_func)
async def _register_knowledge_base_function(
    self, document_uuids: list[str]
) -> None:
    """Register knowledge base retrieval function with the LLM.

    The registered handler searches only the given documents and always
    answers through ``result_callback`` — with the retrieval result on
    success, or with an error payload of the same shape on failure.

    Args:
        document_uuids: List of document UUIDs to filter the search by
    """
    logger.debug(
        f"Registering knowledge base retrieval function with {len(document_uuids)} document(s)"
    )

    async def retrieve_kb_func(function_call_params: FunctionCallParams) -> None:
        logger.info("LLM Function Call EXECUTED: retrieve_from_knowledge_base")
        logger.info(f"Arguments: {function_call_params.arguments}")

        # Bound before the try block: the error payload echoes the query, and
        # reading the arguments is itself one of the steps that can fail.
        query = ""
        try:
            query = function_call_params.arguments.get("query", "")
            organization_id = await self._get_organization_id()
            if not organization_id:
                raise ValueError(
                    "Organization ID not available for knowledge base retrieval"
                )

            result = await retrieve_from_knowledge_base(
                query=query,
                organization_id=organization_id,
                document_uuids=document_uuids,
                limit=3,  # Return top 3 most relevant chunks
            )
            await function_call_params.result_callback(result)
        except Exception as e:
            logger.error(f"Knowledge base retrieval failed: {e}")
            await function_call_params.result_callback(
                {"error": str(e), "chunks": [], "query": query, "total_results": 0}
            )

    # Register the function with the LLM
    self.llm.register_function("retrieve_from_knowledge_base", retrieve_kb_func)
async def _perform_variable_extraction_if_needed(
self, previous_node: Optional[Node]
) -> None:
@ -346,6 +392,10 @@ class PipecatEngine:
if node.tool_uuids and self._custom_tool_manager:
await self._custom_tool_manager.register_handlers(node.tool_uuids)
# Register knowledge base retrieval handler if node has documents
if node.document_uuids:
await self._register_knowledge_base_function(node.document_uuids)
# Set up system message and functions
(
system_message,
@ -575,6 +625,17 @@ class PipecatEngine:
# Add built-in function schemas (calculator and timezone tools)
functions.extend(self.builtin_function_schemas)
# Add knowledge base retrieval tool if node has documents
if node.document_uuids:
kb_tool_def = get_knowledge_base_tool(node.document_uuids)
kb_schema = get_function_schema(
kb_tool_def["function"]["name"],
kb_tool_def["function"]["description"],
properties=kb_tool_def["function"]["parameters"].get("properties", {}),
required=kb_tool_def["function"]["parameters"].get("required", []),
)
functions.append(kb_schema)
# Add custom tools from node.tool_uuids
if node.tool_uuids and self._custom_tool_manager:
custom_tool_schemas = await self._custom_tool_manager.get_tool_schemas(

View file

@ -0,0 +1,305 @@
"""Knowledge Base retrieval tool for workflow execution.
This module provides vector similarity search capabilities for retrieving
relevant information from the knowledge base during conversations.
Implements OpenTelemetry tracing for observability in Langfuse.
"""
import json
from typing import Any, Dict, List, Optional
from loguru import logger
from opentelemetry import trace
from api.db import db_client
from api.services.gen_ai import OpenAIEmbeddingService
from api.services.pipecat.tracing_config import is_tracing_enabled
from pipecat.utils.tracing.context_registry import (
get_current_conversation_context,
get_current_turn_context,
)
async def retrieve_from_knowledge_base(
    query: str,
    organization_id: int,
    document_uuids: Optional[List[str]] = None,
    limit: int = 3,
    embeddings_api_key: Optional[str] = None,
    embeddings_model: Optional[str] = None,
) -> Dict[str, Any]:
    """Retrieve relevant information from the knowledge base using vector similarity search.

    Uses OpenAI text-embedding-3-small for embeddings by default.

    When tracing is enabled and a turn or conversation context is available,
    the retrieval runs inside a ``knowledge_base_retrieval`` span. Tracing is
    best-effort: a failure while annotating the span is logged and recorded
    on the span, but the retrieval result is still returned.

    Args:
        query: The search query to find relevant information
        organization_id: Organization ID for scoping the search
        document_uuids: Optional list of document UUIDs to filter by
        limit: Maximum number of chunks to return (default: 3)
        embeddings_api_key: Optional API key for embedding service
        embeddings_model: Optional model ID for embedding service

    Returns:
        Dictionary containing:
            - chunks: List of relevant text chunks with metadata
            - query: The original query
            - total_results: Number of results returned
            - error: Present only when the retrieval failed
    """

    async def _retrieve() -> Dict[str, Any]:
        # Single call site for the untraced and traced paths.
        return await _perform_retrieval(
            query,
            organization_id,
            document_uuids,
            limit,
            embeddings_api_key,
            embeddings_model,
        )

    if not is_tracing_enabled():
        return await _retrieve()

    try:
        # Get parent context from turn or conversation
        turn_context = get_current_turn_context()
        conversation_context = get_current_conversation_context()
        parent_context = turn_context or conversation_context
        tracer = trace.get_tracer("pipecat")
    except Exception as e:
        logger.debug(f"Failed to setup tracing context: {e}")
        # Fall back to non-traced execution
        return await _retrieve()

    if not parent_context:
        logger.debug(
            "No parent context available for knowledge base retrieval tracing"
        )
        return await _retrieve()

    with tracer.start_as_current_span(
        "knowledge_base_retrieval", context=parent_context
    ) as span:
        try:
            _annotate_retrieval_request(
                span, query, organization_id, document_uuids, limit
            )
        except Exception as e:
            logger.error(f"Error in traced retrieval: {e}")
            span.record_exception(e)

        result = await _retrieve()

        try:
            _annotate_retrieval_result(span, query, result)
        except Exception as e:
            # Observability must not turn a finished retrieval into a failure.
            logger.error(f"Error in traced retrieval: {e}")
            span.record_exception(e)
        return result


def _annotate_retrieval_request(
    span,
    query: str,
    organization_id: int,
    document_uuids: Optional[List[str]],
    limit: int,
) -> None:
    """Attach the request parameters of a retrieval to its tracing span."""
    # Mark trace as public for Langfuse
    span.set_attribute("langfuse.trace.public", True)
    span.set_attribute("gen_ai.operation.name", "knowledge_base_retrieval")
    span.set_attribute("retrieval.query", query)
    span.set_attribute("retrieval.limit", limit)
    span.set_attribute("retrieval.organization_id", organization_id)
    if document_uuids:
        span.set_attribute("retrieval.document_count", len(document_uuids))
        span.set_attribute("retrieval.document_uuids", json.dumps(document_uuids))


def _annotate_retrieval_result(span, query: str, result: Dict[str, Any]) -> None:
    """Attach the outcome of a retrieval (error or chunk summary) to its span."""
    span.set_attribute("retrieval.results_count", result["total_results"])

    error = result.get("error")
    if error:
        span.set_attribute("retrieval.error", error)
        span.set_status(trace.Status(trace.StatusCode.ERROR, error))
        return

    chunks = result["chunks"]
    if chunks:
        similarities = [chunk["similarity"] for chunk in chunks]
        span.set_attribute(
            "retrieval.avg_similarity",
            round(sum(similarities) / len(similarities), 4),
        )
        span.set_attribute("retrieval.max_similarity", max(similarities))
        span.set_attribute("retrieval.min_similarity", min(similarities))

    # Add retrieved documents info
    filenames = list(set(chunk["filename"] for chunk in chunks))
    span.set_attribute("retrieval.source_files", json.dumps(filenames))

    # Add output as JSON for Langfuse; chunk text is cut to a 200-character preview.
    previews = []
    for chunk in chunks:
        # A stored chunk may carry no text at all; treat that as empty.
        text = chunk["text"] or ""
        previews.append(
            {
                "text": text[:200] + "..." if len(text) > 200 else text,
                "filename": chunk["filename"],
                "similarity": chunk["similarity"],
            }
        )
    output_data = {
        "query": query,
        "chunks_retrieved": len(chunks),
        "chunks": previews,
    }
    span.set_attribute("output", json.dumps(output_data))
async def _perform_retrieval(
    query: str,
    organization_id: int,
    document_uuids: Optional[List[str]],
    limit: int,
    embeddings_api_key: Optional[str] = None,
    embeddings_model: Optional[str] = None,
) -> Dict[str, Any]:
    """Run the similarity search and shape the hits for the LLM.

    Kept free of tracing concerns. Never raises for ordinary errors: any
    ``Exception`` is logged and reported through an ``error`` entry in the
    returned dictionary, alongside an empty ``chunks`` list.
    """
    try:
        # A fresh service per call; falls back to text-embedding-3-small when
        # the caller supplies no model.
        service = OpenAIEmbeddingService(
            db_client=db_client,
            max_tokens=128,  # This is only used for chunking, not for retrieval
            api_key=embeddings_api_key,
            model_id=embeddings_model or "text-embedding-3-small",
        )

        matches = await service.search_similar_chunks(
            query=query,
            organization_id=organization_id,
            limit=limit,
            document_uuids=document_uuids,
        )

        # Prefer the contextualized text and fall back to the raw chunk text.
        chunks = [
            {
                "text": match.get("contextualized_text") or match.get("chunk_text"),
                "filename": match.get("filename"),
                "similarity": round(match.get("similarity", 0), 4),
                "chunk_index": match.get("chunk_index"),
            }
            for match in matches
        ]

        logger.info(
            f"Knowledge base retrieval: query='{query}', "
            f"results={len(chunks)}, "
            f"document_filter={document_uuids}"
        )
        return {
            "chunks": chunks,
            "query": query,
            "total_results": len(chunks),
        }
    except Exception as e:
        logger.error(f"Error retrieving from knowledge base: {e}")
        return {
            "error": str(e),
            "chunks": [],
            "query": query,
            "total_results": 0,
        }
def get_knowledge_base_tool(
    document_uuids: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Build the function-calling definition of the knowledge base retrieval tool.

    Args:
        document_uuids: Optional list of document UUIDs; when non-empty, the
            description mentions how many documents the search is limited to

    Returns:
        Tool definition with a single required string parameter, ``query``
    """
    document_count = len(document_uuids) if document_uuids else 0

    if document_count > 0:
        # Scoped wording: tell the model how many documents are searchable.
        description = (
            "Retrieve relevant information from specific documents in the knowledge base. "
            "Use this tool when you need to look up facts, policies, procedures, or any information "
            "that might be stored in the available documents. The search will only look in the "
            f"documents associated with this conversation step ({document_count} document(s) available)."
        )
    else:
        description = (
            "Retrieve relevant information from the knowledge base. "
            "Use this tool when you need to look up facts, policies, procedures, or any information "
            "that might be stored in the knowledge base documents."
        )

    query_property = {
        "type": "string",
        "description": (
            "The search query to find relevant information. "
            "Be specific and use natural language. "
            "Example: 'What is the refund policy for canceled orders?'"
        ),
    }
    parameters = {
        "type": "object",
        "properties": {"query": query_property},
        "required": ["query"],
    }
    return {
        "type": "function",
        "function": {
            "name": "retrieve_from_knowledge_base",
            "description": description,
            "parameters": parameters,
        },
    }

View file

@ -48,6 +48,7 @@ class Node:
self.delayed_start = data.delayed_start
self.delayed_start_duration = data.delayed_start_duration
self.tool_uuids = data.tool_uuids
self.document_uuids = data.document_uuids
self.data = data
@ -189,16 +190,6 @@ class WorkflowGraph:
in_d, out_d = in_deg[n.id], out_deg[n.id]
match n.node_type:
case NodeType.startNode:
if in_d != 0 or out_d < 1:
errors.append(
WorkflowError(
kind=ItemKind.node,
id=n.id,
field=None,
message=f"StartNode must have at least 1 outgoing edge",
)
)
case NodeType.endNode:
if in_d < 1 or out_d != 0:
errors.append(

View file

@ -2,25 +2,33 @@
import os
import tempfile
from typing import Literal
from docling.chunking import HybridChunker
from docling.document_converter import DocumentConverter
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
from loguru import logger
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from api.db import db_client
from api.db.models import KnowledgeBaseChunkModel, KnowledgeBaseDocumentModel
from api.db.models import KnowledgeBaseChunkModel
from api.services.gen_ai import (
OpenAIEmbeddingService,
SentenceTransformerEmbeddingService,
)
from api.services.storage import storage_fs
# Constants
EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
EMBEDDING_DIMENSION = 384
# For tokenization/chunking - use SentenceTransformer tokenizer as baseline
TOKENIZER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
async def process_knowledge_base_document(
ctx, document_id: int, s3_key: str, organization_id: int, max_tokens: int = 128
ctx,
document_id: int,
s3_key: str,
organization_id: int,
max_tokens: int = 128,
embedding_service: Literal["sentence_transformer", "openai"] = "openai",
):
"""Process a knowledge base document: download, chunk, embed, and store.
@ -30,6 +38,9 @@ async def process_knowledge_base_document(
s3_key: S3 key where the file is stored
organization_id: Organization ID
max_tokens: Maximum number of tokens per chunk (default: 128)
embedding_service: Embedding service to use (default: "openai")
- "openai": Use OpenAI text-embedding-3-small (1536-dim, requires API key)
- "sentence_transformer": Use SentenceTransformer (all-MiniLM-L6-v2, 384-dim, free)
"""
logger.info(
f"Starting knowledge base document processing for document_id={document_id}, "
@ -42,8 +53,14 @@ async def process_knowledge_base_document(
# Update status to processing
await db_client.update_document_status(document_id, "processing")
# Create temp file for download
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
# Extract file extension from S3 key
filename = s3_key.split("/")[-1]
file_extension = (
os.path.splitext(filename)[1] or ".bin"
) # Default to .bin if no extension
# Create temp file for download with correct extension
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=file_extension)
temp_file_path = temp_file.name
temp_file.close()
@ -108,27 +125,58 @@ async def process_knowledge_base_document(
mime_type=mime_type,
)
# Initialize models for processing
cache_dir = os.path.expanduser("~/.cache/hf_models")
logger.info(f"Loading embedding model: {EMBED_MODEL_ID} (cache: {cache_dir})")
embedding_model = SentenceTransformer(
EMBED_MODEL_ID,
cache_folder=cache_dir,
)
# Initialize the embedding service based on the parameter
if embedding_service == "openai":
logger.info(
f"Initializing OpenAI embedding service with max_tokens={max_tokens}"
)
# Try to get user's embeddings configuration
embeddings_api_key = None
embeddings_model = None
if document.created_by:
user_config = await db_client.get_user_configurations(
document.created_by
)
if user_config.embeddings:
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
logger.info(
f"Using user embeddings config: model={embeddings_model}"
)
logger.info(f"Loading tokenizer: {EMBED_MODEL_ID} (cache: {cache_dir})")
tokenizer = HuggingFaceTokenizer(
tokenizer=AutoTokenizer.from_pretrained(
EMBED_MODEL_ID,
cache_dir=cache_dir,
),
max_tokens=max_tokens,
)
# Check if API key is configured
if not embeddings_api_key:
error_message = (
"OpenAI API key not configured. Please set your API key in "
"Model Configurations > Embedding to process documents."
)
logger.warning(f"Document {document_id}: {error_message}")
await db_client.update_document_status(
document_id, "failed", error_message=error_message
)
return
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
chunker = HybridChunker(tokenizer=tokenizer)
service = OpenAIEmbeddingService(
db_client=db_client,
max_tokens=max_tokens,
api_key=embeddings_api_key,
model_id=embeddings_model or "text-embedding-3-small",
)
elif embedding_service == "sentence_transformer":
logger.info(
f"Initializing SentenceTransformer embedding service with max_tokens={max_tokens}"
)
service = SentenceTransformerEmbeddingService(
db_client=db_client,
max_tokens=max_tokens,
)
else:
raise ValueError(
f"Invalid embedding_service: {embedding_service}. "
f"Must be 'sentence_transformer' or 'openai'"
)
# Convert document with docling
# Step 1: Convert document with docling
logger.info("Converting document with docling")
converter = DocumentConverter()
conversion_result = converter.convert(temp_file_path)
@ -140,13 +188,26 @@ async def process_knowledge_base_document(
"document_type": type(doc).__name__,
}
# Chunk the document
# Step 2: Initialize tokenizer for chunking
logger.info(
f"Loading tokenizer: {TOKENIZER_MODEL} with max_tokens={max_tokens}"
)
tokenizer = HuggingFaceTokenizer(
tokenizer=AutoTokenizer.from_pretrained(TOKENIZER_MODEL),
max_tokens=max_tokens,
)
# Step 3: Initialize chunker
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
chunker = HybridChunker(tokenizer=tokenizer)
# Step 4: Chunk the document
logger.info(f"Chunking document with max_tokens={max_tokens}")
chunks = list(chunker.chunk(dl_doc=doc))
total_chunks = len(chunks)
logger.info(f"Generated {total_chunks} chunks")
# Process each chunk
# Step 5: Process each chunk
chunk_texts = []
chunk_records = []
token_counts = []
@ -156,7 +217,9 @@ async def process_knowledge_base_document(
contextualized_text = chunker.contextualize(chunk=chunk)
# Calculate token count
text_to_tokenize = contextualized_text if contextualized_text else chunk_text
text_to_tokenize = (
contextualized_text if contextualized_text else chunk_text
)
token_count = len(
tokenizer.tokenizer.encode(text_to_tokenize, add_special_tokens=False)
)
@ -176,7 +239,7 @@ async def process_knowledge_base_document(
),
}
# Create chunk record
# Create chunk record (without embedding yet)
chunk_record = KnowledgeBaseChunkModel(
document_id=document_id,
organization_id=organization_id,
@ -184,8 +247,8 @@ async def process_knowledge_base_document(
contextualized_text=contextualized_text,
chunk_index=i,
chunk_metadata=chunk_metadata,
embedding_model=EMBED_MODEL_ID,
embedding_dimension=EMBEDDING_DIMENSION,
embedding_model=service.get_model_id(),
embedding_dimension=service.get_embedding_dimension(),
token_count=token_count,
)
@ -196,29 +259,25 @@ async def process_knowledge_base_document(
if token_counts:
avg_tokens = sum(token_counts) / len(token_counts)
min_tokens = min(token_counts)
max_tokens = max(token_counts)
logger.info(f"Chunk token statistics:")
max_tokens_actual = max(token_counts)
logger.info("Chunk token statistics:")
logger.info(f" - Average: {avg_tokens:.1f} tokens")
logger.info(f" - Min: {min_tokens} tokens")
logger.info(f" - Max: {max_tokens} tokens")
logger.info(f" - Max: {max_tokens_actual} tokens")
# Generate embeddings in batch
logger.info("Generating embeddings")
embeddings = embedding_model.encode(
chunk_texts,
show_progress_bar=False,
convert_to_numpy=True,
)
# Step 6: Generate embeddings using the embedding service
logger.info(f"Generating embeddings using {embedding_service}")
embeddings = await service.embed_texts(chunk_texts)
# Attach embeddings to chunk records
# Step 7: Attach embeddings to chunk records
for chunk_record, embedding in zip(chunk_records, embeddings):
chunk_record.embedding = embedding.tolist()
chunk_record.embedding = embedding
# Save chunks in database
# Step 8: Save chunks in database
logger.info("Storing chunks in database")
await db_client.create_chunks_batch(chunk_records)
# Update document status to completed
# Step 9: Update document status to completed
await db_client.update_document_status(
document_id,
"completed",

View file

@ -7,8 +7,10 @@ import {
ReactFlow,
} from "@xyflow/react";
import { BrushCleaning, Maximize2, Minus, Plus, Rocket, Settings, Variable } from 'lucide-react';
import React, { useMemo, useState } from 'react';
import React, { useEffect, useMemo, useState } from 'react';
import { listDocumentsApiV1KnowledgeBaseDocumentsGet, listToolsApiV1ToolsGet } from '@/client';
import type { DocumentResponseSchema, ToolResponse } from '@/client/types.gen';
import { FlowEdge, FlowNode, NodeType } from "@/components/flow/types";
import { Button } from '@/components/ui/button';
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/tooltip';
@ -63,6 +65,8 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
const [isConfigurationsDialogOpen, setIsConfigurationsDialogOpen] = useState(false);
const [isEmbedDialogOpen, setIsEmbedDialogOpen] = useState(false);
const [isPhoneCallDialogOpen, setIsPhoneCallDialogOpen] = useState(false);
const [documents, setDocuments] = useState<DocumentResponseSchema[] | undefined>(undefined);
const [tools, setTools] = useState<ToolResponse[] | undefined>(undefined);
const {
rfInstance,
@ -95,6 +99,36 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
getAccessToken
});
// Load the knowledge-base documents and the tools once for the whole workflow;
// the results are exposed to the nodes through the workflow context.
useEffect(() => {
  // Set by the cleanup below so a response that arrives after this effect has
  // been torn down is discarded instead of being written into state.
  let cancelled = false;

  const fetchData = async () => {
    try {
      const accessToken = await getAccessToken();
      const headers = { Authorization: `Bearer ${accessToken}` };

      // The two requests do not depend on each other: start them together and
      // settle them separately so a failure of one cannot block the other.
      // NOTE(review): only the first 100 documents are requested; anything that
      // compares UUIDs against this list will not see documents past that limit.
      const [documentsResult, toolsResult] = await Promise.allSettled([
        listDocumentsApiV1KnowledgeBaseDocumentsGet({
          headers,
          query: { limit: 100 },
        }),
        listToolsApiV1ToolsGet({ headers }),
      ]);

      if (cancelled) {
        return;
      }

      // On failure the state is deliberately left untouched (still undefined),
      // so consumers keep treating the list as "not loaded yet".
      if (documentsResult.status === 'fulfilled') {
        if (documentsResult.value.data) {
          setDocuments(documentsResult.value.data.documents);
        }
      } else {
        console.error('Failed to fetch documents:', documentsResult.reason);
      }

      if (toolsResult.status === 'fulfilled') {
        if (toolsResult.value.data) {
          setTools(toolsResult.value.data);
        }
      } else {
        console.error('Failed to fetch tools:', toolsResult.reason);
      }
    } catch (error) {
      // Reached when the access token itself cannot be obtained.
      console.error('Failed to fetch documents and tools:', error);
    }
  };

  void fetchData();

  return () => {
    cancelled = true;
  };
}, [getAccessToken]);
// Memoize defaultEdgeOptions to prevent unnecessary re-renders
const defaultEdgeOptions = useMemo(() => ({
animated: true,
@ -102,7 +136,11 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
}), []);
// Memoize the context value to prevent unnecessary re-renders
const workflowContextValue = useMemo(() => ({ saveWorkflow }), [saveWorkflow]);
// Context value handed to WorkflowProvider. `documents` and `tools` start out
// as undefined (their initial state) and are filled in by the fetch effect;
// the object is rebuilt only when one of the three dependencies changes.
const workflowContextValue = useMemo(() => ({
saveWorkflow,
documents,
tools
}), [saveWorkflow, documents, tools]);
return (
<WorkflowProvider value={workflowContextValue}>

View file

@ -1,7 +1,11 @@
import { createContext, useContext } from 'react';
import type { DocumentResponseSchema, ToolResponse } from '@/client/types.gen';
/** Shape of the value carried by WorkflowContext. */
interface WorkflowContextType {
/** Persists the workflow; resolves once the save has finished. */
saveWorkflow: (updateWorkflowDefinition?: boolean) => Promise<void>;
/** Knowledge-base documents shared through the context; optional, so it may be undefined. */
documents?: DocumentResponseSchema[];
/** Tools shared through the context; optional, so it may be undefined. */
tools?: ToolResponse[];
}
const WorkflowContext = createContext<WorkflowContextType | undefined>(undefined);
@ -15,3 +19,8 @@ export const useWorkflow = () => {
}
return context;
};
/**
 * Reads the workflow context without enforcing that a provider is present.
 *
 * @returns The current context value, or `undefined` (the context's default)
 * when no value has been provided. This hook never throws.
 */
export const useWorkflowOptional = () => useContext(WorkflowContext);

View file

@ -1,9 +1,8 @@
// This file is auto-generated by @hey-api/openapi-ts
import { type ClientOptions as DefaultClientOptions, type Config, createClient, createConfig } from '@hey-api/client-fetch';
import { createClientConfig } from '../lib/apiClient';
import type { ClientOptions } from './types.gen';
import { type Config, type ClientOptions as DefaultClientOptions, createClient, createConfig } from '@hey-api/client-fetch';
import { createClientConfig } from '../lib/apiClient';
/**
* The `createClientConfig()` function will be called on client initialization
@ -17,4 +16,4 @@ export type CreateClientConfig<T extends DefaultClientOptions = ClientOptions> =
export const client = createClient(createClientConfig(createConfig<ClientOptions>({
baseUrl: 'http://127.0.0.1:8000'
})));
})));

View file

@ -1,3 +1,3 @@
// This file is auto-generated by @hey-api/openapi-ts
export * from './sdk.gen';
export * from './types.gen';
export * from './sdk.gen';

File diff suppressed because one or more lines are too long

View file

@ -351,6 +351,11 @@ export type DefaultConfigurationsResponse = {
[key: string]: unknown;
};
};
embeddings: {
[key: string]: {
[key: string]: unknown;
};
};
default_providers: {
[key: string]: string;
};
@ -672,6 +677,10 @@ export type ProcessDocumentRequestSchema = {
* S3 key of the uploaded file
*/
s3_key: string;
/**
* Embedding service to use for processing. Options: 'openai' (default, 1536-dim, requires API key) or 'sentence_transformer' (free, 384-dim)
*/
embedding_service?: 'sentence_transformer' | 'openai';
};
export type S3SignedUrlResponse = {
@ -924,6 +933,9 @@ export type UserConfigurationRequestResponseSchema = {
stt?: {
[key: string]: string | number;
} | null;
embeddings?: {
[key: string]: string | number;
} | null;
test_phone_number?: string | null;
timezone?: string | null;
organization_pricing?: {
@ -4493,4 +4505,4 @@ export type HealthApiV1HealthGetResponses = {
export type ClientOptions = {
baseUrl: 'http://127.0.0.1:8000' | (string & {});
};
};

View file

@ -14,7 +14,7 @@ import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { VoiceSelector } from "@/components/VoiceSelector";
import { useUserConfig } from "@/context/UserConfigContext";
type ServiceSegment = "llm" | "tts" | "stt";
type ServiceSegment = "llm" | "tts" | "stt" | "embeddings";
interface SchemaProperty {
type?: string;
@ -41,6 +41,7 @@ const TAB_CONFIG: { key: ServiceSegment; label: string }[] = [
{ key: "llm", label: "LLM" },
{ key: "tts", label: "Voice" },
{ key: "stt", label: "Transcriber" },
{ key: "embeddings", label: "Embedding" },
];
// Display names for language codes (Deepgram + Sarvam)
@ -109,12 +110,14 @@ export default function ServiceConfiguration() {
const [schemas, setSchemas] = useState<Record<ServiceSegment, Record<string, ProviderSchema>>>({
llm: {},
tts: {},
stt: {}
stt: {},
embeddings: {}
});
const [serviceProviders, setServiceProviders] = useState<Record<ServiceSegment, string>>({
llm: "",
tts: "",
stt: ""
stt: "",
embeddings: ""
});
const [isManualModelInput, setIsManualModelInput] = useState(false);
const [hasCheckedManualMode, setHasCheckedManualMode] = useState(false);
@ -136,7 +139,8 @@ export default function ServiceConfiguration() {
setSchemas({
llm: response.data.llm as Record<string, ProviderSchema>,
tts: response.data.tts as Record<string, ProviderSchema>,
stt: response.data.stt as Record<string, ProviderSchema>
stt: response.data.stt as Record<string, ProviderSchema>,
embeddings: response.data.embeddings as Record<string, ProviderSchema>
});
} else {
console.error("Failed to fetch configurations");
@ -147,7 +151,8 @@ export default function ServiceConfiguration() {
const selectedProviders: Record<ServiceSegment, string> = {
llm: response.data.default_providers.llm,
tts: response.data.default_providers.tts,
stt: response.data.default_providers.stt
stt: response.data.default_providers.stt,
embeddings: response.data.default_providers.embeddings
};
const setServicePropertyValues = (service: ServiceSegment) => {
@ -173,6 +178,7 @@ export default function ServiceConfiguration() {
setServicePropertyValues("llm");
setServicePropertyValues("tts");
setServicePropertyValues("stt");
setServicePropertyValues("embeddings");
// IMPORTANT: Reset form values BEFORE changing providers
// Otherwise, Radix Select sees old values that don't match new provider's enum
@ -246,7 +252,7 @@ export default function ServiceConfiguration() {
setApiError(null);
setIsSaving(true);
const userConfig = {
const userConfig: Record<ServiceSegment, Record<string, string | number>> = {
llm: {
provider: serviceProviders.llm,
api_key: data.llm_api_key as string,
@ -259,6 +265,11 @@ export default function ServiceConfiguration() {
stt: {
provider: serviceProviders.stt,
api_key: data.stt_api_key as string
},
embeddings: {
provider: serviceProviders.embeddings,
api_key: data.embeddings_api_key as string,
model: data.embeddings_model as string
}
};
@ -273,12 +284,25 @@ export default function ServiceConfiguration() {
}
});
// Assemble the payload to persist. llm, tts and stt are always sent; the
// embeddings section is attached only when an API key was entered for it,
// otherwise the key is left out of the payload entirely.
const saveConfig: {
  llm: Record<string, string | number>;
  tts: Record<string, string | number>;
  stt: Record<string, string | number>;
  embeddings?: Record<string, string | number>;
} = {
  llm: userConfig.llm,
  tts: userConfig.tts,
  stt: userConfig.stt,
  ...(userConfig.embeddings.api_key ? { embeddings: userConfig.embeddings } : {}),
};
try {
await saveUserConfig({
llm: userConfig.llm,
tts: userConfig.tts,
stt: userConfig.stt
});
await saveUserConfig(saveConfig);
setApiError(null);
} catch (error: unknown) {
if (error instanceof Error) {
@ -543,7 +567,7 @@ export default function ServiceConfiguration() {
<Card>
<CardContent className="pt-6">
<Tabs defaultValue="llm" className="w-full">
<TabsList className="grid w-full grid-cols-3 mb-6">
<TabsList className="grid w-full grid-cols-4 mb-6">
{TAB_CONFIG.map(({ key, label }) => (
<TabsTrigger key={key} value={key}>
{label}

View file

@ -1,57 +1,55 @@
"use client";
import { Badge } from "@/components/ui/badge";
import { listDocumentsApiV1KnowledgeBaseDocumentsGet } from "@/client/sdk.gen";
import { useAuth } from "@/lib/auth";
import { useCallback, useEffect, useState } from "react";
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
import type { DocumentResponseSchema } from "@/client/types.gen";
import { Badge } from "@/components/ui/badge";
interface DocumentBadgesProps {
documentUuids: string[];
onStaleUuidsDetected?: (staleUuids: string[]) => void;
}
export const DocumentBadges = ({ documentUuids }: DocumentBadgesProps) => {
const { getAccessToken } = useAuth();
export const DocumentBadges = ({ documentUuids, onStaleUuidsDetected }: DocumentBadgesProps) => {
const { documents } = useWorkflow();
const [documentNames, setDocumentNames] = useState<Record<string, string>>({});
const [loading, setLoading] = useState(false);
const fetchDocuments = useCallback(async () => {
if (documentUuids.length === 0) return;
const processDocuments = useCallback((docs: DocumentResponseSchema[]) => {
const nameMap: Record<string, string> = {};
const validUuids = new Set<string>();
setLoading(true);
try {
const accessToken = await getAccessToken();
const response = await listDocumentsApiV1KnowledgeBaseDocumentsGet({
headers: { Authorization: `Bearer ${accessToken}` },
query: {
limit: 100,
},
docs
.filter((doc) => documentUuids.includes(doc.document_uuid))
.forEach((doc) => {
nameMap[doc.document_uuid] = doc.filename;
validUuids.add(doc.document_uuid);
});
setDocumentNames(nameMap);
if (response.data) {
const nameMap: Record<string, string> = {};
response.data.documents
.filter((doc) => documentUuids.includes(doc.document_uuid))
.forEach((doc) => {
nameMap[doc.document_uuid] = doc.filename;
});
setDocumentNames(nameMap);
// Detect stale UUIDs - this only runs when we have loaded data (not undefined)
if (onStaleUuidsDetected) {
const staleUuids = documentUuids.filter(uuid => !validUuids.has(uuid));
if (staleUuids.length > 0) {
onStaleUuidsDetected(staleUuids);
}
} catch (error) {
console.error("Failed to fetch documents:", error);
} finally {
setLoading(false);
}
}, [documentUuids, getAccessToken]);
}, [documentUuids, onStaleUuidsDetected]);
useEffect(() => {
fetchDocuments();
}, [fetchDocuments]);
if (documentUuids.length > 0 && documents !== undefined) {
processDocuments(documents);
} else if (documentUuids.length === 0) {
setDocumentNames({});
}
}, [documentUuids, documents, processDocuments]);
if (documentUuids.length === 0) {
return <></>;
}
if (loading) {
// Show loading while data hasn't loaded yet
if (documents === undefined) {
return <Badge variant="outline">Loading...</Badge>;
}

View file

@ -1,18 +1,18 @@
"use client";
import { FileText } from "lucide-react";
import Link from "next/link";
import { useMemo } from "react";
import type { DocumentResponseSchema } from "@/client/types.gen";
import { Button } from "@/components/ui/button";
import { Checkbox } from "@/components/ui/checkbox";
import { Label } from "@/components/ui/label";
import { listDocumentsApiV1KnowledgeBaseDocumentsGet } from "@/client/sdk.gen";
import type { DocumentResponseSchema } from "@/client/types.gen";
import { useAuth } from "@/lib/auth";
import { FileText } from "lucide-react";
import Link from "next/link";
import { useCallback, useEffect, useState } from "react";
interface DocumentSelectorProps {
value: string[];
onChange: (uuids: string[]) => void;
documents: DocumentResponseSchema[];
disabled?: boolean;
label?: string;
description?: string;
@ -22,43 +22,17 @@ interface DocumentSelectorProps {
export const DocumentSelector = ({
value,
onChange,
documents,
disabled = false,
label = "Knowledge Base Documents",
description = "Select documents that the agent can reference during conversations.",
showLabel = true,
}: DocumentSelectorProps) => {
const { getAccessToken } = useAuth();
const [documents, setDocuments] = useState<DocumentResponseSchema[]>([]);
const [loading, setLoading] = useState(false);
const fetchDocuments = useCallback(async () => {
setLoading(true);
try {
const accessToken = await getAccessToken();
const response = await listDocumentsApiV1KnowledgeBaseDocumentsGet({
headers: { Authorization: `Bearer ${accessToken}` },
query: {
limit: 100,
},
});
if (response.data) {
// Only show completed documents
const completedDocs = response.data.documents.filter(
(doc) => doc.processing_status === "completed"
);
setDocuments(completedDocs);
}
} catch (error) {
console.error("Failed to fetch documents:", error);
} finally {
setLoading(false);
}
}, [getAccessToken]);
useEffect(() => {
fetchDocuments();
}, [fetchDocuments]);
// The selector lists only documents whose processing_status is "completed".
// The filtered array is recomputed only when the `documents` prop changes.
const completedDocuments = useMemo(() => {
  const isCompleted = (doc: DocumentResponseSchema) => doc.processing_status === "completed";
  return documents.filter(isCompleted);
}, [documents]);
const handleToggle = (documentUuid: string, checked: boolean) => {
if (checked) {
@ -76,25 +50,7 @@ export const DocumentSelector = ({
return Math.round(bytes / Math.pow(k, i) * 100) / 100 + " " + sizes[i];
};
if (loading) {
return (
<div className="space-y-2">
{showLabel && (
<>
<Label>{label}</Label>
{description && (
<Label className="text-xs text-muted-foreground">{description}</Label>
)}
</>
)}
<div className="border rounded-md p-4 text-sm text-muted-foreground text-center">
Loading documents...
</div>
</div>
);
}
if (documents.length === 0) {
if (completedDocuments.length === 0) {
return (
<div className="space-y-2">
{showLabel && (
@ -133,7 +89,7 @@ export const DocumentSelector = ({
)}
<div className="border rounded-md max-h-[300px] overflow-y-auto">
<div className="divide-y">
{documents.map((doc) => (
{completedDocuments.map((doc) => (
<div
key={doc.document_uuid}
className="flex items-start gap-3 p-3 hover:bg-muted/50 transition-colors"

View file

@ -2,43 +2,43 @@
import { useCallback, useEffect, useState } from "react";
import { listToolsApiV1ToolsGet } from "@/client/sdk.gen";
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
import type { ToolResponse } from "@/client/types.gen";
import { Badge } from "@/components/ui/badge";
import { useAuth } from "@/lib/auth";
interface ToolBadgesProps {
toolUuids: string[];
onStaleUuidsDetected?: (staleUuids: string[]) => void;
}
export function ToolBadges({ toolUuids }: ToolBadgesProps) {
const { getAccessToken } = useAuth();
const [tools, setTools] = useState<ToolResponse[]>([]);
export function ToolBadges({ toolUuids, onStaleUuidsDetected }: ToolBadgesProps) {
const { tools } = useWorkflow();
const [selectedTools, setSelectedTools] = useState<ToolResponse[]>([]);
const fetchTools = useCallback(async () => {
try {
const accessToken = await getAccessToken();
const response = await listToolsApiV1ToolsGet({
headers: { Authorization: `Bearer ${accessToken}` },
});
if (response.data) {
setTools(response.data);
const processTools = useCallback((toolsData: ToolResponse[]) => {
const filtered = toolsData.filter(tool => toolUuids.includes(tool.tool_uuid));
setSelectedTools(filtered);
// Detect stale UUIDs - this only runs when we have loaded data (not undefined)
if (onStaleUuidsDetected) {
const validUuids = new Set(toolsData.map(tool => tool.tool_uuid));
const staleUuids = toolUuids.filter(uuid => !validUuids.has(uuid));
if (staleUuids.length > 0) {
onStaleUuidsDetected(staleUuids);
}
} catch (error) {
console.error("Failed to fetch tools:", error);
}
}, [getAccessToken]);
}, [toolUuids, onStaleUuidsDetected]);
useEffect(() => {
if (toolUuids.length > 0) {
fetchTools();
if (toolUuids.length > 0 && tools !== undefined) {
processTools(tools);
} else if (toolUuids.length === 0) {
setSelectedTools([]);
}
}, [toolUuids.length, fetchTools]);
}, [toolUuids, tools, processTools]);
const selectedTools = tools.filter((tool) => toolUuids.includes(tool.tool_uuid));
if (selectedTools.length === 0 && toolUuids.length > 0) {
// Still loading or tools not found
// Show loading while data hasn't loaded yet
if (tools === undefined && toolUuids.length > 0) {
return (
<div className="flex flex-wrap gap-1">
<Badge variant="outline" className="text-xs">

View file

@ -1,20 +1,18 @@
"use client";
import { ExternalLink, Loader2 } from "lucide-react";
import { ExternalLink } from "lucide-react";
import Link from "next/link";
import { useCallback, useEffect, useState } from "react";
import { renderToolIcon } from "@/app/tools/config";
import { listToolsApiV1ToolsGet } from "@/client/sdk.gen";
import type { ToolResponse } from "@/client/types.gen";
import { Button } from "@/components/ui/button";
import { Checkbox } from "@/components/ui/checkbox";
import { Label } from "@/components/ui/label";
import { useAuth } from "@/lib/auth";
interface ToolSelectorProps {
value: string[];
onChange: (uuids: string[]) => void;
tools: ToolResponse[];
disabled?: boolean;
label?: string;
description?: string;
@ -24,43 +22,14 @@ interface ToolSelectorProps {
export function ToolSelector({
value,
onChange,
tools,
disabled = false,
label = "Tools",
description = "Select tools that the agent can use during the conversation.",
showLabel = true,
}: ToolSelectorProps) {
const { getAccessToken } = useAuth();
const [tools, setTools] = useState<ToolResponse[]>([]);
const [loading, setLoading] = useState(false);
const fetchTools = useCallback(async () => {
setLoading(true);
try {
const accessToken = await getAccessToken();
const response = await listToolsApiV1ToolsGet({
headers: { Authorization: `Bearer ${accessToken}` },
query: { status: "active" },
});
if (response.error) {
console.error("Failed to fetch tools:", response.error);
setTools([]);
return;
}
if (response.data) {
setTools(response.data);
}
} catch (error) {
console.error("Failed to fetch tools:", error);
setTools([]);
} finally {
setLoading(false);
}
}, [getAccessToken]);
useEffect(() => {
fetchTools();
}, [fetchTools]);
// Filter to only show active tools
const activeTools = tools.filter((tool) => tool.status === "active");
const handleToggle = (toolUuid: string, checked: boolean) => {
if (checked) {
@ -83,12 +52,7 @@ export function ToolSelector({
</>
)}
{loading ? (
<div className="flex items-center gap-2 p-3 border rounded-md">
<Loader2 className="h-4 w-4 animate-spin" />
<span className="text-sm text-muted-foreground">Loading tools...</span>
</div>
) : tools.length === 0 ? (
{activeTools.length === 0 ? (
<div className="p-4 border rounded-md text-center">
<p className="text-sm text-muted-foreground mb-2">
No tools available.
@ -102,7 +66,7 @@ export function ToolSelector({
</div>
) : (
<div className="border rounded-md divide-y">
{tools.map((tool) => {
{activeTools.map((tool) => {
const isSelected = value.includes(tool.tool_uuid);
return (
<label

View file

@ -1,8 +1,9 @@
import { NodeProps, NodeToolbar, Position } from "@xyflow/react";
import { Edit, FileText, Headset, PlusIcon, Trash2Icon, Wrench } from "lucide-react";
import { memo, useEffect, useMemo, useState } from "react";
import { memo, useCallback, useEffect, useMemo, useState } from "react";
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
import type { DocumentResponseSchema, ToolResponse } from "@/client/types.gen";
import { DocumentBadges } from "@/components/flow/DocumentBadges";
import { DocumentSelector } from "@/components/flow/DocumentSelector";
import { ToolBadges } from "@/components/flow/ToolBadges";
@ -38,6 +39,8 @@ interface AgentNodeEditFormProps {
setToolUuids: (value: string[]) => void;
documentUuids: string[];
setDocumentUuids: (value: string[]) => void;
tools: ToolResponse[];
documents: DocumentResponseSchema[];
}
interface AgentNodeProps extends NodeProps {
@ -46,7 +49,7 @@ interface AgentNodeProps extends NodeProps {
export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
const { open, setOpen, handleSaveNodeData, handleDeleteNode } = useNodeHandlers({ id });
const { saveWorkflow } = useWorkflow();
const { saveWorkflow, tools, documents } = useWorkflow();
// Form state
const [prompt, setPrompt] = useState(data.prompt);
@ -120,6 +123,30 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
}
}, [data, open]);
// Remove the given stale document UUIDs from this node's data, then save the
// workflow. `staleUuids` is the list reported by DocumentBadges.
const handleStaleDocuments = useCallback((staleUuids: string[]) => {
  const stale = new Set(staleUuids);
  const cleanedUuids = (data.document_uuids ?? []).filter(uuid => !stale.has(uuid));
  handleSaveNodeData({
    ...data,
    // An empty selection is stored as undefined rather than as [].
    document_uuids: cleanedUuids.length > 0 ? cleanedUuids : undefined,
  });
  // The save is deferred by 100ms (presumably so the node update above is
  // applied first; TODO confirm).
  setTimeout(() => {
    // Nothing awaits a timer callback, so a rejected save has to be handled
    // here or it would surface as an unhandled promise rejection.
    saveWorkflow().catch((error: unknown) => {
      console.error('Failed to save workflow after removing stale documents:', error);
    });
  }, 100);
}, [data, handleSaveNodeData, saveWorkflow]);
// Remove the given stale tool UUIDs from this node's data, then save the
// workflow. `staleUuids` is the list reported by ToolBadges.
const handleStaleTools = useCallback((staleUuids: string[]) => {
  const stale = new Set(staleUuids);
  const cleanedUuids = (data.tool_uuids ?? []).filter(uuid => !stale.has(uuid));
  handleSaveNodeData({
    ...data,
    // An empty selection is stored as undefined rather than as [].
    tool_uuids: cleanedUuids.length > 0 ? cleanedUuids : undefined,
  });
  // The save is deferred by 100ms (presumably so the node update above is
  // applied first; TODO confirm).
  setTimeout(() => {
    // Nothing awaits a timer callback, so a rejected save has to be handled
    // here or it would surface as an unhandled promise rejection.
    saveWorkflow().catch((error: unknown) => {
      console.error('Failed to save workflow after removing stale tools:', error);
    });
  }, 100);
}, [data, handleSaveNodeData, saveWorkflow]);
return (
<>
<NodeContent
@ -144,7 +171,7 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
<Wrench className="h-3 w-3" />
<span>Tools:</span>
</div>
<ToolBadges toolUuids={data.tool_uuids} />
<ToolBadges toolUuids={data.tool_uuids} onStaleUuidsDetected={handleStaleTools} />
</div>
)}
{data.document_uuids && data.document_uuids.length > 0 && (
@ -153,7 +180,7 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
<FileText className="h-3 w-3" />
<span>Documents:</span>
</div>
<DocumentBadges documentUuids={data.document_uuids} />
<DocumentBadges documentUuids={data.document_uuids} onStaleUuidsDetected={handleStaleDocuments} />
</div>
)}
</NodeContent>
@ -198,6 +225,8 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
setToolUuids={setToolUuids}
documentUuids={documentUuids}
setDocumentUuids={setDocumentUuids}
tools={tools ?? []}
documents={documents ?? []}
/>
)}
</NodeEditDialog>
@ -224,6 +253,8 @@ const AgentNodeEditForm = ({
setToolUuids,
documentUuids,
setDocumentUuids,
tools,
documents,
}: AgentNodeEditFormProps) => {
const handleVariableNameChange = (idx: number, value: string) => {
const newVars = [...variables];
@ -364,6 +395,7 @@ const AgentNodeEditForm = ({
<ToolSelector
value={toolUuids}
onChange={setToolUuids}
tools={tools}
description="Select tools that the agent can invoke during this conversation step."
/>
</div>
@ -373,6 +405,7 @@ const AgentNodeEditForm = ({
<DocumentSelector
value={documentUuids}
onChange={setDocumentUuids}
documents={documents}
description="Select documents from the knowledge base that the agent can reference during this conversation step."
/>
</div>

View file

@ -1,8 +1,9 @@
import { NodeProps, NodeToolbar, Position } from "@xyflow/react";
import { Edit, FileText, Play, PlusIcon, Trash2Icon, Wrench } from "lucide-react";
import { memo, useEffect, useMemo, useState } from "react";
import { memo, useCallback, useEffect, useMemo, useState } from "react";
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
import type { DocumentResponseSchema, ToolResponse } from "@/client/types.gen";
import { DocumentBadges } from "@/components/flow/DocumentBadges";
import { DocumentSelector } from "@/components/flow/DocumentSelector";
import { ToolBadges } from "@/components/flow/ToolBadges";
@ -45,6 +46,8 @@ interface StartCallEditFormProps {
setToolUuids: (value: string[]) => void;
documentUuids: string[];
setDocumentUuids: (value: string[]) => void;
tools: ToolResponse[];
documents: DocumentResponseSchema[];
}
interface StartCallNodeProps extends NodeProps {
@ -56,7 +59,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
id,
additionalData: { is_start: true }
});
const { saveWorkflow } = useWorkflow();
const { saveWorkflow, tools, documents } = useWorkflow();
// Form state
const [prompt, setPrompt] = useState(data.prompt ?? "");
@ -140,6 +143,30 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
}
}, [data, open]);
// Remove the given stale document UUIDs from the start node's data, then save
// the workflow. `staleUuids` is the list reported by DocumentBadges.
const handleStaleDocuments = useCallback((staleUuids: string[]) => {
  const stale = new Set(staleUuids);
  const cleanedUuids = (data.document_uuids ?? []).filter(uuid => !stale.has(uuid));
  handleSaveNodeData({
    ...data,
    // An empty selection is stored as undefined rather than as [].
    document_uuids: cleanedUuids.length > 0 ? cleanedUuids : undefined,
  });
  // The save is deferred by 100ms (presumably so the node update above is
  // applied first; TODO confirm).
  setTimeout(() => {
    // Nothing awaits a timer callback, so a rejected save has to be handled
    // here or it would surface as an unhandled promise rejection.
    saveWorkflow().catch((error: unknown) => {
      console.error('Failed to save workflow after removing stale documents:', error);
    });
  }, 100);
}, [data, handleSaveNodeData, saveWorkflow]);
// Remove the given stale tool UUIDs from the start node's data, then save the
// workflow. `staleUuids` is the list reported by ToolBadges.
const handleStaleTools = useCallback((staleUuids: string[]) => {
  const stale = new Set(staleUuids);
  const cleanedUuids = (data.tool_uuids ?? []).filter(uuid => !stale.has(uuid));
  handleSaveNodeData({
    ...data,
    // An empty selection is stored as undefined rather than as [].
    tool_uuids: cleanedUuids.length > 0 ? cleanedUuids : undefined,
  });
  // The save is deferred by 100ms (presumably so the node update above is
  // applied first; TODO confirm).
  setTimeout(() => {
    // Nothing awaits a timer callback, so a rejected save has to be handled
    // here or it would surface as an unhandled promise rejection.
    saveWorkflow().catch((error: unknown) => {
      console.error('Failed to save workflow after removing stale tools:', error);
    });
  }, 100);
}, [data, handleSaveNodeData, saveWorkflow]);
return (
<>
<NodeContent
@ -163,7 +190,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
<Wrench className="h-3 w-3" />
<span>Tools:</span>
</div>
<ToolBadges toolUuids={data.tool_uuids} />
<ToolBadges toolUuids={data.tool_uuids} onStaleUuidsDetected={handleStaleTools} />
</div>
)}
{data.document_uuids && data.document_uuids.length > 0 && (
@ -172,7 +199,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
<FileText className="h-3 w-3" />
<span>Documents:</span>
</div>
<DocumentBadges documentUuids={data.document_uuids} />
<DocumentBadges documentUuids={data.document_uuids} onStaleUuidsDetected={handleStaleDocuments} />
</div>
)}
</NodeContent>
@ -218,6 +245,8 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
setToolUuids={setToolUuids}
documentUuids={documentUuids}
setDocumentUuids={setDocumentUuids}
tools={tools ?? []}
documents={documents ?? []}
/>
)}
</NodeEditDialog>
@ -250,6 +279,8 @@ const StartCallEditForm = ({
setToolUuids,
documentUuids,
setDocumentUuids,
tools,
documents,
}: StartCallEditFormProps) => {
const handleVariableNameChange = (idx: number, value: string) => {
const newVars = [...variables];
@ -435,6 +466,7 @@ const StartCallEditForm = ({
<ToolSelector
value={toolUuids}
onChange={setToolUuids}
tools={tools}
description="Select tools that the agent can invoke during this conversation step."
/>
</div>
@ -444,6 +476,7 @@ const StartCallEditForm = ({
<DocumentSelector
value={documentUuids}
onChange={setDocumentUuids}
documents={documents}
description="Select documents from the knowledge base that the agent can reference during this conversation step."
/>
</div>

View file

@ -18,6 +18,9 @@ export type SaveUserConfigFunctionParams = {
stt?: {
[key: string]: string | number;
} | null;
embeddings?: {
[key: string]: string | number;
} | null;
test_phone_number?: string | null;
timezone?: string | null;
};