mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-10 08:05:22 +02:00
feat: add openai embedding service
This commit is contained in:
parent
eb41285204
commit
3f0e500fde
39 changed files with 1902 additions and 339 deletions
|
|
@ -2,25 +2,33 @@
|
|||
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Literal
|
||||
|
||||
from docling.chunking import HybridChunker
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
|
||||
from loguru import logger
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import KnowledgeBaseChunkModel, KnowledgeBaseDocumentModel
|
||||
from api.db.models import KnowledgeBaseChunkModel
|
||||
from api.services.gen_ai import (
|
||||
OpenAIEmbeddingService,
|
||||
SentenceTransformerEmbeddingService,
|
||||
)
|
||||
from api.services.storage import storage_fs
|
||||
|
||||
# Constants
|
||||
EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
EMBEDDING_DIMENSION = 384
|
||||
# For tokenization/chunking - use SentenceTransformer tokenizer as baseline
|
||||
TOKENIZER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
|
||||
async def process_knowledge_base_document(
|
||||
ctx, document_id: int, s3_key: str, organization_id: int, max_tokens: int = 128
|
||||
ctx,
|
||||
document_id: int,
|
||||
s3_key: str,
|
||||
organization_id: int,
|
||||
max_tokens: int = 128,
|
||||
embedding_service: Literal["sentence_transformer", "openai"] = "openai",
|
||||
):
|
||||
"""Process a knowledge base document: download, chunk, embed, and store.
|
||||
|
||||
|
|
@ -30,6 +38,9 @@ async def process_knowledge_base_document(
|
|||
s3_key: S3 key where the file is stored
|
||||
organization_id: Organization ID
|
||||
max_tokens: Maximum number of tokens per chunk (default: 128)
|
||||
embedding_service: Embedding service to use (default: "openai")
|
||||
- "openai": Use OpenAI text-embedding-3-small (1536-dim, requires API key)
|
||||
- "sentence_transformer": Use SentenceTransformer (all-MiniLM-L6-v2, 384-dim, free)
|
||||
"""
|
||||
logger.info(
|
||||
f"Starting knowledge base document processing for document_id={document_id}, "
|
||||
|
|
@ -42,8 +53,14 @@ async def process_knowledge_base_document(
|
|||
# Update status to processing
|
||||
await db_client.update_document_status(document_id, "processing")
|
||||
|
||||
# Create temp file for download
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
|
||||
# Extract file extension from S3 key
|
||||
filename = s3_key.split("/")[-1]
|
||||
file_extension = (
|
||||
os.path.splitext(filename)[1] or ".bin"
|
||||
) # Default to .bin if no extension
|
||||
|
||||
# Create temp file for download with correct extension
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=file_extension)
|
||||
temp_file_path = temp_file.name
|
||||
temp_file.close()
|
||||
|
||||
|
|
@ -108,27 +125,58 @@ async def process_knowledge_base_document(
|
|||
mime_type=mime_type,
|
||||
)
|
||||
|
||||
# Initialize models for processing
|
||||
cache_dir = os.path.expanduser("~/.cache/hf_models")
|
||||
logger.info(f"Loading embedding model: {EMBED_MODEL_ID} (cache: {cache_dir})")
|
||||
embedding_model = SentenceTransformer(
|
||||
EMBED_MODEL_ID,
|
||||
cache_folder=cache_dir,
|
||||
)
|
||||
# Initialize the embedding service based on the parameter
|
||||
if embedding_service == "openai":
|
||||
logger.info(
|
||||
f"Initializing OpenAI embedding service with max_tokens={max_tokens}"
|
||||
)
|
||||
# Try to get user's embeddings configuration
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
if document.created_by:
|
||||
user_config = await db_client.get_user_configurations(
|
||||
document.created_by
|
||||
)
|
||||
if user_config.embeddings:
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
logger.info(
|
||||
f"Using user embeddings config: model={embeddings_model}"
|
||||
)
|
||||
|
||||
logger.info(f"Loading tokenizer: {EMBED_MODEL_ID} (cache: {cache_dir})")
|
||||
tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=AutoTokenizer.from_pretrained(
|
||||
EMBED_MODEL_ID,
|
||||
cache_dir=cache_dir,
|
||||
),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
# Check if API key is configured
|
||||
if not embeddings_api_key:
|
||||
error_message = (
|
||||
"OpenAI API key not configured. Please set your API key in "
|
||||
"Model Configurations > Embedding to process documents."
|
||||
)
|
||||
logger.warning(f"Document {document_id}: {error_message}")
|
||||
await db_client.update_document_status(
|
||||
document_id, "failed", error_message=error_message
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
|
||||
chunker = HybridChunker(tokenizer=tokenizer)
|
||||
service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
max_tokens=max_tokens,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
)
|
||||
elif embedding_service == "sentence_transformer":
|
||||
logger.info(
|
||||
f"Initializing SentenceTransformer embedding service with max_tokens={max_tokens}"
|
||||
)
|
||||
service = SentenceTransformerEmbeddingService(
|
||||
db_client=db_client,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid embedding_service: {embedding_service}. "
|
||||
f"Must be 'sentence_transformer' or 'openai'"
|
||||
)
|
||||
|
||||
# Convert document with docling
|
||||
# Step 1: Convert document with docling
|
||||
logger.info("Converting document with docling")
|
||||
converter = DocumentConverter()
|
||||
conversion_result = converter.convert(temp_file_path)
|
||||
|
|
@ -140,13 +188,26 @@ async def process_knowledge_base_document(
|
|||
"document_type": type(doc).__name__,
|
||||
}
|
||||
|
||||
# Chunk the document
|
||||
# Step 2: Initialize tokenizer for chunking
|
||||
logger.info(
|
||||
f"Loading tokenizer: {TOKENIZER_MODEL} with max_tokens={max_tokens}"
|
||||
)
|
||||
tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=AutoTokenizer.from_pretrained(TOKENIZER_MODEL),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
# Step 3: Initialize chunker
|
||||
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
|
||||
chunker = HybridChunker(tokenizer=tokenizer)
|
||||
|
||||
# Step 4: Chunk the document
|
||||
logger.info(f"Chunking document with max_tokens={max_tokens}")
|
||||
chunks = list(chunker.chunk(dl_doc=doc))
|
||||
total_chunks = len(chunks)
|
||||
logger.info(f"Generated {total_chunks} chunks")
|
||||
|
||||
# Process each chunk
|
||||
# Step 5: Process each chunk
|
||||
chunk_texts = []
|
||||
chunk_records = []
|
||||
token_counts = []
|
||||
|
|
@ -156,7 +217,9 @@ async def process_knowledge_base_document(
|
|||
contextualized_text = chunker.contextualize(chunk=chunk)
|
||||
|
||||
# Calculate token count
|
||||
text_to_tokenize = contextualized_text if contextualized_text else chunk_text
|
||||
text_to_tokenize = (
|
||||
contextualized_text if contextualized_text else chunk_text
|
||||
)
|
||||
token_count = len(
|
||||
tokenizer.tokenizer.encode(text_to_tokenize, add_special_tokens=False)
|
||||
)
|
||||
|
|
@ -176,7 +239,7 @@ async def process_knowledge_base_document(
|
|||
),
|
||||
}
|
||||
|
||||
# Create chunk record
|
||||
# Create chunk record (without embedding yet)
|
||||
chunk_record = KnowledgeBaseChunkModel(
|
||||
document_id=document_id,
|
||||
organization_id=organization_id,
|
||||
|
|
@ -184,8 +247,8 @@ async def process_knowledge_base_document(
|
|||
contextualized_text=contextualized_text,
|
||||
chunk_index=i,
|
||||
chunk_metadata=chunk_metadata,
|
||||
embedding_model=EMBED_MODEL_ID,
|
||||
embedding_dimension=EMBEDDING_DIMENSION,
|
||||
embedding_model=service.get_model_id(),
|
||||
embedding_dimension=service.get_embedding_dimension(),
|
||||
token_count=token_count,
|
||||
)
|
||||
|
||||
|
|
@ -196,29 +259,25 @@ async def process_knowledge_base_document(
|
|||
if token_counts:
|
||||
avg_tokens = sum(token_counts) / len(token_counts)
|
||||
min_tokens = min(token_counts)
|
||||
max_tokens = max(token_counts)
|
||||
logger.info(f"Chunk token statistics:")
|
||||
max_tokens_actual = max(token_counts)
|
||||
logger.info("Chunk token statistics:")
|
||||
logger.info(f" - Average: {avg_tokens:.1f} tokens")
|
||||
logger.info(f" - Min: {min_tokens} tokens")
|
||||
logger.info(f" - Max: {max_tokens} tokens")
|
||||
logger.info(f" - Max: {max_tokens_actual} tokens")
|
||||
|
||||
# Generate embeddings in batch
|
||||
logger.info("Generating embeddings")
|
||||
embeddings = embedding_model.encode(
|
||||
chunk_texts,
|
||||
show_progress_bar=False,
|
||||
convert_to_numpy=True,
|
||||
)
|
||||
# Step 6: Generate embeddings using the embedding service
|
||||
logger.info(f"Generating embeddings using {embedding_service}")
|
||||
embeddings = await service.embed_texts(chunk_texts)
|
||||
|
||||
# Attach embeddings to chunk records
|
||||
# Step 7: Attach embeddings to chunk records
|
||||
for chunk_record, embedding in zip(chunk_records, embeddings):
|
||||
chunk_record.embedding = embedding.tolist()
|
||||
chunk_record.embedding = embedding
|
||||
|
||||
# Save chunks in database
|
||||
# Step 8: Save chunks in database
|
||||
logger.info("Storing chunks in database")
|
||||
await db_client.create_chunks_batch(chunk_records)
|
||||
|
||||
# Update document status to completed
|
||||
# Step 9: Update document status to completed
|
||||
await db_client.update_document_status(
|
||||
document_id,
|
||||
"completed",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue