feat: knowledge base functionality for the voice agent (#120)

* feat: upload file and store embedding

* feat: add documents in nodes

* feat: add openai embedding service
This commit is contained in:
Abhishek 2026-01-17 14:37:03 +05:30 committed by GitHub
parent e2fa4bbb98
commit ef5b9e40a9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
52 changed files with 4551 additions and 114 deletions

View file

@ -48,6 +48,12 @@ class UserConfigurationValidator:
status_list.extend(self._validate_service(configuration.llm, "llm"))
status_list.extend(self._validate_service(configuration.stt, "stt"))
status_list.extend(self._validate_service(configuration.tts, "tts"))
# Embeddings is optional - only validate if configured
status_list.extend(
self._validate_service(
configuration.embeddings, "embeddings", required=False
)
)
if status_list:
raise ValueError(status_list)
@ -55,11 +61,16 @@ class UserConfigurationValidator:
return {"status": [{"model": "all", "message": "ok"}]}
def _validate_service(
self, service_config: Optional[ServiceConfig], service_name: str
self,
service_config: Optional[ServiceConfig],
service_name: str,
required: bool = True,
) -> list[APIKeyStatus]:
"""Validate a service configuration and return any error statuses."""
if not service_config:
return [{"model": service_name, "message": "API key is missing"}]
if required:
return [{"model": service_name, "message": "API key is missing"}]
return [] # Optional service not configured is OK
provider = service_config.provider
api_key = service_config.api_key

View file

@ -13,6 +13,7 @@ left as ``None``.
from api.services.configuration.registry import (
DeepgramSTTConfiguration,
ElevenlabsTTSConfiguration,
OpenAIEmbeddingsConfiguration,
OpenAILLMService,
ServiceProviders,
)
@ -22,6 +23,7 @@ _DEFAULTS = {
"llm": (ServiceProviders.OPENAI, OpenAILLMService),
"tts": (ServiceProviders.ELEVENLABS, ElevenlabsTTSConfiguration),
"stt": (ServiceProviders.DEEPGRAM, DeepgramSTTConfiguration),
"embeddings": (ServiceProviders.OPENAI, OpenAIEmbeddingsConfiguration),
}
# Public mapping of service name -> default provider

View file

@ -64,6 +64,7 @@ def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
"llm": _mask_service(config.llm),
"tts": _mask_service(config.tts),
"stt": _mask_service(config.stt),
"embeddings": _mask_service(config.embeddings),
"test_phone_number": config.test_phone_number,
"timezone": config.timezone,
}

View file

@ -9,7 +9,7 @@ from typing import Dict
from api.schemas.user_configuration import UserConfiguration
from api.services.configuration.masking import is_mask_of
SERVICE_FIELDS = ("llm", "tts", "stt")
SERVICE_FIELDS = ("llm", "tts", "stt", "embeddings")
def merge_user_configurations(

View file

@ -8,6 +8,7 @@ class ServiceType(Enum):
LLM = auto()
TTS = auto()
STT = auto()
EMBEDDINGS = auto()
class ServiceProviders(str, Enum):
@ -50,11 +51,16 @@ class BaseSTTConfiguration(BaseServiceConfiguration):
model: str
# Base configuration for embeddings providers: extends BaseServiceConfiguration
# with a mandatory model name.
class BaseEmbeddingsConfiguration(BaseServiceConfiguration):
    model: str  # embedding model identifier; no default at this level
# Unified registry for all service types
REGISTRY: Dict[ServiceType, Dict[str, Type[BaseServiceConfiguration]]] = {
ServiceType.LLM: {},
ServiceType.TTS: {},
ServiceType.STT: {},
ServiceType.EMBEDDINGS: {},
}
T = TypeVar("T", bound=BaseServiceConfiguration)
@ -93,6 +99,10 @@ def register_stt(cls: Type[BaseSTTConfiguration]):
return register_service(ServiceType.STT)(cls)
def register_embeddings(cls: Type[BaseEmbeddingsConfiguration]):
    """Apply the ``register_service`` decorator for ``ServiceType.EMBEDDINGS`` to ``cls``.

    Returns:
        Whatever the ``register_service(ServiceType.EMBEDDINGS)`` decorator
        returns for ``cls``.
    """
    decorate = register_service(ServiceType.EMBEDDINGS)
    return decorate(cls)
###################################################### LLM ########################################################################
# Suggested models for each provider (used for UI dropdown)
@ -436,6 +446,27 @@ STTConfig = Annotated[
Field(discriminator="provider"),
]
ServiceConfig = Annotated[
Union[LLMConfig, TTSConfig, STTConfig], Field(discriminator="provider")
###################################################### EMBEDDINGS ########################################################################
OPENAI_EMBEDDING_MODELS = ["text-embedding-3-small"]
# OpenAI embeddings settings, registered through the register_embeddings
# decorator. Comments (not a docstring) are used on purpose so the model's
# generated schema is left untouched.
@register_embeddings
class OpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
    # Fixed provider tag; "provider" is the discriminator field of the config unions
    provider: Literal[ServiceProviders.OPENAI] = ServiceProviders.OPENAI
    # Model name; OPENAI_EMBEDDING_MODELS is exposed as JSON-schema examples
    model: str = Field(
        default="text-embedding-3-small",
        json_schema_extra={"examples": OPENAI_EMBEDDING_MODELS},
    )
    # Required field: there is no default API key
    api_key: str
EmbeddingsConfig = Annotated[
Union[OpenAIEmbeddingsConfiguration],
Field(discriminator="provider"),
]
ServiceConfig = Annotated[
Union[LLMConfig, TTSConfig, STTConfig, EmbeddingsConfig],
Field(discriminator="provider"),
]

View file

@ -85,3 +85,16 @@ class BaseFileSystem(ABC):
Optional[str]: Presigned PUT URL if successful, None otherwise
"""
pass
@abstractmethod
async def adownload_file(self, source_path: str, local_path: str) -> bool:
    """Copy a stored file down to the local filesystem.

    Args:
        source_path: Location of the file inside the storage backend
        local_path: Destination path on the local machine

    Returns:
        bool: True when the download succeeded, False otherwise
    """
    ...

View file

@ -170,3 +170,15 @@ class MinioFileSystem(BaseFileSystem):
except Exception as e:
logger.error(f"Error generating MinIO upload URL: {e}")
return None
async def adownload_file(self, source_path: str, local_path: str) -> bool:
    """Download an object from MinIO to a local path.

    Args:
        source_path: Object name inside ``self.bucket_name``
        local_path: Local filesystem path the object is written to

    Returns:
        bool: True if the object was downloaded, False on any failure
        (the failure is logged rather than raised).
    """
    try:
        # fget_object is a synchronous call, so it runs in a worker thread
        # to keep the event loop free.
        await asyncio.to_thread(
            self.client.fget_object, self.bucket_name, source_path, local_path
        )
        return True
    except Exception as e:
        # Same handling as the other MinIO helpers in this class: log and
        # report failure instead of letting storage/IO errors escape.
        logger.error(f"Error downloading file from MinIO: {e}")
        return False

View file

@ -126,3 +126,14 @@ class S3FileSystem(BaseFileSystem):
return url
except ClientError:
return None
async def adownload_file(self, source_path: str, local_path: str) -> bool:
    """Fetch an object from the S3 bucket and write it to ``local_path``.

    Returns:
        bool: True on success, False when the S3 client raises ``ClientError``.
    """
    try:
        client_ctx = self.session.client("s3", region_name=self.region_name)
        async with client_ctx as s3:
            await s3.download_file(self.bucket_name, source_path, local_path)
    except ClientError:
        return False
    else:
        return True

View file

@ -0,0 +1,15 @@
"""Generative AI services for embeddings and document processing."""
from .embedding import (
BaseEmbeddingService,
EmbeddingAPIKeyNotConfiguredError,
OpenAIEmbeddingService,
SentenceTransformerEmbeddingService,
)
__all__ = [
"BaseEmbeddingService",
"EmbeddingAPIKeyNotConfiguredError",
"SentenceTransformerEmbeddingService",
"OpenAIEmbeddingService",
]

View file

@ -0,0 +1,12 @@
"""Embedding services for document processing and retrieval."""
from .base import BaseEmbeddingService
from .openai_service import EmbeddingAPIKeyNotConfiguredError, OpenAIEmbeddingService
from .sentence_transformer_service import SentenceTransformerEmbeddingService
__all__ = [
"BaseEmbeddingService",
"EmbeddingAPIKeyNotConfiguredError",
"SentenceTransformerEmbeddingService",
"OpenAIEmbeddingService",
]

View file

@ -0,0 +1,75 @@
"""Base class for embedding services."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
class BaseEmbeddingService(ABC):
    """Common interface implemented by every embedding backend.

    Concrete services (SentenceTransformer, OpenAI, etc.) subclass this and
    must provide all of the abstract methods declared below.
    """

    @abstractmethod
    def get_model_id(self) -> str:
        """Identify the underlying embedding model.

        Returns:
            Model identifier string (e.g., 'sentence-transformers/all-MiniLM-L6-v2')
        """
        ...

    @abstractmethod
    def get_embedding_dimension(self) -> int:
        """Report the size of the vectors this service produces.

        Returns:
            Number of components in each embedding vector
        """
        ...

    @abstractmethod
    async def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Turn a batch of texts into embedding vectors.

        Args:
            texts: Text strings to embed

        Returns:
            One embedding vector (list of floats) per input text
        """
        ...

    @abstractmethod
    async def embed_query(self, query: str) -> List[float]:
        """Turn a single query string into an embedding vector.

        Args:
            query: Query text to embed

        Returns:
            The embedding vector as a list of floats
        """
        ...

    @abstractmethod
    async def search_similar_chunks(
        self,
        query: str,
        organization_id: int,
        limit: int = 5,
        document_uuids: Optional[List[str]] = None,
    ) -> List[Dict[str, Any]]:
        """Find stored chunks that are most similar to a query.

        Args:
            query: Search query text
            organization_id: Organization ID the search is scoped to
            limit: Maximum number of results to return
            document_uuids: Optional list of document UUIDs to restrict the search to

        Returns:
            Dictionaries holding chunk data together with similarity scores
        """
        ...

View file

@ -0,0 +1,372 @@
"""OpenAI embedding service.
This module provides document processing capabilities using:
- OpenAI's text-embedding-3-small for embeddings (1536 dimensions)
- Docling for document conversion and chunking
- pgvector for vector similarity search
"""
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
from docling.chunking import HybridChunker
from docling.document_converter import DocumentConverter
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
from loguru import logger
from openai import AsyncOpenAI
from transformers import AutoTokenizer
from api.db.db_client import DBClient
from api.db.models import KnowledgeBaseChunkModel
from .base import BaseEmbeddingService
# Model configuration
DEFAULT_MODEL_ID = "text-embedding-3-small"
EMBEDDING_DIMENSION = 1536 # Dimension for text-embedding-3-small
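# For chunking we reuse the SentenceTransformer (all-MiniLM-L6-v2) tokenizer.
# NOTE: this is not the tokenizer OpenAI applies to its embedding models, so
# chunk token counts are only an approximation of what the API will count.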
TOKENIZER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
class EmbeddingAPIKeyNotConfiguredError(Exception):
    """Signals that an embedding operation was attempted without an OpenAI API key."""

    def __init__(self):
        # Fixed, user-facing message that points at the settings page.
        message = (
            "OpenAI API key not configured. Please set your API key in "
            "Model Configurations > Embedding to use document processing."
        )
        super().__init__(message)
class OpenAIEmbeddingService(BaseEmbeddingService):
    """Embedding service backed by the OpenAI embeddings API.

    Besides embedding texts and queries, the service converts documents with
    docling, chunks them with a ``HybridChunker`` and stores/searches the
    resulting chunks through the injected database client.
    """

    # Maximum number of texts sent in one embeddings request. The OpenAI API
    # limits both the number of inputs and the total tokens per request, so
    # large documents are embedded in several smaller calls.
    _EMBED_BATCH_SIZE = 256

    def __init__(
        self,
        db_client: DBClient,
        api_key: Optional[str] = None,
        model_id: str = DEFAULT_MODEL_ID,
        max_tokens: int = 512,
    ):
        """Initialize the OpenAI embedding service.

        Args:
            db_client: Database client for storing documents and chunks
            api_key: OpenAI API key. If not provided, the client is not
                initialized and embedding operations raise
                EmbeddingAPIKeyNotConfiguredError.
            model_id: OpenAI embedding model ID (default: text-embedding-3-small)
            max_tokens: Maximum number of tokens per chunk (default: 512)
        """
        self.db = db_client
        self.model_id = model_id
        self.max_tokens = max_tokens

        # Only initialize OpenAI client if API key is provided
        self._api_key_configured = bool(api_key)
        if self._api_key_configured:
            self.client = AsyncOpenAI(api_key=api_key)
            logger.info(f"OpenAI embedding service initialized with model: {model_id}")
        else:
            self.client = None
            logger.warning(
                "OpenAI embedding service initialized without API key. "
                "Operations will fail until API key is configured in Model Configurations."
            )

        # Tokenizer used for chunking only: try the local cache first, then
        # fall back to downloading it.
        logger.info(
            f"Loading tokenizer for chunking: {TOKENIZER_MODEL} with max_tokens={max_tokens}"
        )
        try:
            self.tokenizer = HuggingFaceTokenizer(
                tokenizer=AutoTokenizer.from_pretrained(
                    TOKENIZER_MODEL,
                    local_files_only=True,
                ),
                max_tokens=max_tokens,
            )
            logger.info("Loaded tokenizer from cache")
        except Exception as e:
            logger.warning(f"Tokenizer not in cache, downloading: {e}")
            self.tokenizer = HuggingFaceTokenizer(
                tokenizer=AutoTokenizer.from_pretrained(TOKENIZER_MODEL),
                max_tokens=max_tokens,
            )
            logger.info("Tokenizer downloaded and cached")

        # Initialize chunker
        logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
        self.chunker = HybridChunker(tokenizer=self.tokenizer)

        # Initialize document converter
        self.converter = DocumentConverter()

    def get_model_id(self) -> str:
        """Return the model identifier."""
        return self.model_id

    def get_embedding_dimension(self) -> int:
        """Return the embedding dimension.

        NOTE(review): this is the module constant for text-embedding-3-small
        and does not vary with ``model_id`` - confirm before allowing other
        models.
        """
        return EMBEDDING_DIMENSION

    def _ensure_api_key_configured(self):
        """Raise EmbeddingAPIKeyNotConfiguredError unless a client was created."""
        if not self._api_key_configured or self.client is None:
            raise EmbeddingAPIKeyNotConfiguredError()

    async def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts using the OpenAI API.

        Texts are sent in batches of at most ``_EMBED_BATCH_SIZE`` and the
        returned vectors are in the same order as ``texts``.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors (each vector is a list of floats);
            an empty list when ``texts`` is empty.

        Raises:
            EmbeddingAPIKeyNotConfiguredError: If API key is not configured
        """
        self._ensure_api_key_configured()

        # The API rejects an empty input list; there is nothing to embed anyway.
        if not texts:
            return []

        embeddings: List[List[float]] = []
        try:
            for start in range(0, len(texts), self._EMBED_BATCH_SIZE):
                batch = texts[start : start + self._EMBED_BATCH_SIZE]
                response = await self.client.embeddings.create(
                    input=batch,
                    model=self.model_id,
                )
                # Each item carries the index of its input; sort on it so the
                # output order never depends on the response ordering.
                ordered = sorted(response.data, key=lambda item: item.index)
                embeddings.extend(item.embedding for item in ordered)
        except Exception as e:
            logger.error(f"Error generating OpenAI embeddings: {e}")
            raise
        return embeddings

    async def embed_query(self, query: str) -> List[float]:
        """Embed a single query text using the OpenAI API.

        Args:
            query: Query text to embed

        Returns:
            Embedding vector as list of floats

        Raises:
            EmbeddingAPIKeyNotConfiguredError: If API key is not configured
        """
        self._ensure_api_key_configured()
        embeddings = await self.embed_texts([query])
        return embeddings[0]

    async def search_similar_chunks(
        self,
        query: str,
        organization_id: int,
        limit: int = 5,
        document_uuids: Optional[List[str]] = None,
    ) -> List[Dict[str, Any]]:
        """Search for similar chunks using vector similarity.

        Args:
            query: Search query text
            organization_id: Organization ID for scoping
            limit: Maximum number of results to return
            document_uuids: Optional list of document UUIDs to filter by

        Returns:
            List of dictionaries with chunk data and similarity scores

        Raises:
            EmbeddingAPIKeyNotConfiguredError: If API key is not configured
        """
        self._ensure_api_key_configured()

        # Generate query embedding
        query_embedding = await self.embed_query(query)

        # The search is restricted to chunks embedded with this model.
        return await self.db.search_similar_chunks(
            query_embedding=query_embedding,
            organization_id=organization_id,
            limit=limit,
            document_uuids=document_uuids,
            embedding_model=self.model_id,
        )

    @staticmethod
    def _extract_chunk_metadata(chunk) -> Dict[str, Any]:
        """Collect the doc_items/headings metadata of a docling chunk, if any."""
        meta = getattr(chunk, "meta", None)
        if not meta:
            return {}
        return {
            "doc_items": (
                [str(item) for item in meta.doc_items]
                if hasattr(meta, "doc_items")
                else []
            ),
            "headings": meta.headings if hasattr(meta, "headings") else [],
        }

    async def process_document(
        self,
        file_path: str,
        organization_id: int,
        created_by: int,
        custom_metadata: Optional[dict] = None,
    ):
        """Process a document: convert, chunk, embed, and store in database.

        If a document with the same file hash already exists for the
        organization, that record is returned and nothing is reprocessed.

        Args:
            file_path: Path to the document file
            organization_id: Organization ID for scoping
            created_by: User ID who uploaded the document
            custom_metadata: Optional custom metadata dictionary

        Returns:
            The created (or already existing) document record

        Raises:
            Exception: Any processing error is logged, the document (if it was
                created) is marked "failed", and the error is re-raised.
        """
        # Bound up front so the error handler can tell whether a record exists.
        doc_record = None
        try:
            # Extract file metadata
            filename = Path(file_path).name
            file_hash = self.db.compute_file_hash(file_path)
            file_size = os.path.getsize(file_path)
            mime_type = self.db.get_mime_type(file_path)

            # Check if document already exists
            existing_doc = await self.db.get_document_by_hash(
                file_hash, organization_id
            )
            if existing_doc:
                logger.info(f"Document already exists: {filename} (hash: {file_hash})")
                return existing_doc

            # Create document record
            doc_record = await self.db.create_document(
                organization_id=organization_id,
                created_by=created_by,
                filename=filename,
                file_size_bytes=file_size,
                file_hash=file_hash,
                mime_type=mime_type,
                custom_metadata=custom_metadata or {},
            )

            logger.info(f"Processing document with OpenAI embeddings: {filename}")

            # Update status to processing
            await self.db.update_document_status(doc_record.id, "processing")

            # Step 1: Convert document using docling
            # NOTE(review): convert() is a synchronous call inside a coroutine;
            # confirm this method only runs where blocking the loop is acceptable.
            logger.info("Converting document with docling...")
            conversion_result = self.converter.convert(file_path)
            doc = conversion_result.document

            # Store docling metadata
            docling_metadata = {
                "num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
                "document_type": type(doc).__name__,
            }

            # Step 2: Chunk the document
            logger.info(f"Chunking document with max_tokens={self.max_tokens}...")
            chunks = list(self.chunker.chunk(dl_doc=doc))
            total_chunks = len(chunks)
            logger.info(f"Generated {total_chunks} chunks")

            # Step 3: Build one record per chunk (embeddings are attached later)
            chunk_texts = []
            chunk_records = []
            token_counts = []
            embedding_dimension = self.get_embedding_dimension()

            for i, chunk in enumerate(chunks):
                chunk_text = chunk.text

                # Prefer the contextualized text for tokenizing and embedding
                contextualized_text = self.chunker.contextualize(chunk=chunk)
                text_to_tokenize = (
                    contextualized_text if contextualized_text else chunk_text
                )
                token_count = len(
                    self.tokenizer.tokenizer.encode(
                        text_to_tokenize, add_special_tokens=False
                    )
                )
                token_counts.append(token_count)

                chunk_records.append(
                    KnowledgeBaseChunkModel(
                        document_id=doc_record.id,
                        organization_id=organization_id,
                        chunk_text=chunk_text,
                        contextualized_text=contextualized_text,
                        chunk_index=i,
                        chunk_metadata=self._extract_chunk_metadata(chunk),
                        embedding_model=self.model_id,
                        embedding_dimension=embedding_dimension,
                        token_count=token_count,
                    )
                )
                chunk_texts.append(text_to_tokenize)

            # Log chunk statistics
            if token_counts:
                average = sum(token_counts) / len(token_counts)
                logger.info("Chunk token statistics:")
                logger.info(f" - Average: {average:.1f} tokens")
                logger.info(f" - Min: {min(token_counts)} tokens")
                logger.info(f" - Max: {max(token_counts)} tokens")

            # Step 4: Generate embeddings using OpenAI API
            logger.info(f"Generating embeddings using OpenAI ({self.model_id})...")
            embeddings = await self.embed_texts(chunk_texts)

            # zip() would silently drop chunks on a mismatch, so fail loudly.
            if len(embeddings) != len(chunk_records):
                raise RuntimeError(
                    f"Expected {len(chunk_records)} embeddings, got {len(embeddings)}"
                )

            # Step 5: Attach embeddings to chunk records
            for chunk_record, embedding in zip(chunk_records, embeddings):
                chunk_record.embedding = embedding

            # Step 6: Save all chunks in batch
            logger.info("Storing chunks in database...")
            await self.db.create_chunks_batch(chunk_records)

            # Update document status to completed
            await self.db.update_document_status(
                doc_record.id,
                "completed",
                total_chunks=total_chunks,
                docling_metadata=docling_metadata,
            )

            logger.info(f"Successfully processed document: {filename}")
            logger.info(f" - Total chunks: {total_chunks}")
            logger.info(f" - Embedding model: {self.model_id}")
            logger.info(f" - Document ID: {doc_record.id}")
            logger.info(f" - Document UUID: {doc_record.document_uuid}")

            return doc_record

        except Exception as e:
            logger.error(f"Error processing document with OpenAI: {e}")
            # Mark the document as failed if it was created. A failure of the
            # status update itself is only logged so the original error is
            # the one that propagates.
            if doc_record is not None:
                try:
                    await self.db.update_document_status(
                        doc_record.id, "failed", error_message=str(e)
                    )
                except Exception as status_error:
                    logger.error(
                        f"Could not mark document {doc_record.id} as failed: {status_error}"
                    )
            raise

View file

@ -0,0 +1,350 @@
"""Sentence Transformer embedding service.
This module provides document processing capabilities using:
- Sentence-transformers for embeddings (all-MiniLM-L6-v2)
- Docling for document conversion and chunking
- pgvector for vector similarity search
Setup for offline usage:
1. First run: Downloads and caches models to ~/.cache/sentence_transformers
2. Subsequent runs: Uses cached models (no internet needed)
3. For fully offline mode: Set TRANSFORMERS_OFFLINE=1 and HF_HUB_OFFLINE=1
"""
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
from docling.chunking import HybridChunker
from docling.document_converter import DocumentConverter
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
from loguru import logger
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from api.db.db_client import DBClient
from api.db.models import KnowledgeBaseChunkModel
from .base import BaseEmbeddingService
# Set environment variables for model caching.
# setdefault() only fills in values the environment has not already set, so a
# deployment can still override any of them (e.g. force offline mode with "1").
# NOTE(review): these statements run after the transformers /
# sentence_transformers imports above - confirm the libraries read the
# variables late enough for the defaults to matter.
os.environ.setdefault("TRANSFORMERS_OFFLINE", "0")
os.environ.setdefault("HF_HUB_OFFLINE", "0")
os.environ.setdefault(
    "SENTENCE_TRANSFORMERS_HOME", os.path.expanduser("~/.cache/sentence_transformers")
)
# Model configuration
DEFAULT_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
EMBEDDING_DIMENSION = 384  # Dimension for all-MiniLM-L6-v2
class SentenceTransformerEmbeddingService(BaseEmbeddingService):
    """Embedding service using a locally loaded Sentence Transformers model.

    Besides embedding texts and queries, the service converts documents with
    docling, chunks them with a ``HybridChunker`` and stores/searches the
    resulting chunks through the injected database client.
    """

    def __init__(
        self,
        db_client: DBClient,
        model_id: str = DEFAULT_MODEL_ID,
        max_tokens: int = 512,
    ):
        """Initialize the Sentence Transformer embedding service.

        Args:
            db_client: Database client for storing documents and chunks
            model_id: Sentence-transformers model ID (default: all-MiniLM-L6-v2)
            max_tokens: Maximum number of tokens per chunk (default: 512)
                Note: This applies to the contextualized text (with headings/captions)
        """
        self.db = db_client
        self.model_id = model_id
        self.max_tokens = max_tokens

        # Embedding model: try the local cache first, download as a fallback
        logger.info(f"Loading embedding model: {model_id}")
        try:
            self.embedding_model = SentenceTransformer(
                model_id,
                cache_folder=os.environ.get("SENTENCE_TRANSFORMERS_HOME"),
                local_files_only=True,
            )
            logger.info("Loaded model from cache")
        except Exception as e:
            logger.warning(f"Model not in cache, downloading: {e}")
            # If not in cache, download it (this will cache it for next time)
            self.embedding_model = SentenceTransformer(
                model_id,
                cache_folder=os.environ.get("SENTENCE_TRANSFORMERS_HOME"),
            )
            logger.info("Model downloaded and cached")

        # Tokenizer for chunking with max_tokens: same cache-then-download order
        logger.info(f"Loading tokenizer: {model_id} with max_tokens={max_tokens}")
        try:
            self.tokenizer = HuggingFaceTokenizer(
                tokenizer=AutoTokenizer.from_pretrained(
                    model_id,
                    local_files_only=True,
                ),
                max_tokens=max_tokens,
            )
            logger.info("Loaded tokenizer from cache")
        except Exception as e:
            logger.warning(f"Tokenizer not in cache, downloading: {e}")
            self.tokenizer = HuggingFaceTokenizer(
                tokenizer=AutoTokenizer.from_pretrained(model_id),
                max_tokens=max_tokens,
            )
            logger.info("Tokenizer downloaded and cached")

        # Initialize chunker
        logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
        self.chunker = HybridChunker(tokenizer=self.tokenizer)

        # Initialize document converter
        self.converter = DocumentConverter()

    def get_model_id(self) -> str:
        """Return the model identifier."""
        return self.model_id

    def get_embedding_dimension(self) -> int:
        """Return the embedding dimension of the loaded model.

        The loaded model is asked for its dimension so that a non-default
        ``model_id`` is reported correctly; the module constant is used only
        when the model does not report one.
        """
        return (
            self.embedding_model.get_sentence_embedding_dimension()
            or EMBEDDING_DIMENSION
        )

    async def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors (each vector is a list of floats);
            an empty list when ``texts`` is empty.
        """
        # Local import: this module does not import asyncio at top level.
        import asyncio

        if not texts:
            return []

        def _encode():
            return self.embedding_model.encode(
                texts,
                show_progress_bar=False,
                convert_to_numpy=True,
            )

        # encode() is synchronous model inference; run it in a worker thread
        # so the event loop is not blocked while it computes.
        embeddings = await asyncio.to_thread(_encode)
        return [embedding.tolist() for embedding in embeddings]

    async def embed_query(self, query: str) -> List[float]:
        """Embed a single query text.

        Delegates to ``embed_texts`` so queries and documents are encoded with
        the same options.

        Args:
            query: Query text to embed

        Returns:
            Embedding vector as list of floats
        """
        embeddings = await self.embed_texts([query])
        return embeddings[0]

    async def search_similar_chunks(
        self,
        query: str,
        organization_id: int,
        limit: int = 5,
        document_uuids: Optional[List[str]] = None,
    ) -> List[Dict[str, Any]]:
        """Search for similar chunks using vector similarity.

        Returns top-k most similar chunks without any threshold filtering.
        Apply similarity thresholds and reranking at the application layer.

        Args:
            query: Search query text
            organization_id: Organization ID for scoping
            limit: Maximum number of results to return
            document_uuids: Optional list of document UUIDs to filter by

        Returns:
            List of dictionaries with chunk data and similarity scores
        """
        # Generate query embedding
        query_embedding = await self.embed_query(query)

        # The search is restricted to chunks embedded with this model.
        return await self.db.search_similar_chunks(
            query_embedding=query_embedding,
            organization_id=organization_id,
            limit=limit,
            document_uuids=document_uuids,
            embedding_model=self.model_id,
        )

    @staticmethod
    def _extract_chunk_metadata(chunk) -> Dict[str, Any]:
        """Collect the doc_items/headings metadata of a docling chunk, if any."""
        meta = getattr(chunk, "meta", None)
        if not meta:
            return {}
        return {
            "doc_items": (
                [str(item) for item in meta.doc_items]
                if hasattr(meta, "doc_items")
                else []
            ),
            "headings": meta.headings if hasattr(meta, "headings") else [],
        }

    async def process_document(
        self,
        file_path: str,
        organization_id: int,
        created_by: int,
        custom_metadata: Optional[dict] = None,
    ):
        """Process a document: convert, chunk, embed, and store in database.

        If a document with the same file hash already exists for the
        organization, that record is returned and nothing is reprocessed.

        Args:
            file_path: Path to the document file
            organization_id: Organization ID for scoping
            created_by: User ID who uploaded the document
            custom_metadata: Optional custom metadata dictionary

        Returns:
            The created (or already existing) document record

        Raises:
            Exception: Any processing error is logged, the document (if it was
                created) is marked "failed", and the error is re-raised.
        """
        # Bound up front so the error handler can tell whether a record exists.
        doc_record = None
        try:
            # Extract file metadata
            filename = Path(file_path).name
            file_hash = self.db.compute_file_hash(file_path)
            file_size = os.path.getsize(file_path)
            mime_type = self.db.get_mime_type(file_path)

            # Check if document already exists
            existing_doc = await self.db.get_document_by_hash(
                file_hash, organization_id
            )
            if existing_doc:
                logger.info(f"Document already exists: {filename} (hash: {file_hash})")
                return existing_doc

            # Create document record
            doc_record = await self.db.create_document(
                organization_id=organization_id,
                created_by=created_by,
                filename=filename,
                file_size_bytes=file_size,
                file_hash=file_hash,
                mime_type=mime_type,
                custom_metadata=custom_metadata or {},
            )

            logger.info(f"Processing document: {filename}")

            # Update status to processing
            await self.db.update_document_status(doc_record.id, "processing")

            # Step 1: Convert document using docling
            # NOTE(review): convert() is a synchronous call inside a coroutine;
            # confirm this method only runs where blocking the loop is acceptable.
            logger.info("Converting document with docling...")
            conversion_result = self.converter.convert(file_path)
            doc = conversion_result.document

            # Store docling metadata
            docling_metadata = {
                "num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
                "document_type": type(doc).__name__,
            }

            # Step 2: Chunk the document
            logger.info(f"Chunking document with max_tokens={self.max_tokens}...")
            chunks = list(self.chunker.chunk(dl_doc=doc))
            total_chunks = len(chunks)
            logger.info(f"Generated {total_chunks} chunks")

            # Step 3: Build one record per chunk (embeddings are attached later)
            chunk_texts = []
            chunk_records = []
            token_counts = []
            embedding_dimension = self.get_embedding_dimension()

            for i, chunk in enumerate(chunks):
                chunk_text = chunk.text

                # Contextualized text is the chunk enriched with surrounding context
                contextualized_text = self.chunker.contextualize(chunk=chunk)

                # Calculate actual token count using the tokenizer
                text_to_tokenize = (
                    contextualized_text if contextualized_text else chunk_text
                )
                token_count = len(
                    self.tokenizer.tokenizer.encode(
                        text_to_tokenize, add_special_tokens=False
                    )
                )
                token_counts.append(token_count)

                chunk_records.append(
                    KnowledgeBaseChunkModel(
                        document_id=doc_record.id,
                        organization_id=organization_id,
                        chunk_text=chunk_text,
                        contextualized_text=contextualized_text,
                        chunk_index=i,
                        chunk_metadata=self._extract_chunk_metadata(chunk),
                        embedding_model=self.model_id,
                        embedding_dimension=embedding_dimension,
                        token_count=token_count,
                    )
                )
                # Use contextualized text for embedding if available
                chunk_texts.append(text_to_tokenize)

            # Log chunk statistics
            if token_counts:
                average = sum(token_counts) / len(token_counts)
                logger.info("Chunk token statistics:")
                logger.info(f" - Average: {average:.1f} tokens")
                logger.info(f" - Min: {min(token_counts)} tokens")
                logger.info(f" - Max: {max(token_counts)} tokens")

            # Step 4: Generate embeddings in batch
            logger.info("Generating embeddings...")
            embeddings = await self.embed_texts(chunk_texts)

            # zip() would silently drop chunks on a mismatch, so fail loudly.
            if len(embeddings) != len(chunk_records):
                raise RuntimeError(
                    f"Expected {len(chunk_records)} embeddings, got {len(embeddings)}"
                )

            # Step 5: Attach embeddings to chunk records
            for chunk_record, embedding in zip(chunk_records, embeddings):
                chunk_record.embedding = embedding

            # Step 6: Save all chunks in batch
            logger.info("Storing chunks in database...")
            await self.db.create_chunks_batch(chunk_records)

            # Update document status to completed
            await self.db.update_document_status(
                doc_record.id,
                "completed",
                total_chunks=total_chunks,
                docling_metadata=docling_metadata,
            )

            logger.info(f"Successfully processed document: {filename}")
            logger.info(f" - Total chunks: {total_chunks}")
            logger.info(f" - Document ID: {doc_record.id}")
            logger.info(f" - Document UUID: {doc_record.document_uuid}")

            return doc_record

        except Exception as e:
            logger.error(f"Error processing document: {e}")
            # Mark the document as failed if it was created. A failure of the
            # status update itself is only logged so the original error is
            # the one that propagates.
            if doc_record is not None:
                try:
                    await self.db.update_document_status(
                        doc_record.id, "failed", error_message=str(e)
                    )
                except Exception as status_error:
                    logger.error(
                        f"Could not mark document {doc_record.id} as failed: {status_error}"
                    )
            raise

View file

@ -0,0 +1,44 @@
"""
Embeddings pricing models for different providers.
Prices are per token for embedding models.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import PricingModel
class EmbeddingPricingModel(PricingModel):
    """Token-metered pricing for embedding services."""

    def __init__(self, token_price: Decimal):
        """Store the per-token rate.

        Args:
            token_price: Cost of embedding a single token
        """
        self.token_price = token_price

    def calculate_cost(self, token_count: int) -> Decimal:
        """Return the cost of embedding ``token_count`` tokens."""
        tokens = Decimal(token_count)
        return tokens * self.token_price
# Embeddings pricing registry: provider -> model name -> pricing model.
# Each rate is written as a USD-per-1M-tokens figure divided by 1_000_000,
# because EmbeddingPricingModel expects a per-token price.
EMBEDDINGS_PRICING: Dict[str, Dict[str, EmbeddingPricingModel]] = {
    ServiceProviders.OPENAI: {
        "text-embedding-3-small": EmbeddingPricingModel(
            token_price=Decimal("0.02") / 1_000_000,  # $0.02 per 1M tokens
        ),
        "text-embedding-3-large": EmbeddingPricingModel(
            token_price=Decimal("0.13") / 1_000_000,  # $0.13 per 1M tokens
        ),
        "text-embedding-ada-002": EmbeddingPricingModel(
            token_price=Decimal("0.10") / 1_000_000,  # $0.10 per 1M tokens (legacy)
        ),
    },
}

View file

@ -4,6 +4,7 @@ Main pricing registry that combines all service type pricing models.
from typing import Dict
from .embeddings import EMBEDDINGS_PRICING
from .llm import LLM_PRICING
from .stt import STT_PRICING
from .tts import TTS_PRICING
@ -13,4 +14,5 @@ PRICING_REGISTRY: Dict = {
"llm": LLM_PRICING,
"tts": TTS_PRICING,
"stt": STT_PRICING,
"embeddings": EMBEDDINGS_PRICING,
}

View file

@ -58,6 +58,7 @@ class NodeDataDTO(BaseModel):
delayed_start: bool = False
delayed_start_duration: Optional[float] = None
tool_uuids: Optional[List[str]] = None
document_uuids: Optional[List[str]] = None
trigger_path: Optional[str] = None
# Webhook node specific fields
enabled: bool = True

View file

@ -41,6 +41,10 @@ from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
from api.services.workflow.tools.calculator import get_calculator_tools, safe_calculator
from api.services.workflow.tools.knowledge_base import (
get_knowledge_base_tool,
retrieve_from_knowledge_base,
)
from api.services.workflow.tools.timezone import (
convert_time,
get_current_time,
@ -290,6 +294,48 @@ class PipecatEngine:
self.llm.register_function("get_current_time", get_current_time_func)
self.llm.register_function("convert_time", convert_time_func)
async def _register_knowledge_base_function(
    self, document_uuids: list[str]
) -> None:
    """Register knowledge base retrieval function with the LLM.

    The registered handler searches the knowledge base, restricted to
    ``document_uuids``, and always answers through the function call's
    ``result_callback`` - with the search result on success or with an
    error payload on failure.

    Args:
        document_uuids: List of document UUIDs to filter the search by
    """
    logger.debug(
        f"Registering knowledge base retrieval function with {len(document_uuids)} document(s)"
    )

    async def retrieve_kb_func(function_call_params: FunctionCallParams) -> None:
        logger.info("LLM Function Call EXECUTED: retrieve_from_knowledge_base")
        logger.info(f"Arguments: {function_call_params.arguments}")

        # Bound before the try block so the error payload below can always
        # echo it, even when reading the arguments is what failed.
        query = ""
        try:
            query = (function_call_params.arguments or {}).get("query", "")
            organization_id = await self._get_organization_id()

            if not organization_id:
                raise ValueError(
                    "Organization ID not available for knowledge base retrieval"
                )

            # NOTE(review): no embeddings API key/model is passed here;
            # confirm retrieve_from_knowledge_base resolves them itself.
            result = await retrieve_from_knowledge_base(
                query=query,
                organization_id=organization_id,
                document_uuids=document_uuids,
                limit=3,  # Return top 3 most relevant chunks
            )

            await function_call_params.result_callback(result)

        except Exception as e:
            logger.error(f"Knowledge base retrieval failed: {e}")
            await function_call_params.result_callback(
                {"error": str(e), "chunks": [], "query": query, "total_results": 0}
            )

    # Register the function with the LLM
    self.llm.register_function("retrieve_from_knowledge_base", retrieve_kb_func)
async def _perform_variable_extraction_if_needed(
self, previous_node: Optional[Node]
) -> None:
@ -346,6 +392,10 @@ class PipecatEngine:
if node.tool_uuids and self._custom_tool_manager:
await self._custom_tool_manager.register_handlers(node.tool_uuids)
# Register knowledge base retrieval handler if node has documents
if node.document_uuids:
await self._register_knowledge_base_function(node.document_uuids)
# Set up system message and functions
(
system_message,
@ -575,6 +625,17 @@ class PipecatEngine:
# Add built-in function schemas (calculator and timezone tools)
functions.extend(self.builtin_function_schemas)
# Add knowledge base retrieval tool if node has documents
if node.document_uuids:
kb_tool_def = get_knowledge_base_tool(node.document_uuids)
kb_schema = get_function_schema(
kb_tool_def["function"]["name"],
kb_tool_def["function"]["description"],
properties=kb_tool_def["function"]["parameters"].get("properties", {}),
required=kb_tool_def["function"]["parameters"].get("required", []),
)
functions.append(kb_schema)
# Add custom tools from node.tool_uuids
if node.tool_uuids and self._custom_tool_manager:
custom_tool_schemas = await self._custom_tool_manager.get_tool_schemas(

View file

@ -0,0 +1,305 @@
"""Knowledge Base retrieval tool for workflow execution.
This module provides vector similarity search capabilities for retrieving
relevant information from the knowledge base during conversations.
Implements OpenTelemetry tracing for observability in Langfuse.
"""
import json
from typing import Any, Dict, List, Optional
from loguru import logger
from opentelemetry import trace
from api.db import db_client
from api.services.gen_ai import OpenAIEmbeddingService
from api.services.pipecat.tracing_config import is_tracing_enabled
from pipecat.utils.tracing.context_registry import (
get_current_conversation_context,
get_current_turn_context,
)
def _chunk_preview(text: Optional[str], max_chars: int = 200) -> str:
    """Shorten ``text`` to ``max_chars`` plus "..."; a missing text becomes ""."""
    if not text:
        return ""
    return text[:max_chars] + "..." if len(text) > max_chars else text


def _annotate_retrieval_request(
    span: Any,
    query: str,
    organization_id: int,
    document_uuids: Optional[List[str]],
    limit: int,
) -> None:
    """Attach the request parameters of a retrieval to its tracing span."""
    # Mark trace as public for Langfuse
    span.set_attribute("langfuse.trace.public", True)

    # Operation metadata
    span.set_attribute("gen_ai.operation.name", "knowledge_base_retrieval")
    span.set_attribute("retrieval.query", query)
    span.set_attribute("retrieval.limit", limit)
    span.set_attribute("retrieval.organization_id", organization_id)

    # Document filter info (only when a filter is in effect)
    if document_uuids:
        span.set_attribute("retrieval.document_count", len(document_uuids))
        span.set_attribute("retrieval.document_uuids", json.dumps(document_uuids))


def _annotate_retrieval_result(span: Any, query: str, result: Dict[str, Any]) -> None:
    """Attach the outcome of a retrieval (error or chunk statistics) to its span."""
    span.set_attribute("retrieval.results_count", result["total_results"])

    error = result.get("error")
    if error:
        span.set_attribute("retrieval.error", error)
        span.set_status(trace.Status(trace.StatusCode.ERROR, error))
        return

    chunks = result["chunks"]
    if chunks:
        # Similarity statistics
        similarities = [chunk["similarity"] for chunk in chunks]
        span.set_attribute(
            "retrieval.avg_similarity",
            round(sum(similarities) / len(similarities), 4),
        )
        span.set_attribute("retrieval.max_similarity", max(similarities))
        span.set_attribute("retrieval.min_similarity", min(similarities))

        # Distinct source files in first-seen order, so the attribute value is
        # deterministic (a plain set has no stable order).
        filenames = list(dict.fromkeys(chunk["filename"] for chunk in chunks))
        span.set_attribute("retrieval.source_files", json.dumps(filenames))

    # Output as JSON for Langfuse; chunk texts are shortened to a preview.
    output_data = {
        "query": query,
        "chunks_retrieved": len(chunks),
        "chunks": [
            {
                # A chunk's text can be None when the search row carried no
                # text field; the preview must not fail on it.
                "text": _chunk_preview(chunk["text"]),
                "filename": chunk["filename"],
                "similarity": chunk["similarity"],
            }
            for chunk in chunks
        ],
    }
    span.set_attribute("output", json.dumps(output_data))


async def retrieve_from_knowledge_base(
    query: str,
    organization_id: int,
    document_uuids: Optional[List[str]] = None,
    limit: int = 3,
    embeddings_api_key: Optional[str] = None,
    embeddings_model: Optional[str] = None,
) -> Dict[str, Any]:
    """Retrieve relevant information from the knowledge base using vector similarity search.

    The search itself is delegated to ``_perform_retrieval``. When tracing is
    enabled and a turn or conversation context is available, the call is
    wrapped in a ``knowledge_base_retrieval`` OpenTelemetry span carrying the
    request parameters and result statistics for Langfuse. In every other
    case (tracing disabled, tracing setup failure, no parent context) the
    retrieval runs untraced.

    Args:
        query: The search query to find relevant information
        organization_id: Organization ID for scoping the search
        document_uuids: Optional list of document UUIDs to filter by
        limit: Maximum number of chunks to return (default: 3)
        embeddings_api_key: Optional API key for embedding service
        embeddings_model: Optional model ID for embedding service

    Returns:
        Dictionary containing:
        - chunks: List of relevant text chunks with metadata
        - query: The original query
        - total_results: Number of results returned
        - error: Present only when the retrieval failed

    Raises:
        Exception: Any error raised inside the traced section is recorded on
            the span and re-raised.
    """
    retrieval_args = (
        query,
        organization_id,
        document_uuids,
        limit,
        embeddings_api_key,
        embeddings_model,
    )

    # Tracing is disabled - perform retrieval without tracing
    if not is_tracing_enabled():
        return await _perform_retrieval(*retrieval_args)

    try:
        # Parent context comes from the turn, falling back to the conversation
        turn_context = get_current_turn_context()
        conversation_context = get_current_conversation_context()
        parent_context = turn_context or conversation_context
        tracer = trace.get_tracer("pipecat")
    except Exception as e:
        logger.debug(f"Failed to setup tracing context: {e}")
        # Fall back to non-traced execution
        return await _perform_retrieval(*retrieval_args)

    if not parent_context:
        # No parent context - perform retrieval without tracing
        logger.debug(
            "No parent context available for knowledge base retrieval tracing"
        )
        return await _perform_retrieval(*retrieval_args)

    with tracer.start_as_current_span(
        "knowledge_base_retrieval", context=parent_context
    ) as span:
        try:
            _annotate_retrieval_request(
                span, query, organization_id, document_uuids, limit
            )
            result = await _perform_retrieval(*retrieval_args)
            _annotate_retrieval_result(span, query, result)
            return result
        except Exception as e:
            logger.error(f"Error in traced retrieval: {e}")
            span.record_exception(e)
            span.set_status(trace.Status(trace.StatusCode.ERROR, str(e)))
            raise
async def _perform_retrieval(
    query: str,
    organization_id: int,
    document_uuids: Optional[List[str]],
    limit: int,
    embeddings_api_key: Optional[str] = None,
    embeddings_model: Optional[str] = None,
) -> Dict[str, Any]:
    """Run the similarity search and shape the matches for the LLM.

    Kept free of tracing concerns. Builds an ``OpenAIEmbeddingService`` (model
    ``text-embedding-3-small`` unless ``embeddings_model`` is given), searches
    for chunks similar to ``query`` and returns them as plain dictionaries.

    Any exception is caught and logged; the function then returns a payload
    with an ``error`` message and no chunks instead of raising.
    """
    try:
        model_id = embeddings_model or "text-embedding-3-small"

        # A fresh service instance per call, using the caller's key/model when given
        service = OpenAIEmbeddingService(
            db_client=db_client,
            max_tokens=128,  # This is only used for chunking, not for retrieval
            api_key=embeddings_api_key,
            model_id=model_id,
        )

        # Vector similarity search
        matches = await service.search_similar_chunks(
            query=query,
            organization_id=organization_id,
            limit=limit,
            document_uuids=document_uuids,
        )

        # Shape each match for LLM consumption; the contextualized text is
        # preferred over the raw chunk text when it is present.
        chunks = [
            {
                "text": match.get("contextualized_text") or match.get("chunk_text"),
                "filename": match.get("filename"),
                "similarity": round(match.get("similarity", 0), 4),
                "chunk_index": match.get("chunk_index"),
            }
            for match in matches
        ]
        found = len(chunks)

        logger.info(
            f"Knowledge base retrieval: query='{query}', "
            f"results={found}, "
            f"document_filter={document_uuids}"
        )

        return {"chunks": chunks, "query": query, "total_results": found}

    except Exception as exc:
        logger.error(f"Error retrieving from knowledge base: {exc}")
        return {"error": str(exc), "chunks": [], "query": query, "total_results": 0}
def get_knowledge_base_tool(
    document_uuids: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Build the LLM function-calling definition for knowledge base retrieval.

    The tool is always named ``retrieve_from_knowledge_base`` and takes a
    single required string parameter, ``query``. Only the description varies:
    with a non-empty ``document_uuids`` it mentions how many documents the
    search is limited to.

    Args:
        document_uuids: Optional list of document UUIDs; only its length is
            used, to word the description

    Returns:
        Tool definition compatible with LLM function calling
    """
    # Generic wording first; replaced below when a document filter is in effect
    description = (
        "Retrieve relevant information from the knowledge base. "
        "Use this tool when you need to look up facts, policies, procedures, or any information "
        "that might be stored in the knowledge base documents."
    )
    if document_uuids:
        doc_count = len(document_uuids)
        description = (
            "Retrieve relevant information from specific documents in the knowledge base. "
            "Use this tool when you need to look up facts, policies, procedures, or any information "
            "that might be stored in the available documents. The search will only look in the "
            f"documents associated with this conversation step ({doc_count} document(s) available)."
        )

    query_property = {
        "type": "string",
        "description": (
            "The search query to find relevant information. "
            "Be specific and use natural language. "
            "Example: 'What is the refund policy for canceled orders?'"
        ),
    }
    parameters = {
        "type": "object",
        "properties": {"query": query_property},
        "required": ["query"],
    }

    return {
        "type": "function",
        "function": {
            "name": "retrieve_from_knowledge_base",
            "description": description,
            "parameters": parameters,
        },
    }

View file

@ -48,6 +48,7 @@ class Node:
self.delayed_start = data.delayed_start
self.delayed_start_duration = data.delayed_start_duration
self.tool_uuids = data.tool_uuids
self.document_uuids = data.document_uuids
self.data = data
@ -189,16 +190,6 @@ class WorkflowGraph:
in_d, out_d = in_deg[n.id], out_deg[n.id]
match n.node_type:
case NodeType.startNode:
if in_d != 0 or out_d < 1:
errors.append(
WorkflowError(
kind=ItemKind.node,
id=n.id,
field=None,
message=f"StartNode must have at least 1 outgoing edge",
)
)
case NodeType.endNode:
if in_d < 1 or out_d != 0:
errors.append(