mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: add openai embedding service
This commit is contained in:
parent
eb41285204
commit
3f0e500fde
39 changed files with 1902 additions and 339 deletions
|
|
@ -5,16 +5,17 @@ Revises: dcb0a27d98c6
|
|||
Create Date: 2026-01-16 13:40:17.808807
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from alembic import op
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'dc33eef8dabe'
|
||||
down_revision: Union[str, None] = 'dcb0a27d98c6'
|
||||
revision: str = "dc33eef8dabe"
|
||||
down_revision: Union[str, None] = "dcb0a27d98c6"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
|
@ -22,75 +23,170 @@ depends_on: Union[str, Sequence[str], None] = None
|
|||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
# Enable pgvector extension
|
||||
op.execute('CREATE EXTENSION IF NOT EXISTS vector')
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
||||
|
||||
sa.Enum('pending', 'processing', 'completed', 'failed', name='document_processing_status').create(op.get_bind())
|
||||
op.create_table('knowledge_base_documents',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('document_uuid', sa.String(length=36), nullable=False),
|
||||
sa.Column('organization_id', sa.Integer(), nullable=False),
|
||||
sa.Column('filename', sa.String(length=500), nullable=False),
|
||||
sa.Column('file_size_bytes', sa.Integer(), nullable=True),
|
||||
sa.Column('file_hash', sa.String(length=64), nullable=True),
|
||||
sa.Column('mime_type', sa.String(length=100), nullable=True),
|
||||
sa.Column('source_url', sa.String(), nullable=True),
|
||||
sa.Column('total_chunks', sa.Integer(), nullable=False),
|
||||
sa.Column('processing_status', postgresql.ENUM('pending', 'processing', 'completed', 'failed', name='document_processing_status', create_type=False), server_default=sa.text("'pending'::document_processing_status"), nullable=False),
|
||||
sa.Column('processing_error', sa.Text(), nullable=True),
|
||||
sa.Column('docling_metadata', sa.JSON(), nullable=False),
|
||||
sa.Column('custom_metadata', sa.JSON(), nullable=False),
|
||||
sa.Column('created_by', sa.Integer(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('archived_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(['created_by'], ['users.id'], ),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
op.create_index('ix_kb_documents_created_at', 'knowledge_base_documents', ['created_at'], unique=False)
|
||||
op.create_index('ix_kb_documents_organization_id', 'knowledge_base_documents', ['organization_id'], unique=False)
|
||||
op.create_index('ix_kb_documents_status', 'knowledge_base_documents', ['processing_status'], unique=False)
|
||||
op.create_index('ix_kb_documents_uuid', 'knowledge_base_documents', ['document_uuid'], unique=False)
|
||||
op.create_index(op.f('ix_knowledge_base_documents_document_uuid'), 'knowledge_base_documents', ['document_uuid'], unique=True)
|
||||
op.create_table('knowledge_base_chunks'),
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('document_id', sa.Integer(), nullable=False),
|
||||
sa.Column('organization_id', sa.Integer(), nullable=False),
|
||||
sa.Column('chunk_text', sa.Text(), nullable=False),
|
||||
sa.Column('contextualized_text', sa.Text(), nullable=True),
|
||||
sa.Column('chunk_index', sa.Integer(), nullable=False),
|
||||
sa.Column('chunk_metadata', sa.JSON(), nullable=False),
|
||||
sa.Column('embedding_model', sa.String(length=200), nullable=False),
|
||||
sa.Column('embedding_dimension', sa.Integer(), nullable=False),
|
||||
sa.Column('embedding', Vector(384), nullable=True),
|
||||
sa.Column('token_count', sa.Integer(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(['document_id'], ['knowledge_base_documents.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
sa.Enum(
|
||||
"pending",
|
||||
"processing",
|
||||
"completed",
|
||||
"failed",
|
||||
name="document_processing_status",
|
||||
).create(op.get_bind())
|
||||
op.create_table(
|
||||
"knowledge_base_documents",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("document_uuid", sa.String(length=36), nullable=False),
|
||||
sa.Column("organization_id", sa.Integer(), nullable=False),
|
||||
sa.Column("filename", sa.String(length=500), nullable=False),
|
||||
sa.Column("file_size_bytes", sa.Integer(), nullable=True),
|
||||
sa.Column("file_hash", sa.String(length=64), nullable=True),
|
||||
sa.Column("mime_type", sa.String(length=100), nullable=True),
|
||||
sa.Column("source_url", sa.String(), nullable=True),
|
||||
sa.Column("total_chunks", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"processing_status",
|
||||
postgresql.ENUM(
|
||||
"pending",
|
||||
"processing",
|
||||
"completed",
|
||||
"failed",
|
||||
name="document_processing_status",
|
||||
create_type=False,
|
||||
),
|
||||
server_default=sa.text("'pending'::document_processing_status"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("processing_error", sa.Text(), nullable=True),
|
||||
sa.Column("docling_metadata", sa.JSON(), nullable=False),
|
||||
sa.Column("custom_metadata", sa.JSON(), nullable=False),
|
||||
sa.Column("created_by", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("archived_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["created_by"],
|
||||
["users.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"], ["organizations.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_documents_created_at",
|
||||
"knowledge_base_documents",
|
||||
["created_at"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_documents_organization_id",
|
||||
"knowledge_base_documents",
|
||||
["organization_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_documents_status",
|
||||
"knowledge_base_documents",
|
||||
["processing_status"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_documents_uuid",
|
||||
"knowledge_base_documents",
|
||||
["document_uuid"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_knowledge_base_documents_document_uuid"),
|
||||
"knowledge_base_documents",
|
||||
["document_uuid"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_table(
|
||||
"knowledge_base_chunks",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("document_id", sa.Integer(), nullable=False),
|
||||
sa.Column("organization_id", sa.Integer(), nullable=False),
|
||||
sa.Column("chunk_text", sa.Text(), nullable=False),
|
||||
sa.Column("contextualized_text", sa.Text(), nullable=True),
|
||||
sa.Column("chunk_index", sa.Integer(), nullable=False),
|
||||
sa.Column("chunk_metadata", sa.JSON(), nullable=False),
|
||||
sa.Column("embedding_model", sa.String(length=200), nullable=False),
|
||||
sa.Column("embedding_dimension", sa.Integer(), nullable=False),
|
||||
sa.Column("embedding", Vector(1536), nullable=True),
|
||||
sa.Column("token_count", sa.Integer(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["document_id"], ["knowledge_base_documents.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"], ["organizations.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_chunks_chunk_index",
|
||||
"knowledge_base_chunks",
|
||||
["chunk_index"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_chunks_document_id",
|
||||
"knowledge_base_chunks",
|
||||
["document_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_chunks_embedding_ivfflat",
|
||||
"knowledge_base_chunks",
|
||||
["embedding"],
|
||||
unique=False,
|
||||
postgresql_using="ivfflat",
|
||||
postgresql_with={"lists": 100},
|
||||
postgresql_ops={"embedding": "vector_cosine_ops"},
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kb_chunks_organization_id",
|
||||
"knowledge_base_chunks",
|
||||
["organization_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index('ix_kb_chunks_chunk_index', 'knowledge_base_chunks', ['chunk_index'], unique=False)
|
||||
op.create_index('ix_kb_chunks_document_id', 'knowledge_base_chunks', ['document_id'], unique=False)
|
||||
op.create_index('ix_kb_chunks_embedding_ivfflat', 'knowledge_base_chunks', ['embedding'], unique=False, postgresql_using='ivfflat', postgresql_with={'lists': 100}, postgresql_ops={'embedding': 'vector_cosine_ops'})
|
||||
op.create_index('ix_kb_chunks_organization_id', 'knowledge_base_chunks', ['organization_id'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index('ix_kb_chunks_organization_id', table_name='knowledge_base_chunks')
|
||||
op.drop_index('ix_kb_chunks_embedding_ivfflat', table_name='knowledge_base_chunks', postgresql_using='ivfflat', postgresql_with={'lists': 100}, postgresql_ops={'embedding': 'vector_cosine_ops'})
|
||||
op.drop_index('ix_kb_chunks_document_id', table_name='knowledge_base_chunks')
|
||||
op.drop_index('ix_kb_chunks_chunk_index', table_name='knowledge_base_chunks')
|
||||
op.drop_table('knowledge_base_chunks')
|
||||
op.drop_index(op.f('ix_knowledge_base_documents_document_uuid'), table_name='knowledge_base_documents')
|
||||
op.drop_index('ix_kb_documents_uuid', table_name='knowledge_base_documents')
|
||||
op.drop_index('ix_kb_documents_status', table_name='knowledge_base_documents')
|
||||
op.drop_index('ix_kb_documents_organization_id', table_name='knowledge_base_documents')
|
||||
op.drop_index('ix_kb_documents_created_at', table_name='knowledge_base_documents')
|
||||
op.drop_table('knowledge_base_documents')
|
||||
sa.Enum('pending', 'processing', 'completed', 'failed', name='document_processing_status').drop(op.get_bind())
|
||||
op.drop_index("ix_kb_chunks_organization_id", table_name="knowledge_base_chunks")
|
||||
op.drop_index(
|
||||
"ix_kb_chunks_embedding_ivfflat",
|
||||
table_name="knowledge_base_chunks",
|
||||
postgresql_using="ivfflat",
|
||||
postgresql_with={"lists": 100},
|
||||
postgresql_ops={"embedding": "vector_cosine_ops"},
|
||||
)
|
||||
op.drop_index("ix_kb_chunks_document_id", table_name="knowledge_base_chunks")
|
||||
op.drop_index("ix_kb_chunks_chunk_index", table_name="knowledge_base_chunks")
|
||||
op.drop_table("knowledge_base_chunks")
|
||||
op.drop_index(
|
||||
op.f("ix_knowledge_base_documents_document_uuid"),
|
||||
table_name="knowledge_base_documents",
|
||||
)
|
||||
op.drop_index("ix_kb_documents_uuid", table_name="knowledge_base_documents")
|
||||
op.drop_index("ix_kb_documents_status", table_name="knowledge_base_documents")
|
||||
op.drop_index(
|
||||
"ix_kb_documents_organization_id", table_name="knowledge_base_documents"
|
||||
)
|
||||
op.drop_index("ix_kb_documents_created_at", table_name="knowledge_base_documents")
|
||||
op.drop_table("knowledge_base_documents")
|
||||
sa.Enum(
|
||||
"pending",
|
||||
"processing",
|
||||
"completed",
|
||||
"failed",
|
||||
name="document_processing_status",
|
||||
).drop(op.get_bind())
|
||||
|
||||
# Note: We don't drop the vector extension as it may be used by other tables
|
||||
# If you want to drop it, uncomment the following line:
|
||||
|
|
|
|||
|
|
@ -332,6 +332,7 @@ class KnowledgeBaseClient(BaseDBClient):
|
|||
limit: int = 5,
|
||||
document_ids: Optional[List[int]] = None,
|
||||
document_uuids: Optional[List[str]] = None,
|
||||
embedding_model: Optional[str] = None,
|
||||
) -> List[dict]:
|
||||
"""Search for similar chunks using vector similarity.
|
||||
|
||||
|
|
@ -344,6 +345,7 @@ class KnowledgeBaseClient(BaseDBClient):
|
|||
limit: Maximum number of results to return
|
||||
document_ids: Optional list of document IDs to filter by
|
||||
document_uuids: Optional list of document UUIDs to filter by
|
||||
embedding_model: Optional embedding model to filter by (for dimension compatibility)
|
||||
|
||||
Returns:
|
||||
List of dictionaries with chunk data and similarity scores, ordered by similarity (highest first)
|
||||
|
|
@ -359,23 +361,37 @@ class KnowledgeBaseClient(BaseDBClient):
|
|||
"c.organization_id = $2",
|
||||
"d.is_active = true",
|
||||
]
|
||||
params = [None, organization_id, limit] # $1 will be embedding_str, $3 is limit
|
||||
params = [
|
||||
None,
|
||||
organization_id,
|
||||
limit,
|
||||
] # $1 will be embedding_str, $3 is limit
|
||||
param_index = 4 # Next available parameter index
|
||||
|
||||
# Add document_ids filter if provided
|
||||
if document_ids:
|
||||
placeholders = ", ".join(f"${param_index + i}" for i in range(len(document_ids)))
|
||||
placeholders = ", ".join(
|
||||
f"${param_index + i}" for i in range(len(document_ids))
|
||||
)
|
||||
where_conditions.append(f"c.document_id IN ({placeholders})")
|
||||
params.extend(document_ids)
|
||||
param_index += len(document_ids)
|
||||
|
||||
# Add document_uuids filter if provided
|
||||
if document_uuids:
|
||||
placeholders = ", ".join(f"${param_index + i}" for i in range(len(document_uuids)))
|
||||
placeholders = ", ".join(
|
||||
f"${param_index + i}" for i in range(len(document_uuids))
|
||||
)
|
||||
where_conditions.append(f"d.document_uuid IN ({placeholders})")
|
||||
params.extend(document_uuids)
|
||||
param_index += len(document_uuids)
|
||||
|
||||
# Add embedding_model filter if provided (for dimension compatibility)
|
||||
if embedding_model:
|
||||
where_conditions.append(f"c.embedding_model = ${param_index}")
|
||||
params.append(embedding_model)
|
||||
param_index += 1
|
||||
|
||||
# Build the complete SQL query
|
||||
where_clause = " AND ".join(where_conditions)
|
||||
query_sql = f"""
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import uuid
|
|||
from datetime import UTC, datetime
|
||||
|
||||
from loguru import logger
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
Boolean,
|
||||
|
|
@ -19,7 +20,6 @@ from sqlalchemy import (
|
|||
and_,
|
||||
text,
|
||||
)
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
|
||||
from ..enums import (
|
||||
|
|
@ -929,7 +929,13 @@ class KnowledgeBaseDocumentModel(Base):
|
|||
source_url = Column(String, nullable=True) # If document was fetched from URL
|
||||
total_chunks = Column(Integer, nullable=False, default=0)
|
||||
processing_status = Column(
|
||||
Enum("pending", "processing", "completed", "failed", name="document_processing_status"),
|
||||
Enum(
|
||||
"pending",
|
||||
"processing",
|
||||
"completed",
|
||||
"failed",
|
||||
name="document_processing_status",
|
||||
),
|
||||
nullable=False,
|
||||
default="pending",
|
||||
server_default=text("'pending'::document_processing_status"),
|
||||
|
|
@ -937,7 +943,9 @@ class KnowledgeBaseDocumentModel(Base):
|
|||
processing_error = Column(Text, nullable=True)
|
||||
|
||||
# Docling conversion metadata
|
||||
docling_metadata = Column(JSON, nullable=False, default=dict) # Store docling document metadata
|
||||
docling_metadata = Column(
|
||||
JSON, nullable=False, default=dict
|
||||
) # Store docling document metadata
|
||||
|
||||
# Custom metadata (user-defined tags, categories, etc.)
|
||||
custom_metadata = Column(JSON, nullable=False, default=dict)
|
||||
|
|
@ -1000,21 +1008,31 @@ class KnowledgeBaseChunkModel(Base):
|
|||
|
||||
# Chunk content
|
||||
chunk_text = Column(Text, nullable=False) # The actual chunk text
|
||||
contextualized_text = Column(Text, nullable=True) # Enriched text from chunker.contextualize()
|
||||
contextualized_text = Column(
|
||||
Text, nullable=True
|
||||
) # Enriched text from chunker.contextualize()
|
||||
|
||||
# Chunk positioning and metadata
|
||||
chunk_index = Column(Integer, nullable=False) # Position in document (0-based)
|
||||
|
||||
# Docling chunk metadata
|
||||
chunk_metadata = Column(JSON, nullable=False, default=dict) # Store chunk.meta if available
|
||||
chunk_metadata = Column(
|
||||
JSON, nullable=False, default=dict
|
||||
) # Store chunk.meta if available
|
||||
|
||||
# Embedding configuration
|
||||
embedding_model = Column(String(200), nullable=False) # e.g., "sentence-transformers/all-MiniLM-L6-v2"
|
||||
embedding_dimension = Column(Integer, nullable=False) # e.g., 384 for all-MiniLM-L6-v2
|
||||
embedding_model = Column(
|
||||
String(200), nullable=False
|
||||
) # e.g., "sentence-transformers/all-MiniLM-L6-v2"
|
||||
embedding_dimension = Column(
|
||||
Integer, nullable=False
|
||||
) # e.g., 384 for all-MiniLM-L6-v2
|
||||
|
||||
# Vector embedding (pgvector column)
|
||||
# The dimension should match the embedding_dimension field
|
||||
embedding = Column(Vector(384), nullable=True) # Default to 384 for all-MiniLM-L6-v2
|
||||
# Default: 1536 dimensions for OpenAI text-embedding-3-small
|
||||
# SentenceTransformer (384-dim) also supported but stored as 384-dim vectors
|
||||
embedding = Column(Vector(1536), nullable=True)
|
||||
|
||||
# Token count (useful for chunking strategy analysis)
|
||||
token_count = Column(Integer, nullable=True)
|
||||
|
|
@ -1036,6 +1054,9 @@ class KnowledgeBaseChunkModel(Base):
|
|||
Index("ix_kb_chunks_document_id", "document_id"),
|
||||
Index("ix_kb_chunks_organization_id", "organization_id"),
|
||||
Index("ix_kb_chunks_chunk_index", "chunk_index"),
|
||||
Index(
|
||||
"ix_kb_chunks_embedding_model", "embedding_model"
|
||||
), # For filtering by model
|
||||
# Vector similarity search index (using IVFFlat or HNSW)
|
||||
# IVFFlat is good for datasets with 10k-1M vectors
|
||||
# HNSW is better for larger datasets but uses more memory
|
||||
|
|
|
|||
|
|
@ -103,6 +103,10 @@ async def process_document(
|
|||
|
||||
The document status will be updated from 'pending' -> 'processing' -> 'completed' or 'failed'.
|
||||
|
||||
Embedding Services:
|
||||
* openai (default): High-quality 1536-dimensional embeddings (requires OPENAI_API_KEY)
|
||||
* sentence_transformer: Free, offline-capable, 384-dimensional embeddings
|
||||
|
||||
Access Control:
|
||||
* Users can only process documents in their organization.
|
||||
"""
|
||||
|
|
@ -129,11 +133,13 @@ async def process_document(
|
|||
document.id,
|
||||
request.s3_key,
|
||||
user.selected_organization_id,
|
||||
128, # max_tokens (default)
|
||||
request.embedding_service,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created document {request.document_uuid} (id={document.id}) and enqueued processing, "
|
||||
f"org {user.selected_organization_id}"
|
||||
f"Created document {request.document_uuid} (id={document.id}) and enqueued processing "
|
||||
f"with {request.embedding_service} embeddings, org {user.selected_organization_id}"
|
||||
)
|
||||
|
||||
return DocumentResponseSchema(
|
||||
|
|
@ -277,9 +283,7 @@ async def get_document(
|
|||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Error getting document: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to get document"
|
||||
) from exc
|
||||
raise HTTPException(status_code=500, detail="Failed to get document") from exc
|
||||
|
||||
|
||||
@router.delete(
|
||||
|
|
@ -342,15 +346,26 @@ async def search_chunks(
|
|||
|
||||
try:
|
||||
# Import here to avoid circular dependency
|
||||
from api.services.admin_utils.local_exec import DocumentProcessor
|
||||
from api.services.gen_ai import OpenAIEmbeddingService
|
||||
|
||||
# Initialize processor (reuses cached models)
|
||||
processor = DocumentProcessor(
|
||||
# Try to get user's embeddings configuration
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
|
||||
if user_config.embeddings:
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
|
||||
# Initialize embedding service with user config or fallback to env
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
)
|
||||
|
||||
# Perform search
|
||||
results = await processor.search_similar_chunks(
|
||||
results = await embedding_service.search_similar_chunks(
|
||||
query=request.query,
|
||||
organization_id=user.selected_organization_id,
|
||||
limit=request.limit,
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ class DefaultConfigurationsResponse(TypedDict):
|
|||
llm: dict[str, dict]
|
||||
tts: dict[str, dict]
|
||||
stt: dict[str, dict]
|
||||
embeddings: dict[str, dict]
|
||||
default_providers: dict[str, str]
|
||||
|
||||
|
||||
|
|
@ -50,6 +51,10 @@ async def get_default_configurations() -> DefaultConfigurationsResponse:
|
|||
provider: model_cls.model_json_schema()
|
||||
for provider, model_cls in REGISTRY[ServiceType.STT].items()
|
||||
},
|
||||
"embeddings": {
|
||||
provider: model_cls.model_json_schema()
|
||||
for provider, model_cls in REGISTRY[ServiceType.EMBEDDINGS].items()
|
||||
},
|
||||
"default_providers": DEFAULT_SERVICE_PROVIDERS,
|
||||
}
|
||||
return configurations
|
||||
|
|
@ -69,6 +74,7 @@ class UserConfigurationRequestResponseSchema(BaseModel):
|
|||
llm: dict[str, Union[str, float]] | None = None
|
||||
tts: dict[str, Union[str, float]] | None = None
|
||||
stt: dict[str, Union[str, float]] | None = None
|
||||
embeddings: dict[str, Union[str, float]] | None = None
|
||||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
organization_pricing: dict[str, Union[float, str, bool]] | None = None
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""Pydantic schemas for knowledge base operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -29,6 +29,11 @@ class ProcessDocumentRequestSchema(BaseModel):
|
|||
|
||||
document_uuid: str = Field(..., description="Document UUID to process")
|
||||
s3_key: str = Field(..., description="S3 key of the uploaded file")
|
||||
embedding_service: Literal["sentence_transformer", "openai"] = Field(
|
||||
default="openai",
|
||||
description="Embedding service to use for processing. "
|
||||
"Options: 'openai' (default, 1536-dim, requires API key) or 'sentence_transformer' (free, 384-dim)",
|
||||
)
|
||||
|
||||
|
||||
class DocumentResponseSchema(BaseModel):
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from datetime import datetime
|
|||
from pydantic import BaseModel
|
||||
|
||||
from api.services.configuration.registry import (
|
||||
EmbeddingsConfig,
|
||||
LLMConfig,
|
||||
STTConfig,
|
||||
TTSConfig,
|
||||
|
|
@ -13,6 +14,7 @@ class UserConfiguration(BaseModel):
|
|||
llm: LLMConfig | None = None
|
||||
stt: STTConfig | None = None
|
||||
tts: TTSConfig | None = None
|
||||
embeddings: EmbeddingsConfig | None = None
|
||||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
last_validated_at: datetime | None = None
|
||||
|
|
|
|||
|
|
@ -48,6 +48,12 @@ class UserConfigurationValidator:
|
|||
status_list.extend(self._validate_service(configuration.llm, "llm"))
|
||||
status_list.extend(self._validate_service(configuration.stt, "stt"))
|
||||
status_list.extend(self._validate_service(configuration.tts, "tts"))
|
||||
# Embeddings is optional - only validate if configured
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.embeddings, "embeddings", required=False
|
||||
)
|
||||
)
|
||||
|
||||
if status_list:
|
||||
raise ValueError(status_list)
|
||||
|
|
@ -55,11 +61,16 @@ class UserConfigurationValidator:
|
|||
return {"status": [{"model": "all", "message": "ok"}]}
|
||||
|
||||
def _validate_service(
|
||||
self, service_config: Optional[ServiceConfig], service_name: str
|
||||
self,
|
||||
service_config: Optional[ServiceConfig],
|
||||
service_name: str,
|
||||
required: bool = True,
|
||||
) -> list[APIKeyStatus]:
|
||||
"""Validate a service configuration and return any error statuses."""
|
||||
if not service_config:
|
||||
return [{"model": service_name, "message": "API key is missing"}]
|
||||
if required:
|
||||
return [{"model": service_name, "message": "API key is missing"}]
|
||||
return [] # Optional service not configured is OK
|
||||
|
||||
provider = service_config.provider
|
||||
api_key = service_config.api_key
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ left as ``None``.
|
|||
from api.services.configuration.registry import (
|
||||
DeepgramSTTConfiguration,
|
||||
ElevenlabsTTSConfiguration,
|
||||
OpenAIEmbeddingsConfiguration,
|
||||
OpenAILLMService,
|
||||
ServiceProviders,
|
||||
)
|
||||
|
|
@ -22,6 +23,7 @@ _DEFAULTS = {
|
|||
"llm": (ServiceProviders.OPENAI, OpenAILLMService),
|
||||
"tts": (ServiceProviders.ELEVENLABS, ElevenlabsTTSConfiguration),
|
||||
"stt": (ServiceProviders.DEEPGRAM, DeepgramSTTConfiguration),
|
||||
"embeddings": (ServiceProviders.OPENAI, OpenAIEmbeddingsConfiguration),
|
||||
}
|
||||
|
||||
# Public mapping of service name -> default provider
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
|
|||
"llm": _mask_service(config.llm),
|
||||
"tts": _mask_service(config.tts),
|
||||
"stt": _mask_service(config.stt),
|
||||
"embeddings": _mask_service(config.embeddings),
|
||||
"test_phone_number": config.test_phone_number,
|
||||
"timezone": config.timezone,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from typing import Dict
|
|||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.masking import is_mask_of
|
||||
|
||||
SERVICE_FIELDS = ("llm", "tts", "stt")
|
||||
SERVICE_FIELDS = ("llm", "tts", "stt", "embeddings")
|
||||
|
||||
|
||||
def merge_user_configurations(
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ class ServiceType(Enum):
|
|||
LLM = auto()
|
||||
TTS = auto()
|
||||
STT = auto()
|
||||
EMBEDDINGS = auto()
|
||||
|
||||
|
||||
class ServiceProviders(str, Enum):
|
||||
|
|
@ -50,11 +51,16 @@ class BaseSTTConfiguration(BaseServiceConfiguration):
|
|||
model: str
|
||||
|
||||
|
||||
class BaseEmbeddingsConfiguration(BaseServiceConfiguration):
|
||||
model: str
|
||||
|
||||
|
||||
# Unified registry for all service types
|
||||
REGISTRY: Dict[ServiceType, Dict[str, Type[BaseServiceConfiguration]]] = {
|
||||
ServiceType.LLM: {},
|
||||
ServiceType.TTS: {},
|
||||
ServiceType.STT: {},
|
||||
ServiceType.EMBEDDINGS: {},
|
||||
}
|
||||
|
||||
T = TypeVar("T", bound=BaseServiceConfiguration)
|
||||
|
|
@ -93,6 +99,10 @@ def register_stt(cls: Type[BaseSTTConfiguration]):
|
|||
return register_service(ServiceType.STT)(cls)
|
||||
|
||||
|
||||
def register_embeddings(cls: Type[BaseEmbeddingsConfiguration]):
|
||||
return register_service(ServiceType.EMBEDDINGS)(cls)
|
||||
|
||||
|
||||
###################################################### LLM ########################################################################
|
||||
|
||||
# Suggested models for each provider (used for UI dropdown)
|
||||
|
|
@ -436,6 +446,27 @@ STTConfig = Annotated[
|
|||
Field(discriminator="provider"),
|
||||
]
|
||||
|
||||
ServiceConfig = Annotated[
|
||||
Union[LLMConfig, TTSConfig, STTConfig], Field(discriminator="provider")
|
||||
###################################################### EMBEDDINGS ########################################################################
|
||||
|
||||
OPENAI_EMBEDDING_MODELS = ["text-embedding-3-small"]
|
||||
|
||||
|
||||
@register_embeddings
|
||||
class OpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
||||
provider: Literal[ServiceProviders.OPENAI] = ServiceProviders.OPENAI
|
||||
model: str = Field(
|
||||
default="text-embedding-3-small",
|
||||
json_schema_extra={"examples": OPENAI_EMBEDDING_MODELS},
|
||||
)
|
||||
api_key: str
|
||||
|
||||
|
||||
EmbeddingsConfig = Annotated[
|
||||
Union[OpenAIEmbeddingsConfiguration],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
||||
ServiceConfig = Annotated[
|
||||
Union[LLMConfig, TTSConfig, STTConfig, EmbeddingsConfig],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -133,9 +133,7 @@ class S3FileSystem(BaseFileSystem):
|
|||
async with self.session.client(
|
||||
"s3", region_name=self.region_name
|
||||
) as s3_client:
|
||||
await s3_client.download_file(
|
||||
self.bucket_name, source_path, local_path
|
||||
)
|
||||
await s3_client.download_file(self.bucket_name, source_path, local_path)
|
||||
return True
|
||||
except ClientError:
|
||||
return False
|
||||
|
|
|
|||
15
api/services/gen_ai/__init__.py
Normal file
15
api/services/gen_ai/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
"""Generative AI services for embeddings and document processing."""
|
||||
|
||||
from .embedding import (
|
||||
BaseEmbeddingService,
|
||||
EmbeddingAPIKeyNotConfiguredError,
|
||||
OpenAIEmbeddingService,
|
||||
SentenceTransformerEmbeddingService,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseEmbeddingService",
|
||||
"EmbeddingAPIKeyNotConfiguredError",
|
||||
"SentenceTransformerEmbeddingService",
|
||||
"OpenAIEmbeddingService",
|
||||
]
|
||||
12
api/services/gen_ai/embedding/__init__.py
Normal file
12
api/services/gen_ai/embedding/__init__.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
"""Embedding services for document processing and retrieval."""
|
||||
|
||||
from .base import BaseEmbeddingService
|
||||
from .openai_service import EmbeddingAPIKeyNotConfiguredError, OpenAIEmbeddingService
|
||||
from .sentence_transformer_service import SentenceTransformerEmbeddingService
|
||||
|
||||
__all__ = [
|
||||
"BaseEmbeddingService",
|
||||
"EmbeddingAPIKeyNotConfiguredError",
|
||||
"SentenceTransformerEmbeddingService",
|
||||
"OpenAIEmbeddingService",
|
||||
]
|
||||
75
api/services/gen_ai/embedding/base.py
Normal file
75
api/services/gen_ai/embedding/base.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
"""Base class for embedding services."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class BaseEmbeddingService(ABC):
|
||||
"""Abstract base class for embedding services.
|
||||
|
||||
All embedding services (SentenceTransformer, OpenAI, etc.) should inherit from this class
|
||||
and implement the required methods.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_model_id(self) -> str:
|
||||
"""Return the model identifier.
|
||||
|
||||
Returns:
|
||||
String identifier for the model (e.g., 'sentence-transformers/all-MiniLM-L6-v2')
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_embedding_dimension(self) -> int:
|
||||
"""Return the embedding dimension.
|
||||
|
||||
Returns:
|
||||
Integer dimension of the embedding vectors
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a batch of texts.
|
||||
|
||||
Args:
|
||||
texts: List of text strings to embed
|
||||
|
||||
Returns:
|
||||
List of embedding vectors (each vector is a list of floats)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def embed_query(self, query: str) -> List[float]:
|
||||
"""Embed a single query text.
|
||||
|
||||
Args:
|
||||
query: Query text to embed
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def search_similar_chunks(
|
||||
self,
|
||||
query: str,
|
||||
organization_id: int,
|
||||
limit: int = 5,
|
||||
document_uuids: Optional[List[str]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search for similar chunks using vector similarity.
|
||||
|
||||
Args:
|
||||
query: Search query text
|
||||
organization_id: Organization ID for scoping
|
||||
limit: Maximum number of results to return
|
||||
document_uuids: Optional list of document UUIDs to filter by
|
||||
|
||||
Returns:
|
||||
List of dictionaries containing chunk data and similarity scores
|
||||
"""
|
||||
pass
|
||||
372
api/services/gen_ai/embedding/openai_service.py
Normal file
372
api/services/gen_ai/embedding/openai_service.py
Normal file
|
|
@ -0,0 +1,372 @@
|
|||
"""OpenAI embedding service.
|
||||
|
||||
This module provides document processing capabilities using:
|
||||
- OpenAI's text-embedding-3-small for embeddings (1536 dimensions)
|
||||
- Docling for document conversion and chunking
|
||||
- pgvector for vector similarity search
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from docling.chunking import HybridChunker
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
|
||||
from loguru import logger
|
||||
from openai import AsyncOpenAI
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from api.db.db_client import DBClient
|
||||
from api.db.models import KnowledgeBaseChunkModel
|
||||
|
||||
from .base import BaseEmbeddingService
|
||||
|
||||
# Model configuration
|
||||
DEFAULT_MODEL_ID = "text-embedding-3-small"
|
||||
EMBEDDING_DIMENSION = 1536 # Dimension for text-embedding-3-small
|
||||
|
||||
# For chunking, we'll use the same tokenizer as SentenceTransformer
|
||||
# since OpenAI uses similar tokenization
|
||||
TOKENIZER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
|
||||
class EmbeddingAPIKeyNotConfiguredError(Exception):
|
||||
"""Raised when OpenAI API key is not configured for embeddings."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
"OpenAI API key not configured. Please set your API key in "
|
||||
"Model Configurations > Embedding to use document processing."
|
||||
)
|
||||
|
||||
|
||||
class OpenAIEmbeddingService(BaseEmbeddingService):
|
||||
"""Embedding service using OpenAI's text-embedding-3-small."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_client: DBClient,
|
||||
api_key: Optional[str] = None,
|
||||
model_id: str = DEFAULT_MODEL_ID,
|
||||
max_tokens: int = 512,
|
||||
):
|
||||
"""Initialize the OpenAI embedding service.
|
||||
|
||||
Args:
|
||||
db_client: Database client for storing documents and chunks
|
||||
api_key: OpenAI API key. If not provided, the client will not be
|
||||
initialized and operations will fail with a clear error.
|
||||
model_id: OpenAI embedding model ID (default: text-embedding-3-small)
|
||||
max_tokens: Maximum number of tokens per chunk (default: 512)
|
||||
"""
|
||||
self.db = db_client
|
||||
self.model_id = model_id
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
# Only initialize OpenAI client if API key is provided
|
||||
self._api_key_configured = bool(api_key)
|
||||
if self._api_key_configured:
|
||||
self.client = AsyncOpenAI(api_key=api_key)
|
||||
logger.info(f"OpenAI embedding service initialized with model: {model_id}")
|
||||
else:
|
||||
self.client = None
|
||||
logger.warning(
|
||||
"OpenAI embedding service initialized without API key. "
|
||||
"Operations will fail until API key is configured in Model Configurations."
|
||||
)
|
||||
|
||||
# Initialize tokenizer for chunking
|
||||
# We use a HuggingFace tokenizer for consistent chunking
|
||||
logger.info(
|
||||
f"Loading tokenizer for chunking: {TOKENIZER_MODEL} with max_tokens={max_tokens}"
|
||||
)
|
||||
try:
|
||||
self.tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=AutoTokenizer.from_pretrained(
|
||||
TOKENIZER_MODEL,
|
||||
local_files_only=True,
|
||||
),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
logger.info("Loaded tokenizer from cache")
|
||||
except Exception as e:
|
||||
logger.warning(f"Tokenizer not in cache, downloading: {e}")
|
||||
self.tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=AutoTokenizer.from_pretrained(TOKENIZER_MODEL),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
logger.info("Tokenizer downloaded and cached")
|
||||
|
||||
# Initialize chunker
|
||||
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
|
||||
self.chunker = HybridChunker(tokenizer=self.tokenizer)
|
||||
|
||||
# Initialize document converter
|
||||
self.converter = DocumentConverter()
|
||||
|
||||
def get_model_id(self) -> str:
|
||||
"""Return the model identifier."""
|
||||
return self.model_id
|
||||
|
||||
def get_embedding_dimension(self) -> int:
|
||||
"""Return the embedding dimension."""
|
||||
return EMBEDDING_DIMENSION
|
||||
|
||||
def _ensure_api_key_configured(self):
|
||||
"""Check if API key is configured and raise error if not."""
|
||||
if not self._api_key_configured or self.client is None:
|
||||
raise EmbeddingAPIKeyNotConfiguredError()
|
||||
|
||||
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a batch of texts using OpenAI API.
|
||||
|
||||
Args:
|
||||
texts: List of text strings to embed
|
||||
|
||||
Returns:
|
||||
List of embedding vectors (each vector is a list of floats)
|
||||
|
||||
Raises:
|
||||
EmbeddingAPIKeyNotConfiguredError: If API key is not configured
|
||||
"""
|
||||
self._ensure_api_key_configured()
|
||||
|
||||
try:
|
||||
# OpenAI API call
|
||||
response = await self.client.embeddings.create(
|
||||
input=texts,
|
||||
model=self.model_id,
|
||||
)
|
||||
|
||||
# Extract embeddings from response
|
||||
embeddings = [item.embedding for item in response.data]
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating OpenAI embeddings: {e}")
|
||||
raise
|
||||
|
||||
async def embed_query(self, query: str) -> List[float]:
|
||||
"""Embed a single query text using OpenAI API.
|
||||
|
||||
Args:
|
||||
query: Query text to embed
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats
|
||||
|
||||
Raises:
|
||||
EmbeddingAPIKeyNotConfiguredError: If API key is not configured
|
||||
"""
|
||||
self._ensure_api_key_configured()
|
||||
embeddings = await self.embed_texts([query])
|
||||
return embeddings[0]
|
||||
|
||||
async def search_similar_chunks(
|
||||
self,
|
||||
query: str,
|
||||
organization_id: int,
|
||||
limit: int = 5,
|
||||
document_uuids: Optional[List[str]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search for similar chunks using vector similarity.
|
||||
|
||||
Args:
|
||||
query: Search query text
|
||||
organization_id: Organization ID for scoping
|
||||
limit: Maximum number of results to return
|
||||
document_uuids: Optional list of document UUIDs to filter by
|
||||
|
||||
Returns:
|
||||
List of dictionaries with chunk data and similarity scores
|
||||
|
||||
Raises:
|
||||
EmbeddingAPIKeyNotConfiguredError: If API key is not configured
|
||||
"""
|
||||
self._ensure_api_key_configured()
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = await self.embed_query(query)
|
||||
|
||||
# Perform vector similarity search
|
||||
results = await self.db.search_similar_chunks(
|
||||
query_embedding=query_embedding,
|
||||
organization_id=organization_id,
|
||||
limit=limit,
|
||||
document_uuids=document_uuids,
|
||||
embedding_model=self.model_id,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def process_document(
|
||||
self,
|
||||
file_path: str,
|
||||
organization_id: int,
|
||||
created_by: int,
|
||||
custom_metadata: dict = None,
|
||||
):
|
||||
"""Process a document: convert, chunk, embed, and store in database.
|
||||
|
||||
Args:
|
||||
file_path: Path to the document file
|
||||
organization_id: Organization ID for scoping
|
||||
created_by: User ID who uploaded the document
|
||||
custom_metadata: Optional custom metadata dictionary
|
||||
|
||||
Returns:
|
||||
The created document record
|
||||
"""
|
||||
try:
|
||||
# Extract file metadata
|
||||
filename = Path(file_path).name
|
||||
file_hash = self.db.compute_file_hash(file_path)
|
||||
file_size = os.path.getsize(file_path)
|
||||
mime_type = self.db.get_mime_type(file_path)
|
||||
|
||||
# Check if document already exists
|
||||
existing_doc = await self.db.get_document_by_hash(
|
||||
file_hash, organization_id
|
||||
)
|
||||
if existing_doc:
|
||||
logger.info(f"Document already exists: {filename} (hash: {file_hash})")
|
||||
return existing_doc
|
||||
|
||||
# Create document record
|
||||
doc_record = await self.db.create_document(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
filename=filename,
|
||||
file_size_bytes=file_size,
|
||||
file_hash=file_hash,
|
||||
mime_type=mime_type,
|
||||
custom_metadata=custom_metadata or {},
|
||||
)
|
||||
|
||||
logger.info(f"Processing document with OpenAI embeddings: {filename}")
|
||||
|
||||
# Update status to processing
|
||||
await self.db.update_document_status(doc_record.id, "processing")
|
||||
|
||||
# Step 1: Convert document using docling
|
||||
logger.info("Converting document with docling...")
|
||||
conversion_result = self.converter.convert(file_path)
|
||||
doc = conversion_result.document
|
||||
|
||||
# Store docling metadata
|
||||
docling_metadata = {
|
||||
"num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
|
||||
"document_type": type(doc).__name__,
|
||||
}
|
||||
|
||||
# Step 2: Chunk the document
|
||||
logger.info(f"Chunking document with max_tokens={self.max_tokens}...")
|
||||
chunks = list(self.chunker.chunk(dl_doc=doc))
|
||||
total_chunks = len(chunks)
|
||||
|
||||
logger.info(f"Generated {total_chunks} chunks")
|
||||
|
||||
# Step 3: Process each chunk
|
||||
chunk_texts = []
|
||||
chunk_records = []
|
||||
token_counts = []
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Get chunk text
|
||||
chunk_text = chunk.text
|
||||
|
||||
# Get contextualized text
|
||||
contextualized_text = self.chunker.contextualize(chunk=chunk)
|
||||
|
||||
# Calculate token count
|
||||
text_to_tokenize = (
|
||||
contextualized_text if contextualized_text else chunk_text
|
||||
)
|
||||
token_count = len(
|
||||
self.tokenizer.tokenizer.encode(
|
||||
text_to_tokenize, add_special_tokens=False
|
||||
)
|
||||
)
|
||||
token_counts.append(token_count)
|
||||
|
||||
# Prepare chunk metadata
|
||||
chunk_metadata = {}
|
||||
if hasattr(chunk, "meta") and chunk.meta:
|
||||
chunk_metadata = {
|
||||
"doc_items": (
|
||||
[str(item) for item in chunk.meta.doc_items]
|
||||
if hasattr(chunk.meta, "doc_items")
|
||||
else []
|
||||
),
|
||||
"headings": (
|
||||
chunk.meta.headings
|
||||
if hasattr(chunk.meta, "headings")
|
||||
else []
|
||||
),
|
||||
}
|
||||
|
||||
# Create chunk record (without embedding yet)
|
||||
chunk_record = KnowledgeBaseChunkModel(
|
||||
document_id=doc_record.id,
|
||||
organization_id=organization_id,
|
||||
chunk_text=chunk_text,
|
||||
contextualized_text=contextualized_text,
|
||||
chunk_index=i,
|
||||
chunk_metadata=chunk_metadata,
|
||||
embedding_model=self.model_id,
|
||||
embedding_dimension=EMBEDDING_DIMENSION,
|
||||
token_count=token_count,
|
||||
)
|
||||
|
||||
chunk_records.append(chunk_record)
|
||||
chunk_texts.append(text_to_tokenize)
|
||||
|
||||
# Log chunk statistics
|
||||
if token_counts:
|
||||
avg_tokens = sum(token_counts) / len(token_counts)
|
||||
min_tokens = min(token_counts)
|
||||
max_tokens = max(token_counts)
|
||||
logger.info("Chunk token statistics:")
|
||||
logger.info(f" - Average: {avg_tokens:.1f} tokens")
|
||||
logger.info(f" - Min: {min_tokens} tokens")
|
||||
logger.info(f" - Max: {max_tokens} tokens")
|
||||
|
||||
# Step 4: Generate embeddings using OpenAI API
|
||||
logger.info(f"Generating embeddings using OpenAI ({self.model_id})...")
|
||||
embeddings = await self.embed_texts(chunk_texts)
|
||||
|
||||
# Step 5: Attach embeddings to chunk records
|
||||
for chunk_record, embedding in zip(chunk_records, embeddings):
|
||||
chunk_record.embedding = embedding
|
||||
|
||||
# Step 6: Save all chunks in batch
|
||||
logger.info("Storing chunks in database...")
|
||||
await self.db.create_chunks_batch(chunk_records)
|
||||
|
||||
# Update document status to completed
|
||||
await self.db.update_document_status(
|
||||
doc_record.id,
|
||||
"completed",
|
||||
total_chunks=total_chunks,
|
||||
docling_metadata=docling_metadata,
|
||||
)
|
||||
|
||||
logger.info(f"Successfully processed document: {filename}")
|
||||
logger.info(f" - Total chunks: {total_chunks}")
|
||||
logger.info(f" - Embedding model: {self.model_id}")
|
||||
logger.info(f" - Document ID: {doc_record.id}")
|
||||
logger.info(f" - Document UUID: {doc_record.document_uuid}")
|
||||
|
||||
return doc_record
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing document with OpenAI: {e}")
|
||||
|
||||
# Update document status to failed if it exists
|
||||
if "doc_record" in locals():
|
||||
await self.db.update_document_status(
|
||||
doc_record.id, "failed", error_message=str(e)
|
||||
)
|
||||
|
||||
raise
|
||||
350
api/services/gen_ai/embedding/sentence_transformer_service.py
Normal file
350
api/services/gen_ai/embedding/sentence_transformer_service.py
Normal file
|
|
@ -0,0 +1,350 @@
|
|||
"""Sentence Transformer embedding service.
|
||||
|
||||
This module provides document processing capabilities using:
|
||||
- Sentence-transformers for embeddings (all-MiniLM-L6-v2)
|
||||
- Docling for document conversion and chunking
|
||||
- pgvector for vector similarity search
|
||||
|
||||
Setup for offline usage:
|
||||
1. First run: Downloads and caches models to ~/.cache/sentence_transformers
|
||||
2. Subsequent runs: Uses cached models (no internet needed)
|
||||
3. For fully offline mode: Set TRANSFORMERS_OFFLINE=1 and HF_HUB_OFFLINE=1
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from docling.chunking import HybridChunker
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
|
||||
from loguru import logger
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from api.db.db_client import DBClient
|
||||
from api.db.models import KnowledgeBaseChunkModel
|
||||
|
||||
from .base import BaseEmbeddingService
|
||||
|
||||
# Set environment variables for model caching
|
||||
os.environ.setdefault("TRANSFORMERS_OFFLINE", "0")
|
||||
os.environ.setdefault("HF_HUB_OFFLINE", "0")
|
||||
os.environ.setdefault(
|
||||
"SENTENCE_TRANSFORMERS_HOME", os.path.expanduser("~/.cache/sentence_transformers")
|
||||
)
|
||||
|
||||
# Model configuration
|
||||
DEFAULT_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
EMBEDDING_DIMENSION = 384 # Dimension for all-MiniLM-L6-v2
|
||||
|
||||
|
||||
class SentenceTransformerEmbeddingService(BaseEmbeddingService):
|
||||
"""Embedding service using Sentence Transformers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_client: DBClient,
|
||||
model_id: str = DEFAULT_MODEL_ID,
|
||||
max_tokens: int = 512,
|
||||
):
|
||||
"""Initialize the Sentence Transformer embedding service.
|
||||
|
||||
Args:
|
||||
db_client: Database client for storing documents and chunks
|
||||
model_id: Sentence-transformers model ID (default: all-MiniLM-L6-v2)
|
||||
max_tokens: Maximum number of tokens per chunk (default: 512)
|
||||
Note: This applies to the contextualized text (with headings/captions)
|
||||
"""
|
||||
self.db = db_client
|
||||
self.model_id = model_id
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
# Initialize embedding model
|
||||
logger.info(f"Loading embedding model: {model_id}")
|
||||
try:
|
||||
# Try to load from cache first (local_files_only=True)
|
||||
self.embedding_model = SentenceTransformer(
|
||||
model_id,
|
||||
cache_folder=os.environ.get("SENTENCE_TRANSFORMERS_HOME"),
|
||||
local_files_only=True,
|
||||
)
|
||||
logger.info("Loaded model from cache")
|
||||
except Exception as e:
|
||||
logger.warning(f"Model not in cache, downloading: {e}")
|
||||
# If not in cache, download it (this will cache it for next time)
|
||||
self.embedding_model = SentenceTransformer(
|
||||
model_id,
|
||||
cache_folder=os.environ.get("SENTENCE_TRANSFORMERS_HOME"),
|
||||
)
|
||||
logger.info("Model downloaded and cached")
|
||||
|
||||
# Initialize tokenizer for chunking with max_tokens
|
||||
logger.info(f"Loading tokenizer: {model_id} with max_tokens={max_tokens}")
|
||||
try:
|
||||
# Try to load from cache first
|
||||
self.tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
local_files_only=True,
|
||||
),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
logger.info("Loaded tokenizer from cache")
|
||||
except Exception as e:
|
||||
logger.warning(f"Tokenizer not in cache, downloading: {e}")
|
||||
# If not in cache, download it
|
||||
self.tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=AutoTokenizer.from_pretrained(model_id),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
logger.info("Tokenizer downloaded and cached")
|
||||
|
||||
# Initialize chunker
|
||||
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
|
||||
self.chunker = HybridChunker(tokenizer=self.tokenizer)
|
||||
|
||||
# Initialize document converter
|
||||
self.converter = DocumentConverter()
|
||||
|
||||
def get_model_id(self) -> str:
|
||||
"""Return the model identifier."""
|
||||
return self.model_id
|
||||
|
||||
def get_embedding_dimension(self) -> int:
|
||||
"""Return the embedding dimension."""
|
||||
return EMBEDDING_DIMENSION
|
||||
|
||||
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a batch of texts.
|
||||
|
||||
Args:
|
||||
texts: List of text strings to embed
|
||||
|
||||
Returns:
|
||||
List of embedding vectors (each vector is a list of floats)
|
||||
"""
|
||||
embeddings = self.embedding_model.encode(
|
||||
texts,
|
||||
show_progress_bar=False,
|
||||
convert_to_numpy=True,
|
||||
)
|
||||
return [embedding.tolist() for embedding in embeddings]
|
||||
|
||||
async def embed_query(self, query: str) -> List[float]:
|
||||
"""Embed a single query text.
|
||||
|
||||
Args:
|
||||
query: Query text to embed
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats
|
||||
"""
|
||||
embedding = self.embedding_model.encode([query])[0]
|
||||
return embedding.tolist()
|
||||
|
||||
async def search_similar_chunks(
|
||||
self,
|
||||
query: str,
|
||||
organization_id: int,
|
||||
limit: int = 5,
|
||||
document_uuids: Optional[List[str]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search for similar chunks using vector similarity.
|
||||
|
||||
Returns top-k most similar chunks without any threshold filtering.
|
||||
Apply similarity thresholds and reranking at the application layer.
|
||||
|
||||
Args:
|
||||
query: Search query text
|
||||
organization_id: Organization ID for scoping
|
||||
limit: Maximum number of results to return
|
||||
document_uuids: Optional list of document UUIDs to filter by
|
||||
|
||||
Returns:
|
||||
List of dictionaries with chunk data and similarity scores
|
||||
"""
|
||||
# Generate query embedding
|
||||
query_embedding = await self.embed_query(query)
|
||||
|
||||
# Perform vector similarity search
|
||||
results = await self.db.search_similar_chunks(
|
||||
query_embedding=query_embedding,
|
||||
organization_id=organization_id,
|
||||
limit=limit,
|
||||
document_uuids=document_uuids,
|
||||
embedding_model=self.model_id,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def process_document(
|
||||
self,
|
||||
file_path: str,
|
||||
organization_id: int,
|
||||
created_by: int,
|
||||
custom_metadata: dict = None,
|
||||
):
|
||||
"""Process a document: convert, chunk, embed, and store in database.
|
||||
|
||||
Args:
|
||||
file_path: Path to the document file
|
||||
organization_id: Organization ID for scoping
|
||||
created_by: User ID who uploaded the document
|
||||
custom_metadata: Optional custom metadata dictionary
|
||||
|
||||
Returns:
|
||||
The created document record
|
||||
"""
|
||||
try:
|
||||
# Extract file metadata
|
||||
filename = Path(file_path).name
|
||||
file_hash = self.db.compute_file_hash(file_path)
|
||||
file_size = os.path.getsize(file_path)
|
||||
mime_type = self.db.get_mime_type(file_path)
|
||||
|
||||
# Check if document already exists
|
||||
existing_doc = await self.db.get_document_by_hash(
|
||||
file_hash, organization_id
|
||||
)
|
||||
if existing_doc:
|
||||
logger.info(f"Document already exists: {filename} (hash: {file_hash})")
|
||||
return existing_doc
|
||||
|
||||
# Create document record
|
||||
doc_record = await self.db.create_document(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
filename=filename,
|
||||
file_size_bytes=file_size,
|
||||
file_hash=file_hash,
|
||||
mime_type=mime_type,
|
||||
custom_metadata=custom_metadata or {},
|
||||
)
|
||||
|
||||
logger.info(f"Processing document: {filename}")
|
||||
|
||||
# Update status to processing
|
||||
await self.db.update_document_status(doc_record.id, "processing")
|
||||
|
||||
# Step 1: Convert document using docling
|
||||
logger.info("Converting document with docling...")
|
||||
conversion_result = self.converter.convert(file_path)
|
||||
doc = conversion_result.document
|
||||
|
||||
# Store docling metadata
|
||||
docling_metadata = {
|
||||
"num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
|
||||
"document_type": type(doc).__name__,
|
||||
}
|
||||
|
||||
# Step 2: Chunk the document
|
||||
logger.info(f"Chunking document with max_tokens={self.max_tokens}...")
|
||||
chunks = list(self.chunker.chunk(dl_doc=doc))
|
||||
total_chunks = len(chunks)
|
||||
|
||||
logger.info(f"Generated {total_chunks} chunks")
|
||||
|
||||
# Step 3: Process each chunk
|
||||
chunk_texts = []
|
||||
chunk_records = []
|
||||
token_counts = []
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Get chunk text
|
||||
chunk_text = chunk.text
|
||||
|
||||
# Get contextualized text (enriched with surrounding context)
|
||||
contextualized_text = self.chunker.contextualize(chunk=chunk)
|
||||
|
||||
# Calculate actual token count using the tokenizer
|
||||
text_to_tokenize = (
|
||||
contextualized_text if contextualized_text else chunk_text
|
||||
)
|
||||
token_count = len(
|
||||
self.tokenizer.tokenizer.encode(
|
||||
text_to_tokenize, add_special_tokens=False
|
||||
)
|
||||
)
|
||||
token_counts.append(token_count)
|
||||
|
||||
# Prepare chunk metadata
|
||||
chunk_metadata = {}
|
||||
if hasattr(chunk, "meta") and chunk.meta:
|
||||
chunk_metadata = {
|
||||
"doc_items": (
|
||||
[str(item) for item in chunk.meta.doc_items]
|
||||
if hasattr(chunk.meta, "doc_items")
|
||||
else []
|
||||
),
|
||||
"headings": (
|
||||
chunk.meta.headings
|
||||
if hasattr(chunk.meta, "headings")
|
||||
else []
|
||||
),
|
||||
}
|
||||
|
||||
# Create chunk record (without embedding yet)
|
||||
chunk_record = KnowledgeBaseChunkModel(
|
||||
document_id=doc_record.id,
|
||||
organization_id=organization_id,
|
||||
chunk_text=chunk_text,
|
||||
contextualized_text=contextualized_text,
|
||||
chunk_index=i,
|
||||
chunk_metadata=chunk_metadata,
|
||||
embedding_model=self.model_id,
|
||||
embedding_dimension=EMBEDDING_DIMENSION,
|
||||
token_count=token_count,
|
||||
)
|
||||
|
||||
chunk_records.append(chunk_record)
|
||||
# Use contextualized text for embedding if available
|
||||
chunk_texts.append(text_to_tokenize)
|
||||
|
||||
# Log chunk statistics
|
||||
if token_counts:
|
||||
avg_tokens = sum(token_counts) / len(token_counts)
|
||||
min_tokens = min(token_counts)
|
||||
max_tokens = max(token_counts)
|
||||
logger.info("Chunk token statistics:")
|
||||
logger.info(f" - Average: {avg_tokens:.1f} tokens")
|
||||
logger.info(f" - Min: {min_tokens} tokens")
|
||||
logger.info(f" - Max: {max_tokens} tokens")
|
||||
|
||||
# Step 4: Generate embeddings in batch
|
||||
logger.info("Generating embeddings...")
|
||||
embeddings = await self.embed_texts(chunk_texts)
|
||||
|
||||
# Step 5: Attach embeddings to chunk records
|
||||
for chunk_record, embedding in zip(chunk_records, embeddings):
|
||||
chunk_record.embedding = embedding
|
||||
|
||||
# Step 6: Save all chunks in batch
|
||||
logger.info("Storing chunks in database...")
|
||||
await self.db.create_chunks_batch(chunk_records)
|
||||
|
||||
# Update document status to completed
|
||||
await self.db.update_document_status(
|
||||
doc_record.id,
|
||||
"completed",
|
||||
total_chunks=total_chunks,
|
||||
docling_metadata=docling_metadata,
|
||||
)
|
||||
|
||||
logger.info(f"Successfully processed document: {filename}")
|
||||
logger.info(f" - Total chunks: {total_chunks}")
|
||||
logger.info(f" - Document ID: {doc_record.id}")
|
||||
logger.info(f" - Document UUID: {doc_record.document_uuid}")
|
||||
|
||||
return doc_record
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing document: {e}")
|
||||
|
||||
# Update document status to failed if it exists
|
||||
if "doc_record" in locals():
|
||||
await self.db.update_document_status(
|
||||
doc_record.id, "failed", error_message=str(e)
|
||||
)
|
||||
|
||||
raise
|
||||
44
api/services/pricing/embeddings.py
Normal file
44
api/services/pricing/embeddings.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"""
|
||||
Embeddings pricing models for different providers.
|
||||
|
||||
Prices are per token for embedding models.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import PricingModel
|
||||
|
||||
|
||||
class EmbeddingPricingModel(PricingModel):
|
||||
"""Pricing model for token-based embedding services."""
|
||||
|
||||
def __init__(self, token_price: Decimal):
|
||||
"""Initialize with price per token.
|
||||
|
||||
Args:
|
||||
token_price: Cost per token for embedding
|
||||
"""
|
||||
self.token_price = token_price
|
||||
|
||||
def calculate_cost(self, token_count: int) -> Decimal:
|
||||
"""Calculate cost for embedding token usage."""
|
||||
return Decimal(token_count) * self.token_price
|
||||
|
||||
|
||||
# Embeddings pricing registry
|
||||
EMBEDDINGS_PRICING: Dict[str, Dict[str, EmbeddingPricingModel]] = {
|
||||
ServiceProviders.OPENAI: {
|
||||
"text-embedding-3-small": EmbeddingPricingModel(
|
||||
token_price=Decimal("0.02") / 1_000_000, # $0.02 per 1M tokens
|
||||
),
|
||||
"text-embedding-3-large": EmbeddingPricingModel(
|
||||
token_price=Decimal("0.13") / 1_000_000, # $0.13 per 1M tokens
|
||||
),
|
||||
"text-embedding-ada-002": EmbeddingPricingModel(
|
||||
token_price=Decimal("0.10") / 1_000_000, # $0.10 per 1M tokens (legacy)
|
||||
),
|
||||
},
|
||||
}
|
||||
|
|
@ -4,6 +4,7 @@ Main pricing registry that combines all service type pricing models.
|
|||
|
||||
from typing import Dict
|
||||
|
||||
from .embeddings import EMBEDDINGS_PRICING
|
||||
from .llm import LLM_PRICING
|
||||
from .stt import STT_PRICING
|
||||
from .tts import TTS_PRICING
|
||||
|
|
@ -13,4 +14,5 @@ PRICING_REGISTRY: Dict = {
|
|||
"llm": LLM_PRICING,
|
||||
"tts": TTS_PRICING,
|
||||
"stt": STT_PRICING,
|
||||
"embeddings": EMBEDDINGS_PRICING,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ class NodeDataDTO(BaseModel):
|
|||
delayed_start: bool = False
|
||||
delayed_start_duration: Optional[float] = None
|
||||
tool_uuids: Optional[List[str]] = None
|
||||
document_uuids: Optional[List[str]] = None
|
||||
trigger_path: Optional[str] = None
|
||||
# Webhook node specific fields
|
||||
enabled: bool = True
|
||||
|
|
|
|||
|
|
@ -41,6 +41,10 @@ from api.services.workflow.pipecat_engine_variable_extractor import (
|
|||
VariableExtractionManager,
|
||||
)
|
||||
from api.services.workflow.tools.calculator import get_calculator_tools, safe_calculator
|
||||
from api.services.workflow.tools.knowledge_base import (
|
||||
get_knowledge_base_tool,
|
||||
retrieve_from_knowledge_base,
|
||||
)
|
||||
from api.services.workflow.tools.timezone import (
|
||||
convert_time,
|
||||
get_current_time,
|
||||
|
|
@ -290,6 +294,48 @@ class PipecatEngine:
|
|||
self.llm.register_function("get_current_time", get_current_time_func)
|
||||
self.llm.register_function("convert_time", convert_time_func)
|
||||
|
||||
async def _register_knowledge_base_function(
|
||||
self, document_uuids: list[str]
|
||||
) -> None:
|
||||
"""Register knowledge base retrieval function with the LLM.
|
||||
|
||||
Args:
|
||||
document_uuids: List of document UUIDs to filter the search by
|
||||
"""
|
||||
logger.debug(
|
||||
f"Registering knowledge base retrieval function with {len(document_uuids)} document(s)"
|
||||
)
|
||||
|
||||
async def retrieve_kb_func(function_call_params: FunctionCallParams) -> None:
|
||||
logger.info("LLM Function Call EXECUTED: retrieve_from_knowledge_base")
|
||||
logger.info(f"Arguments: {function_call_params.arguments}")
|
||||
try:
|
||||
query = function_call_params.arguments.get("query", "")
|
||||
organization_id = await self._get_organization_id()
|
||||
|
||||
if not organization_id:
|
||||
raise ValueError(
|
||||
"Organization ID not available for knowledge base retrieval"
|
||||
)
|
||||
|
||||
result = await retrieve_from_knowledge_base(
|
||||
query=query,
|
||||
organization_id=organization_id,
|
||||
document_uuids=document_uuids,
|
||||
limit=3, # Return top 3 most relevant chunks
|
||||
)
|
||||
|
||||
await function_call_params.result_callback(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Knowledge base retrieval failed: {e}")
|
||||
await function_call_params.result_callback(
|
||||
{"error": str(e), "chunks": [], "query": query, "total_results": 0}
|
||||
)
|
||||
|
||||
# Register the function with the LLM
|
||||
self.llm.register_function("retrieve_from_knowledge_base", retrieve_kb_func)
|
||||
|
||||
async def _perform_variable_extraction_if_needed(
|
||||
self, previous_node: Optional[Node]
|
||||
) -> None:
|
||||
|
|
@ -346,6 +392,10 @@ class PipecatEngine:
|
|||
if node.tool_uuids and self._custom_tool_manager:
|
||||
await self._custom_tool_manager.register_handlers(node.tool_uuids)
|
||||
|
||||
# Register knowledge base retrieval handler if node has documents
|
||||
if node.document_uuids:
|
||||
await self._register_knowledge_base_function(node.document_uuids)
|
||||
|
||||
# Set up system message and functions
|
||||
(
|
||||
system_message,
|
||||
|
|
@ -575,6 +625,17 @@ class PipecatEngine:
|
|||
# Add built-in function schemas (calculator and timezone tools)
|
||||
functions.extend(self.builtin_function_schemas)
|
||||
|
||||
# Add knowledge base retrieval tool if node has documents
|
||||
if node.document_uuids:
|
||||
kb_tool_def = get_knowledge_base_tool(node.document_uuids)
|
||||
kb_schema = get_function_schema(
|
||||
kb_tool_def["function"]["name"],
|
||||
kb_tool_def["function"]["description"],
|
||||
properties=kb_tool_def["function"]["parameters"].get("properties", {}),
|
||||
required=kb_tool_def["function"]["parameters"].get("required", []),
|
||||
)
|
||||
functions.append(kb_schema)
|
||||
|
||||
# Add custom tools from node.tool_uuids
|
||||
if node.tool_uuids and self._custom_tool_manager:
|
||||
custom_tool_schemas = await self._custom_tool_manager.get_tool_schemas(
|
||||
|
|
|
|||
305
api/services/workflow/tools/knowledge_base.py
Normal file
305
api/services/workflow/tools/knowledge_base.py
Normal file
|
|
@ -0,0 +1,305 @@
|
|||
"""Knowledge Base retrieval tool for workflow execution.
|
||||
|
||||
This module provides vector similarity search capabilities for retrieving
|
||||
relevant information from the knowledge base during conversations.
|
||||
|
||||
Implements OpenTelemetry tracing for observability in Langfuse.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from opentelemetry import trace
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.gen_ai import OpenAIEmbeddingService
|
||||
from api.services.pipecat.tracing_config import is_tracing_enabled
|
||||
from pipecat.utils.tracing.context_registry import (
|
||||
get_current_conversation_context,
|
||||
get_current_turn_context,
|
||||
)
|
||||
|
||||
|
||||
async def retrieve_from_knowledge_base(
|
||||
query: str,
|
||||
organization_id: int,
|
||||
document_uuids: Optional[List[str]] = None,
|
||||
limit: int = 3,
|
||||
embeddings_api_key: Optional[str] = None,
|
||||
embeddings_model: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Retrieve relevant information from the knowledge base using vector similarity search.
|
||||
|
||||
Uses OpenAI text-embedding-3-small for embeddings by default. This provides
|
||||
high-quality 1536-dimensional embeddings for accurate retrieval.
|
||||
|
||||
This function includes OpenTelemetry tracing for Langfuse observability.
|
||||
|
||||
Args:
|
||||
query: The search query to find relevant information
|
||||
organization_id: Organization ID for scoping the search
|
||||
document_uuids: Optional list of document UUIDs to filter by
|
||||
limit: Maximum number of chunks to return (default: 3)
|
||||
embeddings_api_key: Optional API key for embedding service
|
||||
embeddings_model: Optional model ID for embedding service
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- chunks: List of relevant text chunks with metadata
|
||||
- query: The original query
|
||||
- total_results: Number of results returned
|
||||
"""
|
||||
# Create span for retrieval operation if tracing is enabled
|
||||
if is_tracing_enabled():
|
||||
try:
|
||||
# Get parent context from turn or conversation
|
||||
turn_context = get_current_turn_context()
|
||||
conversation_context = get_current_conversation_context()
|
||||
parent_context = turn_context or conversation_context
|
||||
|
||||
# Get tracer
|
||||
tracer = trace.get_tracer("pipecat")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to setup tracing context: {e}")
|
||||
# Fall back to non-traced execution
|
||||
return await _perform_retrieval(
|
||||
query,
|
||||
organization_id,
|
||||
document_uuids,
|
||||
limit,
|
||||
embeddings_api_key,
|
||||
embeddings_model,
|
||||
)
|
||||
|
||||
# Create span with parent context
|
||||
if parent_context:
|
||||
with tracer.start_as_current_span(
|
||||
"knowledge_base_retrieval", context=parent_context
|
||||
) as span:
|
||||
try:
|
||||
# Mark trace as public for Langfuse
|
||||
span.set_attribute("langfuse.trace.public", True)
|
||||
|
||||
# Add operation metadata
|
||||
span.set_attribute(
|
||||
"gen_ai.operation.name", "knowledge_base_retrieval"
|
||||
)
|
||||
span.set_attribute("retrieval.query", query)
|
||||
span.set_attribute("retrieval.limit", limit)
|
||||
span.set_attribute("retrieval.organization_id", organization_id)
|
||||
|
||||
# Add document filter info
|
||||
if document_uuids:
|
||||
span.set_attribute(
|
||||
"retrieval.document_count", len(document_uuids)
|
||||
)
|
||||
span.set_attribute(
|
||||
"retrieval.document_uuids", json.dumps(document_uuids)
|
||||
)
|
||||
|
||||
# Perform the actual retrieval
|
||||
result = await _perform_retrieval(
|
||||
query,
|
||||
organization_id,
|
||||
document_uuids,
|
||||
limit,
|
||||
embeddings_api_key,
|
||||
embeddings_model,
|
||||
)
|
||||
|
||||
# Add result metadata to span
|
||||
span.set_attribute(
|
||||
"retrieval.results_count", result["total_results"]
|
||||
)
|
||||
|
||||
if result.get("error"):
|
||||
span.set_attribute("retrieval.error", result["error"])
|
||||
span.set_status(
|
||||
trace.Status(trace.StatusCode.ERROR, result["error"])
|
||||
)
|
||||
else:
|
||||
# Add similarity scores
|
||||
if result["chunks"]:
|
||||
similarities = [
|
||||
chunk["similarity"] for chunk in result["chunks"]
|
||||
]
|
||||
span.set_attribute(
|
||||
"retrieval.avg_similarity",
|
||||
round(sum(similarities) / len(similarities), 4),
|
||||
)
|
||||
span.set_attribute(
|
||||
"retrieval.max_similarity", max(similarities)
|
||||
)
|
||||
span.set_attribute(
|
||||
"retrieval.min_similarity", min(similarities)
|
||||
)
|
||||
|
||||
# Add retrieved documents info
|
||||
filenames = list(
|
||||
set(chunk["filename"] for chunk in result["chunks"])
|
||||
)
|
||||
span.set_attribute(
|
||||
"retrieval.source_files", json.dumps(filenames)
|
||||
)
|
||||
|
||||
# Add output as JSON for Langfuse
|
||||
output_data = {
|
||||
"query": query,
|
||||
"chunks_retrieved": len(result["chunks"]),
|
||||
"chunks": [
|
||||
{
|
||||
"text": chunk["text"][:200] + "..."
|
||||
if len(chunk["text"]) > 200
|
||||
else chunk["text"],
|
||||
"filename": chunk["filename"],
|
||||
"similarity": chunk["similarity"],
|
||||
}
|
||||
for chunk in result["chunks"]
|
||||
],
|
||||
}
|
||||
span.set_attribute("output", json.dumps(output_data))
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in traced retrieval: {e}")
|
||||
span.record_exception(e)
|
||||
span.set_status(trace.Status(trace.StatusCode.ERROR, str(e)))
|
||||
raise
|
||||
else:
|
||||
# No parent context - perform retrieval without tracing
|
||||
logger.debug(
|
||||
"No parent context available for knowledge base retrieval tracing"
|
||||
)
|
||||
return await _perform_retrieval(
|
||||
query,
|
||||
organization_id,
|
||||
document_uuids,
|
||||
limit,
|
||||
embeddings_api_key,
|
||||
embeddings_model,
|
||||
)
|
||||
else:
|
||||
# Tracing is disabled - perform retrieval without tracing
|
||||
return await _perform_retrieval(
|
||||
query,
|
||||
organization_id,
|
||||
document_uuids,
|
||||
limit,
|
||||
embeddings_api_key,
|
||||
embeddings_model,
|
||||
)
|
||||
|
||||
|
||||
async def _perform_retrieval(
|
||||
query: str,
|
||||
organization_id: int,
|
||||
document_uuids: Optional[List[str]],
|
||||
limit: int,
|
||||
embeddings_api_key: Optional[str] = None,
|
||||
embeddings_model: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Internal function to perform the actual retrieval operation.
|
||||
|
||||
Separated from tracing logic for cleaner code organization.
|
||||
Uses OpenAI embeddings by default for high-quality retrieval.
|
||||
"""
|
||||
try:
|
||||
# Create a new embedding service instance
|
||||
# Uses OpenAI text-embedding-3-small by default, or user-provided config
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
max_tokens=128, # This is only used for chunking, not for retrieval
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
)
|
||||
|
||||
# Perform vector similarity search
|
||||
results = await embedding_service.search_similar_chunks(
|
||||
query=query,
|
||||
organization_id=organization_id,
|
||||
limit=limit,
|
||||
document_uuids=document_uuids,
|
||||
)
|
||||
|
||||
# Format results for LLM consumption
|
||||
chunks = []
|
||||
for result in results:
|
||||
chunk_info = {
|
||||
"text": result.get("contextualized_text") or result.get("chunk_text"),
|
||||
"filename": result.get("filename"),
|
||||
"similarity": round(result.get("similarity", 0), 4),
|
||||
"chunk_index": result.get("chunk_index"),
|
||||
}
|
||||
chunks.append(chunk_info)
|
||||
|
||||
logger.info(
|
||||
f"Knowledge base retrieval: query='{query}', "
|
||||
f"results={len(chunks)}, "
|
||||
f"document_filter={document_uuids}"
|
||||
)
|
||||
|
||||
return {
|
||||
"chunks": chunks,
|
||||
"query": query,
|
||||
"total_results": len(chunks),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving from knowledge base: {e}")
|
||||
return {
|
||||
"error": str(e),
|
||||
"chunks": [],
|
||||
"query": query,
|
||||
"total_results": 0,
|
||||
}
|
||||
|
||||
|
||||
def get_knowledge_base_tool(
|
||||
document_uuids: Optional[List[str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Get knowledge base retrieval tool definition for LLM function calling.
|
||||
|
||||
Args:
|
||||
document_uuids: Optional list of document UUIDs to include in description
|
||||
|
||||
Returns:
|
||||
Tool definition compatible with LLM function calling
|
||||
"""
|
||||
# Build description based on whether specific documents are filtered
|
||||
if document_uuids and len(document_uuids) > 0:
|
||||
description = (
|
||||
"Retrieve relevant information from specific documents in the knowledge base. "
|
||||
"Use this tool when you need to look up facts, policies, procedures, or any information "
|
||||
"that might be stored in the available documents. The search will only look in the "
|
||||
f"documents associated with this conversation step ({len(document_uuids)} document(s) available)."
|
||||
)
|
||||
else:
|
||||
description = (
|
||||
"Retrieve relevant information from the knowledge base. "
|
||||
"Use this tool when you need to look up facts, policies, procedures, or any information "
|
||||
"that might be stored in the knowledge base documents."
|
||||
)
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "retrieve_from_knowledge_base",
|
||||
"description": description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The search query to find relevant information. "
|
||||
"Be specific and use natural language. "
|
||||
"Example: 'What is the refund policy for canceled orders?'"
|
||||
),
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -48,6 +48,7 @@ class Node:
|
|||
self.delayed_start = data.delayed_start
|
||||
self.delayed_start_duration = data.delayed_start_duration
|
||||
self.tool_uuids = data.tool_uuids
|
||||
self.document_uuids = data.document_uuids
|
||||
|
||||
self.data = data
|
||||
|
||||
|
|
@ -189,16 +190,6 @@ class WorkflowGraph:
|
|||
in_d, out_d = in_deg[n.id], out_deg[n.id]
|
||||
|
||||
match n.node_type:
|
||||
case NodeType.startNode:
|
||||
if in_d != 0 or out_d < 1:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.node,
|
||||
id=n.id,
|
||||
field=None,
|
||||
message=f"StartNode must have at least 1 outgoing edge",
|
||||
)
|
||||
)
|
||||
case NodeType.endNode:
|
||||
if in_d < 1 or out_d != 0:
|
||||
errors.append(
|
||||
|
|
|
|||
|
|
@ -2,25 +2,33 @@
|
|||
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Literal
|
||||
|
||||
from docling.chunking import HybridChunker
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
|
||||
from loguru import logger
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import KnowledgeBaseChunkModel, KnowledgeBaseDocumentModel
|
||||
from api.db.models import KnowledgeBaseChunkModel
|
||||
from api.services.gen_ai import (
|
||||
OpenAIEmbeddingService,
|
||||
SentenceTransformerEmbeddingService,
|
||||
)
|
||||
from api.services.storage import storage_fs
|
||||
|
||||
# Constants
|
||||
EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
EMBEDDING_DIMENSION = 384
|
||||
# For tokenization/chunking - use SentenceTransformer tokenizer as baseline
|
||||
TOKENIZER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
|
||||
async def process_knowledge_base_document(
|
||||
ctx, document_id: int, s3_key: str, organization_id: int, max_tokens: int = 128
|
||||
ctx,
|
||||
document_id: int,
|
||||
s3_key: str,
|
||||
organization_id: int,
|
||||
max_tokens: int = 128,
|
||||
embedding_service: Literal["sentence_transformer", "openai"] = "openai",
|
||||
):
|
||||
"""Process a knowledge base document: download, chunk, embed, and store.
|
||||
|
||||
|
|
@ -30,6 +38,9 @@ async def process_knowledge_base_document(
|
|||
s3_key: S3 key where the file is stored
|
||||
organization_id: Organization ID
|
||||
max_tokens: Maximum number of tokens per chunk (default: 128)
|
||||
embedding_service: Embedding service to use (default: "openai")
|
||||
- "openai": Use OpenAI text-embedding-3-small (1536-dim, requires API key)
|
||||
- "sentence_transformer": Use SentenceTransformer (all-MiniLM-L6-v2, 384-dim, free)
|
||||
"""
|
||||
logger.info(
|
||||
f"Starting knowledge base document processing for document_id={document_id}, "
|
||||
|
|
@ -42,8 +53,14 @@ async def process_knowledge_base_document(
|
|||
# Update status to processing
|
||||
await db_client.update_document_status(document_id, "processing")
|
||||
|
||||
# Create temp file for download
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
|
||||
# Extract file extension from S3 key
|
||||
filename = s3_key.split("/")[-1]
|
||||
file_extension = (
|
||||
os.path.splitext(filename)[1] or ".bin"
|
||||
) # Default to .bin if no extension
|
||||
|
||||
# Create temp file for download with correct extension
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=file_extension)
|
||||
temp_file_path = temp_file.name
|
||||
temp_file.close()
|
||||
|
||||
|
|
@ -108,27 +125,58 @@ async def process_knowledge_base_document(
|
|||
mime_type=mime_type,
|
||||
)
|
||||
|
||||
# Initialize models for processing
|
||||
cache_dir = os.path.expanduser("~/.cache/hf_models")
|
||||
logger.info(f"Loading embedding model: {EMBED_MODEL_ID} (cache: {cache_dir})")
|
||||
embedding_model = SentenceTransformer(
|
||||
EMBED_MODEL_ID,
|
||||
cache_folder=cache_dir,
|
||||
)
|
||||
# Initialize the embedding service based on the parameter
|
||||
if embedding_service == "openai":
|
||||
logger.info(
|
||||
f"Initializing OpenAI embedding service with max_tokens={max_tokens}"
|
||||
)
|
||||
# Try to get user's embeddings configuration
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
if document.created_by:
|
||||
user_config = await db_client.get_user_configurations(
|
||||
document.created_by
|
||||
)
|
||||
if user_config.embeddings:
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
logger.info(
|
||||
f"Using user embeddings config: model={embeddings_model}"
|
||||
)
|
||||
|
||||
logger.info(f"Loading tokenizer: {EMBED_MODEL_ID} (cache: {cache_dir})")
|
||||
tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=AutoTokenizer.from_pretrained(
|
||||
EMBED_MODEL_ID,
|
||||
cache_dir=cache_dir,
|
||||
),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
# Check if API key is configured
|
||||
if not embeddings_api_key:
|
||||
error_message = (
|
||||
"OpenAI API key not configured. Please set your API key in "
|
||||
"Model Configurations > Embedding to process documents."
|
||||
)
|
||||
logger.warning(f"Document {document_id}: {error_message}")
|
||||
await db_client.update_document_status(
|
||||
document_id, "failed", error_message=error_message
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
|
||||
chunker = HybridChunker(tokenizer=tokenizer)
|
||||
service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
max_tokens=max_tokens,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
)
|
||||
elif embedding_service == "sentence_transformer":
|
||||
logger.info(
|
||||
f"Initializing SentenceTransformer embedding service with max_tokens={max_tokens}"
|
||||
)
|
||||
service = SentenceTransformerEmbeddingService(
|
||||
db_client=db_client,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid embedding_service: {embedding_service}. "
|
||||
f"Must be 'sentence_transformer' or 'openai'"
|
||||
)
|
||||
|
||||
# Convert document with docling
|
||||
# Step 1: Convert document with docling
|
||||
logger.info("Converting document with docling")
|
||||
converter = DocumentConverter()
|
||||
conversion_result = converter.convert(temp_file_path)
|
||||
|
|
@ -140,13 +188,26 @@ async def process_knowledge_base_document(
|
|||
"document_type": type(doc).__name__,
|
||||
}
|
||||
|
||||
# Chunk the document
|
||||
# Step 2: Initialize tokenizer for chunking
|
||||
logger.info(
|
||||
f"Loading tokenizer: {TOKENIZER_MODEL} with max_tokens={max_tokens}"
|
||||
)
|
||||
tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=AutoTokenizer.from_pretrained(TOKENIZER_MODEL),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
# Step 3: Initialize chunker
|
||||
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
|
||||
chunker = HybridChunker(tokenizer=tokenizer)
|
||||
|
||||
# Step 4: Chunk the document
|
||||
logger.info(f"Chunking document with max_tokens={max_tokens}")
|
||||
chunks = list(chunker.chunk(dl_doc=doc))
|
||||
total_chunks = len(chunks)
|
||||
logger.info(f"Generated {total_chunks} chunks")
|
||||
|
||||
# Process each chunk
|
||||
# Step 5: Process each chunk
|
||||
chunk_texts = []
|
||||
chunk_records = []
|
||||
token_counts = []
|
||||
|
|
@ -156,7 +217,9 @@ async def process_knowledge_base_document(
|
|||
contextualized_text = chunker.contextualize(chunk=chunk)
|
||||
|
||||
# Calculate token count
|
||||
text_to_tokenize = contextualized_text if contextualized_text else chunk_text
|
||||
text_to_tokenize = (
|
||||
contextualized_text if contextualized_text else chunk_text
|
||||
)
|
||||
token_count = len(
|
||||
tokenizer.tokenizer.encode(text_to_tokenize, add_special_tokens=False)
|
||||
)
|
||||
|
|
@ -176,7 +239,7 @@ async def process_knowledge_base_document(
|
|||
),
|
||||
}
|
||||
|
||||
# Create chunk record
|
||||
# Create chunk record (without embedding yet)
|
||||
chunk_record = KnowledgeBaseChunkModel(
|
||||
document_id=document_id,
|
||||
organization_id=organization_id,
|
||||
|
|
@ -184,8 +247,8 @@ async def process_knowledge_base_document(
|
|||
contextualized_text=contextualized_text,
|
||||
chunk_index=i,
|
||||
chunk_metadata=chunk_metadata,
|
||||
embedding_model=EMBED_MODEL_ID,
|
||||
embedding_dimension=EMBEDDING_DIMENSION,
|
||||
embedding_model=service.get_model_id(),
|
||||
embedding_dimension=service.get_embedding_dimension(),
|
||||
token_count=token_count,
|
||||
)
|
||||
|
||||
|
|
@ -196,29 +259,25 @@ async def process_knowledge_base_document(
|
|||
if token_counts:
|
||||
avg_tokens = sum(token_counts) / len(token_counts)
|
||||
min_tokens = min(token_counts)
|
||||
max_tokens = max(token_counts)
|
||||
logger.info(f"Chunk token statistics:")
|
||||
max_tokens_actual = max(token_counts)
|
||||
logger.info("Chunk token statistics:")
|
||||
logger.info(f" - Average: {avg_tokens:.1f} tokens")
|
||||
logger.info(f" - Min: {min_tokens} tokens")
|
||||
logger.info(f" - Max: {max_tokens} tokens")
|
||||
logger.info(f" - Max: {max_tokens_actual} tokens")
|
||||
|
||||
# Generate embeddings in batch
|
||||
logger.info("Generating embeddings")
|
||||
embeddings = embedding_model.encode(
|
||||
chunk_texts,
|
||||
show_progress_bar=False,
|
||||
convert_to_numpy=True,
|
||||
)
|
||||
# Step 6: Generate embeddings using the embedding service
|
||||
logger.info(f"Generating embeddings using {embedding_service}")
|
||||
embeddings = await service.embed_texts(chunk_texts)
|
||||
|
||||
# Attach embeddings to chunk records
|
||||
# Step 7: Attach embeddings to chunk records
|
||||
for chunk_record, embedding in zip(chunk_records, embeddings):
|
||||
chunk_record.embedding = embedding.tolist()
|
||||
chunk_record.embedding = embedding
|
||||
|
||||
# Save chunks in database
|
||||
# Step 8: Save chunks in database
|
||||
logger.info("Storing chunks in database")
|
||||
await db_client.create_chunks_batch(chunk_records)
|
||||
|
||||
# Update document status to completed
|
||||
# Step 9: Update document status to completed
|
||||
await db_client.update_document_status(
|
||||
document_id,
|
||||
"completed",
|
||||
|
|
|
|||
|
|
@ -7,8 +7,10 @@ import {
|
|||
ReactFlow,
|
||||
} from "@xyflow/react";
|
||||
import { BrushCleaning, Maximize2, Minus, Plus, Rocket, Settings, Variable } from 'lucide-react';
|
||||
import React, { useMemo, useState } from 'react';
|
||||
import React, { useEffect, useMemo, useState } from 'react';
|
||||
|
||||
import { listDocumentsApiV1KnowledgeBaseDocumentsGet, listToolsApiV1ToolsGet } from '@/client';
|
||||
import type { DocumentResponseSchema, ToolResponse } from '@/client/types.gen';
|
||||
import { FlowEdge, FlowNode, NodeType } from "@/components/flow/types";
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/tooltip';
|
||||
|
|
@ -63,6 +65,8 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
|
|||
const [isConfigurationsDialogOpen, setIsConfigurationsDialogOpen] = useState(false);
|
||||
const [isEmbedDialogOpen, setIsEmbedDialogOpen] = useState(false);
|
||||
const [isPhoneCallDialogOpen, setIsPhoneCallDialogOpen] = useState(false);
|
||||
const [documents, setDocuments] = useState<DocumentResponseSchema[] | undefined>(undefined);
|
||||
const [tools, setTools] = useState<ToolResponse[] | undefined>(undefined);
|
||||
|
||||
const {
|
||||
rfInstance,
|
||||
|
|
@ -95,6 +99,36 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
|
|||
getAccessToken
|
||||
});
|
||||
|
||||
// Fetch documents and tools once for the entire workflow
|
||||
useEffect(() => {
|
||||
const fetchData = async () => {
|
||||
try {
|
||||
const accessToken = await getAccessToken();
|
||||
|
||||
// Fetch documents
|
||||
const documentsResponse = await listDocumentsApiV1KnowledgeBaseDocumentsGet({
|
||||
headers: { Authorization: `Bearer ${accessToken}` },
|
||||
query: { limit: 100 },
|
||||
});
|
||||
if (documentsResponse.data) {
|
||||
setDocuments(documentsResponse.data.documents);
|
||||
}
|
||||
|
||||
// Fetch tools
|
||||
const toolsResponse = await listToolsApiV1ToolsGet({
|
||||
headers: { Authorization: `Bearer ${accessToken}` },
|
||||
});
|
||||
if (toolsResponse.data) {
|
||||
setTools(toolsResponse.data);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch documents and tools:', error);
|
||||
}
|
||||
};
|
||||
|
||||
fetchData();
|
||||
}, [getAccessToken]);
|
||||
|
||||
// Memoize defaultEdgeOptions to prevent unnecessary re-renders
|
||||
const defaultEdgeOptions = useMemo(() => ({
|
||||
animated: true,
|
||||
|
|
@ -102,7 +136,11 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
|
|||
}), []);
|
||||
|
||||
// Memoize the context value to prevent unnecessary re-renders
|
||||
const workflowContextValue = useMemo(() => ({ saveWorkflow }), [saveWorkflow]);
|
||||
const workflowContextValue = useMemo(() => ({
|
||||
saveWorkflow,
|
||||
documents,
|
||||
tools
|
||||
}), [saveWorkflow, documents, tools]);
|
||||
|
||||
return (
|
||||
<WorkflowProvider value={workflowContextValue}>
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
import { createContext, useContext } from 'react';
|
||||
|
||||
import type { DocumentResponseSchema, ToolResponse } from '@/client/types.gen';
|
||||
|
||||
interface WorkflowContextType {
|
||||
saveWorkflow: (updateWorkflowDefinition?: boolean) => Promise<void>;
|
||||
documents?: DocumentResponseSchema[];
|
||||
tools?: ToolResponse[];
|
||||
}
|
||||
|
||||
const WorkflowContext = createContext<WorkflowContextType | undefined>(undefined);
|
||||
|
|
@ -15,3 +19,8 @@ export const useWorkflow = () => {
|
|||
}
|
||||
return context;
|
||||
};
|
||||
|
||||
// Optional hook that doesn't throw if context is not available
|
||||
export const useWorkflowOptional = () => {
|
||||
return useContext(WorkflowContext);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
// This file is auto-generated by @hey-api/openapi-ts
|
||||
|
||||
import { type ClientOptions as DefaultClientOptions, type Config, createClient, createConfig } from '@hey-api/client-fetch';
|
||||
|
||||
import { createClientConfig } from '../lib/apiClient';
|
||||
import type { ClientOptions } from './types.gen';
|
||||
import { type Config, type ClientOptions as DefaultClientOptions, createClient, createConfig } from '@hey-api/client-fetch';
|
||||
import { createClientConfig } from '../lib/apiClient';
|
||||
|
||||
/**
|
||||
* The `createClientConfig()` function will be called on client initialization
|
||||
|
|
@ -17,4 +16,4 @@ export type CreateClientConfig<T extends DefaultClientOptions = ClientOptions> =
|
|||
|
||||
export const client = createClient(createClientConfig(createConfig<ClientOptions>({
|
||||
baseUrl: 'http://127.0.0.1:8000'
|
||||
})));
|
||||
})));
|
||||
|
|
@ -1,3 +1,3 @@
|
|||
// This file is auto-generated by @hey-api/openapi-ts
|
||||
export * from './sdk.gen';
|
||||
export * from './types.gen';
|
||||
export * from './sdk.gen';
|
||||
File diff suppressed because one or more lines are too long
|
|
@ -351,6 +351,11 @@ export type DefaultConfigurationsResponse = {
|
|||
[key: string]: unknown;
|
||||
};
|
||||
};
|
||||
embeddings: {
|
||||
[key: string]: {
|
||||
[key: string]: unknown;
|
||||
};
|
||||
};
|
||||
default_providers: {
|
||||
[key: string]: string;
|
||||
};
|
||||
|
|
@ -672,6 +677,10 @@ export type ProcessDocumentRequestSchema = {
|
|||
* S3 key of the uploaded file
|
||||
*/
|
||||
s3_key: string;
|
||||
/**
|
||||
* Embedding service to use for processing. Options: 'openai' (default, 1536-dim, requires API key) or 'sentence_transformer' (free, 384-dim)
|
||||
*/
|
||||
embedding_service?: 'sentence_transformer' | 'openai';
|
||||
};
|
||||
|
||||
export type S3SignedUrlResponse = {
|
||||
|
|
@ -924,6 +933,9 @@ export type UserConfigurationRequestResponseSchema = {
|
|||
stt?: {
|
||||
[key: string]: string | number;
|
||||
} | null;
|
||||
embeddings?: {
|
||||
[key: string]: string | number;
|
||||
} | null;
|
||||
test_phone_number?: string | null;
|
||||
timezone?: string | null;
|
||||
organization_pricing?: {
|
||||
|
|
@ -4493,4 +4505,4 @@ export type HealthApiV1HealthGetResponses = {
|
|||
|
||||
export type ClientOptions = {
|
||||
baseUrl: 'http://127.0.0.1:8000' | (string & {});
|
||||
};
|
||||
};
|
||||
|
|
@ -14,7 +14,7 @@ import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
|||
import { VoiceSelector } from "@/components/VoiceSelector";
|
||||
import { useUserConfig } from "@/context/UserConfigContext";
|
||||
|
||||
type ServiceSegment = "llm" | "tts" | "stt";
|
||||
type ServiceSegment = "llm" | "tts" | "stt" | "embeddings";
|
||||
|
||||
interface SchemaProperty {
|
||||
type?: string;
|
||||
|
|
@ -41,6 +41,7 @@ const TAB_CONFIG: { key: ServiceSegment; label: string }[] = [
|
|||
{ key: "llm", label: "LLM" },
|
||||
{ key: "tts", label: "Voice" },
|
||||
{ key: "stt", label: "Transcriber" },
|
||||
{ key: "embeddings", label: "Embedding" },
|
||||
];
|
||||
|
||||
// Display names for language codes (Deepgram + Sarvam)
|
||||
|
|
@ -109,12 +110,14 @@ export default function ServiceConfiguration() {
|
|||
const [schemas, setSchemas] = useState<Record<ServiceSegment, Record<string, ProviderSchema>>>({
|
||||
llm: {},
|
||||
tts: {},
|
||||
stt: {}
|
||||
stt: {},
|
||||
embeddings: {}
|
||||
});
|
||||
const [serviceProviders, setServiceProviders] = useState<Record<ServiceSegment, string>>({
|
||||
llm: "",
|
||||
tts: "",
|
||||
stt: ""
|
||||
stt: "",
|
||||
embeddings: ""
|
||||
});
|
||||
const [isManualModelInput, setIsManualModelInput] = useState(false);
|
||||
const [hasCheckedManualMode, setHasCheckedManualMode] = useState(false);
|
||||
|
|
@ -136,7 +139,8 @@ export default function ServiceConfiguration() {
|
|||
setSchemas({
|
||||
llm: response.data.llm as Record<string, ProviderSchema>,
|
||||
tts: response.data.tts as Record<string, ProviderSchema>,
|
||||
stt: response.data.stt as Record<string, ProviderSchema>
|
||||
stt: response.data.stt as Record<string, ProviderSchema>,
|
||||
embeddings: response.data.embeddings as Record<string, ProviderSchema>
|
||||
});
|
||||
} else {
|
||||
console.error("Failed to fetch configurations");
|
||||
|
|
@ -147,7 +151,8 @@ export default function ServiceConfiguration() {
|
|||
const selectedProviders: Record<ServiceSegment, string> = {
|
||||
llm: response.data.default_providers.llm,
|
||||
tts: response.data.default_providers.tts,
|
||||
stt: response.data.default_providers.stt
|
||||
stt: response.data.default_providers.stt,
|
||||
embeddings: response.data.default_providers.embeddings
|
||||
};
|
||||
|
||||
const setServicePropertyValues = (service: ServiceSegment) => {
|
||||
|
|
@ -173,6 +178,7 @@ export default function ServiceConfiguration() {
|
|||
setServicePropertyValues("llm");
|
||||
setServicePropertyValues("tts");
|
||||
setServicePropertyValues("stt");
|
||||
setServicePropertyValues("embeddings");
|
||||
|
||||
// IMPORTANT: Reset form values BEFORE changing providers
|
||||
// Otherwise, Radix Select sees old values that don't match new provider's enum
|
||||
|
|
@ -246,7 +252,7 @@ export default function ServiceConfiguration() {
|
|||
setApiError(null);
|
||||
setIsSaving(true);
|
||||
|
||||
const userConfig = {
|
||||
const userConfig: Record<ServiceSegment, Record<string, string | number>> = {
|
||||
llm: {
|
||||
provider: serviceProviders.llm,
|
||||
api_key: data.llm_api_key as string,
|
||||
|
|
@ -259,6 +265,11 @@ export default function ServiceConfiguration() {
|
|||
stt: {
|
||||
provider: serviceProviders.stt,
|
||||
api_key: data.stt_api_key as string
|
||||
},
|
||||
embeddings: {
|
||||
provider: serviceProviders.embeddings,
|
||||
api_key: data.embeddings_api_key as string,
|
||||
model: data.embeddings_model as string
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -273,12 +284,25 @@ export default function ServiceConfiguration() {
|
|||
}
|
||||
});
|
||||
|
||||
// Build save config - only include embeddings if api_key is provided
|
||||
const saveConfig: {
|
||||
llm: Record<string, string | number>;
|
||||
tts: Record<string, string | number>;
|
||||
stt: Record<string, string | number>;
|
||||
embeddings?: Record<string, string | number>;
|
||||
} = {
|
||||
llm: userConfig.llm,
|
||||
tts: userConfig.tts,
|
||||
stt: userConfig.stt
|
||||
};
|
||||
|
||||
// Only include embeddings if user has configured it (has api_key)
|
||||
if (userConfig.embeddings.api_key) {
|
||||
saveConfig.embeddings = userConfig.embeddings;
|
||||
}
|
||||
|
||||
try {
|
||||
await saveUserConfig({
|
||||
llm: userConfig.llm,
|
||||
tts: userConfig.tts,
|
||||
stt: userConfig.stt
|
||||
});
|
||||
await saveUserConfig(saveConfig);
|
||||
setApiError(null);
|
||||
} catch (error: unknown) {
|
||||
if (error instanceof Error) {
|
||||
|
|
@ -543,7 +567,7 @@ export default function ServiceConfiguration() {
|
|||
<Card>
|
||||
<CardContent className="pt-6">
|
||||
<Tabs defaultValue="llm" className="w-full">
|
||||
<TabsList className="grid w-full grid-cols-3 mb-6">
|
||||
<TabsList className="grid w-full grid-cols-4 mb-6">
|
||||
{TAB_CONFIG.map(({ key, label }) => (
|
||||
<TabsTrigger key={key} value={key}>
|
||||
{label}
|
||||
|
|
|
|||
|
|
@ -1,57 +1,55 @@
|
|||
"use client";
|
||||
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { listDocumentsApiV1KnowledgeBaseDocumentsGet } from "@/client/sdk.gen";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
|
||||
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
|
||||
import type { DocumentResponseSchema } from "@/client/types.gen";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
|
||||
interface DocumentBadgesProps {
|
||||
documentUuids: string[];
|
||||
onStaleUuidsDetected?: (staleUuids: string[]) => void;
|
||||
}
|
||||
|
||||
export const DocumentBadges = ({ documentUuids }: DocumentBadgesProps) => {
|
||||
const { getAccessToken } = useAuth();
|
||||
export const DocumentBadges = ({ documentUuids, onStaleUuidsDetected }: DocumentBadgesProps) => {
|
||||
const { documents } = useWorkflow();
|
||||
const [documentNames, setDocumentNames] = useState<Record<string, string>>({});
|
||||
const [loading, setLoading] = useState(false);
|
||||
|
||||
const fetchDocuments = useCallback(async () => {
|
||||
if (documentUuids.length === 0) return;
|
||||
const processDocuments = useCallback((docs: DocumentResponseSchema[]) => {
|
||||
const nameMap: Record<string, string> = {};
|
||||
const validUuids = new Set<string>();
|
||||
|
||||
setLoading(true);
|
||||
try {
|
||||
const accessToken = await getAccessToken();
|
||||
const response = await listDocumentsApiV1KnowledgeBaseDocumentsGet({
|
||||
headers: { Authorization: `Bearer ${accessToken}` },
|
||||
query: {
|
||||
limit: 100,
|
||||
},
|
||||
docs
|
||||
.filter((doc) => documentUuids.includes(doc.document_uuid))
|
||||
.forEach((doc) => {
|
||||
nameMap[doc.document_uuid] = doc.filename;
|
||||
validUuids.add(doc.document_uuid);
|
||||
});
|
||||
setDocumentNames(nameMap);
|
||||
|
||||
if (response.data) {
|
||||
const nameMap: Record<string, string> = {};
|
||||
response.data.documents
|
||||
.filter((doc) => documentUuids.includes(doc.document_uuid))
|
||||
.forEach((doc) => {
|
||||
nameMap[doc.document_uuid] = doc.filename;
|
||||
});
|
||||
setDocumentNames(nameMap);
|
||||
// Detect stale UUIDs - this only runs when we have loaded data (not undefined)
|
||||
if (onStaleUuidsDetected) {
|
||||
const staleUuids = documentUuids.filter(uuid => !validUuids.has(uuid));
|
||||
if (staleUuids.length > 0) {
|
||||
onStaleUuidsDetected(staleUuids);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch documents:", error);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, [documentUuids, getAccessToken]);
|
||||
}, [documentUuids, onStaleUuidsDetected]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchDocuments();
|
||||
}, [fetchDocuments]);
|
||||
if (documentUuids.length > 0 && documents !== undefined) {
|
||||
processDocuments(documents);
|
||||
} else if (documentUuids.length === 0) {
|
||||
setDocumentNames({});
|
||||
}
|
||||
}, [documentUuids, documents, processDocuments]);
|
||||
|
||||
if (documentUuids.length === 0) {
|
||||
return <></>;
|
||||
}
|
||||
|
||||
if (loading) {
|
||||
// Show loading while data hasn't loaded yet
|
||||
if (documents === undefined) {
|
||||
return <Badge variant="outline">Loading...</Badge>;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,18 +1,18 @@
|
|||
"use client";
|
||||
|
||||
import { FileText } from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { useMemo } from "react";
|
||||
|
||||
import type { DocumentResponseSchema } from "@/client/types.gen";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { listDocumentsApiV1KnowledgeBaseDocumentsGet } from "@/client/sdk.gen";
|
||||
import type { DocumentResponseSchema } from "@/client/types.gen";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
import { FileText } from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
|
||||
interface DocumentSelectorProps {
|
||||
value: string[];
|
||||
onChange: (uuids: string[]) => void;
|
||||
documents: DocumentResponseSchema[];
|
||||
disabled?: boolean;
|
||||
label?: string;
|
||||
description?: string;
|
||||
|
|
@ -22,43 +22,17 @@ interface DocumentSelectorProps {
|
|||
export const DocumentSelector = ({
|
||||
value,
|
||||
onChange,
|
||||
documents,
|
||||
disabled = false,
|
||||
label = "Knowledge Base Documents",
|
||||
description = "Select documents that the agent can reference during conversations.",
|
||||
showLabel = true,
|
||||
}: DocumentSelectorProps) => {
|
||||
const { getAccessToken } = useAuth();
|
||||
const [documents, setDocuments] = useState<DocumentResponseSchema[]>([]);
|
||||
const [loading, setLoading] = useState(false);
|
||||
|
||||
const fetchDocuments = useCallback(async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const accessToken = await getAccessToken();
|
||||
const response = await listDocumentsApiV1KnowledgeBaseDocumentsGet({
|
||||
headers: { Authorization: `Bearer ${accessToken}` },
|
||||
query: {
|
||||
limit: 100,
|
||||
},
|
||||
});
|
||||
|
||||
if (response.data) {
|
||||
// Only show completed documents
|
||||
const completedDocs = response.data.documents.filter(
|
||||
(doc) => doc.processing_status === "completed"
|
||||
);
|
||||
setDocuments(completedDocs);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch documents:", error);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, [getAccessToken]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchDocuments();
|
||||
}, [fetchDocuments]);
|
||||
// Only show completed documents
|
||||
const completedDocuments = useMemo(
|
||||
() => documents.filter((doc) => doc.processing_status === "completed"),
|
||||
[documents]
|
||||
);
|
||||
|
||||
const handleToggle = (documentUuid: string, checked: boolean) => {
|
||||
if (checked) {
|
||||
|
|
@ -76,25 +50,7 @@ export const DocumentSelector = ({
|
|||
return Math.round(bytes / Math.pow(k, i) * 100) / 100 + " " + sizes[i];
|
||||
};
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
{showLabel && (
|
||||
<>
|
||||
<Label>{label}</Label>
|
||||
{description && (
|
||||
<Label className="text-xs text-muted-foreground">{description}</Label>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
<div className="border rounded-md p-4 text-sm text-muted-foreground text-center">
|
||||
Loading documents...
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (documents.length === 0) {
|
||||
if (completedDocuments.length === 0) {
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
{showLabel && (
|
||||
|
|
@ -133,7 +89,7 @@ export const DocumentSelector = ({
|
|||
)}
|
||||
<div className="border rounded-md max-h-[300px] overflow-y-auto">
|
||||
<div className="divide-y">
|
||||
{documents.map((doc) => (
|
||||
{completedDocuments.map((doc) => (
|
||||
<div
|
||||
key={doc.document_uuid}
|
||||
className="flex items-start gap-3 p-3 hover:bg-muted/50 transition-colors"
|
||||
|
|
|
|||
|
|
@ -2,43 +2,43 @@
|
|||
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
|
||||
import { listToolsApiV1ToolsGet } from "@/client/sdk.gen";
|
||||
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
|
||||
import type { ToolResponse } from "@/client/types.gen";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
|
||||
interface ToolBadgesProps {
|
||||
toolUuids: string[];
|
||||
onStaleUuidsDetected?: (staleUuids: string[]) => void;
|
||||
}
|
||||
|
||||
export function ToolBadges({ toolUuids }: ToolBadgesProps) {
|
||||
const { getAccessToken } = useAuth();
|
||||
const [tools, setTools] = useState<ToolResponse[]>([]);
|
||||
export function ToolBadges({ toolUuids, onStaleUuidsDetected }: ToolBadgesProps) {
|
||||
const { tools } = useWorkflow();
|
||||
const [selectedTools, setSelectedTools] = useState<ToolResponse[]>([]);
|
||||
|
||||
const fetchTools = useCallback(async () => {
|
||||
try {
|
||||
const accessToken = await getAccessToken();
|
||||
const response = await listToolsApiV1ToolsGet({
|
||||
headers: { Authorization: `Bearer ${accessToken}` },
|
||||
});
|
||||
if (response.data) {
|
||||
setTools(response.data);
|
||||
const processTools = useCallback((toolsData: ToolResponse[]) => {
|
||||
const filtered = toolsData.filter(tool => toolUuids.includes(tool.tool_uuid));
|
||||
setSelectedTools(filtered);
|
||||
|
||||
// Detect stale UUIDs - this only runs when we have loaded data (not undefined)
|
||||
if (onStaleUuidsDetected) {
|
||||
const validUuids = new Set(toolsData.map(tool => tool.tool_uuid));
|
||||
const staleUuids = toolUuids.filter(uuid => !validUuids.has(uuid));
|
||||
if (staleUuids.length > 0) {
|
||||
onStaleUuidsDetected(staleUuids);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch tools:", error);
|
||||
}
|
||||
}, [getAccessToken]);
|
||||
}, [toolUuids, onStaleUuidsDetected]);
|
||||
|
||||
useEffect(() => {
|
||||
if (toolUuids.length > 0) {
|
||||
fetchTools();
|
||||
if (toolUuids.length > 0 && tools !== undefined) {
|
||||
processTools(tools);
|
||||
} else if (toolUuids.length === 0) {
|
||||
setSelectedTools([]);
|
||||
}
|
||||
}, [toolUuids.length, fetchTools]);
|
||||
}, [toolUuids, tools, processTools]);
|
||||
|
||||
const selectedTools = tools.filter((tool) => toolUuids.includes(tool.tool_uuid));
|
||||
|
||||
if (selectedTools.length === 0 && toolUuids.length > 0) {
|
||||
// Still loading or tools not found
|
||||
// Show loading while data hasn't loaded yet
|
||||
if (tools === undefined && toolUuids.length > 0) {
|
||||
return (
|
||||
<div className="flex flex-wrap gap-1">
|
||||
<Badge variant="outline" className="text-xs">
|
||||
|
|
|
|||
|
|
@ -1,20 +1,18 @@
|
|||
"use client";
|
||||
|
||||
import { ExternalLink, Loader2 } from "lucide-react";
|
||||
import { ExternalLink } from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
|
||||
import { renderToolIcon } from "@/app/tools/config";
|
||||
import { listToolsApiV1ToolsGet } from "@/client/sdk.gen";
|
||||
import type { ToolResponse } from "@/client/types.gen";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
|
||||
interface ToolSelectorProps {
|
||||
value: string[];
|
||||
onChange: (uuids: string[]) => void;
|
||||
tools: ToolResponse[];
|
||||
disabled?: boolean;
|
||||
label?: string;
|
||||
description?: string;
|
||||
|
|
@ -24,43 +22,14 @@ interface ToolSelectorProps {
|
|||
export function ToolSelector({
|
||||
value,
|
||||
onChange,
|
||||
tools,
|
||||
disabled = false,
|
||||
label = "Tools",
|
||||
description = "Select tools that the agent can use during the conversation.",
|
||||
showLabel = true,
|
||||
}: ToolSelectorProps) {
|
||||
const { getAccessToken } = useAuth();
|
||||
|
||||
const [tools, setTools] = useState<ToolResponse[]>([]);
|
||||
const [loading, setLoading] = useState(false);
|
||||
|
||||
const fetchTools = useCallback(async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const accessToken = await getAccessToken();
|
||||
const response = await listToolsApiV1ToolsGet({
|
||||
headers: { Authorization: `Bearer ${accessToken}` },
|
||||
query: { status: "active" },
|
||||
});
|
||||
if (response.error) {
|
||||
console.error("Failed to fetch tools:", response.error);
|
||||
setTools([]);
|
||||
return;
|
||||
}
|
||||
if (response.data) {
|
||||
setTools(response.data);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch tools:", error);
|
||||
setTools([]);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, [getAccessToken]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchTools();
|
||||
}, [fetchTools]);
|
||||
// Filter to only show active tools
|
||||
const activeTools = tools.filter((tool) => tool.status === "active");
|
||||
|
||||
const handleToggle = (toolUuid: string, checked: boolean) => {
|
||||
if (checked) {
|
||||
|
|
@ -83,12 +52,7 @@ export function ToolSelector({
|
|||
</>
|
||||
)}
|
||||
|
||||
{loading ? (
|
||||
<div className="flex items-center gap-2 p-3 border rounded-md">
|
||||
<Loader2 className="h-4 w-4 animate-spin" />
|
||||
<span className="text-sm text-muted-foreground">Loading tools...</span>
|
||||
</div>
|
||||
) : tools.length === 0 ? (
|
||||
{activeTools.length === 0 ? (
|
||||
<div className="p-4 border rounded-md text-center">
|
||||
<p className="text-sm text-muted-foreground mb-2">
|
||||
No tools available.
|
||||
|
|
@ -102,7 +66,7 @@ export function ToolSelector({
|
|||
</div>
|
||||
) : (
|
||||
<div className="border rounded-md divide-y">
|
||||
{tools.map((tool) => {
|
||||
{activeTools.map((tool) => {
|
||||
const isSelected = value.includes(tool.tool_uuid);
|
||||
return (
|
||||
<label
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import { NodeProps, NodeToolbar, Position } from "@xyflow/react";
|
||||
import { Edit, FileText, Headset, PlusIcon, Trash2Icon, Wrench } from "lucide-react";
|
||||
import { memo, useEffect, useMemo, useState } from "react";
|
||||
import { memo, useCallback, useEffect, useMemo, useState } from "react";
|
||||
|
||||
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
|
||||
import type { DocumentResponseSchema, ToolResponse } from "@/client/types.gen";
|
||||
import { DocumentBadges } from "@/components/flow/DocumentBadges";
|
||||
import { DocumentSelector } from "@/components/flow/DocumentSelector";
|
||||
import { ToolBadges } from "@/components/flow/ToolBadges";
|
||||
|
|
@ -38,6 +39,8 @@ interface AgentNodeEditFormProps {
|
|||
setToolUuids: (value: string[]) => void;
|
||||
documentUuids: string[];
|
||||
setDocumentUuids: (value: string[]) => void;
|
||||
tools: ToolResponse[];
|
||||
documents: DocumentResponseSchema[];
|
||||
}
|
||||
|
||||
interface AgentNodeProps extends NodeProps {
|
||||
|
|
@ -46,7 +49,7 @@ interface AgentNodeProps extends NodeProps {
|
|||
|
||||
export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
|
||||
const { open, setOpen, handleSaveNodeData, handleDeleteNode } = useNodeHandlers({ id });
|
||||
const { saveWorkflow } = useWorkflow();
|
||||
const { saveWorkflow, tools, documents } = useWorkflow();
|
||||
|
||||
// Form state
|
||||
const [prompt, setPrompt] = useState(data.prompt);
|
||||
|
|
@ -120,6 +123,30 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
|
|||
}
|
||||
}, [data, open]);
|
||||
|
||||
// Handle cleanup of stale document UUIDs
|
||||
const handleStaleDocuments = useCallback((staleUuids: string[]) => {
|
||||
const cleanedUuids = (data.document_uuids ?? []).filter(uuid => !staleUuids.includes(uuid));
|
||||
handleSaveNodeData({
|
||||
...data,
|
||||
document_uuids: cleanedUuids.length > 0 ? cleanedUuids : undefined,
|
||||
});
|
||||
setTimeout(async () => {
|
||||
await saveWorkflow();
|
||||
}, 100);
|
||||
}, [data, handleSaveNodeData, saveWorkflow]);
|
||||
|
||||
// Handle cleanup of stale tool UUIDs
|
||||
const handleStaleTools = useCallback((staleUuids: string[]) => {
|
||||
const cleanedUuids = (data.tool_uuids ?? []).filter(uuid => !staleUuids.includes(uuid));
|
||||
handleSaveNodeData({
|
||||
...data,
|
||||
tool_uuids: cleanedUuids.length > 0 ? cleanedUuids : undefined,
|
||||
});
|
||||
setTimeout(async () => {
|
||||
await saveWorkflow();
|
||||
}, 100);
|
||||
}, [data, handleSaveNodeData, saveWorkflow]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<NodeContent
|
||||
|
|
@ -144,7 +171,7 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
|
|||
<Wrench className="h-3 w-3" />
|
||||
<span>Tools:</span>
|
||||
</div>
|
||||
<ToolBadges toolUuids={data.tool_uuids} />
|
||||
<ToolBadges toolUuids={data.tool_uuids} onStaleUuidsDetected={handleStaleTools} />
|
||||
</div>
|
||||
)}
|
||||
{data.document_uuids && data.document_uuids.length > 0 && (
|
||||
|
|
@ -153,7 +180,7 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
|
|||
<FileText className="h-3 w-3" />
|
||||
<span>Documents:</span>
|
||||
</div>
|
||||
<DocumentBadges documentUuids={data.document_uuids} />
|
||||
<DocumentBadges documentUuids={data.document_uuids} onStaleUuidsDetected={handleStaleDocuments} />
|
||||
</div>
|
||||
)}
|
||||
</NodeContent>
|
||||
|
|
@ -198,6 +225,8 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
|
|||
setToolUuids={setToolUuids}
|
||||
documentUuids={documentUuids}
|
||||
setDocumentUuids={setDocumentUuids}
|
||||
tools={tools ?? []}
|
||||
documents={documents ?? []}
|
||||
/>
|
||||
)}
|
||||
</NodeEditDialog>
|
||||
|
|
@ -224,6 +253,8 @@ const AgentNodeEditForm = ({
|
|||
setToolUuids,
|
||||
documentUuids,
|
||||
setDocumentUuids,
|
||||
tools,
|
||||
documents,
|
||||
}: AgentNodeEditFormProps) => {
|
||||
const handleVariableNameChange = (idx: number, value: string) => {
|
||||
const newVars = [...variables];
|
||||
|
|
@ -364,6 +395,7 @@ const AgentNodeEditForm = ({
|
|||
<ToolSelector
|
||||
value={toolUuids}
|
||||
onChange={setToolUuids}
|
||||
tools={tools}
|
||||
description="Select tools that the agent can invoke during this conversation step."
|
||||
/>
|
||||
</div>
|
||||
|
|
@ -373,6 +405,7 @@ const AgentNodeEditForm = ({
|
|||
<DocumentSelector
|
||||
value={documentUuids}
|
||||
onChange={setDocumentUuids}
|
||||
documents={documents}
|
||||
description="Select documents from the knowledge base that the agent can reference during this conversation step."
|
||||
/>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import { NodeProps, NodeToolbar, Position } from "@xyflow/react";
|
||||
import { Edit, FileText, Play, PlusIcon, Trash2Icon, Wrench } from "lucide-react";
|
||||
import { memo, useEffect, useMemo, useState } from "react";
|
||||
import { memo, useCallback, useEffect, useMemo, useState } from "react";
|
||||
|
||||
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
|
||||
import type { DocumentResponseSchema, ToolResponse } from "@/client/types.gen";
|
||||
import { DocumentBadges } from "@/components/flow/DocumentBadges";
|
||||
import { DocumentSelector } from "@/components/flow/DocumentSelector";
|
||||
import { ToolBadges } from "@/components/flow/ToolBadges";
|
||||
|
|
@ -45,6 +46,8 @@ interface StartCallEditFormProps {
|
|||
setToolUuids: (value: string[]) => void;
|
||||
documentUuids: string[];
|
||||
setDocumentUuids: (value: string[]) => void;
|
||||
tools: ToolResponse[];
|
||||
documents: DocumentResponseSchema[];
|
||||
}
|
||||
|
||||
interface StartCallNodeProps extends NodeProps {
|
||||
|
|
@ -56,7 +59,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
id,
|
||||
additionalData: { is_start: true }
|
||||
});
|
||||
const { saveWorkflow } = useWorkflow();
|
||||
const { saveWorkflow, tools, documents } = useWorkflow();
|
||||
|
||||
// Form state
|
||||
const [prompt, setPrompt] = useState(data.prompt ?? "");
|
||||
|
|
@ -140,6 +143,30 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
}
|
||||
}, [data, open]);
|
||||
|
||||
// Handle cleanup of stale document UUIDs
|
||||
const handleStaleDocuments = useCallback((staleUuids: string[]) => {
|
||||
const cleanedUuids = (data.document_uuids ?? []).filter(uuid => !staleUuids.includes(uuid));
|
||||
handleSaveNodeData({
|
||||
...data,
|
||||
document_uuids: cleanedUuids.length > 0 ? cleanedUuids : undefined,
|
||||
});
|
||||
setTimeout(async () => {
|
||||
await saveWorkflow();
|
||||
}, 100);
|
||||
}, [data, handleSaveNodeData, saveWorkflow]);
|
||||
|
||||
// Handle cleanup of stale tool UUIDs
|
||||
const handleStaleTools = useCallback((staleUuids: string[]) => {
|
||||
const cleanedUuids = (data.tool_uuids ?? []).filter(uuid => !staleUuids.includes(uuid));
|
||||
handleSaveNodeData({
|
||||
...data,
|
||||
tool_uuids: cleanedUuids.length > 0 ? cleanedUuids : undefined,
|
||||
});
|
||||
setTimeout(async () => {
|
||||
await saveWorkflow();
|
||||
}, 100);
|
||||
}, [data, handleSaveNodeData, saveWorkflow]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<NodeContent
|
||||
|
|
@ -163,7 +190,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
<Wrench className="h-3 w-3" />
|
||||
<span>Tools:</span>
|
||||
</div>
|
||||
<ToolBadges toolUuids={data.tool_uuids} />
|
||||
<ToolBadges toolUuids={data.tool_uuids} onStaleUuidsDetected={handleStaleTools} />
|
||||
</div>
|
||||
)}
|
||||
{data.document_uuids && data.document_uuids.length > 0 && (
|
||||
|
|
@ -172,7 +199,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
<FileText className="h-3 w-3" />
|
||||
<span>Documents:</span>
|
||||
</div>
|
||||
<DocumentBadges documentUuids={data.document_uuids} />
|
||||
<DocumentBadges documentUuids={data.document_uuids} onStaleUuidsDetected={handleStaleDocuments} />
|
||||
</div>
|
||||
)}
|
||||
</NodeContent>
|
||||
|
|
@ -218,6 +245,8 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
setToolUuids={setToolUuids}
|
||||
documentUuids={documentUuids}
|
||||
setDocumentUuids={setDocumentUuids}
|
||||
tools={tools ?? []}
|
||||
documents={documents ?? []}
|
||||
/>
|
||||
)}
|
||||
</NodeEditDialog>
|
||||
|
|
@ -250,6 +279,8 @@ const StartCallEditForm = ({
|
|||
setToolUuids,
|
||||
documentUuids,
|
||||
setDocumentUuids,
|
||||
tools,
|
||||
documents,
|
||||
}: StartCallEditFormProps) => {
|
||||
const handleVariableNameChange = (idx: number, value: string) => {
|
||||
const newVars = [...variables];
|
||||
|
|
@ -435,6 +466,7 @@ const StartCallEditForm = ({
|
|||
<ToolSelector
|
||||
value={toolUuids}
|
||||
onChange={setToolUuids}
|
||||
tools={tools}
|
||||
description="Select tools that the agent can invoke during this conversation step."
|
||||
/>
|
||||
</div>
|
||||
|
|
@ -444,6 +476,7 @@ const StartCallEditForm = ({
|
|||
<DocumentSelector
|
||||
value={documentUuids}
|
||||
onChange={setDocumentUuids}
|
||||
documents={documents}
|
||||
description="Select documents from the knowledge base that the agent can reference during this conversation step."
|
||||
/>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -18,6 +18,9 @@ export type SaveUserConfigFunctionParams = {
|
|||
stt?: {
|
||||
[key: string]: string | number;
|
||||
} | null;
|
||||
embeddings?: {
|
||||
[key: string]: string | number;
|
||||
} | null;
|
||||
test_phone_number?: string | null;
|
||||
timezone?: string | null;
|
||||
};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue