feat: add configurable summary calculation and various improvements

- Replaced direct embedding calls with a shared utility function across various components to streamline embedding logic (sketched after this list).
- Added an enable_summary flag to several models and routes to control summary generation behavior.
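
A minimal sketch of the call-site pattern, assuming the embed_text helper shown in the diff below; the app.utils module path, the DocumentCreate model, and index_document are hypothetical illustrations, and the exact placement of enable_summary varies per model and route:

    from pydantic import BaseModel  # request-model library assumed

    from app.utils import embed_text  # module path assumed


    class DocumentCreate(BaseModel):
        """Hypothetical request model carrying the new flag."""

        content: str
        enable_summary: bool = True  # when False, routes skip summary generation


    def index_document(doc: DocumentCreate):
        # Before: config.embedding_model_instance.embed(doc.content) could
        # overflow the embedding model's context window on long inputs.
        # After: the shared helper truncates to the token limit first.
        return embed_text(doc.content)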
DESKTOP-RTLN3BA$punk 2026-02-26 18:24:57 -08:00
parent dc33a4a68f
commit e9892c8fe9
50 changed files with 380 additions and 298 deletions

@@ -1,11 +1,59 @@
import hashlib
import logging
import warnings

import numpy as np
from litellm import get_model_info, token_counter

from app.config import config
from app.db import Chunk, DocumentType
from app.prompts import SUMMARY_PROMPT_TEMPLATE

logger = logging.getLogger(__name__)


def _get_embedding_max_tokens() -> int:
    """Get the max token limit for the configured embedding model.

    Checks model properties in order: max_seq_length, _max_tokens.
    Falls back to 8192 (OpenAI embedding default).
    """
    model = config.embedding_model_instance
    for attr in ("max_seq_length", "_max_tokens"):
        val = getattr(model, attr, None)
        if isinstance(val, int) and val > 0:
            return val
    return 8192


def truncate_for_embedding(text: str) -> str:
    """Truncate text to fit within the embedding model's context window.

    Uses the embedding model's own tokenizer for accurate token counting,
    so the result is model-agnostic regardless of the underlying provider.
    """
    max_tokens = _get_embedding_max_tokens()
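    # Fast path: assuming at least ~3 characters per token, skip tokenizing
    # texts that are clearly short enough to fit.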
    if len(text) // 3 <= max_tokens:
        return text
    tokenizer = config.embedding_model_instance.get_tokenizer()
    tokens = tokenizer.encode(text)
    if len(tokens) <= max_tokens:
        return text
    warnings.warn(
        f"Truncating text from {len(tokens)} to {max_tokens} tokens for embedding.",
        stacklevel=2,
    )
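    # Hard-truncate at the token boundary and decode the kept tokens back to text.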
    return tokenizer.decode(tokens[:max_tokens])


def embed_text(text: str) -> np.ndarray:
    """Truncate text to fit and embed it.

    Drop-in replacement for ``config.embedding_model_instance.embed(text)``
    that never exceeds the model's context window.
    """
    return config.embedding_model_instance.embed(truncate_for_embedding(text))


def get_model_context_window(model_name: str) -> int:
    """Get the total context window size for a model (input + output tokens)."""
@@ -146,7 +194,7 @@ async def generate_document_summary(
     else:
         enhanced_summary_content = summary_content

-    summary_embedding = config.embedding_model_instance.embed(enhanced_summary_content)
+    summary_embedding = embed_text(enhanced_summary_content)

     return enhanced_summary_content, summary_embedding
@@ -164,7 +212,7 @@ async def create_document_chunks(content: str) -> list[Chunk]:
     return [
         Chunk(
             content=chunk.text,
-            embedding=config.embedding_model_instance.embed(chunk.text),
+            embedding=embed_text(chunk.text),
         )
         for chunk in config.chunker_instance.chunk(content)
     ]