feat: add full document mode in knowledge base

This commit is contained in:
Abhishek Kumar 2026-04-09 13:49:20 +05:30
parent c085398933
commit 87c8c5e2c8
26 changed files with 1144 additions and 351 deletions

View file

@ -0,0 +1,42 @@
"""add retrieval mode in document
Revision ID: e7254d2c6c18
Revises: d688d0da1123
Create Date: 2026-04-09 13:00:13.020713
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "e7254d2c6c18"
down_revision: Union[str, None] = "d688d0da1123"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"knowledge_base_documents",
sa.Column(
"retrieval_mode",
sa.String(length=20),
server_default="chunked",
nullable=False,
),
)
op.add_column(
"knowledge_base_documents", sa.Column("full_text", sa.Text(), nullable=True)
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("knowledge_base_documents", "full_text")
op.drop_column("knowledge_base_documents", "retrieval_mode")
# ### end Alembic commands ###

View file

@ -27,6 +27,7 @@ class KnowledgeBaseClient(BaseDBClient):
custom_metadata: Optional[dict] = None,
docling_metadata: Optional[dict] = None,
document_uuid: Optional[str] = None,
retrieval_mode: str = "chunked",
) -> KnowledgeBaseDocumentModel:
"""Create a new knowledge base document record.
@ -58,6 +59,7 @@ class KnowledgeBaseClient(BaseDBClient):
docling_metadata=docling_metadata or {},
processing_status="pending",
total_chunks=0,
retrieval_mode=retrieval_mode,
)
# Use provided UUID or let the model generate one
@ -425,6 +427,55 @@ class KnowledgeBaseClient(BaseDBClient):
# Convert asyncpg records to dictionaries
return [dict(row) for row in rows]
async def update_document_full_text(
self,
document_id: int,
full_text: str,
) -> None:
"""Store full document text for full_document retrieval mode.
Args:
document_id: ID of the document
full_text: The full extracted text content
"""
async with self.async_session() as session:
query = select(KnowledgeBaseDocumentModel).where(
KnowledgeBaseDocumentModel.id == document_id
)
result = await session.execute(query)
document = result.scalar_one_or_none()
if document:
document.full_text = full_text
await session.commit()
logger.info(
f"Stored full text for document {document_id} ({len(full_text)} chars)"
)
async def get_full_text_documents(
self,
organization_id: int,
document_uuids: List[str],
) -> List[KnowledgeBaseDocumentModel]:
"""Get full_document mode documents by their UUIDs.
Args:
organization_id: Organization ID for scoping
document_uuids: List of document UUIDs to fetch
Returns:
List of documents with retrieval_mode='full_document' and full_text set
"""
async with self.async_session() as session:
query = select(KnowledgeBaseDocumentModel).where(
KnowledgeBaseDocumentModel.organization_id == organization_id,
KnowledgeBaseDocumentModel.document_uuid.in_(document_uuids),
KnowledgeBaseDocumentModel.retrieval_mode == "full_document",
KnowledgeBaseDocumentModel.is_active == True,
KnowledgeBaseDocumentModel.processing_status == "completed",
)
result = await session.execute(query)
return list(result.scalars().all())
async def delete_document(
self,
document_uuid: str,

View file

@ -940,6 +940,14 @@ class KnowledgeBaseDocumentModel(Base):
file_hash = Column(String(64), nullable=True) # SHA-256 hash for deduplication
mime_type = Column(String(100), nullable=True)
# Retrieval mode: "chunked" (vector search) or "full_document" (return full text)
retrieval_mode = Column(
String(20), nullable=False, default="chunked", server_default="chunked"
)
full_text = Column(
Text, nullable=True
) # Stored when retrieval_mode is "full_document"
# Processing metadata
source_url = Column(String, nullable=True) # If document was fetched from URL
total_chunks = Column(Integer, nullable=False, default=0)

View file

@ -124,6 +124,7 @@ async def process_document(
mime_type="application/octet-stream", # Will be detected by background task
custom_metadata={"s3_key": request.s3_key},
document_uuid=request.document_uuid, # Use UUID from upload
retrieval_mode=request.retrieval_mode,
)
# Enqueue background task for processing
@ -133,6 +134,7 @@ async def process_document(
request.s3_key,
user.selected_organization_id,
128, # max_tokens (default)
request.retrieval_mode,
)
logger.info(
@ -150,6 +152,7 @@ async def process_document(
processing_status="pending",
processing_error=None,
total_chunks=0,
retrieval_mode=request.retrieval_mode,
custom_metadata={"s3_key": request.s3_key},
docling_metadata={},
source_url=None,
@ -209,6 +212,7 @@ async def list_documents(
processing_status=doc.processing_status,
processing_error=doc.processing_error,
total_chunks=doc.total_chunks,
retrieval_mode=doc.retrieval_mode,
custom_metadata=doc.custom_metadata,
docling_metadata=doc.docling_metadata,
source_url=doc.source_url,
@ -267,6 +271,7 @@ async def get_document(
processing_status=document.processing_status,
processing_error=document.processing_error,
total_chunks=document.total_chunks,
retrieval_mode=document.retrieval_mode,
custom_metadata=document.custom_metadata,
docling_metadata=document.docling_metadata,
source_url=document.source_url,

View file

@ -1,4 +1,5 @@
import json
import re
import uuid
from datetime import datetime
from typing import List, Literal, Optional
@ -6,13 +7,13 @@ from typing import List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from httpx import HTTPStatusError
from loguru import logger
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel, Field, ValidationError
from api.constants import DEPLOYMENT_MODE
from api.db import db_client
from api.db.models import UserModel
from api.db.workflow_template_client import WorkflowTemplateClient
from api.enums import CallType
from api.enums import CallType, StorageBackend
from api.schemas.workflow import WorkflowRunResponseSchema
from api.services.auth.depends import get_user
from api.services.configuration.check_validity import UserConfigurationValidator
@ -22,6 +23,7 @@ from api.services.configuration.masking import (
)
from api.services.configuration.resolve import resolve_effective_config
from api.services.mps_service_key_client import mps_service_key_client
from api.services.storage import storage_fs
from api.services.workflow.dto import ReactFlowDTO
from api.services.workflow.duplicate import duplicate_workflow
from api.services.workflow.errors import ItemKind, WorkflowError
@ -1030,3 +1032,60 @@ async def duplicate_workflow_template(
"call_disposition_codes": workflow.call_disposition_codes,
"workflow_configurations": workflow.workflow_configurations,
}
# ---------------------------------------------------------------------------
# Ambient Noise Upload
# ---------------------------------------------------------------------------
class AmbientNoiseUploadRequest(BaseModel):
workflow_id: int
filename: str
mime_type: str = "audio/wav"
file_size: int = Field(..., gt=0, le=10_485_760, description="Max 10MB")
class AmbientNoiseUploadResponse(BaseModel):
upload_url: str
storage_key: str
storage_backend: str
@router.post(
"/ambient-noise/upload-url",
response_model=AmbientNoiseUploadResponse,
summary="Get a presigned URL to upload a custom ambient noise audio file",
)
async def get_ambient_noise_upload_url(
request: AmbientNoiseUploadRequest,
user=Depends(get_user),
):
"""Generate a presigned PUT URL for uploading a custom ambient noise file."""
# Verify user owns this workflow
workflow = await db_client.get_workflow(
request.workflow_id, organization_id=user.selected_organization_id
)
if not workflow:
raise HTTPException(status_code=404, detail="Workflow not found")
sanitized = re.sub(r"[^a-zA-Z0-9._-]", "_", request.filename)
storage_key = (
f"ambient-noise/{user.selected_organization_id}"
f"/{request.workflow_id}/{uuid.uuid4()}_{sanitized}"
)
upload_url = await storage_fs.aget_presigned_put_url(
file_path=storage_key,
expiration=1800,
content_type=request.mime_type,
max_size=request.file_size,
)
if not upload_url:
raise HTTPException(status_code=500, detail="Failed to generate upload URL")
return AmbientNoiseUploadResponse(
upload_url=upload_url,
storage_key=storage_key,
storage_backend=StorageBackend.get_current_backend().value,
)

View file

@ -29,6 +29,10 @@ class ProcessDocumentRequestSchema(BaseModel):
document_uuid: str = Field(..., description="Document UUID to process")
s3_key: str = Field(..., description="S3 key of the uploaded file")
retrieval_mode: str = Field(
default="chunked",
description="Retrieval mode: 'chunked' for vector search or 'full_document' for full text retrieval",
)
class DocumentResponseSchema(BaseModel):
@ -43,6 +47,7 @@ class DocumentResponseSchema(BaseModel):
processing_status: str # pending, processing, completed, failed
processing_error: Optional[str] = None
total_chunks: int
retrieval_mode: str = "chunked"
custom_metadata: Dict[str, Any]
docling_metadata: Dict[str, Any]
source_url: Optional[str] = None

View file

@ -0,0 +1,220 @@
"""Shared utilities for downloading, converting, and caching audio files.
Provides helpers used by both the recording audio cache and the ambient
noise cache to avoid duplicating download / ffmpeg / disk-cache logic.
"""
import asyncio
import os
import shutil
import tempfile
from typing import Literal, Optional
from loguru import logger
from api.constants import APP_ROOT_DIR
# ---------------------------------------------------------------------------
# Filesystem cache directory (shared by all audio caches)
# ---------------------------------------------------------------------------
CACHE_DIR = os.path.join(os.path.dirname(APP_ROOT_DIR), "dograh_pcm_cache")
os.makedirs(CACHE_DIR, exist_ok=True)
# ---------------------------------------------------------------------------
# Download helper
# ---------------------------------------------------------------------------
async def download_storage_file(
storage_key: str,
storage_backend: str,
get_storage_fn,
) -> Optional[str]:
"""Download a file from object storage to a local temp file.
Returns the temp file path on success, or None on failure.
The caller is responsible for cleaning up the temp file.
"""
ext = ext_from_key(storage_key)
fd, tmp_path = tempfile.mkstemp(suffix=ext, prefix="dograh_dl_")
os.close(fd)
try:
storage = get_storage_fn(storage_backend)
success = await storage.adownload_file(storage_key, tmp_path)
if not success:
logger.error(f"Failed to download {storage_key}")
_safe_unlink(tmp_path)
return None
return tmp_path
except Exception:
logger.exception(f"Error downloading {storage_key}")
_safe_unlink(tmp_path)
return None
# ---------------------------------------------------------------------------
# Audio conversion via ffmpeg
# ---------------------------------------------------------------------------
async def convert_audio_file(
file_path: str,
target_sample_rate: int,
output_format: Literal["pcm", "wav"] = "pcm",
) -> Optional[bytes]:
"""Convert an audio file via ffmpeg.
Args:
file_path: Path to the source audio file.
target_sample_rate: Desired output sample rate.
output_format: ``"pcm"`` for raw s16le bytes, ``"wav"`` for a
complete WAV file (16-bit mono).
Returns:
Converted audio bytes, or None on failure.
"""
ffmpeg = shutil.which("ffmpeg")
if not ffmpeg:
logger.error("ffmpeg not found on PATH - cannot convert audio")
return None
if output_format == "pcm":
fmt_args = ["-f", "s16le", "-acodec", "pcm_s16le"]
else:
fmt_args = ["-f", "wav", "-acodec", "pcm_s16le"]
cmd = [
ffmpeg,
"-i",
file_path,
*fmt_args,
"-ac",
"1",
"-ar",
str(target_sample_rate),
"-loglevel",
"error",
"pipe:1",
]
try:
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
logger.error(f"ffmpeg failed (rc={proc.returncode}): {stderr.decode()}")
return None
if not stdout:
logger.error("ffmpeg produced no output")
return None
return stdout
except Exception:
logger.exception("ffmpeg subprocess error")
return None
# ---------------------------------------------------------------------------
# File I/O helpers
# ---------------------------------------------------------------------------
def read_cached_file(path: str) -> bytes:
with open(path, "rb") as f:
return f.read()
def write_cache_file(path: str, data: bytes) -> None:
"""Atomically write *data* to *path* (write-to-tmp then rename)."""
fd, tmp = tempfile.mkstemp(dir=CACHE_DIR, suffix=".tmp")
os.close(fd)
with open(tmp, "wb") as f:
f.write(data)
os.replace(tmp, path)
def ext_from_key(storage_key: str) -> str:
"""Extract file extension from a storage key, defaulting to .wav."""
_, ext = os.path.splitext(storage_key)
return ext if ext else ".wav"
def _safe_unlink(path: str) -> None:
try:
if os.path.exists(path):
os.unlink(path)
except OSError:
pass
# ---------------------------------------------------------------------------
# Ambient noise file cache
# ---------------------------------------------------------------------------
def _ambient_noise_cache_path(storage_key: str, sample_rate: int) -> str:
"""Return the on-disk path for a cached ambient noise WAV file."""
# Use a stable hash of the storage key so different uploads get different cache entries
import hashlib
key_hash = hashlib.sha256(storage_key.encode()).hexdigest()[:16]
return os.path.join(CACHE_DIR, f"ambient_{key_hash}_{sample_rate}.wav")
async def get_cached_ambient_noise_path(
storage_key: str,
storage_backend: str,
target_sample_rate: int,
) -> Optional[str]:
"""Return a local WAV file path for a custom ambient noise file.
Downloads from object storage and converts to mono WAV at
*target_sample_rate* on the first call; subsequent calls return the
cached path immediately.
Args:
storage_key: Object storage key for the uploaded audio file.
storage_backend: Storage backend identifier (e.g. ``"minio"``, ``"s3"``).
target_sample_rate: Target sample rate for the output WAV.
Returns:
Absolute path to the cached WAV file, or None on failure.
"""
from api.services.storage import get_storage_for_backend
cached = _ambient_noise_cache_path(storage_key, target_sample_rate)
if os.path.exists(cached):
logger.debug(f"Ambient noise served from cache: {cached}")
return cached
logger.info(f"Downloading custom ambient noise: {storage_key}")
def _get_storage(backend: str):
return get_storage_for_backend(backend)
tmp_path = await download_storage_file(storage_key, storage_backend, _get_storage)
if not tmp_path:
return None
try:
wav_data = await convert_audio_file(
tmp_path, target_sample_rate, output_format="wav"
)
if wav_data is None:
return None
write_cache_file(cached, wav_data)
logger.info(f"Cached custom ambient noise: {cached} ({len(wav_data)} bytes)")
return cached
except Exception:
logger.exception("Error caching ambient noise file")
return None
finally:
_safe_unlink(tmp_path)

View file

@ -6,29 +6,30 @@ leading/trailing silence, and caches the processed bytes on disk so
subsequent plays (even from other workers) are instantaneous.
"""
import asyncio
import os
import shutil
import tempfile
from typing import Awaitable, Callable, Optional
import numpy as np
from loguru import logger
from api.constants import APP_ROOT_DIR
from pipecat.audio.utils import SPEAKING_THRESHOLD
# ---------------------------------------------------------------------------
# Filesystem cache directory
# ---------------------------------------------------------------------------
from .audio_file_cache import (
CACHE_DIR,
convert_audio_file,
download_storage_file,
read_cached_file,
write_cache_file,
)
_CACHE_DIR = os.path.join(os.path.dirname(APP_ROOT_DIR), "dograh_pcm_cache")
os.makedirs(_CACHE_DIR, exist_ok=True)
# ---------------------------------------------------------------------------
# Cache path helper
# ---------------------------------------------------------------------------
def _cache_path(recording_id: str, sample_rate: int) -> str:
"""Return the on-disk path for a cached PCM file."""
return os.path.join(_CACHE_DIR, f"{recording_id}_{sample_rate}.pcm")
return os.path.join(CACHE_DIR, f"{recording_id}_{sample_rate}.pcm")
# ---------------------------------------------------------------------------
@ -72,7 +73,7 @@ def create_recording_audio_fetcher(
# 1. Serve from filesystem cache
if os.path.exists(cached):
logger.debug(f"Recording {recording_id} served from disk cache")
return _read_file(cached)
return read_cached_file(cached)
# 2. DB lookup
recording = await db_client.get_recording_by_recording_id(
@ -172,109 +173,33 @@ async def _download_and_convert(
Returns the processed PCM bytes, or None on failure.
"""
ext = _ext_from_key(recording.storage_key)
fd, tmp_path = tempfile.mkstemp(
suffix=ext, prefix=f"dograh_dl_{recording.recording_id}_"
tmp_path = await download_storage_file(
recording.storage_key, recording.storage_backend, get_storage_fn
)
os.close(fd)
try:
storage = get_storage_fn(recording.storage_backend)
success = await storage.adownload_file(recording.storage_key, tmp_path)
if not success:
logger.error(f"Failed to download recording {recording.recording_id}")
return None
if not tmp_path:
return None
pcm_data = await _audio_file_to_pcm(tmp_path, sample_rate)
try:
pcm_data = await convert_audio_file(tmp_path, sample_rate, output_format="pcm")
if pcm_data is None:
return None
pcm_data = _trim_silence(pcm_data, sample_rate)
# Write to disk cache atomically (write to tmp then rename)
# Write to disk cache
cached = _cache_path(recording.recording_id, sample_rate)
fd, tmp_cache = tempfile.mkstemp(dir=_CACHE_DIR, suffix=".pcm.tmp")
os.close(fd)
_write_file(tmp_cache, pcm_data)
os.replace(tmp_cache, cached)
write_cache_file(cached, pcm_data)
return pcm_data
except Exception:
logger.exception(f"Error fetching recording {recording.recording_id}")
return None
finally:
if os.path.exists(tmp_path):
try:
try:
if os.path.exists(tmp_path):
os.unlink(tmp_path)
except OSError:
pass
# ---------------------------------------------------------------------------
# File I/O helpers (run via asyncio.to_thread)
# ---------------------------------------------------------------------------
def _read_file(path: str) -> bytes:
with open(path, "rb") as f:
return f.read()
def _write_file(path: str, data: bytes) -> None:
with open(path, "wb") as f:
f.write(data)
# ---------------------------------------------------------------------------
# Audio conversion
# ---------------------------------------------------------------------------
async def _audio_file_to_pcm(
file_path: str, target_sample_rate: int
) -> Optional[bytes]:
"""Convert an audio file to raw 16-bit mono PCM bytes via ffmpeg."""
ffmpeg = shutil.which("ffmpeg")
if not ffmpeg:
logger.error("ffmpeg not found on PATH — cannot decode recording")
return None
cmd = [
ffmpeg,
"-i",
file_path,
"-f",
"s16le", # raw 16-bit signed little-endian PCM
"-acodec",
"pcm_s16le",
"-ac",
"1", # mono
"-ar",
str(target_sample_rate),
"-loglevel",
"error",
"pipe:1", # output to stdout
]
try:
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
logger.error(f"ffmpeg failed (rc={proc.returncode}): {stderr.decode()}")
return None
if not stdout:
logger.error("ffmpeg produced no output")
return None
return stdout
except Exception:
logger.exception("ffmpeg subprocess error")
return None
except OSError:
pass
# ---------------------------------------------------------------------------
@ -327,14 +252,3 @@ def _trim_silence(pcm_data: bytes, sample_rate: int) -> bytes:
)
return trimmed.tobytes()
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _ext_from_key(storage_key: str) -> str:
"""Extract file extension from a storage key, defaulting to .wav."""
_, ext = os.path.splitext(storage_key)
return ext if ext else ".wav"

View file

@ -510,7 +510,7 @@ async def run_pipeline_smallwebrtc(
# Create audio configuration for WebRTC
audio_config = create_audio_config(WorkflowRunMode.SMALLWEBRTC.value)
transport = create_webrtc_transport(
transport = await create_webrtc_transport(
webrtc_connection,
workflow_run_id,
audio_config,

View file

@ -1,11 +1,13 @@
import os
from fastapi import WebSocket
from loguru import logger
from api.constants import APP_ROOT_DIR
from api.db import db_client
from api.enums import OrganizationConfigurationKey
from api.services.pipecat.audio_config import AudioConfig
from api.services.pipecat.audio_file_cache import get_cached_ambient_noise_path
from api.services.telephony.providers.ari_call_strategies import (
ARIBridgeSwapStrategy,
ARIHangupStrategy,
@ -37,6 +39,49 @@ librnnoise_path = os.path.normpath(
)
async def _build_audio_out_mixer(
audio_out_sample_rate: int,
ambient_noise_config: dict | None,
):
"""Build the audio output mixer based on the ambient noise configuration.
Returns a ``SoundfileMixer`` when ambient noise is enabled, or a
``SilenceAudioMixer`` otherwise. Supports custom user-uploaded audio
files via the ``storage_key`` / ``storage_backend`` fields in the config.
"""
if not ambient_noise_config or not ambient_noise_config.get("enabled", False):
return SilenceAudioMixer()
volume = ambient_noise_config.get("volume", 0.3)
# Check for a custom uploaded ambient noise file
storage_key = ambient_noise_config.get("storage_key")
storage_backend = ambient_noise_config.get("storage_backend")
if storage_key and storage_backend:
cached_path = await get_cached_ambient_noise_path(
storage_key, storage_backend, audio_out_sample_rate
)
if cached_path:
return SoundfileMixer(
sound_files={"custom": cached_path},
default_sound="custom",
volume=volume,
)
logger.warning("Custom ambient noise file unavailable, falling back to default")
# Default built-in office ambience
return SoundfileMixer(
sound_files={
"office": APP_ROOT_DIR
/ "assets"
/ f"office-ambience-{audio_out_sample_rate}-mono.wav"
},
default_sound="office",
volume=volume,
)
async def create_twilio_transport(
websocket_client: WebSocket,
stream_sid: str,
@ -79,6 +124,10 @@ async def create_twilio_transport(
hangup_strategy=hangup_strategy,
)
mixer = await _build_audio_out_mixer(
audio_config.transport_out_sample_rate, ambient_noise_config
)
return FastAPIWebsocketTransport(
websocket=websocket_client,
params=FastAPIWebsocketParams(
@ -86,19 +135,7 @@ async def create_twilio_transport(
audio_out_enabled=True,
audio_in_sample_rate=audio_config.transport_in_sample_rate,
audio_out_sample_rate=audio_config.transport_out_sample_rate,
audio_out_mixer=(
SoundfileMixer(
sound_files={
"office": APP_ROOT_DIR
/ "assets"
/ f"office-ambience-{audio_config.transport_out_sample_rate}-mono.wav"
},
default_sound="office",
volume=ambient_noise_config.get("volume", 0.3),
)
if ambient_noise_config and ambient_noise_config.get("enabled", False)
else SilenceAudioMixer()
),
audio_out_mixer=mixer,
serializer=serializer,
),
)
@ -144,6 +181,10 @@ async def create_cloudonix_transport(
hangup_strategy=hangup_strategy,
)
mixer = await _build_audio_out_mixer(
audio_config.transport_out_sample_rate, ambient_noise_config
)
return FastAPIWebsocketTransport(
websocket=websocket_client,
params=FastAPIWebsocketParams(
@ -151,19 +192,7 @@ async def create_cloudonix_transport(
audio_out_enabled=True,
audio_in_sample_rate=audio_config.transport_in_sample_rate,
audio_out_sample_rate=audio_config.transport_out_sample_rate,
audio_out_mixer=(
SoundfileMixer(
sound_files={
"office": APP_ROOT_DIR
/ "assets"
/ f"office-ambience-{audio_config.transport_out_sample_rate}-mono.wav"
},
default_sound="office",
volume=ambient_noise_config.get("volume", 0.3),
)
if ambient_noise_config and ambient_noise_config.get("enabled", False)
else SilenceAudioMixer()
),
audio_out_mixer=mixer,
serializer=serializer,
audio_out_10ms_chunks=2,
),
@ -209,6 +238,10 @@ async def create_telnyx_transport(
inbound_encoding="PCMU",
)
mixer = await _build_audio_out_mixer(
audio_config.transport_out_sample_rate, ambient_noise_config
)
return FastAPIWebsocketTransport(
websocket=websocket_client,
params=FastAPIWebsocketParams(
@ -216,19 +249,7 @@ async def create_telnyx_transport(
audio_out_enabled=True,
audio_in_sample_rate=audio_config.transport_in_sample_rate,
audio_out_sample_rate=audio_config.transport_out_sample_rate,
audio_out_mixer=(
SoundfileMixer(
sound_files={
"office": APP_ROOT_DIR
/ "assets"
/ f"office-ambience-{audio_config.transport_out_sample_rate}-mono.wav"
},
default_sound="office",
volume=ambient_noise_config.get("volume", 0.3),
)
if ambient_noise_config and ambient_noise_config.get("enabled", False)
else SilenceAudioMixer()
),
audio_out_mixer=mixer,
serializer=serializer,
),
)
@ -278,6 +299,10 @@ async def create_ari_transport(
),
)
mixer = await _build_audio_out_mixer(
audio_config.transport_out_sample_rate, ambient_noise_config
)
return FastAPIWebsocketTransport(
websocket=websocket_client,
params=FastAPIWebsocketParams(
@ -285,19 +310,7 @@ async def create_ari_transport(
audio_out_enabled=True,
audio_in_sample_rate=audio_config.transport_in_sample_rate,
audio_out_sample_rate=audio_config.transport_out_sample_rate,
audio_out_mixer=(
SoundfileMixer(
sound_files={
"office": APP_ROOT_DIR
/ "assets"
/ f"office-ambience-{audio_config.transport_out_sample_rate}-mono.wav"
},
default_sound="office",
volume=ambient_noise_config.get("volume", 0.3),
)
if ambient_noise_config and ambient_noise_config.get("enabled", False)
else SilenceAudioMixer()
),
audio_out_mixer=mixer,
serializer=serializer,
),
)
@ -340,6 +353,10 @@ async def create_vonage_transport(
),
)
mixer = await _build_audio_out_mixer(
audio_config.transport_out_sample_rate, ambient_noise_config
)
# Important: Vonage uses binary WebSocket mode, not text
return FastAPIWebsocketTransport(
websocket=websocket_client,
@ -348,19 +365,7 @@ async def create_vonage_transport(
audio_out_enabled=True,
audio_in_sample_rate=audio_config.transport_in_sample_rate,
audio_out_sample_rate=audio_config.transport_out_sample_rate,
audio_out_mixer=(
SoundfileMixer(
sound_files={
"office": APP_ROOT_DIR
/ "assets"
/ f"office-ambience-{audio_config.transport_out_sample_rate}-mono.wav"
},
default_sound="office",
volume=ambient_noise_config.get("volume", 0.3),
)
if ambient_noise_config and ambient_noise_config.get("enabled", False)
else SilenceAudioMixer()
),
audio_out_mixer=mixer,
serializer=serializer,
),
)
@ -428,6 +433,10 @@ async def create_vobiz_transport(
f"transport_rate=8000Hz, pipeline_rate={audio_config.pipeline_sample_rate}Hz"
)
mixer = await _build_audio_out_mixer(
audio_config.transport_out_sample_rate, ambient_noise_config
)
# Create WebSocket transport (same structure as Twilio/Vonage)
transport = FastAPIWebsocketTransport(
websocket=websocket_client,
@ -436,19 +445,7 @@ async def create_vobiz_transport(
audio_out_enabled=True,
audio_in_sample_rate=audio_config.transport_in_sample_rate,
audio_out_sample_rate=audio_config.transport_out_sample_rate,
audio_out_mixer=(
SoundfileMixer(
sound_files={
"office": APP_ROOT_DIR
/ "assets"
/ f"office-ambience-{audio_config.transport_out_sample_rate}-mono.wav"
},
default_sound="office",
volume=ambient_noise_config.get("volume", 0.3),
)
if ambient_noise_config and ambient_noise_config.get("enabled", False)
else SilenceAudioMixer()
),
audio_out_mixer=mixer,
serializer=serializer,
),
)
@ -459,7 +456,7 @@ async def create_vobiz_transport(
return transport
def create_webrtc_transport(
async def create_webrtc_transport(
webrtc_connection: SmallWebRTCConnection,
workflow_run_id: int,
audio_config: AudioConfig,
@ -468,6 +465,10 @@ def create_webrtc_transport(
):
"""Create a transport for WebRTC connections"""
mixer = await _build_audio_out_mixer(
audio_config.transport_out_sample_rate, ambient_noise_config
)
return SmallWebRTCTransport(
webrtc_connection=webrtc_connection,
params=TransportParams(
@ -475,19 +476,7 @@ def create_webrtc_transport(
audio_out_enabled=True,
audio_in_sample_rate=audio_config.transport_in_sample_rate,
audio_out_sample_rate=audio_config.transport_out_sample_rate,
audio_out_mixer=(
SoundfileMixer(
sound_files={
"office": APP_ROOT_DIR
/ "assets"
/ f"office-ambience-{audio_config.transport_out_sample_rate}-mono.wav"
},
default_sound="office",
volume=ambient_noise_config.get("volume", 0.3),
)
if ambient_noise_config and ambient_noise_config.get("enabled", False)
else SilenceAudioMixer()
),
audio_out_mixer=mixer,
),
)

View file

@ -301,12 +301,6 @@ class PipecatEngine:
"Organization ID not available for knowledge base retrieval"
)
if not self._embeddings_api_key:
raise ValueError(
"Embeddings API key not configured. Please set your API key in "
"Model Configurations > Embedding."
)
result = await retrieve_from_knowledge_base(
query=query,
organization_id=organization_id,

View file

@ -204,37 +204,66 @@ async def _perform_retrieval(
"""Internal function to perform the actual retrieval operation.
Separated from tracing logic for cleaner code organization.
Uses OpenAI embeddings by default for high-quality retrieval.
Handles both chunked (vector search) and full_document (full text) modes.
"""
try:
# Create a new embedding service instance
# Uses OpenAI text-embedding-3-small by default, or user-provided config
embedding_service = OpenAIEmbeddingService(
db_client=db_client,
max_tokens=128, # This is only used for chunking, not for retrieval
api_key=embeddings_api_key,
model_id=embeddings_model or "text-embedding-3-small",
base_url=embeddings_base_url,
)
# Perform vector similarity search
results = await embedding_service.search_similar_chunks(
query=query,
organization_id=organization_id,
limit=limit,
document_uuids=document_uuids,
)
# Format results for LLM consumption
chunks = []
for result in results:
chunk_info = {
"text": result.get("contextualized_text") or result.get("chunk_text"),
"filename": result.get("filename"),
"similarity": round(result.get("similarity", 0), 4),
"chunk_index": result.get("chunk_index"),
}
chunks.append(chunk_info)
# Check for full_document mode documents and return their full text
if document_uuids:
full_text_docs = await db_client.get_full_text_documents(
organization_id=organization_id,
document_uuids=document_uuids,
)
for doc in full_text_docs:
if doc.full_text:
chunks.append(
{
"text": doc.full_text,
"filename": doc.filename,
"similarity": 1.0,
"chunk_index": 0,
}
)
# Filter out full_document UUIDs so vector search only hits chunked docs
full_doc_uuids = {doc.document_uuid for doc in full_text_docs}
chunked_uuids = [u for u in document_uuids if u not in full_doc_uuids]
else:
chunked_uuids = document_uuids
# Perform vector similarity search on chunked documents
if chunked_uuids is None or len(chunked_uuids) > 0:
if not embeddings_api_key:
raise ValueError(
"Embeddings API key not configured. Please set your API key in "
"Model Configurations > Embedding."
)
embedding_service = OpenAIEmbeddingService(
db_client=db_client,
max_tokens=128,
api_key=embeddings_api_key,
model_id=embeddings_model or "text-embedding-3-small",
base_url=embeddings_base_url,
)
results = await embedding_service.search_similar_chunks(
query=query,
organization_id=organization_id,
limit=limit,
document_uuids=chunked_uuids if chunked_uuids else None,
)
for result in results:
chunk_info = {
"text": result.get("contextualized_text")
or result.get("chunk_text"),
"filename": result.get("filename"),
"similarity": round(result.get("similarity", 0), 4),
"chunk_index": result.get("chunk_index"),
}
chunks.append(chunk_info)
logger.info(
f"Knowledge base retrieval: query='{query}', "

View file

@ -25,6 +25,7 @@ async def process_knowledge_base_document(
s3_key: str,
organization_id: int,
max_tokens: int = 128,
retrieval_mode: str = "chunked",
):
"""Process a knowledge base document: download, chunk, embed, and store.
@ -34,6 +35,7 @@ async def process_knowledge_base_document(
s3_key: S3 key where the file is stored
organization_id: Organization ID
max_tokens: Maximum number of tokens per chunk (default: 128)
retrieval_mode: "chunked" for vector search or "full_document" for full text
"""
logger.info(
f"Starting knowledge base document processing for document_id={document_id}, "
@ -128,6 +130,47 @@ async def process_knowledge_base_document(
mime_type=mime_type,
)
# Full document mode: extract text and store it, skip chunking/embedding
if retrieval_mode == "full_document":
logger.info(f"Document {document_id}: full_document mode, extracting text")
plain_text_extensions = {".txt", ".json"}
if file_extension.lower() in plain_text_extensions:
with open(temp_file_path, "r", encoding="utf-8") as f:
full_text = f.read()
if file_extension.lower() == ".json":
try:
parsed = json.loads(full_text)
full_text = json.dumps(parsed, indent=2, ensure_ascii=False)
except json.JSONDecodeError:
pass
docling_metadata = {"document_type": "PlainText"}
else:
converter = DocumentConverter()
conversion_result = converter.convert(temp_file_path)
doc = conversion_result.document
full_text = doc.export_to_text()
docling_metadata = {
"num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
"document_type": type(doc).__name__,
}
# Store full text on the document record
await db_client.update_document_full_text(document_id, full_text)
await db_client.update_document_status(
document_id,
"completed",
total_chunks=0,
docling_metadata=docling_metadata,
)
logger.info(
f"Successfully processed full_document {document_id}. "
f"Text length: {len(full_text)} chars"
)
return
# Initialize the OpenAI embedding service
logger.info(
f"Initializing OpenAI embedding service with max_tokens={max_tokens}"