mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
367 lines
14 KiB
Python
367 lines
14 KiB
Python
"""Knowledge Base retrieval tool for workflow execution.
|
|
|
|
This module provides vector similarity search capabilities for retrieving
|
|
relevant information from the knowledge base during conversations.
|
|
|
|
Implements OpenTelemetry tracing for observability in Langfuse.
|
|
"""
|
|
|
|
import json
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from loguru import logger
|
|
from opentelemetry import trace
|
|
|
|
from api.db import db_client
|
|
from api.services.configuration.registry import ServiceProviders
|
|
from api.services.gen_ai import AzureOpenAIEmbeddingService, OpenAIEmbeddingService
|
|
from api.services.pipecat.tracing_config import ensure_tracing
|
|
|
|
|
|
async def retrieve_from_knowledge_base(
|
|
query: str,
|
|
organization_id: int,
|
|
document_uuids: Optional[List[str]] = None,
|
|
limit: int = 3,
|
|
embeddings_api_key: Optional[str] = None,
|
|
embeddings_model: Optional[str] = None,
|
|
embeddings_base_url: Optional[str] = None,
|
|
embeddings_provider: Optional[str] = None,
|
|
embeddings_endpoint: Optional[str] = None,
|
|
embeddings_api_version: Optional[str] = None,
|
|
tracing_context=None,
|
|
) -> Dict[str, Any]:
|
|
"""Retrieve relevant information from the knowledge base using vector similarity search.
|
|
|
|
Uses OpenAI text-embedding-3-small for embeddings by default. This provides
|
|
high-quality 1536-dimensional embeddings for accurate retrieval.
|
|
|
|
This function includes OpenTelemetry tracing for Langfuse observability.
|
|
|
|
Args:
|
|
query: The search query to find relevant information
|
|
organization_id: Organization ID for scoping the search
|
|
document_uuids: Optional list of document UUIDs to filter by
|
|
limit: Maximum number of chunks to return (default: 3)
|
|
embeddings_api_key: Optional API key for embedding service
|
|
embeddings_model: Optional model ID for embedding service
|
|
embeddings_base_url: Optional base URL for embedding service
|
|
tracing_context: Optional OpenTelemetry context for tracing
|
|
|
|
Returns:
|
|
Dictionary containing:
|
|
- chunks: List of relevant text chunks with metadata
|
|
- query: The original query
|
|
- total_results: Number of results returned
|
|
"""
|
|
# Create span for retrieval operation if tracing is enabled
|
|
if ensure_tracing():
|
|
try:
|
|
parent_context = tracing_context
|
|
|
|
# Get tracer
|
|
tracer = trace.get_tracer("pipecat")
|
|
except Exception as e:
|
|
logger.debug(f"Failed to setup tracing context: {e}")
|
|
# Fall back to non-traced execution
|
|
return await _perform_retrieval(
|
|
query,
|
|
organization_id,
|
|
document_uuids,
|
|
limit,
|
|
embeddings_api_key,
|
|
embeddings_model,
|
|
embeddings_base_url,
|
|
embeddings_provider,
|
|
embeddings_endpoint,
|
|
embeddings_api_version,
|
|
)
|
|
|
|
# Create span with parent context
|
|
if parent_context:
|
|
with tracer.start_as_current_span(
|
|
"knowledge_base_retrieval", context=parent_context
|
|
) as span:
|
|
try:
|
|
# Mark trace as public for Langfuse
|
|
span.set_attribute("langfuse.trace.public", True)
|
|
|
|
# Add operation metadata
|
|
span.set_attribute(
|
|
"gen_ai.operation.name", "knowledge_base_retrieval"
|
|
)
|
|
span.set_attribute("retrieval.query", query)
|
|
span.set_attribute("retrieval.limit", limit)
|
|
span.set_attribute("retrieval.organization_id", organization_id)
|
|
|
|
# Add document filter info
|
|
if document_uuids:
|
|
span.set_attribute(
|
|
"retrieval.document_count", len(document_uuids)
|
|
)
|
|
span.set_attribute(
|
|
"retrieval.document_uuids", json.dumps(document_uuids)
|
|
)
|
|
|
|
# Perform the actual retrieval
|
|
result = await _perform_retrieval(
|
|
query,
|
|
organization_id,
|
|
document_uuids,
|
|
limit,
|
|
embeddings_api_key,
|
|
embeddings_model,
|
|
embeddings_base_url,
|
|
embeddings_provider,
|
|
embeddings_endpoint,
|
|
embeddings_api_version,
|
|
)
|
|
|
|
# Add result metadata to span
|
|
span.set_attribute(
|
|
"retrieval.results_count", result["total_results"]
|
|
)
|
|
|
|
if result.get("error"):
|
|
span.set_attribute("retrieval.error", result["error"])
|
|
span.set_status(
|
|
trace.Status(trace.StatusCode.ERROR, result["error"])
|
|
)
|
|
else:
|
|
# Add similarity scores
|
|
if result["chunks"]:
|
|
similarities = [
|
|
chunk["similarity"] for chunk in result["chunks"]
|
|
]
|
|
span.set_attribute(
|
|
"retrieval.avg_similarity",
|
|
round(sum(similarities) / len(similarities), 4),
|
|
)
|
|
span.set_attribute(
|
|
"retrieval.max_similarity", max(similarities)
|
|
)
|
|
span.set_attribute(
|
|
"retrieval.min_similarity", min(similarities)
|
|
)
|
|
|
|
# Add retrieved documents info
|
|
filenames = list(
|
|
set(chunk["filename"] for chunk in result["chunks"])
|
|
)
|
|
span.set_attribute(
|
|
"retrieval.source_files", json.dumps(filenames)
|
|
)
|
|
|
|
# Add output as JSON for Langfuse
|
|
output_data = {
|
|
"query": query,
|
|
"chunks_retrieved": len(result["chunks"]),
|
|
"chunks": [
|
|
{
|
|
"text": chunk["text"][:200] + "..."
|
|
if len(chunk["text"]) > 200
|
|
else chunk["text"],
|
|
"filename": chunk["filename"],
|
|
"similarity": chunk["similarity"],
|
|
}
|
|
for chunk in result["chunks"]
|
|
],
|
|
}
|
|
span.set_attribute("output", json.dumps(output_data))
|
|
|
|
return result
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in traced retrieval: {e}")
|
|
span.record_exception(e)
|
|
span.set_status(trace.Status(trace.StatusCode.ERROR, str(e)))
|
|
raise
|
|
else:
|
|
# No parent context - perform retrieval without tracing
|
|
logger.debug(
|
|
"No parent context available for knowledge base retrieval tracing"
|
|
)
|
|
return await _perform_retrieval(
|
|
query,
|
|
organization_id,
|
|
document_uuids,
|
|
limit,
|
|
embeddings_api_key,
|
|
embeddings_model,
|
|
embeddings_base_url,
|
|
embeddings_provider,
|
|
embeddings_endpoint,
|
|
embeddings_api_version,
|
|
)
|
|
else:
|
|
# Tracing is disabled - perform retrieval without tracing
|
|
return await _perform_retrieval(
|
|
query,
|
|
organization_id,
|
|
document_uuids,
|
|
limit,
|
|
embeddings_api_key,
|
|
embeddings_model,
|
|
embeddings_base_url,
|
|
embeddings_provider,
|
|
embeddings_endpoint,
|
|
embeddings_api_version,
|
|
)
|
|
|
|
|
|
async def _perform_retrieval(
|
|
query: str,
|
|
organization_id: int,
|
|
document_uuids: Optional[List[str]],
|
|
limit: int,
|
|
embeddings_api_key: Optional[str] = None,
|
|
embeddings_model: Optional[str] = None,
|
|
embeddings_base_url: Optional[str] = None,
|
|
embeddings_provider: Optional[str] = None,
|
|
embeddings_endpoint: Optional[str] = None,
|
|
embeddings_api_version: Optional[str] = None,
|
|
) -> Dict[str, Any]:
|
|
"""Internal function to perform the actual retrieval operation.
|
|
|
|
Separated from tracing logic for cleaner code organization.
|
|
Handles both chunked (vector search) and full_document (full text) modes.
|
|
"""
|
|
try:
|
|
chunks = []
|
|
|
|
# Check for full_document mode documents and return their full text
|
|
if document_uuids:
|
|
full_text_docs = await db_client.get_full_text_documents(
|
|
organization_id=organization_id,
|
|
document_uuids=document_uuids,
|
|
)
|
|
for doc in full_text_docs:
|
|
if doc.full_text:
|
|
chunks.append(
|
|
{
|
|
"text": doc.full_text,
|
|
"filename": doc.filename,
|
|
"similarity": 1.0,
|
|
"chunk_index": 0,
|
|
}
|
|
)
|
|
|
|
# Filter out full_document UUIDs so vector search only hits chunked docs
|
|
full_doc_uuids = {doc.document_uuid for doc in full_text_docs}
|
|
chunked_uuids = [u for u in document_uuids if u not in full_doc_uuids]
|
|
else:
|
|
chunked_uuids = document_uuids
|
|
|
|
# Perform vector similarity search on chunked documents
|
|
if chunked_uuids is None or len(chunked_uuids) > 0:
|
|
if not embeddings_api_key:
|
|
raise ValueError(
|
|
"Embeddings API key not configured. Please set your API key in "
|
|
"Model Configurations > Embedding."
|
|
)
|
|
|
|
if (
|
|
embeddings_provider == ServiceProviders.AZURE.value
|
|
and embeddings_endpoint
|
|
):
|
|
embedding_service = AzureOpenAIEmbeddingService(
|
|
db_client=db_client,
|
|
api_key=embeddings_api_key,
|
|
endpoint=embeddings_endpoint,
|
|
model_id=embeddings_model or "text-embedding-3-small",
|
|
api_version=embeddings_api_version or "2024-02-15-preview",
|
|
)
|
|
else:
|
|
embedding_service = OpenAIEmbeddingService(
|
|
db_client=db_client,
|
|
api_key=embeddings_api_key,
|
|
model_id=embeddings_model or "text-embedding-3-small",
|
|
base_url=embeddings_base_url,
|
|
)
|
|
|
|
results = await embedding_service.search_similar_chunks(
|
|
query=query,
|
|
organization_id=organization_id,
|
|
limit=limit,
|
|
document_uuids=chunked_uuids if chunked_uuids else None,
|
|
)
|
|
|
|
for result in results:
|
|
chunk_info = {
|
|
"text": result.get("contextualized_text")
|
|
or result.get("chunk_text"),
|
|
"filename": result.get("filename"),
|
|
"similarity": round(result.get("similarity", 0), 4),
|
|
"chunk_index": result.get("chunk_index"),
|
|
}
|
|
chunks.append(chunk_info)
|
|
|
|
logger.info(
|
|
f"Knowledge base retrieval: query='{query}', "
|
|
f"results={len(chunks)}, "
|
|
f"document_filter={document_uuids}"
|
|
)
|
|
|
|
return {
|
|
"chunks": chunks,
|
|
"query": query,
|
|
"total_results": len(chunks),
|
|
}
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error retrieving from knowledge base: {e}")
|
|
return {
|
|
"error": str(e),
|
|
"chunks": [],
|
|
"query": query,
|
|
"total_results": 0,
|
|
}
|
|
|
|
|
|
def get_knowledge_base_tool(
|
|
document_uuids: Optional[List[str]] = None,
|
|
) -> Dict[str, Any]:
|
|
"""Get knowledge base retrieval tool definition for LLM function calling.
|
|
|
|
Args:
|
|
document_uuids: Optional list of document UUIDs to include in description
|
|
|
|
Returns:
|
|
Tool definition compatible with LLM function calling
|
|
"""
|
|
# Build description based on whether specific documents are filtered
|
|
if document_uuids and len(document_uuids) > 0:
|
|
description = (
|
|
"Retrieve relevant information from specific documents in the knowledge base. "
|
|
"Use this tool when you need to look up facts, policies, procedures, or any information "
|
|
"that might be stored in the available documents. The search will only look in the "
|
|
f"documents associated with this conversation step ({len(document_uuids)} document(s) available)."
|
|
)
|
|
else:
|
|
description = (
|
|
"Retrieve relevant information from the knowledge base. "
|
|
"Use this tool when you need to look up facts, policies, procedures, or any information "
|
|
"that might be stored in the knowledge base documents."
|
|
)
|
|
|
|
return {
|
|
"type": "function",
|
|
"function": {
|
|
"name": "retrieve_from_knowledge_base",
|
|
"description": description,
|
|
"parameters": {
|
|
"type": "object",
|
|
"properties": {
|
|
"query": {
|
|
"type": "string",
|
|
"description": (
|
|
"The search query to find relevant information. "
|
|
"Be specific and use natural language. "
|
|
"Example: 'What is the refund policy for canceled orders?'"
|
|
),
|
|
}
|
|
},
|
|
"required": ["query"],
|
|
},
|
|
},
|
|
}
|